Unify executor and futures logic

This commit is contained in:
Neil Booth 2017-01-24 21:14:41 +09:00
parent 9b5cb105d5
commit cb0160901f
5 changed files with 78 additions and 92 deletions

View File

@ -138,9 +138,10 @@ class BlockProcessor(server.db.DB):
Coordinate backing up in case of chain reorganisations. Coordinate backing up in case of chain reorganisations.
''' '''
def __init__(self, env, daemon): def __init__(self, env, controller, daemon):
super().__init__(env) super().__init__(env)
self.daemon = daemon self.daemon = daemon
self.controller = controller
# These are our state as we move ahead of DB state # These are our state as we move ahead of DB state
self.fs_height = self.db_height self.fs_height = self.db_height
@ -190,6 +191,7 @@ class BlockProcessor(server.db.DB):
async def main_loop(self): async def main_loop(self):
'''Main loop for block processing.''' '''Main loop for block processing.'''
self.controller.ensure_future(self.prefetcher.main_loop())
await self.prefetcher.reset_height() await self.prefetcher.reset_height()
while True: while True:
@ -205,16 +207,11 @@ class BlockProcessor(server.db.DB):
self.logger.info('flushing state to DB for a clean shutdown...') self.logger.info('flushing state to DB for a clean shutdown...')
self.flush(True) self.flush(True)
async def executor(self, func, *args, **kwargs):
'''Run func taking args in the executor.'''
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, partial(func, *args, **kwargs))
async def first_caught_up(self): async def first_caught_up(self):
'''Called when first caught up to daemon after starting.''' '''Called when first caught up to daemon after starting.'''
# Flush everything with updated first_sync->False state. # Flush everything with updated first_sync->False state.
self.first_sync = False self.first_sync = False
await self.executor(self.flush, True) await self.controller.run_in_executor(self.flush, True)
if self.utxo_db.for_sync: if self.utxo_db.for_sync:
self.logger.info('{} synced to height {:,d}' self.logger.info('{} synced to height {:,d}'
.format(VERSION, self.height)) .format(VERSION, self.height))
@ -240,7 +237,8 @@ class BlockProcessor(server.db.DB):
if hprevs == chain: if hprevs == chain:
start = time.time() start = time.time()
await self.executor(self.advance_blocks, blocks, headers) await self.controller.run_in_executor(self.advance_blocks,
blocks, headers)
if not self.first_sync: if not self.first_sync:
s = '' if len(blocks) == 1 else 's' s = '' if len(blocks) == 1 else 's'
self.logger.info('processed {:,d} block{} in {:.1f}s' self.logger.info('processed {:,d} block{} in {:.1f}s'
@ -277,14 +275,14 @@ class BlockProcessor(server.db.DB):
self.logger.info('chain reorg detected') self.logger.info('chain reorg detected')
else: else:
self.logger.info('faking a reorg of {:,d} blocks'.format(count)) self.logger.info('faking a reorg of {:,d} blocks'.format(count))
await self.executor(self.flush, True) await self.controller.run_in_executor(self.flush, True)
hashes = await self.reorg_hashes(count) hashes = await self.reorg_hashes(count)
# Reverse and convert to hex strings. # Reverse and convert to hex strings.
hashes = [hash_to_str(hash) for hash in reversed(hashes)] hashes = [hash_to_str(hash) for hash in reversed(hashes)]
for hex_hashes in chunks(hashes, 50): for hex_hashes in chunks(hashes, 50):
blocks = await self.daemon.raw_blocks(hex_hashes) blocks = await self.daemon.raw_blocks(hex_hashes)
await self.executor(self.backup_blocks, blocks) await self.controller.run_in_executor(self.backup_blocks, blocks)
await self.prefetcher.reset_height() await self.prefetcher.reset_height()
async def reorg_hashes(self, count): async def reorg_hashes(self, count):

View File

@ -11,6 +11,7 @@ import json
import os import os
import ssl import ssl
import time import time
import traceback
from bisect import bisect_left from bisect import bisect_left
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -49,9 +50,9 @@ class Controller(util.LoggedClass):
self.start_time = time.time() self.start_time = time.time()
self.coin = env.coin self.coin = env.coin
self.daemon = Daemon(env.coin.daemon_urls(env.daemon_url)) self.daemon = Daemon(env.coin.daemon_urls(env.daemon_url))
self.bp = BlockProcessor(env, self.daemon) self.bp = BlockProcessor(env, self, self.daemon)
self.mempool = MemPool(self.bp) self.mempool = MemPool(self.bp, self)
self.peers = PeerManager(env) self.peers = PeerManager(env, self)
self.env = env self.env = env
self.servers = {} self.servers = {}
# Map of session to the key of its list in self.groups # Map of session to the key of its list in self.groups
@ -63,6 +64,7 @@ class Controller(util.LoggedClass):
self.max_sessions = env.max_sessions self.max_sessions = env.max_sessions
self.low_watermark = self.max_sessions * 19 // 20 self.low_watermark = self.max_sessions * 19 // 20
self.max_subs = env.max_subs self.max_subs = env.max_subs
self.futures = set()
# Cache some idea of room to avoid recounting on each subscription # Cache some idea of room to avoid recounting on each subscription
self.subs_room = 0 self.subs_room = 0
self.next_stale_check = 0 self.next_stale_check = 0
@ -199,43 +201,59 @@ class Controller(util.LoggedClass):
if session.items: if session.items:
self.enqueue_session(session) self.enqueue_session(session)
async def run_in_executor(self, func, *args):
'''Wait whilst running func in the executor.'''
return await self.loop.run_in_executor(None, func, *args)
def schedule_executor(self, func, *args):
'''Schedule running func in the executor, return a task.'''
return self.ensure_future(self.run_in_executor(func, *args))
def ensure_future(self, coro):
'''Schedule the coro to be run.'''
future = asyncio.ensure_future(coro)
future.add_done_callback(self.on_future_done)
self.futures.add(future)
return future
def on_future_done(self, future):
'''Collect the result of a future after removing it from our set.'''
self.futures.remove(future)
try:
future.result()
except asyncio.CancelledError:
pass
except Exception:
self.log_error(traceback.format_exc())
async def wait_for_bp_catchup(self):
'''Called when the block processor catches up.'''
await self.bp.caught_up_event.wait()
self.logger.info('block processor has caught up')
self.ensure_future(self.peers.main_loop())
self.ensure_future(self.start_servers())
self.ensure_future(self.mempool.main_loop())
self.ensure_future(self.enqueue_delayed_sessions())
self.ensure_future(self.notify())
for n in range(4):
self.ensure_future(self.serve_requests())
async def main_loop(self):
'''Controller main loop.'''
self.ensure_future(self.bp.main_loop())
self.ensure_future(self.wait_for_bp_catchup())
# Shut down cleanly after waiting for shutdown to be signalled
await self.shutdown_event.wait()
self.logger.info('shutting down')
await self.shutdown()
self.logger.info('shutdown complete')
def initiate_shutdown(self): def initiate_shutdown(self):
'''Call this function to start the shutdown process.''' '''Call this function to start the shutdown process.'''
self.shutdown_event.set() self.shutdown_event.set()
async def main_loop(self): async def shutdown(self):
'''Controller main loop.'''
def add_future(coro):
futures.append(asyncio.ensure_future(coro))
async def await_bp_catchup():
'''Wait for the block processor to catch up.
Then start the servers and the peer manager.
'''
await self.bp.caught_up_event.wait()
self.logger.info('block processor has caught up')
add_future(self.peers.main_loop())
add_future(self.start_servers())
add_future(self.mempool.main_loop())
add_future(self.enqueue_delayed_sessions())
add_future(self.notify())
for n in range(4):
add_future(self.serve_requests())
futures = []
add_future(self.bp.main_loop())
add_future(self.bp.prefetcher.main_loop())
add_future(await_bp_catchup())
# Perform a clean shutdown when this event is signalled.
await self.shutdown_event.wait()
self.logger.info('shutting down')
await self.shutdown(futures)
self.logger.info('shutdown complete')
async def shutdown(self, futures):
'''Perform the shutdown sequence.''' '''Perform the shutdown sequence.'''
self.state = self.SHUTTING_DOWN self.state = self.SHUTTING_DOWN
@ -244,13 +262,13 @@ class Controller(util.LoggedClass):
for session in self.sessions: for session in self.sessions:
self.close_session(session) self.close_session(session)
# Cancel the futures # Cancel pending futures
for future in futures: for future in self.futures:
future.cancel() future.cancel()
# Wait for all futures to finish # Wait for all futures to finish
while any(not future.done() for future in futures): while not all (future.done() for future in self.futures):
await asyncio.sleep(1) await asyncio.sleep(0.1)
# Finally shut down the block processor and executor # Finally shut down the block processor and executor
self.bp.shutdown(self.executor) self.bp.shutdown(self.executor)
@ -694,8 +712,7 @@ class Controller(util.LoggedClass):
limit = self.env.max_send // 97 limit = self.env.max_send // 97
return list(self.bp.get_history(hashX, limit=limit)) return list(self.bp.get_history(hashX, limit=limit))
loop = asyncio.get_event_loop() history = await self.run_in_executor(job)
history = await loop.run_in_executor(None, job)
self.history_cache[hashX] = history self.history_cache[hashX] = history
return history return history
@ -725,8 +742,8 @@ class Controller(util.LoggedClass):
'''Get UTXOs asynchronously to reduce latency.''' '''Get UTXOs asynchronously to reduce latency.'''
def job(): def job():
return list(self.bp.get_utxos(hashX, limit=None)) return list(self.bp.get_utxos(hashX, limit=None))
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, job) return await self.run_in_executor(job)
def get_chunk(self, index): def get_chunk(self, index):
'''Return header chunk as hex. Index is a non-negative integer.''' '''Return header chunk as hex. Index is a non-negative integer.'''

View File

@ -31,9 +31,10 @@ class MemPool(util.LoggedClass):
A pair is a (hashX, value) tuple. tx hashes are hex strings. A pair is a (hashX, value) tuple. tx hashes are hex strings.
''' '''
def __init__(self, bp): def __init__(self, bp, controller):
super().__init__() super().__init__()
self.daemon = bp.daemon self.daemon = bp.daemon
self.controller = controller
self.coin = bp.coin self.coin = bp.coin
self.db = bp self.db = bp
self.touched = bp.touched self.touched = bp.touched
@ -139,7 +140,6 @@ class MemPool(util.LoggedClass):
break break
def async_process_some(self, unfetched, limit): def async_process_some(self, unfetched, limit):
loop = asyncio.get_event_loop()
pending = [] pending = []
txs = self.txs txs = self.txs
@ -162,9 +162,8 @@ class MemPool(util.LoggedClass):
deferred = pending deferred = pending
pending = [] pending = []
def job(): result, deferred = await self.controller.run_in_executor \
return self.process_raw_txs(raw_txs, deferred) (self.process_raw_txs, raw_txs, deferred)
result, deferred = await loop.run_in_executor(None, job)
pending.extend(deferred) pending.extend(deferred)
hashXs = self.hashXs hashXs = self.hashXs

View File

@ -7,11 +7,8 @@
'''Peer management.''' '''Peer management.'''
import asyncio
import socket import socket
import traceback
from collections import namedtuple from collections import namedtuple
from functools import partial
import lib.util as util import lib.util as util
from server.irc import IRC from server.irc import IRC
@ -30,12 +27,11 @@ class PeerManager(util.LoggedClass):
VERSION = '1.0' VERSION = '1.0'
DEFAULT_PORTS = {'t': 50001, 's': 50002} DEFAULT_PORTS = {'t': 50001, 's': 50002}
def __init__(self, env): def __init__(self, env, controller):
super().__init__() super().__init__()
self.env = env self.env = env
self.loop = asyncio.get_event_loop() self.controller = controller
self.irc = IRC(env, self) self.irc = IRC(env, self)
self.futures = set()
self.identities = [] self.identities = []
# Keyed by nick # Keyed by nick
self.irc_peers = {} self.irc_peers = {}
@ -51,10 +47,6 @@ class PeerManager(util.LoggedClass):
env.report_ssl_port_tor, env.report_ssl_port_tor,
'_tor')) '_tor'))
async def executor(self, func, *args, **kwargs):
'''Run func taking args in the executor.'''
await self.loop.run_in_executor(None, partial(func, *args, **kwargs))
@classmethod @classmethod
def real_name(cls, identity): def real_name(cls, identity):
'''Real name as used on IRC.''' '''Real name as used on IRC.'''
@ -70,38 +62,19 @@ class PeerManager(util.LoggedClass):
ssl = port_text('s', identity.ssl_port) ssl = port_text('s', identity.ssl_port)
return '{} v{}{}{}'.format(identity.host, cls.VERSION, tcp, ssl) return '{} v{}{}{}'.format(identity.host, cls.VERSION, tcp, ssl)
def ensure_future(self, coro):
'''Convert a coro into a future and add it to our pending list
to be waited for.'''
self.futures.add(asyncio.ensure_future(coro))
def start_irc(self): def start_irc(self):
'''Start up the IRC connections if enabled.''' '''Start up the IRC connections if enabled.'''
if self.env.irc: if self.env.irc:
name_pairs = [(self.real_name(identity), identity.nick_suffix) name_pairs = [(self.real_name(identity), identity.nick_suffix)
for identity in self.identities] for identity in self.identities]
self.ensure_future(self.irc.start(name_pairs)) self.controller.ensure_future(self.irc.start(name_pairs))
else: else:
self.logger.info('IRC is disabled') self.logger.info('IRC is disabled')
async def main_loop(self): async def main_loop(self):
'''Start and then enter the main loop.''' '''Main loop. No loop for now.'''
self.start_irc() self.start_irc()
try:
while True:
await asyncio.sleep(10)
done = [future for future in self.futures if future.done()]
self.futures.difference_update(done)
for future in done:
try:
future.result()
except:
self.log_error(traceback.format_exc())
finally:
for future in self.futures:
future.cancel()
def dns_lookup_peer(self, nick, hostname, details): def dns_lookup_peer(self, nick, hostname, details):
try: try:
ip_addr = None ip_addr = None
@ -119,7 +92,7 @@ class PeerManager(util.LoggedClass):
def add_irc_peer(self, *args): def add_irc_peer(self, *args):
'''Schedule DNS lookup of peer.''' '''Schedule DNS lookup of peer.'''
self.ensure_future(self.executor(self.dns_lookup_peer, *args)) self.controller.schedule_executor(self.dns_lookup_peer, *args)
def remove_irc_peer(self, nick): def remove_irc_peer(self, nick):
'''Remove a peer from our IRC peers map.''' '''Remove a peer from our IRC peers map.'''

View File

@ -8,7 +8,6 @@
'''Classes for local RPC server and remote client TCP/SSL servers.''' '''Classes for local RPC server and remote client TCP/SSL servers.'''
import asyncio
import time import time
import traceback import traceback
from functools import partial from functools import partial