Merge branch 'devel'

Neil Booth 2018-08-14 12:58:52 +09:00
commit f0a2f128dc
13 changed files with 775 additions and 350 deletions

View File

@@ -15,7 +15,7 @@
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
VERSION="ElectrumX 1.8.3"
VERSION="ElectrumX 1.8.4"
# -- Project information -----------------------------------------------------

View File

@@ -1,4 +1,4 @@
version = 'ElectrumX 1.8.3'
version = 'ElectrumX 1.8.4-dev'
version_short = version.split()[-1]
from electrumx.server.controller import Controller

View File

@@ -114,10 +114,6 @@ class Coin(object):
url = 'http://' + url
return url + '/'
@classmethod
def daemon_urls(cls, urls):
return [cls.sanitize_url(url) for url in urls.split(',')]
@classmethod
def genesis_block(cls, block):
'''Check the Genesis block is the right one for this coin.

View File

@@ -13,7 +13,7 @@ import sys
import time
from functools import partial
from aiorpcx import TaskGroup
from aiorpcx import spawn
from electrumx.lib.util import class_logger
@@ -93,12 +93,11 @@ class ServerBase(object):
loop.set_exception_handler(self.on_exception)
shutdown_event = asyncio.Event()
async with TaskGroup() as group:
server_task = await group.spawn(self.serve(shutdown_event))
# Wait for shutdown, log on receipt of the event
await shutdown_event.wait()
self.logger.info('shutting down')
server_task.cancel()
server_task = await spawn(self.serve(shutdown_event))
# Wait for shutdown, log on receipt of the event
await shutdown_event.wait()
self.logger.info('shutting down')
server_task.cancel()
# Prevent some silly logs
await asyncio.sleep(0.01)
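The hunk above replaces a TaskGroup-managed server task with a free-standing one. The distinction matters: a TaskGroup cancels and awaits its children when the async-with block exits, whereas a task created with aiorpcx's spawn() lives until cancelled explicitly, which the code now does after the shutdown event fires. A minimal sketch of the new shape, assuming spawn() behaves like asyncio.ensure_future() for a coroutine:

import asyncio

async def run_until_shutdown(serve):
    shutdown_event = asyncio.Event()
    # Free-standing task: nothing implicitly waits on or cancels it.
    server_task = asyncio.ensure_future(serve(shutdown_event))
    # Block until the server signals shutdown, then cancel explicitly.
    await shutdown_event.wait()
    server_task.cancel()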

View File

@@ -650,10 +650,7 @@ class BlockProcessor(object):
could be lost.
'''
self._caught_up_event = caught_up_event
async with TaskGroup() as group:
await group.spawn(self._first_open_dbs())
# Ensure cached_height is set
await group.spawn(self.daemon.height())
await self._first_open_dbs()
try:
async with TaskGroup() as group:
await group.spawn(self.prefetcher.main_loop(self.height))

View File

@@ -1,102 +0,0 @@
# Copyright (c) 2016-2018, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
from electrumx.lib.hash import hash_to_hex_str
class ChainState(object):
'''Used as an interface by servers to request information about
blocks, transaction history, UTXOs and the mempool.
'''
def __init__(self, env, db, daemon, bp):
self._env = env
self._db = db
self._daemon = daemon
# External interface pass-throughs for session.py
self.force_chain_reorg = bp.force_chain_reorg
self.tx_branch_and_root = db.merkle.branch_and_root
self.read_headers = db.read_headers
self.all_utxos = db.all_utxos
self.limited_history = db.limited_history
self.header_branch_and_root = db.header_branch_and_root
async def broadcast_transaction(self, raw_tx):
return await self._daemon.sendrawtransaction([raw_tx])
async def daemon_request(self, method, args=()):
return await getattr(self._daemon, method)(*args)
def db_height(self):
return self._db.db_height
def get_info(self):
'''Chain state info for LocalRPC and logs.'''
return {
'daemon': self._daemon.logged_url(),
'daemon_height': self._daemon.cached_height(),
'db_height': self.db_height(),
}
async def raw_header(self, height):
'''Return the binary header at the given height.'''
header, n = await self.read_headers(height, 1)
if n != 1:
raise IndexError(f'height {height:,d} out of range')
return header
def set_daemon_url(self, daemon_url):
self._daemon.set_urls(self._env.coin.daemon_urls(daemon_url))
return self._daemon.logged_url()
async def query(self, args, limit):
coin = self._env.coin
db = self._db
lines = []
def arg_to_hashX(arg):
try:
script = bytes.fromhex(arg)
lines.append(f'Script: {arg}')
return coin.hashX_from_script(script)
except ValueError:
pass
hashX = coin.address_to_hashX(arg)
lines.append(f'Address: {arg}')
return hashX
for arg in args:
hashX = arg_to_hashX(arg)
if not hashX:
continue
n = None
history = await db.limited_history(hashX, limit=limit)
for n, (tx_hash, height) in enumerate(history):
lines.append(f'History #{n:,d}: height {height:,d} '
f'tx_hash {hash_to_hex_str(tx_hash)}')
if n is None:
lines.append('No history found')
n = None
utxos = await db.all_utxos(hashX)
for n, utxo in enumerate(utxos, start=1):
lines.append(f'UTXO #{n:,d}: tx_hash '
f'{hash_to_hex_str(utxo.tx_hash)} '
f'tx_pos {utxo.tx_pos:,d} height '
f'{utxo.height:,d} value {utxo.value:,d}')
if n == limit:
break
if n is None:
lines.append('No UTXOs found')
balance = sum(utxo.value for utxo in utxos)
lines.append(f'Balance: {coin.decimal_value(balance):,f} '
f'{coin.SHORTNAME}')
return lines

View File

@@ -12,7 +12,6 @@ from aiorpcx import _version as aiorpcx_version, TaskGroup
import electrumx
from electrumx.lib.server_base import ServerBase
from electrumx.lib.util import version_string
from electrumx.server.chain_state import ChainState
from electrumx.server.db import DB
from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.server.session import SessionManager
@@ -93,11 +92,12 @@ class Controller(ServerBase):
self.logger.info(f'reorg limit is {env.reorg_limit:,d} blocks')
notifications = Notifications()
daemon = env.coin.DAEMON(env)
db = DB(env)
Daemon = env.coin.DAEMON
BlockProcessor = env.coin.BLOCK_PROCESSOR
daemon = Daemon(env.coin, env.daemon_url)
db = DB(env)
bp = BlockProcessor(env, db, daemon, notifications)
chain_state = ChainState(env, db, daemon, bp)
# Set ourselves up to implement the MemPoolAPI
self.height = daemon.height
@@ -109,13 +109,16 @@
MemPoolAPI.register(Controller)
mempool = MemPool(env.coin, self)
session_mgr = SessionManager(env, chain_state, mempool,
session_mgr = SessionManager(env, db, bp, daemon, mempool,
notifications, shutdown_event)
# Test daemon authentication, and also ensure it has a cached
# height. Do this before entering the task group.
await daemon.height()
caught_up_event = Event()
serve_externally_event = Event()
synchronized_event = Event()
async with TaskGroup() as group:
await group.spawn(session_mgr.serve(serve_externally_event))
await group.spawn(bp.fetch_and_process_blocks(caught_up_event))

View File

@@ -9,6 +9,7 @@
daemon.'''
import asyncio
import itertools
import json
import time
from calendar import timegm
@@ -28,48 +29,53 @@ class DaemonError(Exception):
'''Raised when the daemon returns an error in its results.'''
class WarmingUpError(Exception):
'''Internal - when the daemon is warming up.'''
class WorkQueueFullError(Exception):
'''Internal - when the daemon's work queue is full.'''
class Daemon(object):
'''Handles connections to a daemon at the given URL.'''
WARMING_UP = -28
RPC_MISC_ERROR = -1
id_counter = itertools.count()
class DaemonWarmingUpError(Exception):
'''Raised when the daemon is warming up.'''
def __init__(self, env):
def __init__(self, coin, url, max_workqueue=10, init_retry=0.25,
max_retry=4.0):
self.coin = coin
self.logger = class_logger(__name__, self.__class__.__name__)
self.coin = env.coin
self.set_urls(env.coin.daemon_urls(env.daemon_url))
self._height = None
self.set_url(url)
# Limit concurrent RPC calls to this number.
# See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16
self.workqueue_semaphore = asyncio.Semaphore(value=10)
self.down = False
self.last_error_time = 0
self.req_id = 0
self._available_rpcs = {} # caches results for _is_rpc_available()
self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue)
self.init_retry = init_retry
self.max_retry = max_retry
self._height = None
self.available_rpcs = {}
def next_req_id(self):
'''Returns the next request ID.'''
self.req_id += 1
return self.req_id
def set_urls(self, urls):
def set_url(self, url):
'''Set the daemon URL(s); a comma-separated string gives failover URLs, the first of which becomes current.'''
if not urls:
raise DaemonError('no daemon URLs provided')
self.urls = urls
self.url_index = 0
urls = url.split(',')
urls = [self.coin.sanitize_url(url) for url in urls]
for n, url in enumerate(urls):
self.logger.info('daemon #{:d} at {}{}'
.format(n + 1, self.logged_url(url),
'' if n else ' (current)'))
status = '' if n else ' (current)'
logged_url = self.logged_url(url)
self.logger.info(f'daemon #{n + 1} at {logged_url}{status}')
self.url_index = 0
self.urls = urls
def url(self):
def current_url(self):
'''Returns the current daemon URL.'''
return self.urls[self.url_index]
def logged_url(self, url=None):
'''The host and port part, for logging.'''
url = url or self.current_url()
return url[url.rindex('@') + 1:]
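The semaphore initialised above is what enforces the max_workqueue limit: _send_data() acquires it around every HTTP round trip, so at most ten requests (by default) are in flight at once, staying under bitcoind's DEFAULT_HTTP_WORKQUEUE of 16. A minimal sketch of the pattern, with do_rpc standing in for an actual daemon call:

import asyncio

workqueue_semaphore = asyncio.Semaphore(value=10)

async def limited_send(do_rpc):
    # At most 10 callers proceed concurrently; the rest queue here
    # until a slot is released.
    async with workqueue_semaphore:
        return await do_rpc()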
def failover(self):
'''Call to fail-over to the next daemon URL.
@@ -77,7 +83,7 @@ class Daemon(object):
'''
if len(self.urls) > 1:
self.url_index = (self.url_index + 1) % len(self.urls)
self.logger.info('failing over to {}'.format(self.logged_url()))
self.logger.info(f'failing over to {self.logged_url()}')
return True
return False
@@ -88,13 +94,17 @@ class Daemon(object):
async def _send_data(self, data):
async with self.workqueue_semaphore:
async with self.client_session() as session:
async with session.post(self.url(), data=data) as resp:
# If bitcoind can't find a tx, for some reason
# it returns 500 but fills out the JSON.
# Should still return 200 IMO.
if resp.status in (200, 404, 500):
async with session.post(self.current_url(), data=data) as resp:
kind = resp.headers.get('Content-Type', None)
if kind == 'application/json':
return await resp.json()
return (resp.status, resp.reason)
# bitcoind's HTTP protocol "handling" is a bad joke
text = await resp.text()
if 'Work queue depth exceeded' in text:
raise WorkQueueFullError
text = text.strip() or resp.reason
self.logger.error(text)
raise DaemonError(text)
async def _send(self, payload, processor):
'''Send a payload to be converted to JSON.
@@ -103,54 +113,42 @@ class Daemon(object):
are raised through DaemonError.
'''
def log_error(error):
self.down = True
nonlocal last_error_log, retry
now = time.time()
prior_time = self.last_error_time
if now - prior_time > 60:
self.last_error_time = now
if prior_time and self.failover():
secs = 0
else:
self.logger.error('{} Retrying occasionally...'
.format(error))
if now - last_error_log > 60:
last_error_time = now
self.logger.error(f'{error} Retrying occasionally...')
if retry == self.max_retry and self.failover():
retry = 0
on_good_message = None
last_error_log = 0
data = json.dumps(payload)
secs = 1
max_secs = 4
retry = self.init_retry
while True:
try:
result = await self._send_data(data)
if not isinstance(result, tuple):
result = processor(result)
if self.down:
self.down = False
self.last_error_time = 0
self.logger.info('connection restored')
return result
log_error('HTTP error code {:d}: {}'
.format(result[0], result[1]))
result = processor(result)
if on_good_message:
self.logger.info(on_good_message)
return result
except asyncio.TimeoutError:
log_error('timeout error.')
except aiohttp.ServerDisconnectedError:
log_error('disconnected.')
except aiohttp.ClientPayloadError:
log_error('payload encoding error.')
on_good_message = 'connection restored'
except aiohttp.ClientConnectionError:
log_error('connection problem - is your daemon running?')
except self.DaemonWarmingUpError:
on_good_message = 'connection restored'
except WarmingUpError:
log_error('starting up checking blocks.')
except (asyncio.CancelledError, DaemonError):
raise
except Exception as e:
self.logger.exception(f'uncaught exception: {e}')
on_good_message = 'running normally'
except WorkQueueFullError:
log_error('work queue full.')
on_good_message = 'running normally'
await asyncio.sleep(secs)
secs = min(max_secs, secs * 2)
def logged_url(self, url=None):
'''The host and port part, for logging.'''
url = url or self.url()
return url[url.rindex('@') + 1:]
await asyncio.sleep(retry)
retry = max(min(self.max_retry, retry * 2), self.init_retry)
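With the default init_retry=0.25 and max_retry=4.0 from __init__, the sleep between failed attempts doubles until it is pinned at max_retry, and log_error() above triggers a failover once retry has reached that cap. Spelling out the backoff arithmetic under those defaults:

init_retry, max_retry = 0.25, 4.0

retry = init_retry
delays = []
for _ in range(7):
    delays.append(retry)
    # Double the delay, clamped to the [init_retry, max_retry] range.
    retry = max(min(max_retry, retry * 2), init_retry)

assert delays == [0.25, 0.5, 1.0, 2.0, 4.0, 4.0, 4.0]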
async def _send_single(self, method, params=None):
'''Send a single request to the daemon.'''
@@ -159,10 +157,10 @@ class Daemon(object):
if not err:
return result['result']
if err.get('code') == self.WARMING_UP:
raise self.DaemonWarmingUpError
raise WarmingUpError
raise DaemonError(err)
payload = {'method': method, 'id': self.next_req_id()}
payload = {'method': method, 'id': next(self.id_counter)}
if params:
payload['params'] = params
return await self._send(payload, processor)
@@ -176,12 +174,12 @@ class Daemon(object):
def processor(result):
errs = [item['error'] for item in result if item['error']]
if any(err.get('code') == self.WARMING_UP for err in errs):
raise self.DaemonWarmingUpError
raise WarmingUpError
if not errs or replace_errs:
return [item['result'] for item in result]
raise DaemonError(errs)
payload = [{'method': method, 'params': p, 'id': self.next_req_id()}
payload = [{'method': method, 'params': p, 'id': next(self.id_counter)}
for p in params_iterable]
if payload:
return await self._send(payload, processor)
@@ -192,27 +190,16 @@ class Daemon(object):
Results are cached and the daemon will generally not be queried with
the same method more than once.'''
available = self._available_rpcs.get(method, None)
available = self.available_rpcs.get(method)
if available is None:
available = True
try:
await self._send_single(method)
available = True
except DaemonError as e:
err = e.args[0]
error_code = err.get("code")
if error_code == JSONRPC.METHOD_NOT_FOUND:
available = False
elif error_code == self.RPC_MISC_ERROR:
# method found but exception was thrown in command handling
# probably because we did not provide arguments
available = True
else:
self.logger.warning('error (code {:d}: {}) when testing '
'RPC availability of method {}'
.format(error_code, err.get("message"),
method))
available = False
self._available_rpcs[method] = available
available = error_code != JSONRPC.METHOD_NOT_FOUND
self.available_rpcs[method] = available
return available
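_is_rpc_available() is a memoized capability probe: each method is invoked once with no arguments, METHOD_NOT_FOUND marks it unavailable, any other error (typically a complaint about the missing arguments) proves it exists, and the verdict is cached so the daemon is not asked again. A hedged sketch of the same shape, with probe() and MethodNotFound as hypothetical stand-ins for _send_single() and the METHOD_NOT_FOUND check:

class MethodNotFound(Exception):
    '''Stand-in for a METHOD_NOT_FOUND DaemonError.'''

available_rpcs = {}

async def is_rpc_available(method, probe):
    available = available_rpcs.get(method)
    if available is None:
        try:
            await probe(method)    # deliberately called with no arguments
            available = True
        except MethodNotFound:
            available = False      # the daemon does not know this RPC
        except Exception:
            available = True       # it exists; we merely called it badly
        available_rpcs[method] = available
    return available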
async def block_hex_hashes(self, first, count):
@@ -235,12 +222,16 @@ class Daemon(object):
'''Update our record of the daemon's mempool hashes.'''
return await self._send_single('getrawmempool')
async def estimatefee(self, params):
'''Return the fee estimate for the given parameters.'''
async def estimatefee(self, block_count):
'''Return the fee estimate for the block count. Units are whole
currency units per KB, e.g. 0.00000995, or -1 if no estimate
is available.
'''
args = (block_count, )
if await self._is_rpc_available('estimatesmartfee'):
estimate = await self._send_single('estimatesmartfee', params)
estimate = await self._send_single('estimatesmartfee', args)
return estimate.get('feerate', -1)
return await self._send_single('estimatefee', params)
return await self._send_single('estimatefee', args)
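Since estimatefee() now reports whole currency units per kilobyte (or -1 for no estimate), a caller wanting the more familiar satoshis per byte must scale by 1e8 satoshis per coin and divide by the 1,000 bytes in a kilobyte. An illustrative helper, not part of this commit:

def coin_per_kb_to_sat_per_byte(rate):
    if rate < 0:
        return None    # the daemon had no estimate available
    # 1 coin = 1e8 satoshis; 1 kB = 1,000 bytes.
    return rate * 1e8 / 1000

print(coin_per_kb_to_sat_per_byte(0.00000995))    # roughly 0.995 sat/byte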
async def getnetworkinfo(self):
'''Return the result of the 'getnetworkinfo' RPC call.'''
@@ -268,9 +259,9 @@ class Daemon(object):
# Convert hex strings to bytes
return [hex_to_bytes(tx) if tx else None for tx in txs]
async def sendrawtransaction(self, params):
async def broadcast_transaction(self, raw_tx):
'''Broadcast a transaction to the network.'''
return await self._send_single('sendrawtransaction', params)
return await self._send_single('sendrawtransaction', (raw_tx, ))
async def height(self):
'''Query the daemon for its current height.'''
@@ -299,7 +290,7 @@ class FakeEstimateFeeDaemon(Daemon):
'''Daemon that simulates estimatefee and relayfee RPC calls. Coin that
wants to use this daemon must define ESTIMATE_FEE & RELAY_FEE'''
async def estimatefee(self, params):
async def estimatefee(self, block_count):
'''Return the fee estimate for the given block count.'''
return self.coin.ESTIMATE_FEE

View File

@@ -370,6 +370,13 @@ class DB(object):
# Truncate header_mc: header count is 1 more than the height.
self.header_mc.truncate(height + 1)
async def raw_header(self, height):
'''Return the binary header at the given height.'''
header, n = await self.read_headers(height, 1)
if n != 1:
raise IndexError(f'height {height:,d} out of range')
return header
async def read_headers(self, start_height, count):
'''Requires start_height >= 0, count >= 0. Reads as many headers as
are available starting at start_height up to count. This

View File

@@ -55,12 +55,12 @@ class PeerManager(object):
Attempts to maintain a connection with up to 8 peers.
Issues a 'peers.subscribe' RPC to them and tells them our data.
'''
def __init__(self, env, chain_state):
def __init__(self, env, db):
self.logger = class_logger(__name__, self.__class__.__name__)
# Initialise the Peer class
Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
self.env = env
self.chain_state = chain_state
self.db = db
# Our clearnet and Tor Peers, if any
sclass = env.coin.SESSIONCLS
@@ -300,7 +300,7 @@ class PeerManager(object):
result = await session.send_request(message)
assert_good(message, result, dict)
our_height = self.chain_state.db_height()
our_height = self.db.db_height
if ptuple < (1, 3):
their_height = result.get('block_height')
else:
@@ -313,7 +313,7 @@ class PeerManager(object):
# Check prior header too in case of hard fork.
check_height = min(our_height, their_height)
raw_header = await self.chain_state.raw_header(check_height)
raw_header = await self.db.raw_header(check_height)
if ptuple >= (1, 4):
ours = raw_header.hex()
message = 'blockchain.block.header'

View File

@@ -109,13 +109,15 @@ class SessionManager(object):
CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4)
def __init__(self, env, chain_state, mempool, notifications,
def __init__(self, env, db, bp, daemon, mempool, notifications,
shutdown_event):
env.max_send = max(350000, env.max_send)
self.env = env
self.chain_state = chain_state
self.db = db
self.bp = bp
self.daemon = daemon
self.mempool = mempool
self.peer_mgr = PeerManager(env, chain_state)
self.peer_mgr = PeerManager(env, db)
self.shutdown_event = shutdown_event
self.logger = util.class_logger(__name__, self.__class__.__name__)
self.servers = {}
@@ -127,8 +129,8 @@ class SessionManager(object):
self.state = self.CATCHING_UP
self.txs_sent = 0
self.start_time = time.time()
self._history_cache = pylru.lrucache(256)
self._hc_height = 0
self.history_cache = pylru.lrucache(256)
self.notified_height = None
# Cache some idea of room to avoid recounting on each subscription
self.subs_room = 0
# Masternode stuff only for such coins
@@ -152,7 +154,7 @@ class SessionManager(object):
protocol_class = LocalRPC
else:
protocol_class = self.env.coin.SESSIONCLS
protocol_factory = partial(protocol_class, self, self.chain_state,
protocol_factory = partial(protocol_class, self, self.db,
self.mempool, self.peer_mgr, kind)
server = loop.create_server(protocol_factory, *args, **kw_args)
@@ -276,10 +278,11 @@ class SessionManager(object):
def _get_info(self):
'''A summary of server state.'''
group_map = self._group_map()
result = self.chain_state.get_info()
result.update({
'version': electrumx.version,
return {
'closing': len([s for s in self.sessions if s.is_closing()]),
'daemon': self.daemon.logged_url(),
'daemon_height': self.daemon.cached_height(),
'db_height': self.db.db_height,
'errors': sum(s.errors for s in self.sessions),
'groups': len(group_map),
'logged': len([s for s in self.sessions if s.log_me]),
@@ -291,8 +294,8 @@ class SessionManager(object):
'subs': self._sub_count(),
'txs_sent': self.txs_sent,
'uptime': util.formatted_time(time.time() - self.start_time),
})
return result
'version': electrumx.version,
}
def _session_data(self, for_log):
'''Returned to the RPC 'sessions' call.'''
@@ -329,6 +332,19 @@ class SessionManager(object):
])
return result
async def _electrum_and_raw_headers(self, height):
raw_header = await self.raw_header(height)
electrum_header = self.env.coin.electrum_header(raw_header, height)
return electrum_header, raw_header
async def _refresh_hsub_results(self, height):
'''Refresh the cached header subscription responses to be for height,
and record that as notified_height.
'''
electrum, raw = await self._electrum_and_raw_headers(height)
self.hsub_results = (electrum, {'hex': raw.hex(), 'height': height})
self.notified_height = height
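hsub_results is a two-element tuple that subscribe_headers_result() further down indexes directly with the boolean subscribe_headers_raw: since Python's bool is an int subclass, False selects element 0 (the deserialized Electrum header) and True selects element 1 (the raw-hex dictionary). With placeholder header data:

electrum = {'block_height': 1000, 'version': 536870912}   # illustrative only
raw = {'hex': 'deadbeef', 'height': 1000}
hsub_results = (electrum, raw)

assert hsub_results[False] is electrum    # False == 0
assert hsub_results[True] is raw          # True == 1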
# --- LocalRPC command handlers
async def rpc_add_peer(self, real_name):
@@ -367,10 +383,10 @@ class SessionManager(object):
'''Replace the daemon URL.'''
daemon_url = daemon_url or self.env.daemon_url
try:
daemon_url = self.chain_state.set_daemon_url(daemon_url)
self.daemon.set_url(daemon_url)
except Exception as e:
raise RPCError(BAD_REQUEST, f'an error occurred: {e!r}')
return f'now using daemon at {daemon_url}'
return f'now using daemon at {self.daemon.logged_url()}'
async def rpc_stop(self):
'''Shut down the server cleanly.'''
@@ -391,10 +407,54 @@ class SessionManager(object):
async def rpc_query(self, items, limit):
'''Return history and UTXO data for the given addresses or scripts.'''
try:
return await self.chain_state.query(items, limit)
except Base58Error as e:
raise RPCError(BAD_REQUEST, e.args[0]) from None
coin = self.env.coin
db = self.db
lines = []
def arg_to_hashX(arg):
try:
script = bytes.fromhex(arg)
lines.append(f'Script: {arg}')
return coin.hashX_from_script(script)
except ValueError:
pass
try:
hashX = coin.address_to_hashX(arg)
except Base58Error as e:
lines.append(e.args[0])
return None
lines.append(f'Address: {arg}')
return hashX
for arg in items:
hashX = arg_to_hashX(arg)
if not hashX:
continue
n = None
history = await db.limited_history(hashX, limit=limit)
for n, (tx_hash, height) in enumerate(history):
lines.append(f'History #{n:,d}: height {height:,d} '
f'tx_hash {hash_to_hex_str(tx_hash)}')
if n is None:
lines.append('No history found')
n = None
utxos = await db.all_utxos(hashX)
for n, utxo in enumerate(utxos, start=1):
lines.append(f'UTXO #{n:,d}: tx_hash '
f'{hash_to_hex_str(utxo.tx_hash)} '
f'tx_pos {utxo.tx_pos:,d} height '
f'{utxo.height:,d} value {utxo.value:,d}')
if n == limit:
break
if n is None:
lines.append('No UTXOs found')
balance = sum(utxo.value for utxo in utxos)
lines.append(f'Balance: {coin.decimal_value(balance):,f} '
f'{coin.SHORTNAME}')
return lines
async def rpc_sessions(self):
'''Return statistics about connected sessions.'''
@@ -406,7 +466,7 @@ class SessionManager(object):
count: number of blocks to reorg
'''
count = non_negative_integer(count)
if not self.chain_state.force_chain_reorg(count):
if not self.bp.force_chain_reorg(count):
raise RPCError(BAD_REQUEST, 'still catching up with daemon')
return f'scheduled a reorg of {count:,d} blocks'
@@ -454,31 +514,57 @@ class SessionManager(object):
'''The number of connections that we've sent something to.'''
return len(self.sessions)
async def daemon_request(self, method, *args):
'''Catch a DaemonError and convert it to an RPCError.'''
try:
return await getattr(self.daemon, method)(*args)
except DaemonError as e:
raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None
async def raw_header(self, height):
'''Return the binary header at the given height.'''
try:
return await self.db.raw_header(height)
except IndexError:
raise RPCError(BAD_REQUEST, f'height {height:,d} '
'out of range') from None
async def electrum_header(self, height):
'''Return the deserialized header at the given height.'''
electrum_header, _ = await self._electrum_and_raw_headers(height)
return electrum_header
async def broadcast_transaction(self, raw_tx):
hex_hash = await self.daemon.broadcast_transaction(raw_tx)
self.txs_sent += 1
return hex_hash
async def limited_history(self, hashX):
'''A caching layer.'''
hc = self._history_cache
hc = self.history_cache
if hashX not in hc:
# History DoS limit. Each element of history is about 99
# bytes when encoded as JSON. This limits resource usage
# on bloated history requests, and uses a smaller divisor
# so large requests are logged before refusing them.
limit = self.env.max_send // 97
hc[hashX] = await self.chain_state.limited_history(hashX,
limit=limit)
hc[hashX] = await self.db.limited_history(hashX, limit=limit)
return hc[hashX]
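Given the floor env.max_send = max(350000, env.max_send) applied in __init__ above, the default cap works out to 350000 // 97 = 3608 history items per request; at roughly 99 bytes per JSON-encoded item, a maximal response stays close to the send limit. The arithmetic, spelled out:

max_send = 350000              # the default floor from __init__
limit = max_send // 97
assert limit == 3608           # history items allowed per request
assert limit * 99 == 357192    # approximate JSON size in bytes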
async def _notify_sessions(self, height, touched):
'''Notify sessions about height changes and touched addresses.'''
# Invalidate our history cache for touched hashXs
if height != self._hc_height:
self._hc_height = height
hc = self._history_cache
height_changed = height != self.notified_height
if height_changed:
# Paranoia: a reorg could race and leave db_height lower
await self._refresh_hsub_results(min(height, self.db.db_height))
# Invalidate our history cache for touched hashXs
hc = self.history_cache
for hashX in set(hc).intersection(touched):
del hc[hashX]
async with TaskGroup() as group:
for session in self.sessions:
await group.spawn(session.notify(height, touched))
await group.spawn(session.notify(touched, height_changed))
def add_session(self, session):
self.sessions.add(session)
@@ -518,12 +604,12 @@ class SessionBase(ServerSession):
MAX_CHUNK_SIZE = 2016
session_counter = itertools.count()
def __init__(self, session_mgr, chain_state, mempool, peer_mgr, kind):
def __init__(self, session_mgr, db, mempool, peer_mgr, kind):
connection = JSONRPCConnection(JSONRPCAutoDetect)
super().__init__(connection=connection)
self.logger = util.class_logger(__name__, self.__class__.__name__)
self.session_mgr = session_mgr
self.chain_state = chain_state
self.db = db
self.mempool = mempool
self.peer_mgr = peer_mgr
self.kind = kind # 'RPC', 'TCP' etc.
@@ -534,11 +620,12 @@ class SessionBase(ServerSession):
self.txs_sent = 0
self.log_me = False
self.bw_limit = self.env.bandwidth_limit
self.daemon_request = self.session_mgr.daemon_request
# Hijack the connection so we can log messages
self._receive_message_orig = self.connection.receive_message
self.connection.receive_message = self.receive_message
async def notify(self, height, touched):
async def notify(self, touched, height_changed):
pass
def peer_address_str(self, *, for_log=True):
@@ -623,14 +710,12 @@ class ElectrumX(SessionBase):
super().__init__(*args, **kwargs)
self.subscribe_headers = False
self.subscribe_headers_raw = False
self.notified_height = None
self.connection.max_response_size = self.env.max_send
self.max_subs = self.env.max_session_subs
self.hashX_subs = {}
self.sv_seen = False
self.mempool_statuses = {}
self.set_request_handlers(self.PROTOCOL_MIN)
self.db_height = self.chain_state.db_height
@classmethod
def protocol_min_max_strings(cls):
@@ -662,96 +747,58 @@ class ElectrumX(SessionBase):
def protocol_version_string(self):
return util.version_string(self.protocol_tuple)
async def daemon_request(self, method, *args):
'''Catch a DaemonError and convert it to an RPCError.'''
try:
return await self.chain_state.daemon_request(method, args)
except DaemonError as e:
raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None
def sub_count(self):
return len(self.hashX_subs)
async def notify_touched(self, our_touched):
changed = {}
for hashX in our_touched:
alias = self.hashX_subs[hashX]
status = await self.address_status(hashX)
changed[alias] = status
# Check mempool hashXs - the status is a function of the
# confirmed state of other transactions. Note: we cannot
# iterate over mempool_statuses as it changes size.
for hashX in tuple(self.mempool_statuses):
# Items can be evicted whilst await-ing below; False
# ensures such hashXs are notified
old_status = self.mempool_statuses.get(hashX, False)
status = await self.address_status(hashX)
if status != old_status:
alias = self.hashX_subs[hashX]
changed[alias] = status
for alias, status in changed.items():
if len(alias) == 64:
method = 'blockchain.scripthash.subscribe'
else:
method = 'blockchain.address.subscribe'
await self.send_notification(method, (alias, status))
if changed:
es = '' if len(changed) == 1 else 'es'
self.logger.info('notified of {:,d} address{}'
.format(len(changed), es))
async def notify(self, height, touched):
async def notify(self, touched, height_changed):
'''Notify the client about changes to touched addresses (from mempool
updates or new blocks) and height.
Return the set of addresses the session needs to be
asynchronously notified about. This can be empty if there are
possible mempool status updates.
Returns None if nothing needs to be notified asynchronously.
'''
height_changed = height != self.notified_height
if height_changed:
self.notified_height = height
if self.subscribe_headers:
args = (await self.subscribe_headers_result(height), )
await self.send_notification('blockchain.headers.subscribe',
args)
if height_changed and self.subscribe_headers:
args = (await self.subscribe_headers_result(), )
await self.send_notification('blockchain.headers.subscribe', args)
touched = touched.intersection(self.hashX_subs)
if touched or (height_changed and self.mempool_statuses):
await self.notify_touched(touched)
changed = {}
async def raw_header(self, height):
'''Return the binary header at the given height.'''
try:
return await self.chain_state.raw_header(height)
except IndexError:
raise RPCError(BAD_REQUEST, f'height {height:,d} '
'out of range') from None
for hashX in touched:
alias = self.hashX_subs[hashX]
status = await self.address_status(hashX)
changed[alias] = status
async def electrum_header(self, height):
'''Return the deserialized header at the given height.'''
raw_header = await self.raw_header(height)
return self.coin.electrum_header(raw_header, height)
# Check mempool hashXs - the status is a function of the
# confirmed state of other transactions. Note: we cannot
# iterate over mempool_statuses as it changes size.
for hashX in tuple(self.mempool_statuses):
# Items can be evicted whilst await-ing status; False
# ensures such hashXs are notified
old_status = self.mempool_statuses.get(hashX, False)
status = await self.address_status(hashX)
if status != old_status:
alias = self.hashX_subs[hashX]
changed[alias] = status
async def subscribe_headers_result(self, height):
'''The result of a header subscription for the given height.'''
if self.subscribe_headers_raw:
raw_header = await self.raw_header(height)
return {'hex': raw_header.hex(), 'height': height}
return await self.electrum_header(height)
for alias, status in changed.items():
if len(alias) == 64:
method = 'blockchain.scripthash.subscribe'
else:
method = 'blockchain.address.subscribe'
await self.send_notification(method, (alias, status))
if changed:
es = '' if len(changed) == 1 else 'es'
self.logger.info(f'notified of {len(changed):,d} address{es}')
async def subscribe_headers_result(self):
'''The result of a header subscription or notification.'''
return self.session_mgr.hsub_results[self.subscribe_headers_raw]
async def _headers_subscribe(self, raw):
'''Subscribe to get headers of new blocks.'''
self.subscribe_headers = True
self.subscribe_headers_raw = assert_boolean(raw)
self.notified_height = self.db_height()
return await self.subscribe_headers_result(self.notified_height)
self.subscribe_headers = True
return await self.subscribe_headers_result()
async def headers_subscribe(self):
'''Subscribe to get raw headers of new blocks.'''
@@ -804,7 +851,7 @@ class ElectrumX(SessionBase):
async def hashX_listunspent(self, hashX):
'''Return the list of UTXOs of a script hash, including mempool
effects.'''
utxos = await self.chain_state.all_utxos(hashX)
utxos = await self.db.all_utxos(hashX)
utxos = sorted(utxos)
utxos.extend(await self.mempool.unordered_UTXOs(hashX))
spends = await self.mempool.potential_spends(hashX)
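The hunk is truncated here, but the three ingredients — sorted confirmed UTXOs, the mempool's unordered UTXOs, and its potential spends — presumably combine by dropping any confirmed output that a mempool transaction already spends. A hedged sketch of that final filter, reusing the UTXO attribute names seen elsewhere in this diff:

# spends is assumed to be a set of (tx_hash, tx_pos) outpoints.
utxos = [utxo for utxo in utxos
         if (utxo.tx_hash, utxo.tx_pos) not in spends]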
@@ -861,7 +908,7 @@ class ElectrumX(SessionBase):
return await self.hashX_subscribe(hashX, address)
async def get_balance(self, hashX):
utxos = await self.chain_state.all_utxos(hashX)
utxos = await self.db.all_utxos(hashX)
confirmed = sum(utxo.value for utxo in utxos)
unconfirmed = await self.mempool.balance_delta(hashX)
return {'confirmed': confirmed, 'unconfirmed': unconfirmed}
@@ -909,14 +956,14 @@ class ElectrumX(SessionBase):
return await self.hashX_subscribe(hashX, scripthash)
async def _merkle_proof(self, cp_height, height):
max_height = self.db_height()
max_height = self.db.db_height
if not height <= cp_height <= max_height:
raise RPCError(BAD_REQUEST,
f'require header height {height:,d} <= '
f'cp_height {cp_height:,d} <= '
f'chain height {max_height:,d}')
branch, root = await self.chain_state.header_branch_and_root(
cp_height + 1, height)
branch, root = await self.db.header_branch_and_root(cp_height + 1,
height)
return {
'branch': [hash_to_hex_str(elt) for elt in branch],
'root': hash_to_hex_str(root),
@@ -927,7 +974,7 @@ class ElectrumX(SessionBase):
dictionary with a merkle proof.'''
height = non_negative_integer(height)
cp_height = non_negative_integer(cp_height)
raw_header_hex = (await self.raw_header(height)).hex()
raw_header_hex = (await self.session_mgr.raw_header(height)).hex()
if cp_height == 0:
return raw_header_hex
result = {'header': raw_header_hex}
@@ -953,8 +1000,7 @@ class ElectrumX(SessionBase):
max_size = self.MAX_CHUNK_SIZE
count = min(count, max_size)
headers, count = await self.chain_state.read_headers(start_height,
count)
headers, count = await self.db.read_headers(start_height, count)
result = {'hex': headers.hex(), 'count': count, 'max': max_size}
if count and cp_height:
last_height = start_height + count - 1
@@ -971,7 +1017,7 @@ class ElectrumX(SessionBase):
index = non_negative_integer(index)
size = self.coin.CHUNK_SIZE
start_height = index * size
headers, _ = await self.chain_state.read_headers(start_height, size)
headers, _ = await self.db.read_headers(start_height, size)
return headers.hex()
async def block_get_header(self, height):
@@ -979,7 +1025,7 @@ class ElectrumX(SessionBase):
height: the header's height'''
height = non_negative_integer(height)
return await self.electrum_header(height)
return await self.session_mgr.electrum_header(height)
def is_tor(self):
'''Try to detect if the connection is to a tor hidden service we are
@@ -1042,7 +1088,7 @@ class ElectrumX(SessionBase):
number: the number of blocks
'''
number = non_negative_integer(number)
return await self.daemon_request('estimatefee', [number])
return await self.daemon_request('estimatefee', number)
async def ping(self):
'''Serves as a connection keep-alive mechanism and for the client to
@@ -1091,15 +1137,14 @@ class ElectrumX(SessionBase):
raw_tx: the raw transaction as a hexadecimal string'''
# This returns errors as JSON RPC errors, as is natural
try:
tx_hash = await self.chain_state.broadcast_transaction(raw_tx)
hex_hash = await self.session_mgr.broadcast_transaction(raw_tx)
self.txs_sent += 1
self.session_mgr.txs_sent += 1
self.logger.info('sent tx: {}'.format(tx_hash))
return tx_hash
self.logger.info(f'sent tx: {hex_hash}')
return hex_hash
except DaemonError as e:
error, = e.args
message = error['message']
self.logger.info(f'sendrawtransaction: {message}')
self.logger.info(f'error sending transaction: {message}')
raise RPCError(BAD_REQUEST, 'the transaction was rejected by '
f'network rules.\n\n{message}\n[{raw_tx}]')
@@ -1135,7 +1180,7 @@ class ElectrumX(SessionBase):
tx_pos: index of transaction in tx_hashes to create branch for
'''
hashes = [hex_str_to_hash(hash) for hash in tx_hashes]
branch, root = self.chain_state.tx_branch_and_root(hashes, tx_pos)
branch, root = self.db.merkle.branch_and_root(hashes, tx_pos)
branch = [hash_to_hex_str(hash) for hash in branch]
return branch
@@ -1264,9 +1309,9 @@ class DashElectrumX(ElectrumX):
'masternode.list': self.masternode_list
})
async def notify(self, height, touched):
async def notify(self, touched, height_changed):
'''Notify the client about changes in masternode list.'''
await super().notify(height, touched)
await super().notify(touched, height_changed)
for mn in self.mns:
status = await self.daemon_request('masternode_list',
['status', mn])
@@ -1358,7 +1403,7 @@ class DashElectrumX(ElectrumX):
# with the masternode information including the payment
# position is returned.
cache = self.session_mgr.mn_cache
if not cache or self.session_mgr.mn_cache_height != self.db_height():
if not cache or self.session_mgr.mn_cache_height != self.db.db_height:
full_mn_list = await self.daemon_request('masternode_list',
['full'])
mn_payment_queue = get_masternode_payment_queue(full_mn_list)
@@ -1386,7 +1431,7 @@ class DashElectrumX(ElectrumX):
mn_list.append(mn_info)
cache.clear()
cache.extend(mn_list)
self.session_mgr.mn_cache_height = self.db_height()
self.session_mgr.mn_cache_height = self.db.db_height
# If payees is an empty list the whole masternode list is returned
if payees:

View File

@@ -1,5 +1,5 @@
import setuptools
version = '1.8.3'
version = '1.8.4'
setuptools.setup(
name='electrumX',

tests/server/test_daemon.py (new file, 489 additions)
View File

@@ -0,0 +1,489 @@
import aiohttp
import asyncio
import json
import logging
import pytest
from aiorpcx import (
JSONRPCv1, JSONRPCLoose, RPCError, ignore_after,
Request, Batch,
)
from electrumx.lib.coins import BitcoinCash, CoinError, Bitzeny
from electrumx.server.daemon import (
Daemon, FakeEstimateFeeDaemon, DaemonError
)
coin = BitcoinCash
# These should be full, canonical URLs
urls = ['http://rpc_user:rpc_pass@127.0.0.1:8332/',
'http://rpc_user:rpc_pass@192.168.0.1:8332/']
@pytest.fixture(params=[BitcoinCash, Bitzeny])
def daemon(request):
coin = request.param
return coin.DAEMON(coin, ','.join(urls))
class ResponseBase(object):
def __init__(self, headers, status):
self.headers = headers
self.status = status
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
pass
class JSONResponse(ResponseBase):
def __init__(self, result, msg_id, status=200):
super().__init__({'Content-Type': 'application/json'}, status)
self.result = result
self.msg_id = msg_id
async def json(self):
if isinstance(self.msg_id, int):
message = JSONRPCv1.response_message(self.result, self.msg_id)
else:
parts = [JSONRPCv1.response_message(item, msg_id)
for item, msg_id in zip(self.result, self.msg_id)]
message = JSONRPCv1.batch_message_from_parts(parts)
return json.loads(message.decode())
class HTMLResponse(ResponseBase):
def __init__(self, text, reason, status):
super().__init__({'Content-Type': 'text/html; charset=ISO-8859-1'},
status)
self._text = text
self.reason = reason
async def text(self):
return self._text
class ClientSessionBase(object):
def __enter__(self):
self.prior_class = aiohttp.ClientSession
aiohttp.ClientSession = lambda: self
def __exit__(self, exc_type, exc_value, traceback):
aiohttp.ClientSession = self.prior_class
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
pass
class ClientSessionGood(ClientSessionBase):
'''Imitate aiohttp for testing purposes.'''
def __init__(self, *triples):
self.triples = triples # each a (method, args, result)
self.count = 0
self.expected_url = urls[0]
def post(self, url, data=""):
assert url == self.expected_url
request, request_id = JSONRPCLoose.message_to_item(data.encode())
method, args, result = self.triples[self.count]
self.count += 1
if isinstance(request, Request):
assert request.method == method
assert request.args == args
return JSONResponse(result, request_id)
else:
assert isinstance(request, Batch)
for request, args in zip(request, args):
assert request.method == method
assert request.args == args
return JSONResponse(result, request_id)
class ClientSessionBadAuth(ClientSessionBase):
def post(self, url, data=""):
return HTMLResponse('', 'Unauthorized', 401)
class ClientSessionWorkQueueFull(ClientSessionGood):
def post(self, url, data=""):
self.post = super().post
return HTMLResponse('Work queue depth exceeded',
'Internal server error', 500)
class ClientSessionNoConnection(ClientSessionGood):
def __init__(self, *args):
self.args = args
async def __aenter__(self):
aiohttp.ClientSession = lambda: ClientSessionGood(*self.args)
raise aiohttp.ClientConnectionError
class ClientSessionPostError(ClientSessionGood):
def __init__(self, exception, *args):
self.exception = exception
self.args = args
def post(self, url, data=""):
aiohttp.ClientSession = lambda: ClientSessionGood(*self.args)
raise self.exception
class ClientSessionFailover(ClientSessionGood):
def post(self, url, data=""):
# If not failed over, simulate disconnecting
if url == self.expected_url:
raise aiohttp.ServerDisconnectedError
else:
self.expected_url = urls[1]
return super().post(url, data)
def in_caplog(caplog, message, count=1):
return sum(message in record.message
for record in caplog.records) == count
#
# Tests
#
def test_set_urls_bad():
with pytest.raises(CoinError):
Daemon(coin, '')
with pytest.raises(CoinError):
Daemon(coin, 'a')
def test_set_urls_one(caplog):
with caplog.at_level(logging.INFO):
daemon = Daemon(coin, urls[0])
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 1
logged_url = daemon.logged_url()
assert logged_url == '127.0.0.1:8332/'
assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)')
def test_set_urls_two(caplog):
with caplog.at_level(logging.INFO):
daemon = Daemon(coin, ','.join(urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
logged_url = daemon.logged_url()
assert logged_url == '127.0.0.1:8332/'
assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)')
assert in_caplog(caplog, 'daemon #2 at 192.168.0.1:8332')
def test_set_urls_short():
no_prefix_urls = ['/'.join(part for part in url.split('/')[2:])
for url in urls]
daemon = Daemon(coin, ','.join(no_prefix_urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
no_slash_urls = [url[:-1] for url in urls]
daemon = Daemon(coin, ','.join(no_slash_urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
no_port_urls = [url[:url.rfind(':')] for url in urls]
daemon = Daemon(coin, ','.join(no_port_urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
def test_failover_good(caplog):
daemon = Daemon(coin, ','.join(urls))
with caplog.at_level(logging.INFO):
result = daemon.failover()
assert result is True
assert daemon.current_url() == urls[1]
logged_url = daemon.logged_url()
assert in_caplog(caplog, f'failing over to {logged_url}')
# And again
result = daemon.failover()
assert result is True
assert daemon.current_url() == urls[0]
def test_failover_fail(caplog):
daemon = Daemon(coin, urls[0])
with caplog.at_level(logging.INFO):
result = daemon.failover()
assert result is False
assert daemon.current_url() == urls[0]
assert not in_caplog(caplog, 'failing over')
@pytest.mark.asyncio
async def test_height(daemon):
assert daemon.cached_height() is None
height = 300
with ClientSessionGood(('getblockcount', [], height)):
assert await daemon.height() == height
assert daemon.cached_height() == height
@pytest.mark.asyncio
async def test_broadcast_transaction(daemon):
raw_tx = 'deadbeef'
tx_hash = 'hash'
with ClientSessionGood(('sendrawtransaction', [raw_tx], tx_hash)):
assert await daemon.broadcast_transaction(raw_tx) == tx_hash
@pytest.mark.asyncio
async def test_getnetworkinfo(daemon):
response = {"relayfee": 2, "other:": "cruft"}
with ClientSessionGood(('getnetworkinfo', [], response)):
assert await daemon.getnetworkinfo() == response
@pytest.mark.asyncio
async def test_relayfee(daemon):
if isinstance(daemon, FakeEstimateFeeDaemon):
sats = daemon.coin.ESTIMATE_FEE
else:
sats = 2
response = {"relayfee": sats, "other:": "cruft"}
with ClientSessionGood(('getnetworkinfo', [], response)):
assert await daemon.relayfee() == sats
@pytest.mark.asyncio
async def test_mempool_hashes(daemon):
hashes = ['hex_hash1', 'hex_hash2']
with ClientSessionGood(('getrawmempool', [], hashes)):
assert await daemon.mempool_hashes() == hashes
@pytest.mark.asyncio
async def test_deserialised_block(daemon):
block_hash = 'block_hash'
result = {'some': 'mess'}
with ClientSessionGood(('getblock', [block_hash, True], result)):
assert await daemon.deserialised_block(block_hash) == result
@pytest.mark.asyncio
async def test_estimatefee(daemon):
method_not_found = RPCError(JSONRPCv1.METHOD_NOT_FOUND, 'nope')
if isinstance(daemon, FakeEstimateFeeDaemon):
result = daemon.coin.ESTIMATE_FEE
else:
result = -1
with ClientSessionGood(
('estimatesmartfee', [], method_not_found),
('estimatefee', [2], result)
):
assert await daemon.estimatefee(2) == result
@pytest.mark.asyncio
async def test_estimatefee_smart(daemon):
bad_args = RPCError(JSONRPCv1.INVALID_ARGS, 'bad args')
if isinstance(daemon, FakeEstimateFeeDaemon):
return
rate = 0.0002
result = {'feerate': rate}
with ClientSessionGood(
('estimatesmartfee', [], bad_args),
('estimatesmartfee', [2], result)
):
assert await daemon.estimatefee(2) == rate
# Test the rpc_available_cache is used
with ClientSessionGood(('estimatesmartfee', [2], result)):
assert await daemon.estimatefee(2) == rate
@pytest.mark.asyncio
async def test_getrawtransaction(daemon):
hex_hash = 'deadbeef'
simple = 'tx_in_hex'
verbose = {'hex': hex_hash, 'other': 'cruft'}
# Test False is converted to 0 - old daemon's reject False
with ClientSessionGood(('getrawtransaction', [hex_hash, 0], simple)):
assert await daemon.getrawtransaction(hex_hash) == simple
# Test True is converted to 1
with ClientSessionGood(('getrawtransaction', [hex_hash, 1], verbose)):
assert await daemon.getrawtransaction(
hex_hash, True) == verbose
# Batch tests
@pytest.mark.asyncio
async def test_empty_send(daemon):
first = 5
count = 0
with ClientSessionGood(('getblockhash', [], [])):
assert await daemon.block_hex_hashes(first, count) == []
@pytest.mark.asyncio
async def test_block_hex_hashes(daemon):
first = 5
count = 3
hashes = [f'hex_hash{n}' for n in range(count)]
with ClientSessionGood(('getblockhash',
[[n] for n in range(first, first + count)],
hashes)):
assert await daemon.block_hex_hashes(first, count) == hashes
@pytest.mark.asyncio
async def test_raw_blocks(daemon):
count = 3
hex_hashes = [f'hex_hash{n}' for n in range(count)]
args_list = [[hex_hash, False] for hex_hash in hex_hashes]
iterable = (hex_hash for hex_hash in hex_hashes)
blocks = ["00", "019a", "02fe"]
blocks_raw = [bytes.fromhex(block) for block in blocks]
with ClientSessionGood(('getblock', args_list, blocks)):
assert await daemon.raw_blocks(iterable) == blocks_raw
@pytest.mark.asyncio
async def test_get_raw_transactions(daemon):
hex_hashes = ['deadbeef0', 'deadbeef1']
args_list = [[hex_hash, 0] for hex_hash in hex_hashes]
raw_txs_hex = ['fffefdfc', '0a0b0c0d']
raw_txs = [bytes.fromhex(raw_tx) for raw_tx in raw_txs_hex]
# Test 0 - old daemon's reject False
with ClientSessionGood(('getrawtransaction', args_list, raw_txs_hex)):
assert await daemon.getrawtransactions(hex_hashes) == raw_txs
# Test one error
tx_not_found = RPCError(-1, 'some error message')
results = ['ff0b7d', tx_not_found]
raw_txs = [bytes.fromhex(results[0]), None]
with ClientSessionGood(('getrawtransaction', args_list, results)):
assert await daemon.getrawtransactions(hex_hashes) == raw_txs
# Other tests
@pytest.mark.asyncio
async def test_bad_auth(daemon, caplog):
with pytest.raises(DaemonError) as e:
with ClientSessionBadAuth():
await daemon.height()
assert "Unauthorized" in e.value.args[0]
assert in_caplog(caplog, "Unauthorized")
@pytest.mark.asyncio
async def test_workqueue_depth(daemon, caplog):
daemon.init_retry = 0.01
height = 125
with caplog.at_level(logging.INFO):
with ClientSessionWorkQueueFull(('getblockcount', [], height)):
assert await daemon.height() == height
assert in_caplog(caplog, "work queue full")
assert in_caplog(caplog, "running normally")
@pytest.mark.asyncio
async def test_connection_error(daemon, caplog):
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionNoConnection(('getblockcount', [], height)):
assert await daemon.height() == height
assert in_caplog(caplog, "connection problem - is your daemon running?")
assert in_caplog(caplog, "connection restored")
@pytest.mark.asyncio
async def test_timeout_error(daemon, caplog):
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionPostError(asyncio.TimeoutError,
('getblockcount', [], height)):
assert await daemon.height() == height
assert in_caplog(caplog, "timeout error")
@pytest.mark.asyncio
async def test_disconnected(daemon, caplog):
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionPostError(aiohttp.ServerDisconnectedError,
('getblockcount', [], height)):
assert await daemon.height() == height
assert in_caplog(caplog, "disconnected")
assert in_caplog(caplog, "connection restored")
@pytest.mark.asyncio
async def test_warming_up(daemon, caplog):
warming_up = RPCError(-28, 'reading block index')
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionGood(
('getblockcount', [], warming_up),
('getblockcount', [], height)
):
assert await daemon.height() == height
assert in_caplog(caplog, "starting up checking blocks")
assert in_caplog(caplog, "running normally")
@pytest.mark.asyncio
async def test_warming_up_batch(daemon, caplog):
warming_up = RPCError(-28, 'reading block index')
first = 5
count = 1
daemon.init_retry = 0.01
hashes = ['hex_hash5']
with caplog.at_level(logging.INFO):
with ClientSessionGood(('getblockhash', [[first]], [warming_up]),
('getblockhash', [[first]], hashes)):
assert await daemon.block_hex_hashes(first, count) == hashes
assert in_caplog(caplog, "starting up checking blocks")
assert in_caplog(caplog, "running normally")
@pytest.mark.asyncio
async def test_failover(daemon, caplog):
height = 100
daemon.init_retry = 0.01
daemon.max_retry = 0.04
with caplog.at_level(logging.INFO):
with ClientSessionFailover(('getblockcount', [], height)):
assert await daemon.height() == height
assert in_caplog(caplog, "disconnected", 3)
assert in_caplog(caplog, "failing over")
assert in_caplog(caplog, "connection restored")