Unify executor and futures logic

This commit is contained in:
Neil Booth 2017-01-24 21:14:41 +09:00
parent 9b5cb105d5
commit cb0160901f
5 changed files with 78 additions and 92 deletions

View File

@ -138,9 +138,10 @@ class BlockProcessor(server.db.DB):
Coordinate backing up in case of chain reorganisations.
'''
def __init__(self, env, daemon):
def __init__(self, env, controller, daemon):
super().__init__(env)
self.daemon = daemon
self.controller = controller
# These are our state as we move ahead of DB state
self.fs_height = self.db_height
@ -190,6 +191,7 @@ class BlockProcessor(server.db.DB):
async def main_loop(self):
'''Main loop for block processing.'''
self.controller.ensure_future(self.prefetcher.main_loop())
await self.prefetcher.reset_height()
while True:
@ -205,16 +207,11 @@ class BlockProcessor(server.db.DB):
self.logger.info('flushing state to DB for a clean shutdown...')
self.flush(True)
async def executor(self, func, *args, **kwargs):
'''Run func taking args in the executor.'''
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, partial(func, *args, **kwargs))
async def first_caught_up(self):
'''Called when first caught up to daemon after starting.'''
# Flush everything with updated first_sync->False state.
self.first_sync = False
await self.executor(self.flush, True)
await self.controller.run_in_executor(self.flush, True)
if self.utxo_db.for_sync:
self.logger.info('{} synced to height {:,d}'
.format(VERSION, self.height))
@ -240,7 +237,8 @@ class BlockProcessor(server.db.DB):
if hprevs == chain:
start = time.time()
await self.executor(self.advance_blocks, blocks, headers)
await self.controller.run_in_executor(self.advance_blocks,
blocks, headers)
if not self.first_sync:
s = '' if len(blocks) == 1 else 's'
self.logger.info('processed {:,d} block{} in {:.1f}s'
@ -277,14 +275,14 @@ class BlockProcessor(server.db.DB):
self.logger.info('chain reorg detected')
else:
self.logger.info('faking a reorg of {:,d} blocks'.format(count))
await self.executor(self.flush, True)
await self.controller.run_in_executor(self.flush, True)
hashes = await self.reorg_hashes(count)
# Reverse and convert to hex strings.
hashes = [hash_to_str(hash) for hash in reversed(hashes)]
for hex_hashes in chunks(hashes, 50):
blocks = await self.daemon.raw_blocks(hex_hashes)
await self.executor(self.backup_blocks, blocks)
await self.controller.run_in_executor(self.backup_blocks, blocks)
await self.prefetcher.reset_height()
async def reorg_hashes(self, count):

View File

@ -11,6 +11,7 @@ import json
import os
import ssl
import time
import traceback
from bisect import bisect_left
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
@ -49,9 +50,9 @@ class Controller(util.LoggedClass):
self.start_time = time.time()
self.coin = env.coin
self.daemon = Daemon(env.coin.daemon_urls(env.daemon_url))
self.bp = BlockProcessor(env, self.daemon)
self.mempool = MemPool(self.bp)
self.peers = PeerManager(env)
self.bp = BlockProcessor(env, self, self.daemon)
self.mempool = MemPool(self.bp, self)
self.peers = PeerManager(env, self)
self.env = env
self.servers = {}
# Map of session to the key of its list in self.groups
@ -63,6 +64,7 @@ class Controller(util.LoggedClass):
self.max_sessions = env.max_sessions
self.low_watermark = self.max_sessions * 19 // 20
self.max_subs = env.max_subs
self.futures = set()
# Cache some idea of room to avoid recounting on each subscription
self.subs_room = 0
self.next_stale_check = 0
@ -199,43 +201,59 @@ class Controller(util.LoggedClass):
if session.items:
self.enqueue_session(session)
async def run_in_executor(self, func, *args):
'''Wait whilst running func in the executor.'''
return await self.loop.run_in_executor(None, func, *args)
def schedule_executor(self, func, *args):
'''Schedule running func in the executor, return a task.'''
return self.ensure_future(self.run_in_executor(func, *args))
def ensure_future(self, coro):
'''Schedule the coro to be run.'''
future = asyncio.ensure_future(coro)
future.add_done_callback(self.on_future_done)
self.futures.add(future)
return future
def on_future_done(self, future):
'''Collect the result of a future after removing it from our set.'''
self.futures.remove(future)
try:
future.result()
except asyncio.CancelledError:
pass
except Exception:
self.log_error(traceback.format_exc())
async def wait_for_bp_catchup(self):
'''Called when the block processor catches up.'''
await self.bp.caught_up_event.wait()
self.logger.info('block processor has caught up')
self.ensure_future(self.peers.main_loop())
self.ensure_future(self.start_servers())
self.ensure_future(self.mempool.main_loop())
self.ensure_future(self.enqueue_delayed_sessions())
self.ensure_future(self.notify())
for n in range(4):
self.ensure_future(self.serve_requests())
async def main_loop(self):
'''Controller main loop.'''
self.ensure_future(self.bp.main_loop())
self.ensure_future(self.wait_for_bp_catchup())
# Shut down cleanly after waiting for shutdown to be signalled
await self.shutdown_event.wait()
self.logger.info('shutting down')
await self.shutdown()
self.logger.info('shutdown complete')
def initiate_shutdown(self):
'''Call this function to start the shutdown process.'''
self.shutdown_event.set()
async def main_loop(self):
'''Controller main loop.'''
def add_future(coro):
futures.append(asyncio.ensure_future(coro))
async def await_bp_catchup():
'''Wait for the block processor to catch up.
Then start the servers and the peer manager.
'''
await self.bp.caught_up_event.wait()
self.logger.info('block processor has caught up')
add_future(self.peers.main_loop())
add_future(self.start_servers())
add_future(self.mempool.main_loop())
add_future(self.enqueue_delayed_sessions())
add_future(self.notify())
for n in range(4):
add_future(self.serve_requests())
futures = []
add_future(self.bp.main_loop())
add_future(self.bp.prefetcher.main_loop())
add_future(await_bp_catchup())
# Perform a clean shutdown when this event is signalled.
await self.shutdown_event.wait()
self.logger.info('shutting down')
await self.shutdown(futures)
self.logger.info('shutdown complete')
async def shutdown(self, futures):
async def shutdown(self):
'''Perform the shutdown sequence.'''
self.state = self.SHUTTING_DOWN
@ -244,13 +262,13 @@ class Controller(util.LoggedClass):
for session in self.sessions:
self.close_session(session)
# Cancel the futures
for future in futures:
# Cancel pending futures
for future in self.futures:
future.cancel()
# Wait for all futures to finish
while any(not future.done() for future in futures):
await asyncio.sleep(1)
while not all (future.done() for future in self.futures):
await asyncio.sleep(0.1)
# Finally shut down the block processor and executor
self.bp.shutdown(self.executor)
@ -694,8 +712,7 @@ class Controller(util.LoggedClass):
limit = self.env.max_send // 97
return list(self.bp.get_history(hashX, limit=limit))
loop = asyncio.get_event_loop()
history = await loop.run_in_executor(None, job)
history = await self.run_in_executor(job)
self.history_cache[hashX] = history
return history
@ -725,8 +742,8 @@ class Controller(util.LoggedClass):
'''Get UTXOs asynchronously to reduce latency.'''
def job():
return list(self.bp.get_utxos(hashX, limit=None))
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, job)
return await self.run_in_executor(job)
def get_chunk(self, index):
'''Return header chunk as hex. Index is a non-negative integer.'''

View File

@ -31,9 +31,10 @@ class MemPool(util.LoggedClass):
A pair is a (hashX, value) tuple. tx hashes are hex strings.
'''
def __init__(self, bp):
def __init__(self, bp, controller):
super().__init__()
self.daemon = bp.daemon
self.controller = controller
self.coin = bp.coin
self.db = bp
self.touched = bp.touched
@ -139,7 +140,6 @@ class MemPool(util.LoggedClass):
break
def async_process_some(self, unfetched, limit):
loop = asyncio.get_event_loop()
pending = []
txs = self.txs
@ -162,9 +162,8 @@ class MemPool(util.LoggedClass):
deferred = pending
pending = []
def job():
return self.process_raw_txs(raw_txs, deferred)
result, deferred = await loop.run_in_executor(None, job)
result, deferred = await self.controller.run_in_executor \
(self.process_raw_txs, raw_txs, deferred)
pending.extend(deferred)
hashXs = self.hashXs

View File

@ -7,11 +7,8 @@
'''Peer management.'''
import asyncio
import socket
import traceback
from collections import namedtuple
from functools import partial
import lib.util as util
from server.irc import IRC
@ -30,12 +27,11 @@ class PeerManager(util.LoggedClass):
VERSION = '1.0'
DEFAULT_PORTS = {'t': 50001, 's': 50002}
def __init__(self, env):
def __init__(self, env, controller):
super().__init__()
self.env = env
self.loop = asyncio.get_event_loop()
self.controller = controller
self.irc = IRC(env, self)
self.futures = set()
self.identities = []
# Keyed by nick
self.irc_peers = {}
@ -51,10 +47,6 @@ class PeerManager(util.LoggedClass):
env.report_ssl_port_tor,
'_tor'))
async def executor(self, func, *args, **kwargs):
'''Run func taking args in the executor.'''
await self.loop.run_in_executor(None, partial(func, *args, **kwargs))
@classmethod
def real_name(cls, identity):
'''Real name as used on IRC.'''
@ -70,38 +62,19 @@ class PeerManager(util.LoggedClass):
ssl = port_text('s', identity.ssl_port)
return '{} v{}{}{}'.format(identity.host, cls.VERSION, tcp, ssl)
def ensure_future(self, coro):
'''Convert a coro into a future and add it to our pending list
to be waited for.'''
self.futures.add(asyncio.ensure_future(coro))
def start_irc(self):
'''Start up the IRC connections if enabled.'''
if self.env.irc:
name_pairs = [(self.real_name(identity), identity.nick_suffix)
for identity in self.identities]
self.ensure_future(self.irc.start(name_pairs))
self.controller.ensure_future(self.irc.start(name_pairs))
else:
self.logger.info('IRC is disabled')
async def main_loop(self):
'''Start and then enter the main loop.'''
'''Main loop. No loop for now.'''
self.start_irc()
try:
while True:
await asyncio.sleep(10)
done = [future for future in self.futures if future.done()]
self.futures.difference_update(done)
for future in done:
try:
future.result()
except:
self.log_error(traceback.format_exc())
finally:
for future in self.futures:
future.cancel()
def dns_lookup_peer(self, nick, hostname, details):
try:
ip_addr = None
@ -119,7 +92,7 @@ class PeerManager(util.LoggedClass):
def add_irc_peer(self, *args):
'''Schedule DNS lookup of peer.'''
self.ensure_future(self.executor(self.dns_lookup_peer, *args))
self.controller.schedule_executor(self.dns_lookup_peer, *args)
def remove_irc_peer(self, nick):
'''Remove a peer from our IRC peers map.'''

View File

@ -8,7 +8,6 @@
'''Classes for local RPC server and remote client TCP/SSL servers.'''
import asyncio
import time
import traceback
from functools import partial