diff --git a/docs/conf.py b/docs/conf.py index a84ef96..f6c5a69 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,7 +15,7 @@ import os import sys sys.path.insert(0, os.path.abspath('..')) -VERSION="ElectrumX 1.8.3" +VERSION="ElectrumX 1.8.4" # -- Project information ----------------------------------------------------- diff --git a/electrumx/__init__.py b/electrumx/__init__.py index a352144..fd5584f 100644 --- a/electrumx/__init__.py +++ b/electrumx/__init__.py @@ -1,4 +1,4 @@ -version = 'ElectrumX 1.8.3' +version = 'ElectrumX 1.8.4-dev' version_short = version.split()[-1] from electrumx.server.controller import Controller diff --git a/electrumx/lib/coins.py b/electrumx/lib/coins.py index fcc3141..8a5c1b0 100644 --- a/electrumx/lib/coins.py +++ b/electrumx/lib/coins.py @@ -114,10 +114,6 @@ class Coin(object): url = 'http://' + url return url + '/' - @classmethod - def daemon_urls(cls, urls): - return [cls.sanitize_url(url) for url in urls.split(',')] - @classmethod def genesis_block(cls, block): '''Check the Genesis block is the right one for this coin. diff --git a/electrumx/lib/server_base.py b/electrumx/lib/server_base.py index db7a213..f3688d6 100644 --- a/electrumx/lib/server_base.py +++ b/electrumx/lib/server_base.py @@ -13,7 +13,7 @@ import sys import time from functools import partial -from aiorpcx import TaskGroup +from aiorpcx import spawn from electrumx.lib.util import class_logger @@ -93,12 +93,11 @@ class ServerBase(object): loop.set_exception_handler(self.on_exception) shutdown_event = asyncio.Event() - async with TaskGroup() as group: - server_task = await group.spawn(self.serve(shutdown_event)) - # Wait for shutdown, log on receipt of the event - await shutdown_event.wait() - self.logger.info('shutting down') - server_task.cancel() + server_task = await spawn(self.serve(shutdown_event)) + # Wait for shutdown, log on receipt of the event + await shutdown_event.wait() + self.logger.info('shutting down') + server_task.cancel() # Prevent some silly logs await asyncio.sleep(0.01) diff --git a/electrumx/server/block_processor.py b/electrumx/server/block_processor.py index bea879e..e658fe5 100644 --- a/electrumx/server/block_processor.py +++ b/electrumx/server/block_processor.py @@ -650,10 +650,7 @@ class BlockProcessor(object): could be lost. ''' self._caught_up_event = caught_up_event - async with TaskGroup() as group: - await group.spawn(self._first_open_dbs()) - # Ensure cached_height is set - await group.spawn(self.daemon.height()) + await self._first_open_dbs() try: async with TaskGroup() as group: await group.spawn(self.prefetcher.main_loop(self.height)) diff --git a/electrumx/server/chain_state.py b/electrumx/server/chain_state.py deleted file mode 100644 index 135d42a..0000000 --- a/electrumx/server/chain_state.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2016-2018, Neil Booth -# -# All rights reserved. -# -# See the file "LICENCE" for information about the copyright -# and warranty status of this software. - - -from electrumx.lib.hash import hash_to_hex_str - - -class ChainState(object): - '''Used as an interface by servers to request information about - blocks, transaction history, UTXOs and the mempool. - ''' - - def __init__(self, env, db, daemon, bp): - self._env = env - self._db = db - self._daemon = daemon - - # External interface pass-throughs for session.py - self.force_chain_reorg = bp.force_chain_reorg - self.tx_branch_and_root = db.merkle.branch_and_root - self.read_headers = db.read_headers - self.all_utxos = db.all_utxos - self.limited_history = db.limited_history - self.header_branch_and_root = db.header_branch_and_root - - async def broadcast_transaction(self, raw_tx): - return await self._daemon.sendrawtransaction([raw_tx]) - - async def daemon_request(self, method, args=()): - return await getattr(self._daemon, method)(*args) - - def db_height(self): - return self._db.db_height - - def get_info(self): - '''Chain state info for LocalRPC and logs.''' - return { - 'daemon': self._daemon.logged_url(), - 'daemon_height': self._daemon.cached_height(), - 'db_height': self.db_height(), - } - - async def raw_header(self, height): - '''Return the binary header at the given height.''' - header, n = await self.read_headers(height, 1) - if n != 1: - raise IndexError(f'height {height:,d} out of range') - return header - - def set_daemon_url(self, daemon_url): - self._daemon.set_urls(self._env.coin.daemon_urls(daemon_url)) - return self._daemon.logged_url() - - async def query(self, args, limit): - coin = self._env.coin - db = self._db - lines = [] - - def arg_to_hashX(arg): - try: - script = bytes.fromhex(arg) - lines.append(f'Script: {arg}') - return coin.hashX_from_script(script) - except ValueError: - pass - - hashX = coin.address_to_hashX(arg) - lines.append(f'Address: {arg}') - return hashX - - for arg in args: - hashX = arg_to_hashX(arg) - if not hashX: - continue - n = None - history = await db.limited_history(hashX, limit=limit) - for n, (tx_hash, height) in enumerate(history): - lines.append(f'History #{n:,d}: height {height:,d} ' - f'tx_hash {hash_to_hex_str(tx_hash)}') - if n is None: - lines.append('No history found') - n = None - utxos = await db.all_utxos(hashX) - for n, utxo in enumerate(utxos, start=1): - lines.append(f'UTXO #{n:,d}: tx_hash ' - f'{hash_to_hex_str(utxo.tx_hash)} ' - f'tx_pos {utxo.tx_pos:,d} height ' - f'{utxo.height:,d} value {utxo.value:,d}') - if n == limit: - break - if n is None: - lines.append('No UTXOs found') - - balance = sum(utxo.value for utxo in utxos) - lines.append(f'Balance: {coin.decimal_value(balance):,f} ' - f'{coin.SHORTNAME}') - - return lines diff --git a/electrumx/server/controller.py b/electrumx/server/controller.py index 9061f93..09d022b 100644 --- a/electrumx/server/controller.py +++ b/electrumx/server/controller.py @@ -12,7 +12,6 @@ from aiorpcx import _version as aiorpcx_version, TaskGroup import electrumx from electrumx.lib.server_base import ServerBase from electrumx.lib.util import version_string -from electrumx.server.chain_state import ChainState from electrumx.server.db import DB from electrumx.server.mempool import MemPool, MemPoolAPI from electrumx.server.session import SessionManager @@ -93,11 +92,12 @@ class Controller(ServerBase): self.logger.info(f'reorg limit is {env.reorg_limit:,d} blocks') notifications = Notifications() - daemon = env.coin.DAEMON(env) - db = DB(env) + Daemon = env.coin.DAEMON BlockProcessor = env.coin.BLOCK_PROCESSOR + + daemon = Daemon(env.coin, env.daemon_url) + db = DB(env) bp = BlockProcessor(env, db, daemon, notifications) - chain_state = ChainState(env, db, daemon, bp) # Set ourselves up to implement the MemPoolAPI self.height = daemon.height @@ -109,13 +109,16 @@ class Controller(ServerBase): MemPoolAPI.register(Controller) mempool = MemPool(env.coin, self) - session_mgr = SessionManager(env, chain_state, mempool, + session_mgr = SessionManager(env, db, bp, daemon, mempool, notifications, shutdown_event) + # Test daemon authentication, and also ensure it has a cached + # height. Do this before entering the task group. + await daemon.height() + caught_up_event = Event() serve_externally_event = Event() synchronized_event = Event() - async with TaskGroup() as group: await group.spawn(session_mgr.serve(serve_externally_event)) await group.spawn(bp.fetch_and_process_blocks(caught_up_event)) diff --git a/electrumx/server/daemon.py b/electrumx/server/daemon.py index 4c47834..6302145 100644 --- a/electrumx/server/daemon.py +++ b/electrumx/server/daemon.py @@ -9,6 +9,7 @@ daemon.''' import asyncio +import itertools import json import time from calendar import timegm @@ -28,48 +29,53 @@ class DaemonError(Exception): '''Raised when the daemon returns an error in its results.''' +class WarmingUpError(Exception): + '''Internal - when the daemon is warming up.''' + + +class WorkQueueFullError(Exception): + '''Internal - when the daemon's work queue is full.''' + + class Daemon(object): '''Handles connections to a daemon at the given URL.''' WARMING_UP = -28 - RPC_MISC_ERROR = -1 + id_counter = itertools.count() - class DaemonWarmingUpError(Exception): - '''Raised when the daemon returns an error in its results.''' - - def __init__(self, env): + def __init__(self, coin, url, max_workqueue=10, init_retry=0.25, + max_retry=4.0): + self.coin = coin self.logger = class_logger(__name__, self.__class__.__name__) - self.coin = env.coin - self.set_urls(env.coin.daemon_urls(env.daemon_url)) - self._height = None + self.set_url(url) # Limit concurrent RPC calls to this number. # See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16 - self.workqueue_semaphore = asyncio.Semaphore(value=10) - self.down = False - self.last_error_time = 0 - self.req_id = 0 - self._available_rpcs = {} # caches results for _is_rpc_available() + self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue) + self.init_retry = init_retry + self.max_retry = max_retry + self._height = None + self.available_rpcs = {} - def next_req_id(self): - '''Retrns the next request ID.''' - self.req_id += 1 - return self.req_id - - def set_urls(self, urls): + def set_url(self, url): '''Set the URLS to the given list, and switch to the first one.''' - if not urls: - raise DaemonError('no daemon URLs provided') - self.urls = urls - self.url_index = 0 + urls = url.split(',') + urls = [self.coin.sanitize_url(url) for url in urls] for n, url in enumerate(urls): - self.logger.info('daemon #{:d} at {}{}' - .format(n + 1, self.logged_url(url), - '' if n else ' (current)')) + status = '' if n else ' (current)' + logged_url = self.logged_url(url) + self.logger.info(f'daemon #{n + 1} at {logged_url}{status}') + self.url_index = 0 + self.urls = urls - def url(self): + def current_url(self): '''Returns the current daemon URL.''' return self.urls[self.url_index] + def logged_url(self, url=None): + '''The host and port part, for logging.''' + url = url or self.current_url() + return url[url.rindex('@') + 1:] + def failover(self): '''Call to fail-over to the next daemon URL. @@ -77,7 +83,7 @@ class Daemon(object): ''' if len(self.urls) > 1: self.url_index = (self.url_index + 1) % len(self.urls) - self.logger.info('failing over to {}'.format(self.logged_url())) + self.logger.info(f'failing over to {self.logged_url()}') return True return False @@ -88,13 +94,17 @@ class Daemon(object): async def _send_data(self, data): async with self.workqueue_semaphore: async with self.client_session() as session: - async with session.post(self.url(), data=data) as resp: - # If bitcoind can't find a tx, for some reason - # it returns 500 but fills out the JSON. - # Should still return 200 IMO. - if resp.status in (200, 404, 500): + async with session.post(self.current_url(), data=data) as resp: + kind = resp.headers.get('Content-Type', None) + if kind == 'application/json': return await resp.json() - return (resp.status, resp.reason) + # bitcoind's HTTP protocol "handling" is a bad joke + text = await resp.text() + if 'Work queue depth exceeded' in text: + raise WorkQueueFullError + text = text.strip() or resp.reason + self.logger.error(text) + raise DaemonError(text) async def _send(self, payload, processor): '''Send a payload to be converted to JSON. @@ -103,54 +113,42 @@ class Daemon(object): are raise through DaemonError. ''' def log_error(error): - self.down = True + nonlocal last_error_log, retry now = time.time() - prior_time = self.last_error_time - if now - prior_time > 60: - self.last_error_time = now - if prior_time and self.failover(): - secs = 0 - else: - self.logger.error('{} Retrying occasionally...' - .format(error)) + if now - last_error_log > 60: + last_error_time = now + self.logger.error(f'{error} Retrying occasionally...') + if retry == self.max_retry and self.failover(): + retry = 0 + on_good_message = None + last_error_log = 0 data = json.dumps(payload) - secs = 1 - max_secs = 4 + retry = self.init_retry while True: try: result = await self._send_data(data) - if not isinstance(result, tuple): - result = processor(result) - if self.down: - self.down = False - self.last_error_time = 0 - self.logger.info('connection restored') - return result - log_error('HTTP error code {:d}: {}' - .format(result[0], result[1])) + result = processor(result) + if on_good_message: + self.logger.info(on_good_message) + return result except asyncio.TimeoutError: log_error('timeout error.') except aiohttp.ServerDisconnectedError: log_error('disconnected.') - except aiohttp.ClientPayloadError: - log_error('payload encoding error.') + on_good_message = 'connection restored' except aiohttp.ClientConnectionError: log_error('connection problem - is your daemon running?') - except self.DaemonWarmingUpError: + on_good_message = 'connection restored' + except WarmingUpError: log_error('starting up checking blocks.') - except (asyncio.CancelledError, DaemonError): - raise - except Exception as e: - self.logger.exception(f'uncaught exception: {e}') + on_good_message = 'running normally' + except WorkQueueFullError: + log_error('work queue full.') + on_good_message = 'running normally' - await asyncio.sleep(secs) - secs = min(max_secs, secs * 2, 1) - - def logged_url(self, url=None): - '''The host and port part, for logging.''' - url = url or self.url() - return url[url.rindex('@') + 1:] + await asyncio.sleep(retry) + retry = max(min(self.max_retry, retry * 2), self.init_retry) async def _send_single(self, method, params=None): '''Send a single request to the daemon.''' @@ -159,10 +157,10 @@ class Daemon(object): if not err: return result['result'] if err.get('code') == self.WARMING_UP: - raise self.DaemonWarmingUpError + raise WarmingUpError raise DaemonError(err) - payload = {'method': method, 'id': self.next_req_id()} + payload = {'method': method, 'id': next(self.id_counter)} if params: payload['params'] = params return await self._send(payload, processor) @@ -176,12 +174,12 @@ class Daemon(object): def processor(result): errs = [item['error'] for item in result if item['error']] if any(err.get('code') == self.WARMING_UP for err in errs): - raise self.DaemonWarmingUpError + raise WarmingUpError if not errs or replace_errs: return [item['result'] for item in result] raise DaemonError(errs) - payload = [{'method': method, 'params': p, 'id': self.next_req_id()} + payload = [{'method': method, 'params': p, 'id': next(self.id_counter)} for p in params_iterable] if payload: return await self._send(payload, processor) @@ -192,27 +190,16 @@ class Daemon(object): Results are cached and the daemon will generally not be queried with the same method more than once.''' - available = self._available_rpcs.get(method, None) + available = self.available_rpcs.get(method) if available is None: + available = True try: await self._send_single(method) - available = True except DaemonError as e: err = e.args[0] error_code = err.get("code") - if error_code == JSONRPC.METHOD_NOT_FOUND: - available = False - elif error_code == self.RPC_MISC_ERROR: - # method found but exception was thrown in command handling - # probably because we did not provide arguments - available = True - else: - self.logger.warning('error (code {:d}: {}) when testing ' - 'RPC availability of method {}' - .format(error_code, err.get("message"), - method)) - available = False - self._available_rpcs[method] = available + available = error_code != JSONRPC.METHOD_NOT_FOUND + self.available_rpcs[method] = available return available async def block_hex_hashes(self, first, count): @@ -235,12 +222,16 @@ class Daemon(object): '''Update our record of the daemon's mempool hashes.''' return await self._send_single('getrawmempool') - async def estimatefee(self, params): - '''Return the fee estimate for the given parameters.''' + async def estimatefee(self, block_count): + '''Return the fee estimate for the block count. Units are whole + currency units per KB, e.g. 0.00000995, or -1 if no estimate + is available. + ''' + args = (block_count, ) if await self._is_rpc_available('estimatesmartfee'): - estimate = await self._send_single('estimatesmartfee', params) + estimate = await self._send_single('estimatesmartfee', args) return estimate.get('feerate', -1) - return await self._send_single('estimatefee', params) + return await self._send_single('estimatefee', args) async def getnetworkinfo(self): '''Return the result of the 'getnetworkinfo' RPC call.''' @@ -268,9 +259,9 @@ class Daemon(object): # Convert hex strings to bytes return [hex_to_bytes(tx) if tx else None for tx in txs] - async def sendrawtransaction(self, params): + async def broadcast_transaction(self, raw_tx): '''Broadcast a transaction to the network.''' - return await self._send_single('sendrawtransaction', params) + return await self._send_single('sendrawtransaction', (raw_tx, )) async def height(self): '''Query the daemon for its current height.''' @@ -299,7 +290,7 @@ class FakeEstimateFeeDaemon(Daemon): '''Daemon that simulates estimatefee and relayfee RPC calls. Coin that wants to use this daemon must define ESTIMATE_FEE & RELAY_FEE''' - async def estimatefee(self, params): + async def estimatefee(self, block_count): '''Return the fee estimate for the given parameters.''' return self.coin.ESTIMATE_FEE diff --git a/electrumx/server/db.py b/electrumx/server/db.py index 35646be..4c08497 100644 --- a/electrumx/server/db.py +++ b/electrumx/server/db.py @@ -370,6 +370,13 @@ class DB(object): # Truncate header_mc: header count is 1 more than the height. self.header_mc.truncate(height + 1) + async def raw_header(self, height): + '''Return the binary header at the given height.''' + header, n = await self.read_headers(height, 1) + if n != 1: + raise IndexError(f'height {height:,d} out of range') + return header + async def read_headers(self, start_height, count): '''Requires start_height >= 0, count >= 0. Reads as many headers as are available starting at start_height up to count. This diff --git a/electrumx/server/peers.py b/electrumx/server/peers.py index 31e9f74..ad36dbf 100644 --- a/electrumx/server/peers.py +++ b/electrumx/server/peers.py @@ -55,12 +55,12 @@ class PeerManager(object): Attempts to maintain a connection with up to 8 peers. Issues a 'peers.subscribe' RPC to them and tells them our data. ''' - def __init__(self, env, chain_state): + def __init__(self, env, db): self.logger = class_logger(__name__, self.__class__.__name__) # Initialise the Peer class Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS self.env = env - self.chain_state = chain_state + self.db = db # Our clearnet and Tor Peers, if any sclass = env.coin.SESSIONCLS @@ -300,7 +300,7 @@ class PeerManager(object): result = await session.send_request(message) assert_good(message, result, dict) - our_height = self.chain_state.db_height() + our_height = self.db.db_height if ptuple < (1, 3): their_height = result.get('block_height') else: @@ -313,7 +313,7 @@ class PeerManager(object): # Check prior header too in case of hard fork. check_height = min(our_height, their_height) - raw_header = await self.chain_state.raw_header(check_height) + raw_header = await self.db.raw_header(check_height) if ptuple >= (1, 4): ours = raw_header.hex() message = 'blockchain.block.header' diff --git a/electrumx/server/session.py b/electrumx/server/session.py index 2d2e8f1..c379f45 100644 --- a/electrumx/server/session.py +++ b/electrumx/server/session.py @@ -109,13 +109,15 @@ class SessionManager(object): CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4) - def __init__(self, env, chain_state, mempool, notifications, + def __init__(self, env, db, bp, daemon, mempool, notifications, shutdown_event): env.max_send = max(350000, env.max_send) self.env = env - self.chain_state = chain_state + self.db = db + self.bp = bp + self.daemon = daemon self.mempool = mempool - self.peer_mgr = PeerManager(env, chain_state) + self.peer_mgr = PeerManager(env, db) self.shutdown_event = shutdown_event self.logger = util.class_logger(__name__, self.__class__.__name__) self.servers = {} @@ -127,8 +129,8 @@ class SessionManager(object): self.state = self.CATCHING_UP self.txs_sent = 0 self.start_time = time.time() - self._history_cache = pylru.lrucache(256) - self._hc_height = 0 + self.history_cache = pylru.lrucache(256) + self.notified_height = None # Cache some idea of room to avoid recounting on each subscription self.subs_room = 0 # Masternode stuff only for such coins @@ -152,7 +154,7 @@ class SessionManager(object): protocol_class = LocalRPC else: protocol_class = self.env.coin.SESSIONCLS - protocol_factory = partial(protocol_class, self, self.chain_state, + protocol_factory = partial(protocol_class, self, self.db, self.mempool, self.peer_mgr, kind) server = loop.create_server(protocol_factory, *args, **kw_args) @@ -276,10 +278,11 @@ class SessionManager(object): def _get_info(self): '''A summary of server state.''' group_map = self._group_map() - result = self.chain_state.get_info() - result.update({ - 'version': electrumx.version, + return { 'closing': len([s for s in self.sessions if s.is_closing()]), + 'daemon': self.daemon.logged_url(), + 'daemon_height': self.daemon.cached_height(), + 'db_height': self.db.db_height, 'errors': sum(s.errors for s in self.sessions), 'groups': len(group_map), 'logged': len([s for s in self.sessions if s.log_me]), @@ -291,8 +294,8 @@ class SessionManager(object): 'subs': self._sub_count(), 'txs_sent': self.txs_sent, 'uptime': util.formatted_time(time.time() - self.start_time), - }) - return result + 'version': electrumx.version, + } def _session_data(self, for_log): '''Returned to the RPC 'sessions' call.''' @@ -329,6 +332,19 @@ class SessionManager(object): ]) return result + async def _electrum_and_raw_headers(self, height): + raw_header = await self.raw_header(height) + electrum_header = self.env.coin.electrum_header(raw_header, height) + return electrum_header, raw_header + + async def _refresh_hsub_results(self, height): + '''Refresh the cached header subscription responses to be for height, + and record that as notified_height. + ''' + electrum, raw = await self._electrum_and_raw_headers(height) + self.hsub_results = (electrum, {'hex': raw.hex(), 'height': height}) + self.notified_height = height + # --- LocalRPC command handlers async def rpc_add_peer(self, real_name): @@ -367,10 +383,10 @@ class SessionManager(object): '''Replace the daemon URL.''' daemon_url = daemon_url or self.env.daemon_url try: - daemon_url = self.chain_state.set_daemon_url(daemon_url) + self.daemon.set_url(daemon_url) except Exception as e: raise RPCError(BAD_REQUEST, f'an error occured: {e!r}') - return f'now using daemon at {daemon_url}' + return f'now using daemon at {self.daemon.logged_url()}' async def rpc_stop(self): '''Shut down the server cleanly.''' @@ -391,10 +407,54 @@ class SessionManager(object): async def rpc_query(self, items, limit): '''Return a list of data about server peers.''' - try: - return await self.chain_state.query(items, limit) - except Base58Error as e: - raise RPCError(BAD_REQUEST, e.args[0]) from None + coin = self.env.coin + db = self.db + lines = [] + + def arg_to_hashX(arg): + try: + script = bytes.fromhex(arg) + lines.append(f'Script: {arg}') + return coin.hashX_from_script(script) + except ValueError: + pass + + try: + hashX = coin.address_to_hashX(arg) + except Base58Error as e: + lines.append(e.args[0]) + return None + lines.append(f'Address: {arg}') + return hashX + + for arg in args: + hashX = arg_to_hashX(arg) + if not hashX: + continue + n = None + history = await db.limited_history(hashX, limit=limit) + for n, (tx_hash, height) in enumerate(history): + lines.append(f'History #{n:,d}: height {height:,d} ' + f'tx_hash {hash_to_hex_str(tx_hash)}') + if n is None: + lines.append('No history found') + n = None + utxos = await db.all_utxos(hashX) + for n, utxo in enumerate(utxos, start=1): + lines.append(f'UTXO #{n:,d}: tx_hash ' + f'{hash_to_hex_str(utxo.tx_hash)} ' + f'tx_pos {utxo.tx_pos:,d} height ' + f'{utxo.height:,d} value {utxo.value:,d}') + if n == limit: + break + if n is None: + lines.append('No UTXOs found') + + balance = sum(utxo.value for utxo in utxos) + lines.append(f'Balance: {coin.decimal_value(balance):,f} ' + f'{coin.SHORTNAME}') + + return lines async def rpc_sessions(self): '''Return statistics about connected sessions.''' @@ -406,7 +466,7 @@ class SessionManager(object): count: number of blocks to reorg ''' count = non_negative_integer(count) - if not self.chain_state.force_chain_reorg(count): + if not self.bp.force_chain_reorg(count): raise RPCError(BAD_REQUEST, 'still catching up with daemon') return f'scheduled a reorg of {count:,d} blocks' @@ -454,31 +514,57 @@ class SessionManager(object): '''The number of connections that we've sent something to.''' return len(self.sessions) + async def daemon_request(self, method, *args): + '''Catch a DaemonError and convert it to an RPCError.''' + try: + return await getattr(self.daemon, method)(*args) + except DaemonError as e: + raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None + + async def raw_header(self, height): + '''Return the binary header at the given height.''' + try: + return await self.db.raw_header(height) + except IndexError: + raise RPCError(BAD_REQUEST, f'height {height:,d} ' + 'out of range') from None + + async def electrum_header(self, height): + '''Return the deserialized header at the given height.''' + electrum_header, _ = await self._electrum_and_raw_headers(height) + return electrum_header + + async def broadcast_transaction(self, raw_tx): + hex_hash = await self.daemon.broadcast_transaction(raw_tx) + self.txs_sent += 1 + return hex_hash + async def limited_history(self, hashX): '''A caching layer.''' - hc = self._history_cache + hc = self.history_cache if hashX not in hc: # History DoS limit. Each element of history is about 99 # bytes when encoded as JSON. This limits resource usage # on bloated history requests, and uses a smaller divisor # so large requests are logged before refusing them. limit = self.env.max_send // 97 - hc[hashX] = await self.chain_state.limited_history(hashX, - limit=limit) + hc[hashX] = await self.db.limited_history(hashX, limit=limit) return hc[hashX] async def _notify_sessions(self, height, touched): '''Notify sessions about height changes and touched addresses.''' - # Invalidate our history cache for touched hashXs - if height != self._hc_height: - self._hc_height = height - hc = self._history_cache + height_changed = height != self.notified_height + if height_changed: + # Paranoia: a reorg could race and leave db_height lower + await self._refresh_hsub_results(min(height, self.db.db_height)) + # Invalidate our history cache for touched hashXs + hc = self.history_cache for hashX in set(hc).intersection(touched): del hc[hashX] async with TaskGroup() as group: for session in self.sessions: - await group.spawn(session.notify(height, touched)) + await group.spawn(session.notify(touched, height_changed)) def add_session(self, session): self.sessions.add(session) @@ -518,12 +604,12 @@ class SessionBase(ServerSession): MAX_CHUNK_SIZE = 2016 session_counter = itertools.count() - def __init__(self, session_mgr, chain_state, mempool, peer_mgr, kind): + def __init__(self, session_mgr, db, mempool, peer_mgr, kind): connection = JSONRPCConnection(JSONRPCAutoDetect) super().__init__(connection=connection) self.logger = util.class_logger(__name__, self.__class__.__name__) self.session_mgr = session_mgr - self.chain_state = chain_state + self.db = db self.mempool = mempool self.peer_mgr = peer_mgr self.kind = kind # 'RPC', 'TCP' etc. @@ -534,11 +620,12 @@ class SessionBase(ServerSession): self.txs_sent = 0 self.log_me = False self.bw_limit = self.env.bandwidth_limit + self.daemon_request = self.session_mgr.daemon_request # Hijack the connection so we can log messages self._receive_message_orig = self.connection.receive_message self.connection.receive_message = self.receive_message - async def notify(self, height, touched): + async def notify(self, touched, height_changed): pass def peer_address_str(self, *, for_log=True): @@ -623,14 +710,12 @@ class ElectrumX(SessionBase): super().__init__(*args, **kwargs) self.subscribe_headers = False self.subscribe_headers_raw = False - self.notified_height = None self.connection.max_response_size = self.env.max_send self.max_subs = self.env.max_session_subs self.hashX_subs = {} self.sv_seen = False self.mempool_statuses = {} self.set_request_handlers(self.PROTOCOL_MIN) - self.db_height = self.chain_state.db_height @classmethod def protocol_min_max_strings(cls): @@ -662,96 +747,58 @@ class ElectrumX(SessionBase): def protocol_version_string(self): return util.version_string(self.protocol_tuple) - async def daemon_request(self, method, *args): - '''Catch a DaemonError and convert it to an RPCError.''' - try: - return await self.chain_state.daemon_request(method, args) - except DaemonError as e: - raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None - def sub_count(self): return len(self.hashX_subs) - async def notify_touched(self, our_touched): - changed = {} - - for hashX in our_touched: - alias = self.hashX_subs[hashX] - status = await self.address_status(hashX) - changed[alias] = status - - # Check mempool hashXs - the status is a function of the - # confirmed state of other transactions. Note: we cannot - # iterate over mempool_statuses as it changes size. - for hashX in tuple(self.mempool_statuses): - # Items can be evicted whilst await-ing below; False - # ensures such hashXs are notified - old_status = self.mempool_statuses.get(hashX, False) - status = await self.address_status(hashX) - if status != old_status: - alias = self.hashX_subs[hashX] - changed[alias] = status - - for alias, status in changed.items(): - if len(alias) == 64: - method = 'blockchain.scripthash.subscribe' - else: - method = 'blockchain.address.subscribe' - await self.send_notification(method, (alias, status)) - - if changed: - es = '' if len(changed) == 1 else 'es' - self.logger.info('notified of {:,d} address{}' - .format(len(changed), es)) - - async def notify(self, height, touched): + async def notify(self, touched, height_changed): '''Notify the client about changes to touched addresses (from mempool updates or new blocks) and height. - - Return the set of addresses the session needs to be - asyncronously notified about. This can be empty if there are - possible mempool status updates. - - Returns None if nothing needs to be notified asynchronously. ''' - height_changed = height != self.notified_height - if height_changed: - self.notified_height = height - if self.subscribe_headers: - args = (await self.subscribe_headers_result(height), ) - await self.send_notification('blockchain.headers.subscribe', - args) + if height_changed and self.subscribe_headers: + args = (await self.subscribe_headers_result(), ) + await self.send_notification('blockchain.headers.subscribe', args) touched = touched.intersection(self.hashX_subs) if touched or (height_changed and self.mempool_statuses): - await self.notify_touched(touched) + changed = {} - async def raw_header(self, height): - '''Return the binary header at the given height.''' - try: - return await self.chain_state.raw_header(height) - except IndexError: - raise RPCError(BAD_REQUEST, f'height {height:,d} ' - 'out of range') from None + for hashX in touched: + alias = self.hashX_subs[hashX] + status = await self.address_status(hashX) + changed[alias] = status - async def electrum_header(self, height): - '''Return the deserialized header at the given height.''' - raw_header = await self.raw_header(height) - return self.coin.electrum_header(raw_header, height) + # Check mempool hashXs - the status is a function of the + # confirmed state of other transactions. Note: we cannot + # iterate over mempool_statuses as it changes size. + for hashX in tuple(self.mempool_statuses): + # Items can be evicted whilst await-ing status; False + # ensures such hashXs are notified + old_status = self.mempool_statuses.get(hashX, False) + status = await self.address_status(hashX) + if status != old_status: + alias = self.hashX_subs[hashX] + changed[alias] = status - async def subscribe_headers_result(self, height): - '''The result of a header subscription for the given height.''' - if self.subscribe_headers_raw: - raw_header = await self.raw_header(height) - return {'hex': raw_header.hex(), 'height': height} - return await self.electrum_header(height) + for alias, status in changed.items(): + if len(alias) == 64: + method = 'blockchain.scripthash.subscribe' + else: + method = 'blockchain.address.subscribe' + await self.send_notification(method, (alias, status)) + + if changed: + es = '' if len(changed) == 1 else 'es' + self.logger.info(f'notified of {len(changed):,d} address{es}') + + async def subscribe_headers_result(self): + '''The result of a header subscription or notification.''' + return self.session_mgr.hsub_results[self.subscribe_headers_raw] async def _headers_subscribe(self, raw): '''Subscribe to get headers of new blocks.''' - self.subscribe_headers = True self.subscribe_headers_raw = assert_boolean(raw) - self.notified_height = self.db_height() - return await self.subscribe_headers_result(self.notified_height) + self.subscribe_headers = True + return await self.subscribe_headers_result() async def headers_subscribe(self): '''Subscribe to get raw headers of new blocks.''' @@ -804,7 +851,7 @@ class ElectrumX(SessionBase): async def hashX_listunspent(self, hashX): '''Return the list of UTXOs of a script hash, including mempool effects.''' - utxos = await self.chain_state.all_utxos(hashX) + utxos = await self.db.all_utxos(hashX) utxos = sorted(utxos) utxos.extend(await self.mempool.unordered_UTXOs(hashX)) spends = await self.mempool.potential_spends(hashX) @@ -861,7 +908,7 @@ class ElectrumX(SessionBase): return await self.hashX_subscribe(hashX, address) async def get_balance(self, hashX): - utxos = await self.chain_state.all_utxos(hashX) + utxos = await self.db.all_utxos(hashX) confirmed = sum(utxo.value for utxo in utxos) unconfirmed = await self.mempool.balance_delta(hashX) return {'confirmed': confirmed, 'unconfirmed': unconfirmed} @@ -909,14 +956,14 @@ class ElectrumX(SessionBase): return await self.hashX_subscribe(hashX, scripthash) async def _merkle_proof(self, cp_height, height): - max_height = self.db_height() + max_height = self.db.db_height if not height <= cp_height <= max_height: raise RPCError(BAD_REQUEST, f'require header height {height:,d} <= ' f'cp_height {cp_height:,d} <= ' f'chain height {max_height:,d}') - branch, root = await self.chain_state.header_branch_and_root( - cp_height + 1, height) + branch, root = await self.db.header_branch_and_root(cp_height + 1, + height) return { 'branch': [hash_to_hex_str(elt) for elt in branch], 'root': hash_to_hex_str(root), @@ -927,7 +974,7 @@ class ElectrumX(SessionBase): dictionary with a merkle proof.''' height = non_negative_integer(height) cp_height = non_negative_integer(cp_height) - raw_header_hex = (await self.raw_header(height)).hex() + raw_header_hex = (await self.session_mgr.raw_header(height)).hex() if cp_height == 0: return raw_header_hex result = {'header': raw_header_hex} @@ -953,8 +1000,7 @@ class ElectrumX(SessionBase): max_size = self.MAX_CHUNK_SIZE count = min(count, max_size) - headers, count = await self.chain_state.read_headers(start_height, - count) + headers, count = await self.db.read_headers(start_height, count) result = {'hex': headers.hex(), 'count': count, 'max': max_size} if count and cp_height: last_height = start_height + count - 1 @@ -971,7 +1017,7 @@ class ElectrumX(SessionBase): index = non_negative_integer(index) size = self.coin.CHUNK_SIZE start_height = index * size - headers, _ = await self.chain_state.read_headers(start_height, size) + headers, _ = await self.db.read_headers(start_height, size) return headers.hex() async def block_get_header(self, height): @@ -979,7 +1025,7 @@ class ElectrumX(SessionBase): height: the header's height''' height = non_negative_integer(height) - return await self.electrum_header(height) + return await self.session_mgr.electrum_header(height) def is_tor(self): '''Try to detect if the connection is to a tor hidden service we are @@ -1042,7 +1088,7 @@ class ElectrumX(SessionBase): number: the number of blocks ''' number = non_negative_integer(number) - return await self.daemon_request('estimatefee', [number]) + return await self.daemon_request('estimatefee', number) async def ping(self): '''Serves as a connection keep-alive mechanism and for the client to @@ -1091,15 +1137,14 @@ class ElectrumX(SessionBase): raw_tx: the raw transaction as a hexadecimal string''' # This returns errors as JSON RPC errors, as is natural try: - tx_hash = await self.chain_state.broadcast_transaction(raw_tx) + hex_hash = await self.session_mgr.broadcast_transaction(raw_tx) self.txs_sent += 1 - self.session_mgr.txs_sent += 1 - self.logger.info('sent tx: {}'.format(tx_hash)) - return tx_hash + self.logger.info(f'sent tx: {hex_hash}') + return hex_hash except DaemonError as e: error, = e.args message = error['message'] - self.logger.info(f'sendrawtransaction: {message}') + self.logger.info(f'error sending transaction: {message}') raise RPCError(BAD_REQUEST, 'the transaction was rejected by ' f'network rules.\n\n{message}\n[{raw_tx}]') @@ -1135,7 +1180,7 @@ class ElectrumX(SessionBase): tx_pos: index of transaction in tx_hashes to create branch for ''' hashes = [hex_str_to_hash(hash) for hash in tx_hashes] - branch, root = self.chain_state.tx_branch_and_root(hashes, tx_pos) + branch, root = self.db.merkle.branch_and_root(hashes, tx_pos) branch = [hash_to_hex_str(hash) for hash in branch] return branch @@ -1264,9 +1309,9 @@ class DashElectrumX(ElectrumX): 'masternode.list': self.masternode_list }) - async def notify(self, height, touched): + async def notify(self, touched, height_changed): '''Notify the client about changes in masternode list.''' - await super().notify(height, touched) + await super().notify(touched, height_changed) for mn in self.mns: status = await self.daemon_request('masternode_list', ['status', mn]) @@ -1358,7 +1403,7 @@ class DashElectrumX(ElectrumX): # with the masternode information including the payment # position is returned. cache = self.session_mgr.mn_cache - if not cache or self.session_mgr.mn_cache_height != self.db_height(): + if not cache or self.session_mgr.mn_cache_height != self.db.db_height: full_mn_list = await self.daemon_request('masternode_list', ['full']) mn_payment_queue = get_masternode_payment_queue(full_mn_list) @@ -1386,7 +1431,7 @@ class DashElectrumX(ElectrumX): mn_list.append(mn_info) cache.clear() cache.extend(mn_list) - self.session_mgr.mn_cache_height = self.db_height() + self.session_mgr.mn_cache_height = self.db.db_height # If payees is an empty list the whole masternode list is returned if payees: diff --git a/setup.py b/setup.py index bfaae71..4f975f7 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ import setuptools -version = '1.8.3' +version = '1.8.4' setuptools.setup( name='electrumX', diff --git a/tests/server/test_daemon.py b/tests/server/test_daemon.py new file mode 100644 index 0000000..2a905a9 --- /dev/null +++ b/tests/server/test_daemon.py @@ -0,0 +1,489 @@ +import aiohttp +import asyncio +import json +import logging + +import pytest + +from aiorpcx import ( + JSONRPCv1, JSONRPCLoose, RPCError, ignore_after, + Request, Batch, +) +from electrumx.lib.coins import BitcoinCash, CoinError, Bitzeny +from electrumx.server.daemon import ( + Daemon, FakeEstimateFeeDaemon, DaemonError +) + + +coin = BitcoinCash + +# These should be full, canonical URLs +urls = ['http://rpc_user:rpc_pass@127.0.0.1:8332/', + 'http://rpc_user:rpc_pass@192.168.0.1:8332/'] + + +@pytest.fixture(params=[BitcoinCash, Bitzeny]) +def daemon(request): + coin = request.param + return coin.DAEMON(coin, ','.join(urls)) + + +class ResponseBase(object): + + def __init__(self, headers, status): + self.headers = headers + self.status = status + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + +class JSONResponse(ResponseBase): + + def __init__(self, result, msg_id, status=200): + super().__init__({'Content-Type': 'application/json'}, status) + self.result = result + self.msg_id = msg_id + + async def json(self): + if isinstance(self.msg_id, int): + message = JSONRPCv1.response_message(self.result, self.msg_id) + else: + parts = [JSONRPCv1.response_message(item, msg_id) + for item, msg_id in zip(self.result, self.msg_id)] + message = JSONRPCv1.batch_message_from_parts(parts) + return json.loads(message.decode()) + + +class HTMLResponse(ResponseBase): + + def __init__(self, text, reason, status): + super().__init__({'Content-Type': 'text/html; charset=ISO-8859-1'}, + status) + self._text = text + self.reason = reason + + async def text(self): + return self._text + + +class ClientSessionBase(object): + + def __enter__(self): + self.prior_class = aiohttp.ClientSession + aiohttp.ClientSession = lambda: self + + def __exit__(self, exc_type, exc_value, traceback): + aiohttp.ClientSession = self.prior_class + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + +class ClientSessionGood(ClientSessionBase): + '''Imitate aiohttp for testing purposes.''' + + def __init__(self, *triples): + self.triples = triples # each a (method, args, result) + self.count = 0 + self.expected_url = urls[0] + + def post(self, url, data=""): + assert url == self.expected_url + request, request_id = JSONRPCLoose.message_to_item(data.encode()) + method, args, result = self.triples[self.count] + self.count += 1 + if isinstance(request, Request): + assert request.method == method + assert request.args == args + return JSONResponse(result, request_id) + else: + assert isinstance(request, Batch) + for request, args in zip(request, args): + assert request.method == method + assert request.args == args + return JSONResponse(result, request_id) + + +class ClientSessionBadAuth(ClientSessionBase): + + def post(self, url, data=""): + return HTMLResponse('', 'Unauthorized', 401) + + +class ClientSessionWorkQueueFull(ClientSessionGood): + + def post(self, url, data=""): + self.post = super().post + return HTMLResponse('Work queue depth exceeded', + 'Internal server error', 500) + + +class ClientSessionNoConnection(ClientSessionGood): + + def __init__(self, *args): + self.args = args + + async def __aenter__(self): + aiohttp.ClientSession = lambda: ClientSessionGood(*self.args) + raise aiohttp.ClientConnectionError + + +class ClientSessionPostError(ClientSessionGood): + + def __init__(self, exception, *args): + self.exception = exception + self.args = args + + def post(self, url, data=""): + aiohttp.ClientSession = lambda: ClientSessionGood(*self.args) + raise self.exception + + +class ClientSessionFailover(ClientSessionGood): + + def post(self, url, data=""): + # If not failed over; simulate disconnecting + if url == self.expected_url: + raise aiohttp.ServerDisconnectedError + else: + self.expected_url = urls[1] + return super().post(url, data) + + +def in_caplog(caplog, message, count=1): + return sum(message in record.message + for record in caplog.records) == count + +# +# Tests +# + +def test_set_urls_bad(): + with pytest.raises(CoinError): + Daemon(coin, '') + with pytest.raises(CoinError): + Daemon(coin, 'a') + + +def test_set_urls_one(caplog): + with caplog.at_level(logging.INFO): + daemon = Daemon(coin, urls[0]) + assert daemon.current_url() == urls[0] + assert len(daemon.urls) == 1 + logged_url = daemon.logged_url() + assert logged_url == '127.0.0.1:8332/' + assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)') + + +def test_set_urls_two(caplog): + with caplog.at_level(logging.INFO): + daemon = Daemon(coin, ','.join(urls)) + assert daemon.current_url() == urls[0] + assert len(daemon.urls) == 2 + logged_url = daemon.logged_url() + assert logged_url == '127.0.0.1:8332/' + assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)') + assert in_caplog(caplog, 'daemon #2 at 192.168.0.1:8332') + + +def test_set_urls_short(): + no_prefix_urls = ['/'.join(part for part in url.split('/')[2:]) + for url in urls] + daemon = Daemon(coin, ','.join(no_prefix_urls)) + assert daemon.current_url() == urls[0] + assert len(daemon.urls) == 2 + + no_slash_urls = [url[:-1] for url in urls] + daemon = Daemon(coin, ','.join(no_slash_urls)) + assert daemon.current_url() == urls[0] + assert len(daemon.urls) == 2 + + no_port_urls = [url[:url.rfind(':')] for url in urls] + daemon = Daemon(coin, ','.join(no_port_urls)) + assert daemon.current_url() == urls[0] + assert len(daemon.urls) == 2 + + +def test_failover_good(caplog): + daemon = Daemon(coin, ','.join(urls)) + with caplog.at_level(logging.INFO): + result = daemon.failover() + assert result is True + assert daemon.current_url() == urls[1] + logged_url = daemon.logged_url() + assert in_caplog(caplog, f'failing over to {logged_url}') + # And again + result = daemon.failover() + assert result is True + assert daemon.current_url() == urls[0] + + +def test_failover_fail(caplog): + daemon = Daemon(coin, urls[0]) + with caplog.at_level(logging.INFO): + result = daemon.failover() + assert result is False + assert daemon.current_url() == urls[0] + assert not in_caplog(caplog, f'failing over') + + +@pytest.mark.asyncio +async def test_height(daemon): + assert daemon.cached_height() is None + height = 300 + with ClientSessionGood(('getblockcount', [], height)): + assert await daemon.height() == height + assert daemon.cached_height() == height + + +@pytest.mark.asyncio +async def test_broadcast_transaction(daemon): + raw_tx = 'deadbeef' + tx_hash = 'hash' + with ClientSessionGood(('sendrawtransaction', [raw_tx], tx_hash)): + assert await daemon.broadcast_transaction(raw_tx) == tx_hash + + +@pytest.mark.asyncio +async def test_relayfee(daemon): + response = {"relayfee": sats, "other:": "cruft"} + with ClientSessionGood(('getnetworkinfo', [], response)): + assert await daemon.getnetworkinfo() == response + + +@pytest.mark.asyncio +async def test_relayfee(daemon): + if isinstance(daemon, FakeEstimateFeeDaemon): + sats = daemon.coin.ESTIMATE_FEE + else: + sats = 2 + response = {"relayfee": sats, "other:": "cruft"} + with ClientSessionGood(('getnetworkinfo', [], response)): + assert await daemon.relayfee() == sats + + +@pytest.mark.asyncio +async def test_mempool_hashes(daemon): + hashes = ['hex_hash1', 'hex_hash2'] + with ClientSessionGood(('getrawmempool', [], hashes)): + assert await daemon.mempool_hashes() == hashes + + +@pytest.mark.asyncio +async def test_deserialised_block(daemon): + block_hash = 'block_hash' + result = {'some': 'mess'} + with ClientSessionGood(('getblock', [block_hash, True], result)): + assert await daemon.deserialised_block(block_hash) == result + + +@pytest.mark.asyncio +async def test_estimatefee(daemon): + method_not_found = RPCError(JSONRPCv1.METHOD_NOT_FOUND, 'nope') + if isinstance(daemon, FakeEstimateFeeDaemon): + result = daemon.coin.ESTIMATE_FEE + else: + result = -1 + with ClientSessionGood( + ('estimatesmartfee', [], method_not_found), + ('estimatefee', [2], result) + ): + assert await daemon.estimatefee(2) == result + + +@pytest.mark.asyncio +async def test_estimatefee_smart(daemon): + bad_args = RPCError(JSONRPCv1.INVALID_ARGS, 'bad args') + if isinstance(daemon, FakeEstimateFeeDaemon): + return + rate = 0.0002 + result = {'feerate': rate} + with ClientSessionGood( + ('estimatesmartfee', [], bad_args), + ('estimatesmartfee', [2], result) + ): + assert await daemon.estimatefee(2) == rate + + # Test the rpc_available_cache is used + with ClientSessionGood(('estimatesmartfee', [2], result)): + assert await daemon.estimatefee(2) == rate + + +@pytest.mark.asyncio +async def test_getrawtransaction(daemon): + hex_hash = 'deadbeef' + simple = 'tx_in_hex' + verbose = {'hex': hex_hash, 'other': 'cruft'} + # Test False is converted to 0 - old daemon's reject False + with ClientSessionGood(('getrawtransaction', [hex_hash, 0], simple)): + assert await daemon.getrawtransaction(hex_hash) == simple + + # Test True is converted to 1 + with ClientSessionGood(('getrawtransaction', [hex_hash, 1], verbose)): + assert await daemon.getrawtransaction( + hex_hash, True) == verbose + + +# Batch tests + +@pytest.mark.asyncio +async def test_empty_send(daemon): + first = 5 + count = 0 + with ClientSessionGood(('getblockhash', [], [])): + assert await daemon.block_hex_hashes(first, count) == [] + + +@pytest.mark.asyncio +async def test_block_hex_hashes(daemon): + first = 5 + count = 3 + hashes = [f'hex_hash{n}' for n in range(count)] + with ClientSessionGood(('getblockhash', + [[n] for n in range(first, first + count)], + hashes)): + assert await daemon.block_hex_hashes(first, count) == hashes + + +@pytest.mark.asyncio +async def test_raw_blocks(daemon): + count = 3 + hex_hashes = [f'hex_hash{n}' for n in range(count)] + args_list = [[hex_hash, False] for hex_hash in hex_hashes] + iterable = (hex_hash for hex_hash in hex_hashes) + blocks = ["00", "019a", "02fe"] + blocks_raw = [bytes.fromhex(block) for block in blocks] + with ClientSessionGood(('getblock', args_list, blocks)): + assert await daemon.raw_blocks(iterable) == blocks_raw + + +@pytest.mark.asyncio +async def test_get_raw_transactions(daemon): + hex_hashes = ['deadbeef0', 'deadbeef1'] + args_list = [[hex_hash, 0] for hex_hash in hex_hashes] + raw_txs_hex = ['fffefdfc', '0a0b0c0d'] + raw_txs = [bytes.fromhex(raw_tx) for raw_tx in raw_txs_hex] + # Test 0 - old daemon's reject False + with ClientSessionGood(('getrawtransaction', args_list, raw_txs_hex)): + assert await daemon.getrawtransactions(hex_hashes) == raw_txs + + # Test one error + tx_not_found = RPCError(-1, 'some error message') + results = ['ff0b7d', tx_not_found] + raw_txs = [bytes.fromhex(results[0]), None] + with ClientSessionGood(('getrawtransaction', args_list, results)): + assert await daemon.getrawtransactions(hex_hashes) == raw_txs + + +# Other tests + +@pytest.mark.asyncio +async def test_bad_auth(daemon, caplog): + with pytest.raises(DaemonError) as e: + with ClientSessionBadAuth(): + await daemon.height() + + assert "Unauthorized" in e.value.args[0] + assert in_caplog(caplog, "Unauthorized") + + +@pytest.mark.asyncio +async def test_workqueue_depth(daemon, caplog): + daemon.init_retry = 0.01 + height = 125 + with caplog.at_level(logging.INFO): + with ClientSessionWorkQueueFull(('getblockcount', [], height)): + await daemon.height() == height + + assert in_caplog(caplog, "work queue full") + assert in_caplog(caplog, "running normally") + + +@pytest.mark.asyncio +async def test_connection_error(daemon, caplog): + height = 100 + daemon.init_retry = 0.01 + with caplog.at_level(logging.INFO): + with ClientSessionNoConnection(('getblockcount', [], height)): + await daemon.height() == height + + assert in_caplog(caplog, "connection problem - is your daemon running?") + assert in_caplog(caplog, "connection restored") + + +@pytest.mark.asyncio +async def test_timeout_error(daemon, caplog): + height = 100 + daemon.init_retry = 0.01 + with caplog.at_level(logging.INFO): + with ClientSessionPostError(asyncio.TimeoutError, + ('getblockcount', [], height)): + await daemon.height() == height + + assert in_caplog(caplog, "timeout error") + + +@pytest.mark.asyncio +async def test_disconnected(daemon, caplog): + height = 100 + daemon.init_retry = 0.01 + with caplog.at_level(logging.INFO): + with ClientSessionPostError(aiohttp.ServerDisconnectedError, + ('getblockcount', [], height)): + await daemon.height() == height + + assert in_caplog(caplog, "disconnected") + assert in_caplog(caplog, "connection restored") + + +@pytest.mark.asyncio +async def test_warming_up(daemon, caplog): + warming_up = RPCError(-28, 'reading block index') + height = 100 + daemon.init_retry = 0.01 + with caplog.at_level(logging.INFO): + with ClientSessionGood( + ('getblockcount', [], warming_up), + ('getblockcount', [], height) + ): + assert await daemon.height() == height + + assert in_caplog(caplog, "starting up checking blocks") + assert in_caplog(caplog, "running normally") + + +@pytest.mark.asyncio +async def test_warming_up_batch(daemon, caplog): + warming_up = RPCError(-28, 'reading block index') + first = 5 + count = 1 + daemon.init_retry = 0.01 + hashes = ['hex_hash5'] + with caplog.at_level(logging.INFO): + with ClientSessionGood(('getblockhash', [[first]], [warming_up]), + ('getblockhash', [[first]], hashes)): + assert await daemon.block_hex_hashes(first, count) == hashes + + assert in_caplog(caplog, "starting up checking blocks") + assert in_caplog(caplog, "running normally") + + +@pytest.mark.asyncio +async def test_failover(daemon, caplog): + height = 100 + daemon.init_retry = 0.01 + daemon.max_retry = 0.04 + with caplog.at_level(logging.INFO): + with ClientSessionFailover(('getblockcount', [], height)): + await daemon.height() == height + + assert in_caplog(caplog, "disconnected", 3) + assert in_caplog(caplog, "failing over") + assert in_caplog(caplog, "connection restored")