From 8452d0c01636bd6a07f0ecc58a84fd446d33d5fe Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Tue, 18 Oct 2016 06:35:45 +0900 Subject: [PATCH] Split out daemon handler into separate file. --- server/controller.py | 85 +++++++++++--------------------------------- server/daemon.py | 70 ++++++++++++++++++++++++++++++++++++ server/protocol.py | 14 ++++---- 3 files changed, 99 insertions(+), 70 deletions(-) create mode 100644 server/daemon.py diff --git a/server/controller.py b/server/controller.py index ca013da..23245ab 100644 --- a/server/controller.py +++ b/server/controller.py @@ -2,13 +2,11 @@ # and warranty status of this software. import asyncio -import json import signal import traceback from functools import partial -import aiohttp - +from server.daemon import Daemon, DaemonError from server.db import DB from server.protocol import ElectrumX, LocalRPC from lib.hash import (sha256, double_sha256, hash_to_str, @@ -19,10 +17,15 @@ from lib.util import LoggedClass class Controller(LoggedClass): def __init__(self, env): + '''Create up the controller. + + Creates DB, Daemon and BlockCache instances. + ''' super().__init__() self.env = env self.db = DB(env) - self.block_cache = BlockCache(env, self.db) + self.daemon = Daemon(env.daemon_url) + self.block_cache = BlockCache(self.db, self.daemon) self.servers = [] self.sessions = set() self.addresses = {} @@ -30,6 +33,7 @@ class Controller(LoggedClass): self.peers = {} def start(self, loop): + '''Prime the event loop with asynchronous servers and jobs.''' env = self.env if False: @@ -41,7 +45,7 @@ class Controller(LoggedClass): self.logger.info('RPC server listening on {}:{:d}' .format(host, env.rpc_port)) - protocol = partial(ElectrumX, self, env) + protocol = partial(ElectrumX, self, self.db, self.daemon, env) if env.tcp_port is not None: tcp_server = loop.create_server(protocol, env.host, env.tcp_port) self.servers.append(loop.run_until_complete(tcp_server)) @@ -68,10 +72,12 @@ class Controller(LoggedClass): partial(self.on_signal, loop, signame)) def stop(self): + '''Close the listening servers.''' for server in self.servers: server.close() def on_signal(self, loop, signame): + '''Call on receipt of a signal to cleanly shutdown.''' self.logger.warning('received {} signal, preparing to shut down' .format(signame)) for task in asyncio.Task.all_tasks(loop): @@ -119,9 +125,8 @@ class Controller(LoggedClass): async def get_merkle(self, tx_hash, height): '''tx_hash is a hex string.''' - daemon_send = self.block_cache.send_single - block_hash = await daemon_send('getblockhash', (height,)) - block = await daemon_send('getblock', (block_hash, True)) + block_hash = await self.daemon.send_single('getblockhash', (height,)) + block = await self.daemon.send_single('getblock', (block_hash, True)) tx_hashes = block['tx'] # This will throw if the tx_hash is bad pos = tx_hashes.index(tx_hash) @@ -151,13 +156,10 @@ class BlockCache(LoggedClass): block chain reorganisations. ''' - class DaemonError: - pass - - def __init__(self, env, db): + def __init__(self, db, daemon): super().__init__() self.db = db - self.daemon_url = env.daemon_url + self.daemon = daemon # Target cache size. Has little effect on sync time. self.target_cache_size = 10 * 1024 * 1024 self.daemon_height = 0 @@ -166,8 +168,6 @@ class BlockCache(LoggedClass): self.queue_size = 0 self.recent_sizes = [0] - self.logger.info('using daemon URL {}'.format(self.daemon_url)) - def flush_db(self): self.db.flush(self.daemon_height, True) @@ -194,8 +194,8 @@ class BlockCache(LoggedClass): while True: try: await self.maybe_prefetch() - except self.DaemonError: - pass + except DaemonError as e: + self.logger.info('ignoring daemon errors: {}'.format(e)) await asyncio.sleep(2) def cache_used(self): @@ -208,9 +208,10 @@ class BlockCache(LoggedClass): async def maybe_prefetch(self): '''Prefetch blocks if there are any to prefetch.''' + daemon = self.daemon while self.queue_size < self.target_cache_size: # Keep going by getting a whole new cache_limit of blocks - self.daemon_height = await self.send_single('getblockcount') + self.daemon_height = await daemon.send_single('getblockcount') max_count = min(self.daemon_height - self.fetched_height, 4000) count = min(max_count, self.prefill_count(self.target_cache_size)) if not count: @@ -218,11 +219,11 @@ class BlockCache(LoggedClass): first = self.fetched_height + 1 param_lists = [[height] for height in range(first, first + count)] - hashes = await self.send_vector('getblockhash', param_lists) + hashes = await daemon.send_vector('getblockhash', param_lists) # Hashes is an array of hex strings param_lists = [(h, False) for h in hashes] - blocks = await self.send_vector('getblock', param_lists) + blocks = await daemon.send_vector('getblock', param_lists) self.fetched_height += count # Convert hex string to bytes @@ -237,47 +238,3 @@ class BlockCache(LoggedClass): excess = len(self.recent_sizes) - 50 if excess > 0: self.recent_sizes = self.recent_sizes[excess:] - - async def send_single(self, method, params=None): - payload = {'method': method} - if params: - payload['params'] = params - result, = await self.send((payload, )) - return result - - async def send_many(self, mp_pairs): - payload = [{'method': method, 'params': params} - for method, params in mp_pairs] - return await self.send(payload) - - async def send_vector(self, method, params_list): - payload = [{'method': method, 'params': params} - for params in params_list] - return await self.send(payload) - - async def send(self, payload): - assert isinstance(payload, (tuple, list)) - data = json.dumps(payload) - while True: - try: - async with aiohttp.post(self.daemon_url, data=data) as resp: - result = await resp.json() - except asyncio.CancelledError: - raise - except Exception as e: - msg = 'aiohttp error: {}'.format(e) - secs = 3 - else: - errs = tuple(item['error'] for item in result) - if not any(errs): - return tuple(item['result'] for item in result) - if any(err.get('code') == -28 for err in errs): - msg = 'daemon still warming up.' - secs = 30 - else: - msg = 'daemon errors: {}'.format(errs) - raise self.DaemonError(msg) - - self.logger.error('{}. Sleeping {:d}s and trying again...' - .format(msg, secs)) - await asyncio.sleep(secs) diff --git a/server/daemon.py b/server/daemon.py new file mode 100644 index 0000000..c60a4c8 --- /dev/null +++ b/server/daemon.py @@ -0,0 +1,70 @@ +# See the file "LICENSE" for information about the copyright +# and warranty status of this software. + +'''Classes for handling asynchronous connections to a blockchain +daemon.''' + +import asyncio +import json + +import aiohttp + +from lib.util import LoggedClass + + +class DaemonError(Exception): + '''Raised when the daemon returns an error in its results that + cannot be remedied by retrying.''' + + +class Daemon(LoggedClass): + '''Handles connections to a daemon at the given URL.''' + + def __init__(self, url): + super().__init__() + self.url = url + self.logger.info('connecting to daemon at URL {}'.format(url)) + + async def send_single(self, method, params=None): + payload = {'method': method} + if params: + payload['params'] = params + result, = await self.send((payload, )) + return result + + async def send_many(self, mp_pairs): + payload = [{'method': method, 'params': params} + for method, params in mp_pairs] + return await self.send(payload) + + async def send_vector(self, method, params_list): + payload = [{'method': method, 'params': params} + for params in params_list] + return await self.send(payload) + + async def send(self, payload): + assert isinstance(payload, (tuple, list)) + data = json.dumps(payload) + while True: + try: + async with aiohttp.post(self.url, data=data) as resp: + result = await resp.json() + except asyncio.CancelledError: + raise + except Exception as e: + msg = 'aiohttp error: {}'.format(e) + secs = 3 + else: + errs = tuple(item['error'] for item in result) + if not any(errs): + return tuple(item['result'] for item in result) + if any(err.get('code') == -28 for err in errs): + msg = 'daemon still warming up.' + secs = 30 + else: + msg = '{}'.format(errs) + raise DaemonError(msg) + + self.logger.error('{}. Sleeping {:d}s and trying again...' + .format(msg, secs)) + await asyncio.sleep(secs) diff --git a/server/protocol.py b/server/protocol.py index cf72aa8..a20f08c 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -100,11 +100,12 @@ class JSONRPC(asyncio.Protocol, LoggedClass): class ElectrumX(JSONRPC): + '''A TCP server that handles incoming Electrum connections.''' - def __init__(self, controller, env): + def __init__(self, controller, db, daemon, env): super().__init__(controller) - self.BC = controller.block_cache - self.db = controller.db + self.db = db + self.daemon = daemon self.env = env self.addresses = set() self.subscribe_headers = False @@ -134,7 +135,7 @@ class ElectrumX(JSONRPC): return status.hex() if status else None async def handle_blockchain_estimatefee(self, params): - result = await self.BC.send_single('estimatefee', params) + result = await self.daemon.send_single('estimatefee', params) return result async def handle_blockchain_headers_subscribe(self, params): @@ -145,7 +146,7 @@ class ElectrumX(JSONRPC): '''The minimum fee a low-priority tx must pay in order to be accepted to this daemon's memory pool. ''' - net_info = await self.BC.send_single('getnetworkinfo') + net_info = await self.daemon.send_single('getnetworkinfo') return net_info['relayfee'] async def handle_blockchain_transaction_get(self, params): @@ -153,7 +154,7 @@ class ElectrumX(JSONRPC): raise Error(Error.BAD_REQUEST, 'params should contain a transaction hash') tx_hash = params[0] - return await self.BC.send_single('getrawtransaction', (tx_hash, 0)) + return await self.daemon.send_single('getrawtransaction', (tx_hash, 0)) async def handle_blockchain_transaction_get_merkle(self, params): if len(params) != 2: @@ -196,6 +197,7 @@ class ElectrumX(JSONRPC): class LocalRPC(JSONRPC): + '''A local TCP RPC server for querying status.''' async def handle_getinfo(self, params): return {