Switch to curio primitives

Gives much clearer code
This commit is contained in:
Neil Booth 2018-07-28 12:27:47 +08:00
parent 55ef1ab157
commit 751f9917a4
9 changed files with 211 additions and 302 deletions

View File

@ -12,6 +12,8 @@ import sys
import time import time
from functools import partial from functools import partial
from aiorpcx import TaskGroup
from electrumx.lib.util import class_logger from electrumx.lib.util import class_logger
@ -93,18 +95,18 @@ class ServerBase(object):
loop.set_exception_handler(self.on_exception) loop.set_exception_handler(self.on_exception)
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
task = loop.create_task(self.serve(shutdown_event))
try: try:
# Wait for shutdown to be signalled, and log it. async with TaskGroup() as group:
await shutdown_event.wait() server_task = await group.spawn(self.serve(shutdown_event))
self.logger.info('shutting down') # Wait for shutdown, log on receipt of the event
task.cancel() await shutdown_event.wait()
await task self.logger.info('shutting down')
server_task.cancel()
finally: finally:
await loop.shutdown_asyncgens() await loop.shutdown_asyncgens()
# Prevent some silly logs # Prevent some silly logs
await asyncio.sleep(0) await asyncio.sleep(0.001)
# Finally, work around an apparent asyncio bug that causes log # Finally, work around an apparent asyncio bug that causes log
# spew on shutdown for partially opened SSL sockets # spew on shutdown for partially opened SSL sockets
try: try:

View File

@ -1,68 +0,0 @@
# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# and warranty status of this software.
'''Concurrency via tasks and threads.'''
from aiorpcx import TaskSet
import electrumx.lib.util as util
class Tasks(object):
# Functionality here will be incorporated into aiorpcX's TaskSet
# after experience is gained.
def __init__(self, *, loop=None):
self.tasks = TaskSet(loop=loop)
self.logger = util.class_logger(__name__, self.__class__.__name__)
# Pass through until integrated
self.loop = self.tasks.loop
self.wait = self.tasks.wait
async def run_in_thread(self, func, *args):
'''Run a function in a separate thread, and await its completion.'''
return await self.loop.run_in_executor(None, func, *args)
def create_task(self, coro, daemon=True):
'''Schedule the coro to be run.'''
task = self.tasks.create_task(coro)
if daemon:
task.add_done_callback(self._check_task_exception)
return task
def _check_task_exception(self, task):
'''Check a task for exceptions.'''
try:
if not task.cancelled():
task.result()
except Exception as e:
self.logger.exception(f'uncaught task exception: {e}')
async def cancel_all(self, wait=True):
'''Cancels all tasks and waits for them to complete.'''
self.tasks.cancel_all()
if wait:
await self.tasks.wait()

View File

@ -15,6 +15,8 @@ from struct import pack, unpack
import time import time
from functools import partial from functools import partial
from aiorpcx import TaskGroup, run_in_thread
import electrumx import electrumx
from electrumx.server.daemon import DaemonError from electrumx.server.daemon import DaemonError
from electrumx.lib.hash import hash_to_hex_str, HASHX_LEN from electrumx.lib.hash import hash_to_hex_str, HASHX_LEN
@ -44,8 +46,9 @@ class Prefetcher(object):
# This makes the first fetch be 10 blocks # This makes the first fetch be 10 blocks
self.ave_size = self.min_cache_size // 10 self.ave_size = self.min_cache_size // 10
async def main_loop(self): async def main_loop(self, bp_height):
'''Loop forever polling for more blocks.''' '''Loop forever polling for more blocks.'''
await self.reset_height(bp_height)
while True: while True:
try: try:
# Sleep a while if there is nothing to prefetch # Sleep a while if there is nothing to prefetch
@ -153,14 +156,12 @@ class BlockProcessor(electrumx.server.db.DB):
Coordinate backing up in case of chain reorganisations. Coordinate backing up in case of chain reorganisations.
''' '''
def __init__(self, env, tasks, daemon, notifications): def __init__(self, env, daemon, notifications):
super().__init__(env) super().__init__(env)
self.tasks = tasks
self.daemon = daemon self.daemon = daemon
self.notifications = notifications self.notifications = notifications
self._caught_up_event = asyncio.Event()
self.blocks_event = asyncio.Event() self.blocks_event = asyncio.Event()
self.prefetcher = Prefetcher(daemon, env.coin, self.blocks_event) self.prefetcher = Prefetcher(daemon, env.coin, self.blocks_event)
@ -187,16 +188,10 @@ class BlockProcessor(electrumx.server.db.DB):
# If the lock is successfully acquired, in-memory chain state # If the lock is successfully acquired, in-memory chain state
# is consistent with self.height # is consistent with self.height
self.state_lock = asyncio.Lock() self.state_lock = asyncio.Lock()
self.worker_task = None
def add_new_block_callback(self, callback): async def run_in_thread_shielded(self, func, *args):
'''Add a function called when a new block is found. async with self.state_lock:
return await asyncio.shield(run_in_thread(func, *args))
If several blocks are processed simultaneously, only called
once. The callback is passed a set of hashXs touched by the
block(s), which is cleared on return.
'''
self.callbacks.append(callback)
async def check_and_advance_blocks(self, raw_blocks): async def check_and_advance_blocks(self, raw_blocks):
'''Process the list of raw blocks passed. Detects and handles '''Process the list of raw blocks passed. Detects and handles
@ -212,14 +207,7 @@ class BlockProcessor(electrumx.server.db.DB):
chain = [self.tip] + [self.coin.header_hash(h) for h in headers[:-1]] chain = [self.tip] + [self.coin.header_hash(h) for h in headers[:-1]]
if hprevs == chain: if hprevs == chain:
start = time.time() await self.run_in_thread_shielded(self.advance_blocks, blocks)
async with self.state_lock:
await self.tasks.run_in_thread(self.advance_blocks, blocks)
if not self.first_sync:
s = '' if len(blocks) == 1 else 's'
self.logger.info('processed {:,d} block{} in {:.1f}s'
.format(len(blocks), s,
time.time() - start))
if self._caught_up_event.is_set(): if self._caught_up_event.is_set():
await self.notifications.on_block(self.touched, self.height) await self.notifications.on_block(self.touched, self.height)
self.touched = set() self.touched = set()
@ -244,7 +232,7 @@ class BlockProcessor(electrumx.server.db.DB):
self.logger.info('chain reorg detected') self.logger.info('chain reorg detected')
else: else:
self.logger.info(f'faking a reorg of {count:,d} blocks') self.logger.info(f'faking a reorg of {count:,d} blocks')
await self.tasks.run_in_thread(self.flush, True) await run_in_thread(self.flush, True)
async def get_raw_blocks(last_height, hex_hashes): async def get_raw_blocks(last_height, hex_hashes):
heights = range(last_height, last_height - len(hex_hashes), -1) heights = range(last_height, last_height - len(hex_hashes), -1)
@ -260,8 +248,7 @@ class BlockProcessor(electrumx.server.db.DB):
hashes = [hash_to_hex_str(hash) for hash in reversed(hashes)] hashes = [hash_to_hex_str(hash) for hash in reversed(hashes)]
for hex_hashes in chunks(hashes, 50): for hex_hashes in chunks(hashes, 50):
raw_blocks = await get_raw_blocks(last, hex_hashes) raw_blocks = await get_raw_blocks(last, hex_hashes)
async with self.state_lock: await self.run_in_thread_shielded(self.backup_blocks, raw_blocks)
await self.tasks.run_in_thread(self.backup_blocks, raw_blocks)
last -= len(raw_blocks) last -= len(raw_blocks)
# Truncate header_mc: header count is 1 more than the height. # Truncate header_mc: header count is 1 more than the height.
# Note header_mc is None if the reorg happens at startup. # Note header_mc is None if the reorg happens at startup.
@ -468,6 +455,7 @@ class BlockProcessor(electrumx.server.db.DB):
It is already verified they correctly connect onto our tip. It is already verified they correctly connect onto our tip.
''' '''
start = time.time()
min_height = self.min_undo_height(self.daemon.cached_height()) min_height = self.min_undo_height(self.daemon.cached_height())
height = self.height height = self.height
@ -492,6 +480,12 @@ class BlockProcessor(electrumx.server.db.DB):
self.check_cache_size() self.check_cache_size()
self.next_cache_check = time.time() + 30 self.next_cache_check = time.time() + 30
if not self.first_sync:
s = '' if len(blocks) == 1 else 's'
self.logger.info('processed {:,d} block{} in {:.1f}s'
.format(len(blocks), s,
time.time() - start))
def advance_txs(self, txs): def advance_txs(self, txs):
self.tx_hashes.append(b''.join(tx_hash for tx, tx_hash in txs)) self.tx_hashes.append(b''.join(tx_hash for tx, tx_hash in txs))
@ -744,20 +738,13 @@ class BlockProcessor(electrumx.server.db.DB):
self.db_height = self.height self.db_height = self.height
self.db_tip = self.tip self.db_tip = self.tip
async def _process_blocks(self): async def _process_prefetched_blocks(self):
'''Loop forever processing blocks as they arrive.''' '''Loop forever processing blocks as they arrive.'''
while True: while True:
if self.height == self.daemon.cached_height(): if self.height == self.daemon.cached_height():
if not self._caught_up_event.is_set(): if not self._caught_up_event.is_set():
self.logger.info(f'caught up to height {self.height}') await self._first_caught_up()
self._caught_up_event.set() self._caught_up_event.set()
# Flush everything but with first_sync->False state.
first_sync = self.first_sync
self.first_sync = False
self.flush(True)
if first_sync:
self.logger.info(f'{electrumx.version} synced to '
f'height {self.height:,d}')
await self.blocks_event.wait() await self.blocks_event.wait()
self.blocks_event.clear() self.blocks_event.clear()
if self.reorg_count: if self.reorg_count:
@ -767,7 +754,26 @@ class BlockProcessor(electrumx.server.db.DB):
blocks = self.prefetcher.get_prefetched_blocks() blocks = self.prefetcher.get_prefetched_blocks()
await self.check_and_advance_blocks(blocks) await self.check_and_advance_blocks(blocks)
def _on_dbs_opened(self): async def _first_caught_up(self):
self.logger.info(f'caught up to height {self.height}')
# Flush everything but with first_sync->False state.
first_sync = self.first_sync
self.first_sync = False
self.flush(True)
if first_sync:
self.logger.info(f'{electrumx.version} synced to '
f'height {self.height:,d}')
# Initialise the notification framework
await self.notifications.on_block(set(), self.height)
# Reopen for serving
await self.open_for_serving()
# Populate the header merkle cache
length = max(1, self.height - self.env.reorg_limit)
self.header_mc = MerkleCache(self.merkle, HeaderSource(self), length)
self.logger.info('populated header merkle cache')
async def _first_open_dbs(self):
await self.open_for_sync()
# An incomplete compaction needs to be cancelled otherwise # An incomplete compaction needs to be cancelled otherwise
# restarting it will corrupt the history # restarting it will corrupt the history
self.history.cancel_compaction() self.history.cancel_compaction()
@ -783,31 +789,32 @@ class BlockProcessor(electrumx.server.db.DB):
# --- External API # --- External API
async def catch_up_to_daemon(self): async def fetch_and_process_blocks(self, caught_up_event):
'''Process and index blocks until we catch up with the daemon. '''Fetch, process and index blocks from the daemon.
Returns once caught up. Future blocks continue to be Sets caught_up_event when first caught up. Flushes to disk
processed in a separate task. and shuts down cleanly if cancelled.
This is mainly because if, during initial sync ElectrumX is
asked to shut down when a large number of blocks have been
processed but not written to disk, it should write those to
disk before exiting, as otherwise a significant amount of work
could be lost.
''' '''
# Open the databases first. self._caught_up_event = caught_up_event
await self.open_for_sync() async with TaskGroup() as group:
self._on_dbs_opened() await group.spawn(self._first_open_dbs())
# Get the prefetcher running # Ensure cached_height is set
self.tasks.create_task(self.prefetcher.main_loop()) await group.spawn(self.daemon.height())
await self.prefetcher.reset_height(self.height) try:
# Start our loop that processes blocks as they are fetched async with TaskGroup() as group:
self.worker_task = self.tasks.create_task(self._process_blocks()) await group.spawn(self.prefetcher.main_loop(self.height))
# Wait until caught up await group.spawn(self._process_prefetched_blocks())
await self._caught_up_event.wait() finally:
# Initialise the notification framework async with self.state_lock:
await self.notifications.on_block(set(), self.height) # Shut down block processing
# Reopen for serving self.logger.info('flushing to DB for a clean shutdown...')
await self.open_for_serving() self.flush(True)
# Populate the header merkle cache
length = max(1, self.height - self.env.reorg_limit)
self.header_mc = MerkleCache(self.merkle, HeaderSource(self), length)
self.logger.info('populated header merkle cache')
def force_chain_reorg(self, count): def force_chain_reorg(self, count):
'''Force a reorg of the given number of blocks. '''Force a reorg of the given number of blocks.
@ -819,18 +826,3 @@ class BlockProcessor(electrumx.server.db.DB):
self.blocks_event.set() self.blocks_event.set()
return True return True
return False return False
async def shutdown(self):
'''Shutdown cleanly and flush to disk.
If during initial sync ElectrumX is asked to shut down when a
large number of blocks have been processed but not written to
disk, it should write those to disk before exiting, as
otherwise a significant amount of work could be lost.
'''
if self.worker_task:
async with self.state_lock:
# Shut down block processing
self.worker_task.cancel()
self.logger.info('flushing to DB for a clean shutdown...')
self.flush(True)

View File

@ -9,15 +9,16 @@
import asyncio import asyncio
import pylru import pylru
from aiorpcx import run_in_thread
class ChainState(object): class ChainState(object):
'''Used as an interface by servers to request information about '''Used as an interface by servers to request information about
blocks, transaction history, UTXOs and the mempool. blocks, transaction history, UTXOs and the mempool.
''' '''
def __init__(self, env, tasks, daemon, bp, notifications): def __init__(self, env, daemon, bp, notifications):
self._env = env self._env = env
self._tasks = tasks
self._daemon = daemon self._daemon = daemon
self._bp = bp self._bp = bp
self._history_cache = pylru.lrucache(256) self._history_cache = pylru.lrucache(256)
@ -64,7 +65,7 @@ class ChainState(object):
hc = self._history_cache hc = self._history_cache
if hashX not in hc: if hashX not in hc:
hc[hashX] = await self._tasks.run_in_thread(job) hc[hashX] = await run_in_thread(job)
return hc[hashX] return hc[hashX]
async def get_utxos(self, hashX): async def get_utxos(self, hashX):
@ -72,7 +73,7 @@ class ChainState(object):
def job(): def job():
return list(self._bp.get_utxos(hashX, limit=None)) return list(self._bp.get_utxos(hashX, limit=None))
return await self._tasks.run_in_thread(job) return await run_in_thread(job)
def header_branch_and_root(self, length, height): def header_branch_and_root(self, length, height):
return self._bp.header_mc.branch_and_root(length, height) return self._bp.header_mc.branch_and_root(length, height)
@ -91,7 +92,3 @@ class ChainState(object):
def set_daemon_url(self, daemon_url): def set_daemon_url(self, daemon_url):
self._daemon.set_urls(self._env.coin.daemon_urls(daemon_url)) self._daemon.set_urls(self._env.coin.daemon_urls(daemon_url))
return self._daemon.logged_url() return self._daemon.logged_url()
async def shutdown(self):
'''Shut down the block processor to flush chain state to disk.'''
await self._bp.shutdown()

View File

@ -5,14 +5,15 @@
# See the file "LICENCE" for information about the copyright # See the file "LICENCE" for information about the copyright
# and warranty status of this software. # and warranty status of this software.
from aiorpcx import _version as aiorpcx_version from asyncio import Event
from aiorpcx import _version as aiorpcx_version, TaskGroup
import electrumx import electrumx
from electrumx.lib.server_base import ServerBase from electrumx.lib.server_base import ServerBase
from electrumx.lib.util import version_string from electrumx.lib.util import version_string
from electrumx.server.chain_state import ChainState from electrumx.server.chain_state import ChainState
from electrumx.server.mempool import MemPool from electrumx.server.mempool import MemPool
from electrumx.server.peers import PeerManager
from electrumx.server.session import SessionManager from electrumx.server.session import SessionManager
@ -76,17 +77,16 @@ class Controller(ServerBase):
Servers are started once the mempool is synced after the block Servers are started once the mempool is synced after the block
processor first catches up with the daemon. processor first catches up with the daemon.
''' '''
async def serve(self, shutdown_event):
'''Start the RPC server and wait for the mempool to synchronize. Then
start serving external clients.
'''
reqd_version = (0, 5, 8)
if aiorpcx_version != reqd_version:
raise RuntimeError('ElectrumX requires aiorpcX version '
f'{version_string(reqd_version)}')
AIORPCX_MIN = (0, 5, 6) env = self.env
def __init__(self, env):
'''Initialize everything that doesn't require the event loop.'''
super().__init__(env)
if aiorpcx_version < self.AIORPCX_MIN:
raise RuntimeError('ElectrumX requires aiorpcX >= '
f'{version_string(self.AIORPCX_MIN)}')
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()
self.logger.info(f'software version: {electrumx.version}') self.logger.info(f'software version: {electrumx.version}')
self.logger.info(f'aiorpcX version: {version_string(aiorpcx_version)}') self.logger.info(f'aiorpcX version: {version_string(aiorpcx_version)}')
@ -97,29 +97,20 @@ class Controller(ServerBase):
notifications = Notifications() notifications = Notifications()
daemon = env.coin.DAEMON(env) daemon = env.coin.DAEMON(env)
BlockProcessor = env.coin.BLOCK_PROCESSOR BlockProcessor = env.coin.BLOCK_PROCESSOR
self.bp = BlockProcessor(env, self.tasks, daemon, notifications) bp = BlockProcessor(env, daemon, notifications)
self.mempool = MemPool(env.coin, self.tasks, daemon, notifications, mempool = MemPool(env.coin, daemon, notifications, bp.lookup_utxos)
self.bp.lookup_utxos) chain_state = ChainState(env, daemon, bp, notifications)
self.chain_state = ChainState(env, self.tasks, daemon, self.bp, session_mgr = SessionManager(env, chain_state, mempool,
notifications) notifications, shutdown_event)
self.session_mgr = SessionManager(env, self.tasks, self.chain_state,
self.mempool, notifications,
self.shutdown_event)
async def start_servers(self): caught_up_event = Event()
'''Start the RPC server and wait for the mempool to synchronize. Then serve_externally_event = Event()
start serving external clients. synchronized_event = Event()
'''
await self.session_mgr.start_rpc_server()
await self.bp.catch_up_to_daemon()
await self.mempool.start_and_wait_for_sync()
await self.session_mgr.start_serving()
async def shutdown(self): async with TaskGroup() as group:
'''Perform the shutdown sequence.''' await group.spawn(session_mgr.serve(serve_externally_event))
# Close servers and connections - main source of new task creation await group.spawn(bp.fetch_and_process_blocks(caught_up_event))
await self.session_mgr.shutdown() await caught_up_event.wait()
# Flush chain state to disk await group.spawn(mempool.keep_synchronized(synchronized_event))
await self.chain_state.shutdown() await synchronized_event.wait()
# Cancel all tasks; this shuts down the prefetcher serve_externally_event.set()
await self.tasks.cancel_all(wait=True)

View File

@ -17,6 +17,8 @@ from collections import namedtuple
from glob import glob from glob import glob
from struct import pack, unpack from struct import pack, unpack
from aiorpcx import run_in_thread
import electrumx.lib.util as util import electrumx.lib.util as util
from electrumx.lib.hash import hash_to_hex_str, HASHX_LEN from electrumx.lib.hash import hash_to_hex_str, HASHX_LEN
from electrumx.server.storage import db_class from electrumx.server.storage import db_class
@ -442,6 +444,5 @@ class DB(object):
return hashX, value return hashX, value
return [lookup_utxo(*hashX_pair) for hashX_pair in hashX_pairs] return [lookup_utxo(*hashX_pair) for hashX_pair in hashX_pairs]
run_in_thread = self.tasks.run_in_thread
hashX_pairs = await run_in_thread(lookup_hashXs) hashX_pairs = await run_in_thread(lookup_hashXs)
return await run_in_thread(lookup_utxos, hashX_pairs) return await run_in_thread(lookup_utxos, hashX_pairs)

View File

@ -13,6 +13,7 @@ import time
from collections import defaultdict from collections import defaultdict
import attr import attr
from aiorpcx import TaskGroup, run_in_thread
from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash
from electrumx.lib.util import class_logger, chunks from electrumx.lib.util import class_logger, chunks
@ -40,11 +41,10 @@ class MemPool(object):
hashXs: hashX -> set of all hashes of txs touching the hashX hashXs: hashX -> set of all hashes of txs touching the hashX
''' '''
def __init__(self, coin, tasks, daemon, notifications, lookup_utxos): def __init__(self, coin, daemon, notifications, lookup_utxos):
self.logger = class_logger(__name__, self.__class__.__name__) self.logger = class_logger(__name__, self.__class__.__name__)
self.coin = coin self.coin = coin
self.lookup_utxos = lookup_utxos self.lookup_utxos = lookup_utxos
self.tasks = tasks
self.daemon = daemon self.daemon = daemon
self.notifications = notifications self.notifications = notifications
self.txs = {} self.txs = {}
@ -127,7 +127,7 @@ class MemPool(object):
return deferred, {prevout: utxo_map[prevout] for prevout in unspent} return deferred, {prevout: utxo_map[prevout] for prevout in unspent}
async def _refresh_hashes(self, once): async def _refresh_hashes(self, synchronized_event):
'''Refresh our view of the daemon's mempool.''' '''Refresh our view of the daemon's mempool.'''
sleep = 5 sleep = 5
histogram_refresh = self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS // sleep histogram_refresh = self.coin.MEMPOOL_HISTOGRAM_REFRESH_SECS // sleep
@ -138,12 +138,11 @@ class MemPool(object):
continue continue
hashes = set(hex_str_to_hash(hh) for hh in hex_hashes) hashes = set(hex_str_to_hash(hh) for hh in hex_hashes)
touched = await self._process_mempool(hashes) touched = await self._process_mempool(hashes)
synchronized_event.set()
await self.notifications.on_mempool(touched, height) await self.notifications.on_mempool(touched, height)
# Thread mempool histogram refreshes - they can be expensive # Thread mempool histogram refreshes - they can be expensive
if loop_count % histogram_refresh == 0: if loop_count % histogram_refresh == 0:
await self.tasks.run_in_thread(self._update_histogram) await run_in_thread(self._update_histogram)
if once:
return
await asyncio.sleep(sleep) await asyncio.sleep(sleep)
async def _process_mempool(self, all_hashes): async def _process_mempool(self, all_hashes):
@ -165,16 +164,15 @@ class MemPool(object):
# Process new transactions # Process new transactions
new_hashes = list(all_hashes.difference(txs)) new_hashes = list(all_hashes.difference(txs))
jobs = [self.tasks.create_task(self._fetch_and_accept if new_hashes:
(hashes, all_hashes, touched), group = TaskGroup()
daemon=False) for hashes in chunks(new_hashes, 200):
for hashes in chunks(new_hashes, 2000)] coro = self._fetch_and_accept(hashes, all_hashes, touched)
if jobs: await group.spawn(coro)
await asyncio.gather(*jobs)
tx_map = {} tx_map = {}
utxo_map = {} utxo_map = {}
for job in jobs: async for task in group:
deferred, unspent = job.result() deferred, unspent = task.result()
tx_map.update(deferred) tx_map.update(deferred)
utxo_map.update(unspent) utxo_map.update(unspent)
@ -218,7 +216,7 @@ class MemPool(object):
return txs return txs
# Thread this potentially slow operation so as not to block # Thread this potentially slow operation so as not to block
tx_map = await self.tasks.run_in_thread(deserialize_txs) tx_map = await run_in_thread(deserialize_txs)
# Determine all prevouts not in the mempool, and fetch the # Determine all prevouts not in the mempool, and fetch the
# UTXO information from the database. Failed prevout lookups # UTXO information from the database. Failed prevout lookups
@ -236,19 +234,20 @@ class MemPool(object):
# External interface # External interface
# #
async def start_and_wait_for_sync(self): async def keep_synchronized(self, synchronized_event):
'''Starts the mempool synchronizer. '''Starts the mempool synchronizer.
Waits for an initial synchronization before returning. Waits for an initial synchronization before returning.
''' '''
self.logger.info('beginning processing of daemon mempool. ' self.logger.info('beginning processing of daemon mempool. '
'This can take some time...') 'This can take some time...')
start = time.time() async with TaskGroup() as group:
await self._refresh_hashes(once=True) await group.spawn(self._refresh_hashes(synchronized_event))
elapsed = time.time() - start start = time.time()
self.logger.info(f'synced in {elapsed:.2f}s') await synchronized_event.wait()
self.tasks.create_task(self._log_stats()) elapsed = time.time() - start
self.tasks.create_task(self._refresh_hashes(once=False)) self.logger.info(f'synced in {elapsed:.2f}s')
await group.spawn(self._log_stats())
async def balance_delta(self, hashX): async def balance_delta(self, hashX):
'''Return the unconfirmed amount in the mempool for hashX. '''Return the unconfirmed amount in the mempool for hashX.

View File

@ -14,8 +14,9 @@ import ssl
import time import time
from collections import defaultdict, Counter from collections import defaultdict, Counter
from aiorpcx import (ClientSession, RPCError, SOCKSProxy, from aiorpcx import (ClientSession, SOCKSProxy, SOCKSError,
SOCKSError, ConnectionError) RPCError, ConnectionError,
TaskGroup, run_in_thread, ignore_after)
from electrumx.lib.peer import Peer from electrumx.lib.peer import Peer
from electrumx.lib.util import class_logger, protocol_tuple from electrumx.lib.util import class_logger, protocol_tuple
@ -55,14 +56,12 @@ class PeerManager(object):
Attempts to maintain a connection with up to 8 peers. Attempts to maintain a connection with up to 8 peers.
Issues a 'peers.subscribe' RPC to them and tells them our data. Issues a 'peers.subscribe' RPC to them and tells them our data.
''' '''
def __init__(self, env, tasks, chain_state): def __init__(self, env, chain_state):
self.logger = class_logger(__name__, self.__class__.__name__) self.logger = class_logger(__name__, self.__class__.__name__)
# Initialise the Peer class # Initialise the Peer class
Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
self.env = env self.env = env
self.tasks = tasks
self.chain_state = chain_state self.chain_state = chain_state
self.loop = tasks.loop
# Our clearnet and Tor Peers, if any # Our clearnet and Tor Peers, if any
sclass = env.coin.SESSIONCLS sclass = env.coin.SESSIONCLS
@ -155,30 +154,13 @@ class PeerManager(object):
self.logger.info(f'trying to detect proxy on "{host}" ports {ports}') self.logger.info(f'trying to detect proxy on "{host}" ports {ports}')
cls = SOCKSProxy cls = SOCKSProxy
result = await cls.auto_detect_host(host, ports, None, loop=self.loop) result = await cls.auto_detect_host(host, ports, None)
if isinstance(result, cls): if isinstance(result, cls):
self.proxy = result self.proxy = result
self.logger.info(f'detected {self.proxy}') self.logger.info(f'detected {self.proxy}')
else: else:
self.logger.info('no proxy detected') self.logger.info('no proxy detected')
async def _discover_peers(self):
'''Main loop performing peer maintenance. This includes
1) Forgetting unreachable peers.
2) Verifying connectivity of new peers.
3) Retrying old peers at regular intervals.
'''
self._import_peers()
while True:
await self._maybe_detect_proxy()
await self._retry_peers()
timeout = self.loop.call_later(WAKEUP_SECS, self.retry_event.set)
await self.retry_event.wait()
self.retry_event.clear()
timeout.cancel()
async def _retry_peers(self): async def _retry_peers(self):
'''Retry peers that are close to getting stale.''' '''Retry peers that are close to getting stale.'''
# Exponential backoff of retries # Exponential backoff of retries
@ -195,11 +177,10 @@ class PeerManager(object):
# Retry a failed connection if enough time has passed # Retry a failed connection if enough time has passed
return peer.last_try < now - WAKEUP_SECS * 2 ** peer.try_count return peer.last_try < now - WAKEUP_SECS * 2 ** peer.try_count
tasks = [] async with TaskGroup() as group:
for peer in self.peers: for peer in self.peers:
if should_retry(peer): if should_retry(peer):
tasks.append(self.tasks.create_task(self._retry_peer(peer))) await group.spawn(self._retry_peer(peer))
await asyncio.gather(*tasks)
async def _retry_peer(self, peer): async def _retry_peer(self, peer):
peer.try_count += 1 peer.try_count += 1
@ -278,12 +259,13 @@ class PeerManager(object):
peer.features['server_version'] = server_version peer.features['server_version'] = server_version
ptuple = protocol_tuple(protocol_version) ptuple = protocol_tuple(protocol_version)
jobs = [self.tasks.create_task(message, daemon=False) for message in ( async with TaskGroup() as group:
self._send_headers_subscribe(session, peer, timeout, ptuple), await group.spawn(self._send_headers_subscribe(session, peer,
self._send_server_features(session, peer, timeout), timeout, ptuple))
self._send_peers_subscribe(session, peer, timeout) await group.spawn(self._send_server_features(session, peer,
)] timeout))
await asyncio.gather(*jobs) await group.spawn(self._send_peers_subscribe(session, peer,
timeout))
async def _send_headers_subscribe(self, session, peer, timeout, ptuple): async def _send_headers_subscribe(self, session, peer, timeout, ptuple):
message = 'blockchain.headers.subscribe' message = 'blockchain.headers.subscribe'
@ -389,13 +371,27 @@ class PeerManager(object):
# #
# External interface # External interface
# #
def start_peer_discovery(self): async def discover_peers(self):
if self.env.peer_discovery == self.env.PD_ON: '''Perform peer maintenance. This includes
self.logger.info(f'beginning peer discovery. Force use of '
f'proxy: {self.env.force_proxy}') 1) Forgetting unreachable peers.
self.tasks.create_task(self._discover_peers()) 2) Verifying connectivity of new peers.
else: 3) Retrying old peers at regular intervals.
'''
if self.env.peer_discovery != self.env.PD_ON:
self.logger.info('peer discovery is disabled') self.logger.info('peer discovery is disabled')
return
self.logger.info(f'beginning peer discovery. Force use of '
f'proxy: {self.env.force_proxy}')
self._import_peers()
while True:
await self._maybe_detect_proxy()
await self._retry_peers()
async with ignore_after(WAKEUP_SECS):
await self.retry_event.wait()
self.retry_event.clear()
def add_peers(self, peers, limit=2, check_ports=False, source=None): def add_peers(self, peers, limit=2, check_ports=False, source=None):
'''Add a limited number of peers that are not already present.''' '''Add a limited number of peers that are not already present.'''
@ -422,9 +418,8 @@ class PeerManager(object):
use_peers = new_peers[:limit] use_peers = new_peers[:limit]
else: else:
use_peers = new_peers use_peers = new_peers
for n, peer in enumerate(use_peers): for peer in use_peers:
self.logger.info(f'accepted new peer {n+1}/{len(use_peers)} ' self.logger.info(f'accepted new peer {peer} from {source}')
f'{peer} from {source}')
self.peers.update(use_peers) self.peers.update(use_peers)
if retry: if retry:
@ -460,9 +455,9 @@ class PeerManager(object):
permit = self._permit_new_onion_peer() permit = self._permit_new_onion_peer()
reason = 'rate limiting' reason = 'rate limiting'
else: else:
getaddrinfo = asyncio.get_event_loop().getaddrinfo
try: try:
infos = await self.loop.getaddrinfo(host, 80, infos = await getaddrinfo(host, 80, type=socket.SOCK_STREAM)
type=socket.SOCK_STREAM)
except socket.gaierror: except socket.gaierror:
permit = False permit = False
reason = 'address resolution failure' reason = 'address resolution failure'

View File

@ -18,7 +18,7 @@ import time
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from aiorpcx import ServerSession, JSONRPCAutoDetect, RPCError from aiorpcx import ServerSession, JSONRPCAutoDetect, RPCError, TaskGroup
import electrumx import electrumx
import electrumx.lib.text as text import electrumx.lib.text as text
@ -27,6 +27,7 @@ from electrumx.lib.hash import (sha256, hash_to_hex_str, hex_str_to_hash,
HASHX_LEN) HASHX_LEN)
from electrumx.lib.peer import Peer from electrumx.lib.peer import Peer
from electrumx.server.daemon import DaemonError from electrumx.server.daemon import DaemonError
from electrumx.server.peers import PeerManager
BAD_REQUEST = 1 BAD_REQUEST = 1
@ -97,14 +98,13 @@ class SessionManager(object):
CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4) CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4)
def __init__(self, env, tasks, chain_state, mempool, notifications, def __init__(self, env, chain_state, mempool, notifications,
shutdown_event): shutdown_event):
env.max_send = max(350000, env.max_send) env.max_send = max(350000, env.max_send)
self.env = env self.env = env
self.tasks = tasks
self.chain_state = chain_state self.chain_state = chain_state
self.mempool = mempool self.mempool = mempool
self.peer_mgr = PeerManager(env, tasks, chain_state) self.peer_mgr = PeerManager(env, chain_state)
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.logger = util.class_logger(__name__, self.__class__.__name__) self.logger = util.class_logger(__name__, self.__class__.__name__)
self.servers = {} self.servers = {}
@ -396,42 +396,42 @@ class SessionManager(object):
# --- External Interface # --- External Interface
async def start_rpc_server(self): async def serve(self, event):
'''Start the RPC server if enabled.''' '''Start the RPC server if enabled. When the event is triggered,
if self.env.rpc_port is not None: start TCP and SSL servers.'''
await self._start_server('RPC', self.env.cs_host(for_rpc=True), try:
self.env.rpc_port) if self.env.rpc_port is not None:
await self._start_server('RPC', self.env.cs_host(for_rpc=True),
async def start_serving(self): self.env.rpc_port)
'''Start TCP and SSL servers.''' await event.wait()
self.logger.info('max session count: {:,d}'.format(self.max_sessions)) self.logger.info(f'max session count: {self.max_sessions:,d}')
self.logger.info('session timeout: {:,d} seconds' self.logger.info(f'session timeout: '
.format(self.env.session_timeout)) f'{self.env.session_timeout:,d} seconds')
self.logger.info('session bandwidth limit {:,d} bytes' self.logger.info('session bandwidth limit {:,d} bytes'
.format(self.env.bandwidth_limit)) .format(self.env.bandwidth_limit))
self.logger.info('max response size {:,d} bytes' self.logger.info('max response size {:,d} bytes'
.format(self.env.max_send)) .format(self.env.max_send))
self.logger.info('max subscriptions across all sessions: {:,d}' self.logger.info('max subscriptions across all sessions: {:,d}'
.format(self.max_subs)) .format(self.max_subs))
self.logger.info('max subscriptions per session: {:,d}' self.logger.info('max subscriptions per session: {:,d}'
.format(self.env.max_session_subs)) .format(self.env.max_session_subs))
if self.env.drop_client is not None: if self.env.drop_client is not None:
self.logger.info('drop clients matching: {}' self.logger.info('drop clients matching: {}'
.format(self.env.drop_client.pattern)) .format(self.env.drop_client.pattern))
await self._start_external_servers() await self._start_external_servers()
# Peer discovery should start after the external servers # Peer discovery should start after the external servers
# because we connect to ourself # because we connect to ourself
self.peer_mgr.start_peer_discovery() async with TaskGroup() as group:
self.tasks.create_task(self._housekeeping()) await group.spawn(self.peer_mgr.discover_peers())
await group.spawn(self._housekeeping())
async def shutdown(self): finally:
'''Close servers and sessions.''' # Close servers and sessions
self.state = self.SHUTTING_DOWN self.state = self.SHUTTING_DOWN
self._close_servers(list(self.servers.keys())) self._close_servers(list(self.servers.keys()))
for session in self.sessions: for session in self.sessions:
session.abort() session.abort()
for session in list(self.sessions): for session in list(self.sessions):
await session.wait_closed() await session.wait_closed()
def session_count(self): def session_count(self):
'''The number of connections that we've sent something to.''' '''The number of connections that we've sent something to.'''
@ -439,9 +439,9 @@ class SessionManager(object):
async def _notify_sessions(self, height, touched): async def _notify_sessions(self, height, touched):
'''Notify sessions about height changes and touched addresses.''' '''Notify sessions about height changes and touched addresses.'''
create_task = self.tasks.create_task async with TaskGroup() as group:
for session in self.sessions: for session in self.sessions:
create_task(session.notify(height, touched)) await group.spawn(session.notify(height, touched))
def add_session(self, session): def add_session(self, session):
self.sessions.add(session) self.sessions.add(session)