From cb0160901f5f2223fe5452b67c41903cd5bfe167 Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Tue, 24 Jan 2017 21:14:41 +0900 Subject: [PATCH] Unify executor and futures logic --- server/block_processor.py | 18 +++---- server/controller.py | 105 ++++++++++++++++++++++---------------- server/mempool.py | 9 ++-- server/peers.py | 37 ++------------ server/session.py | 1 - 5 files changed, 78 insertions(+), 92 deletions(-) diff --git a/server/block_processor.py b/server/block_processor.py index ea8fb8c..74ae878 100644 --- a/server/block_processor.py +++ b/server/block_processor.py @@ -138,9 +138,10 @@ class BlockProcessor(server.db.DB): Coordinate backing up in case of chain reorganisations. ''' - def __init__(self, env, daemon): + def __init__(self, env, controller, daemon): super().__init__(env) self.daemon = daemon + self.controller = controller # These are our state as we move ahead of DB state self.fs_height = self.db_height @@ -190,6 +191,7 @@ class BlockProcessor(server.db.DB): async def main_loop(self): '''Main loop for block processing.''' + self.controller.ensure_future(self.prefetcher.main_loop()) await self.prefetcher.reset_height() while True: @@ -205,16 +207,11 @@ class BlockProcessor(server.db.DB): self.logger.info('flushing state to DB for a clean shutdown...') self.flush(True) - async def executor(self, func, *args, **kwargs): - '''Run func taking args in the executor.''' - loop = asyncio.get_event_loop() - await loop.run_in_executor(None, partial(func, *args, **kwargs)) - async def first_caught_up(self): '''Called when first caught up to daemon after starting.''' # Flush everything with updated first_sync->False state. self.first_sync = False - await self.executor(self.flush, True) + await self.controller.run_in_executor(self.flush, True) if self.utxo_db.for_sync: self.logger.info('{} synced to height {:,d}' .format(VERSION, self.height)) @@ -240,7 +237,8 @@ class BlockProcessor(server.db.DB): if hprevs == chain: start = time.time() - await self.executor(self.advance_blocks, blocks, headers) + await self.controller.run_in_executor(self.advance_blocks, + blocks, headers) if not self.first_sync: s = '' if len(blocks) == 1 else 's' self.logger.info('processed {:,d} block{} in {:.1f}s' @@ -277,14 +275,14 @@ class BlockProcessor(server.db.DB): self.logger.info('chain reorg detected') else: self.logger.info('faking a reorg of {:,d} blocks'.format(count)) - await self.executor(self.flush, True) + await self.controller.run_in_executor(self.flush, True) hashes = await self.reorg_hashes(count) # Reverse and convert to hex strings. 
hashes = [hash_to_str(hash) for hash in reversed(hashes)] for hex_hashes in chunks(hashes, 50): blocks = await self.daemon.raw_blocks(hex_hashes) - await self.executor(self.backup_blocks, blocks) + await self.controller.run_in_executor(self.backup_blocks, blocks) await self.prefetcher.reset_height() async def reorg_hashes(self, count): diff --git a/server/controller.py b/server/controller.py index 803d557..2ef9b44 100644 --- a/server/controller.py +++ b/server/controller.py @@ -11,6 +11,7 @@ import json import os import ssl import time +import traceback from bisect import bisect_left from collections import defaultdict from concurrent.futures import ThreadPoolExecutor @@ -49,9 +50,9 @@ class Controller(util.LoggedClass): self.start_time = time.time() self.coin = env.coin self.daemon = Daemon(env.coin.daemon_urls(env.daemon_url)) - self.bp = BlockProcessor(env, self.daemon) - self.mempool = MemPool(self.bp) - self.peers = PeerManager(env) + self.bp = BlockProcessor(env, self, self.daemon) + self.mempool = MemPool(self.bp, self) + self.peers = PeerManager(env, self) self.env = env self.servers = {} # Map of session to the key of its list in self.groups @@ -63,6 +64,7 @@ class Controller(util.LoggedClass): self.max_sessions = env.max_sessions self.low_watermark = self.max_sessions * 19 // 20 self.max_subs = env.max_subs + self.futures = set() # Cache some idea of room to avoid recounting on each subscription self.subs_room = 0 self.next_stale_check = 0 @@ -199,43 +201,59 @@ class Controller(util.LoggedClass): if session.items: self.enqueue_session(session) + async def run_in_executor(self, func, *args): + '''Wait whilst running func in the executor.''' + return await self.loop.run_in_executor(None, func, *args) + + def schedule_executor(self, func, *args): + '''Schedule running func in the executor, return a task.''' + return self.ensure_future(self.run_in_executor(func, *args)) + + def ensure_future(self, coro): + '''Schedule the coro to be run.''' + future = asyncio.ensure_future(coro) + future.add_done_callback(self.on_future_done) + self.futures.add(future) + return future + + def on_future_done(self, future): + '''Collect the result of a future after removing it from our set.''' + self.futures.remove(future) + try: + future.result() + except asyncio.CancelledError: + pass + except Exception: + self.log_error(traceback.format_exc()) + + async def wait_for_bp_catchup(self): + '''Called when the block processor catches up.''' + await self.bp.caught_up_event.wait() + self.logger.info('block processor has caught up') + self.ensure_future(self.peers.main_loop()) + self.ensure_future(self.start_servers()) + self.ensure_future(self.mempool.main_loop()) + self.ensure_future(self.enqueue_delayed_sessions()) + self.ensure_future(self.notify()) + for n in range(4): + self.ensure_future(self.serve_requests()) + + async def main_loop(self): + '''Controller main loop.''' + self.ensure_future(self.bp.main_loop()) + self.ensure_future(self.wait_for_bp_catchup()) + + # Shut down cleanly after waiting for shutdown to be signalled + await self.shutdown_event.wait() + self.logger.info('shutting down') + await self.shutdown() + self.logger.info('shutdown complete') + def initiate_shutdown(self): '''Call this function to start the shutdown process.''' self.shutdown_event.set() - async def main_loop(self): - '''Controller main loop.''' - def add_future(coro): - futures.append(asyncio.ensure_future(coro)) - - async def await_bp_catchup(): - '''Wait for the block processor to catch up. 
- - Then start the servers and the peer manager. - ''' - await self.bp.caught_up_event.wait() - self.logger.info('block processor has caught up') - add_future(self.peers.main_loop()) - add_future(self.start_servers()) - add_future(self.mempool.main_loop()) - add_future(self.enqueue_delayed_sessions()) - add_future(self.notify()) - for n in range(4): - add_future(self.serve_requests()) - - futures = [] - add_future(self.bp.main_loop()) - add_future(self.bp.prefetcher.main_loop()) - add_future(await_bp_catchup()) - - # Perform a clean shutdown when this event is signalled. - await self.shutdown_event.wait() - - self.logger.info('shutting down') - await self.shutdown(futures) - self.logger.info('shutdown complete') - - async def shutdown(self, futures): + async def shutdown(self): '''Perform the shutdown sequence.''' self.state = self.SHUTTING_DOWN @@ -244,13 +262,13 @@ class Controller(util.LoggedClass): for session in self.sessions: self.close_session(session) - # Cancel the futures - for future in futures: + # Cancel pending futures + for future in self.futures: future.cancel() # Wait for all futures to finish - while any(not future.done() for future in futures): - await asyncio.sleep(1) + while not all (future.done() for future in self.futures): + await asyncio.sleep(0.1) # Finally shut down the block processor and executor self.bp.shutdown(self.executor) @@ -694,8 +712,7 @@ class Controller(util.LoggedClass): limit = self.env.max_send // 97 return list(self.bp.get_history(hashX, limit=limit)) - loop = asyncio.get_event_loop() - history = await loop.run_in_executor(None, job) + history = await self.run_in_executor(job) self.history_cache[hashX] = history return history @@ -725,8 +742,8 @@ class Controller(util.LoggedClass): '''Get UTXOs asynchronously to reduce latency.''' def job(): return list(self.bp.get_utxos(hashX, limit=None)) - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, job) + + return await self.run_in_executor(job) def get_chunk(self, index): '''Return header chunk as hex. Index is a non-negative integer.''' diff --git a/server/mempool.py b/server/mempool.py index 0a0a952..387b12c 100644 --- a/server/mempool.py +++ b/server/mempool.py @@ -31,9 +31,10 @@ class MemPool(util.LoggedClass): A pair is a (hashX, value) tuple. tx hashes are hex strings. 
''' - def __init__(self, bp): + def __init__(self, bp, controller): super().__init__() self.daemon = bp.daemon + self.controller = controller self.coin = bp.coin self.db = bp self.touched = bp.touched @@ -139,7 +140,6 @@ class MemPool(util.LoggedClass): break def async_process_some(self, unfetched, limit): - loop = asyncio.get_event_loop() pending = [] txs = self.txs @@ -162,9 +162,8 @@ class MemPool(util.LoggedClass): deferred = pending pending = [] - def job(): - return self.process_raw_txs(raw_txs, deferred) - result, deferred = await loop.run_in_executor(None, job) + result, deferred = await self.controller.run_in_executor \ + (self.process_raw_txs, raw_txs, deferred) pending.extend(deferred) hashXs = self.hashXs diff --git a/server/peers.py b/server/peers.py index 31519a3..2f1e832 100644 --- a/server/peers.py +++ b/server/peers.py @@ -7,11 +7,8 @@ '''Peer management.''' -import asyncio import socket -import traceback from collections import namedtuple -from functools import partial import lib.util as util from server.irc import IRC @@ -30,12 +27,11 @@ class PeerManager(util.LoggedClass): VERSION = '1.0' DEFAULT_PORTS = {'t': 50001, 's': 50002} - def __init__(self, env): + def __init__(self, env, controller): super().__init__() self.env = env - self.loop = asyncio.get_event_loop() + self.controller = controller self.irc = IRC(env, self) - self.futures = set() self.identities = [] # Keyed by nick self.irc_peers = {} @@ -51,10 +47,6 @@ class PeerManager(util.LoggedClass): env.report_ssl_port_tor, '_tor')) - async def executor(self, func, *args, **kwargs): - '''Run func taking args in the executor.''' - await self.loop.run_in_executor(None, partial(func, *args, **kwargs)) - @classmethod def real_name(cls, identity): '''Real name as used on IRC.''' @@ -70,38 +62,19 @@ class PeerManager(util.LoggedClass): ssl = port_text('s', identity.ssl_port) return '{} v{}{}{}'.format(identity.host, cls.VERSION, tcp, ssl) - def ensure_future(self, coro): - '''Convert a coro into a future and add it to our pending list - to be waited for.''' - self.futures.add(asyncio.ensure_future(coro)) - def start_irc(self): '''Start up the IRC connections if enabled.''' if self.env.irc: name_pairs = [(self.real_name(identity), identity.nick_suffix) for identity in self.identities] - self.ensure_future(self.irc.start(name_pairs)) + self.controller.ensure_future(self.irc.start(name_pairs)) else: self.logger.info('IRC is disabled') async def main_loop(self): - '''Start and then enter the main loop.''' + '''Main loop. 
No loop for now.''' self.start_irc() - try: - while True: - await asyncio.sleep(10) - done = [future for future in self.futures if future.done()] - self.futures.difference_update(done) - for future in done: - try: - future.result() - except: - self.log_error(traceback.format_exc()) - finally: - for future in self.futures: - future.cancel() - def dns_lookup_peer(self, nick, hostname, details): try: ip_addr = None @@ -119,7 +92,7 @@ class PeerManager(util.LoggedClass): def add_irc_peer(self, *args): '''Schedule DNS lookup of peer.''' - self.ensure_future(self.executor(self.dns_lookup_peer, *args)) + self.controller.schedule_executor(self.dns_lookup_peer, *args) def remove_irc_peer(self, nick): '''Remove a peer from our IRC peers map.''' diff --git a/server/session.py b/server/session.py index 8916f49..0b341e8 100644 --- a/server/session.py +++ b/server/session.py @@ -8,7 +8,6 @@ '''Classes for local RPC server and remote client TCP/SSL servers.''' -import asyncio import time import traceback from functools import partial
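
The pattern this patch unifies — every scheduled coroutine is tracked in Controller.futures, results are reaped in a done callback, and blocking work is funnelled through a single run_in_executor helper — can be sketched in isolation roughly as below. This is an illustrative sketch only, not part of the patch: TaskOwner, demo() and the print-based error reporting are stand-in names for the example, not identifiers from ElectrumX.

import asyncio
import traceback


class TaskOwner:
    '''Minimal stand-in for the unified scheduling done by Controller.'''

    def __init__(self, loop):
        self.loop = loop
        self.futures = set()

    async def run_in_executor(self, func, *args):
        '''Wait whilst running func in the default (thread pool) executor.'''
        return await self.loop.run_in_executor(None, func, *args)

    def schedule_executor(self, func, *args):
        '''Schedule func to run in the executor; return the wrapping task.'''
        return self.ensure_future(self.run_in_executor(func, *args))

    def ensure_future(self, coro):
        '''Schedule the coro and remember its future until it is done.'''
        future = asyncio.ensure_future(coro)
        future.add_done_callback(self.on_future_done)
        self.futures.add(future)
        return future

    def on_future_done(self, future):
        '''Forget the finished future and surface any exception it raised.'''
        self.futures.remove(future)
        try:
            future.result()
        except asyncio.CancelledError:
            pass
        except Exception:
            print(traceback.format_exc())

    async def shutdown(self):
        '''Cancel anything still pending and wait for it all to finish.'''
        for future in self.futures:
            future.cancel()
        while not all(future.done() for future in self.futures):
            await asyncio.sleep(0.1)


async def demo(loop):
    owner = TaskOwner(loop)
    owner.ensure_future(asyncio.sleep(60))       # long-lived coroutine
    owner.schedule_executor(sum, range(10))      # blocking work off the loop
    await asyncio.sleep(0.2)
    await owner.shutdown()


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    loop.run_until_complete(demo(loop))

Keeping the set of futures in one object means shutdown is the single place that cancels and drains outstanding work, which is what lets peers.py drop its private futures set and main-loop reaping, and block_processor.py and mempool.py drop their local executor helpers in the hunks above.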