Merge branch 'devel'

This commit is contained in:
Neil Booth 2018-08-14 12:58:52 +09:00
commit f0a2f128dc
13 changed files with 775 additions and 350 deletions

View File

@ -15,7 +15,7 @@
import os import os
import sys import sys
sys.path.insert(0, os.path.abspath('..')) sys.path.insert(0, os.path.abspath('..'))
VERSION="ElectrumX 1.8.3" VERSION="ElectrumX 1.8.4"
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------

View File

@ -1,4 +1,4 @@
version = 'ElectrumX 1.8.3' version = 'ElectrumX 1.8.4-dev'
version_short = version.split()[-1] version_short = version.split()[-1]
from electrumx.server.controller import Controller from electrumx.server.controller import Controller

View File

@ -114,10 +114,6 @@ class Coin(object):
url = 'http://' + url url = 'http://' + url
return url + '/' return url + '/'
@classmethod
def daemon_urls(cls, urls):
return [cls.sanitize_url(url) for url in urls.split(',')]
@classmethod @classmethod
def genesis_block(cls, block): def genesis_block(cls, block):
'''Check the Genesis block is the right one for this coin. '''Check the Genesis block is the right one for this coin.

View File

@ -13,7 +13,7 @@ import sys
import time import time
from functools import partial from functools import partial
from aiorpcx import TaskGroup from aiorpcx import spawn
from electrumx.lib.util import class_logger from electrumx.lib.util import class_logger
@ -93,12 +93,11 @@ class ServerBase(object):
loop.set_exception_handler(self.on_exception) loop.set_exception_handler(self.on_exception)
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
async with TaskGroup() as group: server_task = await spawn(self.serve(shutdown_event))
server_task = await group.spawn(self.serve(shutdown_event)) # Wait for shutdown, log on receipt of the event
# Wait for shutdown, log on receipt of the event await shutdown_event.wait()
await shutdown_event.wait() self.logger.info('shutting down')
self.logger.info('shutting down') server_task.cancel()
server_task.cancel()
# Prevent some silly logs # Prevent some silly logs
await asyncio.sleep(0.01) await asyncio.sleep(0.01)

View File

@ -650,10 +650,7 @@ class BlockProcessor(object):
could be lost. could be lost.
''' '''
self._caught_up_event = caught_up_event self._caught_up_event = caught_up_event
async with TaskGroup() as group: await self._first_open_dbs()
await group.spawn(self._first_open_dbs())
# Ensure cached_height is set
await group.spawn(self.daemon.height())
try: try:
async with TaskGroup() as group: async with TaskGroup() as group:
await group.spawn(self.prefetcher.main_loop(self.height)) await group.spawn(self.prefetcher.main_loop(self.height))

View File

@ -1,102 +0,0 @@
# Copyright (c) 2016-2018, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
from electrumx.lib.hash import hash_to_hex_str
class ChainState(object):
'''Used as an interface by servers to request information about
blocks, transaction history, UTXOs and the mempool.
'''
def __init__(self, env, db, daemon, bp):
self._env = env
self._db = db
self._daemon = daemon
# External interface pass-throughs for session.py
self.force_chain_reorg = bp.force_chain_reorg
self.tx_branch_and_root = db.merkle.branch_and_root
self.read_headers = db.read_headers
self.all_utxos = db.all_utxos
self.limited_history = db.limited_history
self.header_branch_and_root = db.header_branch_and_root
async def broadcast_transaction(self, raw_tx):
return await self._daemon.sendrawtransaction([raw_tx])
async def daemon_request(self, method, args=()):
return await getattr(self._daemon, method)(*args)
def db_height(self):
return self._db.db_height
def get_info(self):
'''Chain state info for LocalRPC and logs.'''
return {
'daemon': self._daemon.logged_url(),
'daemon_height': self._daemon.cached_height(),
'db_height': self.db_height(),
}
async def raw_header(self, height):
'''Return the binary header at the given height.'''
header, n = await self.read_headers(height, 1)
if n != 1:
raise IndexError(f'height {height:,d} out of range')
return header
def set_daemon_url(self, daemon_url):
self._daemon.set_urls(self._env.coin.daemon_urls(daemon_url))
return self._daemon.logged_url()
async def query(self, args, limit):
coin = self._env.coin
db = self._db
lines = []
def arg_to_hashX(arg):
try:
script = bytes.fromhex(arg)
lines.append(f'Script: {arg}')
return coin.hashX_from_script(script)
except ValueError:
pass
hashX = coin.address_to_hashX(arg)
lines.append(f'Address: {arg}')
return hashX
for arg in args:
hashX = arg_to_hashX(arg)
if not hashX:
continue
n = None
history = await db.limited_history(hashX, limit=limit)
for n, (tx_hash, height) in enumerate(history):
lines.append(f'History #{n:,d}: height {height:,d} '
f'tx_hash {hash_to_hex_str(tx_hash)}')
if n is None:
lines.append('No history found')
n = None
utxos = await db.all_utxos(hashX)
for n, utxo in enumerate(utxos, start=1):
lines.append(f'UTXO #{n:,d}: tx_hash '
f'{hash_to_hex_str(utxo.tx_hash)} '
f'tx_pos {utxo.tx_pos:,d} height '
f'{utxo.height:,d} value {utxo.value:,d}')
if n == limit:
break
if n is None:
lines.append('No UTXOs found')
balance = sum(utxo.value for utxo in utxos)
lines.append(f'Balance: {coin.decimal_value(balance):,f} '
f'{coin.SHORTNAME}')
return lines

View File

@ -12,7 +12,6 @@ from aiorpcx import _version as aiorpcx_version, TaskGroup
import electrumx import electrumx
from electrumx.lib.server_base import ServerBase from electrumx.lib.server_base import ServerBase
from electrumx.lib.util import version_string from electrumx.lib.util import version_string
from electrumx.server.chain_state import ChainState
from electrumx.server.db import DB from electrumx.server.db import DB
from electrumx.server.mempool import MemPool, MemPoolAPI from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.server.session import SessionManager from electrumx.server.session import SessionManager
@ -93,11 +92,12 @@ class Controller(ServerBase):
self.logger.info(f'reorg limit is {env.reorg_limit:,d} blocks') self.logger.info(f'reorg limit is {env.reorg_limit:,d} blocks')
notifications = Notifications() notifications = Notifications()
daemon = env.coin.DAEMON(env) Daemon = env.coin.DAEMON
db = DB(env)
BlockProcessor = env.coin.BLOCK_PROCESSOR BlockProcessor = env.coin.BLOCK_PROCESSOR
daemon = Daemon(env.coin, env.daemon_url)
db = DB(env)
bp = BlockProcessor(env, db, daemon, notifications) bp = BlockProcessor(env, db, daemon, notifications)
chain_state = ChainState(env, db, daemon, bp)
# Set ourselves up to implement the MemPoolAPI # Set ourselves up to implement the MemPoolAPI
self.height = daemon.height self.height = daemon.height
@ -109,13 +109,16 @@ class Controller(ServerBase):
MemPoolAPI.register(Controller) MemPoolAPI.register(Controller)
mempool = MemPool(env.coin, self) mempool = MemPool(env.coin, self)
session_mgr = SessionManager(env, chain_state, mempool, session_mgr = SessionManager(env, db, bp, daemon, mempool,
notifications, shutdown_event) notifications, shutdown_event)
# Test daemon authentication, and also ensure it has a cached
# height. Do this before entering the task group.
await daemon.height()
caught_up_event = Event() caught_up_event = Event()
serve_externally_event = Event() serve_externally_event = Event()
synchronized_event = Event() synchronized_event = Event()
async with TaskGroup() as group: async with TaskGroup() as group:
await group.spawn(session_mgr.serve(serve_externally_event)) await group.spawn(session_mgr.serve(serve_externally_event))
await group.spawn(bp.fetch_and_process_blocks(caught_up_event)) await group.spawn(bp.fetch_and_process_blocks(caught_up_event))

View File

@ -9,6 +9,7 @@
daemon.''' daemon.'''
import asyncio import asyncio
import itertools
import json import json
import time import time
from calendar import timegm from calendar import timegm
@ -28,48 +29,53 @@ class DaemonError(Exception):
'''Raised when the daemon returns an error in its results.''' '''Raised when the daemon returns an error in its results.'''
class WarmingUpError(Exception):
'''Internal - when the daemon is warming up.'''
class WorkQueueFullError(Exception):
'''Internal - when the daemon's work queue is full.'''
class Daemon(object): class Daemon(object):
'''Handles connections to a daemon at the given URL.''' '''Handles connections to a daemon at the given URL.'''
WARMING_UP = -28 WARMING_UP = -28
RPC_MISC_ERROR = -1 id_counter = itertools.count()
class DaemonWarmingUpError(Exception): def __init__(self, coin, url, max_workqueue=10, init_retry=0.25,
'''Raised when the daemon returns an error in its results.''' max_retry=4.0):
self.coin = coin
def __init__(self, env):
self.logger = class_logger(__name__, self.__class__.__name__) self.logger = class_logger(__name__, self.__class__.__name__)
self.coin = env.coin self.set_url(url)
self.set_urls(env.coin.daemon_urls(env.daemon_url))
self._height = None
# Limit concurrent RPC calls to this number. # Limit concurrent RPC calls to this number.
# See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16 # See DEFAULT_HTTP_WORKQUEUE in bitcoind, which is typically 16
self.workqueue_semaphore = asyncio.Semaphore(value=10) self.workqueue_semaphore = asyncio.Semaphore(value=max_workqueue)
self.down = False self.init_retry = init_retry
self.last_error_time = 0 self.max_retry = max_retry
self.req_id = 0 self._height = None
self._available_rpcs = {} # caches results for _is_rpc_available() self.available_rpcs = {}
def next_req_id(self): def set_url(self, url):
'''Retrns the next request ID.'''
self.req_id += 1
return self.req_id
def set_urls(self, urls):
'''Set the URLS to the given list, and switch to the first one.''' '''Set the URLS to the given list, and switch to the first one.'''
if not urls: urls = url.split(',')
raise DaemonError('no daemon URLs provided') urls = [self.coin.sanitize_url(url) for url in urls]
self.urls = urls
self.url_index = 0
for n, url in enumerate(urls): for n, url in enumerate(urls):
self.logger.info('daemon #{:d} at {}{}' status = '' if n else ' (current)'
.format(n + 1, self.logged_url(url), logged_url = self.logged_url(url)
'' if n else ' (current)')) self.logger.info(f'daemon #{n + 1} at {logged_url}{status}')
self.url_index = 0
self.urls = urls
def url(self): def current_url(self):
'''Returns the current daemon URL.''' '''Returns the current daemon URL.'''
return self.urls[self.url_index] return self.urls[self.url_index]
def logged_url(self, url=None):
'''The host and port part, for logging.'''
url = url or self.current_url()
return url[url.rindex('@') + 1:]
def failover(self): def failover(self):
'''Call to fail-over to the next daemon URL. '''Call to fail-over to the next daemon URL.
@ -77,7 +83,7 @@ class Daemon(object):
''' '''
if len(self.urls) > 1: if len(self.urls) > 1:
self.url_index = (self.url_index + 1) % len(self.urls) self.url_index = (self.url_index + 1) % len(self.urls)
self.logger.info('failing over to {}'.format(self.logged_url())) self.logger.info(f'failing over to {self.logged_url()}')
return True return True
return False return False
@ -88,13 +94,17 @@ class Daemon(object):
async def _send_data(self, data): async def _send_data(self, data):
async with self.workqueue_semaphore: async with self.workqueue_semaphore:
async with self.client_session() as session: async with self.client_session() as session:
async with session.post(self.url(), data=data) as resp: async with session.post(self.current_url(), data=data) as resp:
# If bitcoind can't find a tx, for some reason kind = resp.headers.get('Content-Type', None)
# it returns 500 but fills out the JSON. if kind == 'application/json':
# Should still return 200 IMO.
if resp.status in (200, 404, 500):
return await resp.json() return await resp.json()
return (resp.status, resp.reason) # bitcoind's HTTP protocol "handling" is a bad joke
text = await resp.text()
if 'Work queue depth exceeded' in text:
raise WorkQueueFullError
text = text.strip() or resp.reason
self.logger.error(text)
raise DaemonError(text)
async def _send(self, payload, processor): async def _send(self, payload, processor):
'''Send a payload to be converted to JSON. '''Send a payload to be converted to JSON.
@ -103,54 +113,42 @@ class Daemon(object):
are raise through DaemonError. are raise through DaemonError.
''' '''
def log_error(error): def log_error(error):
self.down = True nonlocal last_error_log, retry
now = time.time() now = time.time()
prior_time = self.last_error_time if now - last_error_log > 60:
if now - prior_time > 60: last_error_time = now
self.last_error_time = now self.logger.error(f'{error} Retrying occasionally...')
if prior_time and self.failover(): if retry == self.max_retry and self.failover():
secs = 0 retry = 0
else:
self.logger.error('{} Retrying occasionally...'
.format(error))
on_good_message = None
last_error_log = 0
data = json.dumps(payload) data = json.dumps(payload)
secs = 1 retry = self.init_retry
max_secs = 4
while True: while True:
try: try:
result = await self._send_data(data) result = await self._send_data(data)
if not isinstance(result, tuple): result = processor(result)
result = processor(result) if on_good_message:
if self.down: self.logger.info(on_good_message)
self.down = False return result
self.last_error_time = 0
self.logger.info('connection restored')
return result
log_error('HTTP error code {:d}: {}'
.format(result[0], result[1]))
except asyncio.TimeoutError: except asyncio.TimeoutError:
log_error('timeout error.') log_error('timeout error.')
except aiohttp.ServerDisconnectedError: except aiohttp.ServerDisconnectedError:
log_error('disconnected.') log_error('disconnected.')
except aiohttp.ClientPayloadError: on_good_message = 'connection restored'
log_error('payload encoding error.')
except aiohttp.ClientConnectionError: except aiohttp.ClientConnectionError:
log_error('connection problem - is your daemon running?') log_error('connection problem - is your daemon running?')
except self.DaemonWarmingUpError: on_good_message = 'connection restored'
except WarmingUpError:
log_error('starting up checking blocks.') log_error('starting up checking blocks.')
except (asyncio.CancelledError, DaemonError): on_good_message = 'running normally'
raise except WorkQueueFullError:
except Exception as e: log_error('work queue full.')
self.logger.exception(f'uncaught exception: {e}') on_good_message = 'running normally'
await asyncio.sleep(secs) await asyncio.sleep(retry)
secs = min(max_secs, secs * 2, 1) retry = max(min(self.max_retry, retry * 2), self.init_retry)
def logged_url(self, url=None):
'''The host and port part, for logging.'''
url = url or self.url()
return url[url.rindex('@') + 1:]
async def _send_single(self, method, params=None): async def _send_single(self, method, params=None):
'''Send a single request to the daemon.''' '''Send a single request to the daemon.'''
@ -159,10 +157,10 @@ class Daemon(object):
if not err: if not err:
return result['result'] return result['result']
if err.get('code') == self.WARMING_UP: if err.get('code') == self.WARMING_UP:
raise self.DaemonWarmingUpError raise WarmingUpError
raise DaemonError(err) raise DaemonError(err)
payload = {'method': method, 'id': self.next_req_id()} payload = {'method': method, 'id': next(self.id_counter)}
if params: if params:
payload['params'] = params payload['params'] = params
return await self._send(payload, processor) return await self._send(payload, processor)
@ -176,12 +174,12 @@ class Daemon(object):
def processor(result): def processor(result):
errs = [item['error'] for item in result if item['error']] errs = [item['error'] for item in result if item['error']]
if any(err.get('code') == self.WARMING_UP for err in errs): if any(err.get('code') == self.WARMING_UP for err in errs):
raise self.DaemonWarmingUpError raise WarmingUpError
if not errs or replace_errs: if not errs or replace_errs:
return [item['result'] for item in result] return [item['result'] for item in result]
raise DaemonError(errs) raise DaemonError(errs)
payload = [{'method': method, 'params': p, 'id': self.next_req_id()} payload = [{'method': method, 'params': p, 'id': next(self.id_counter)}
for p in params_iterable] for p in params_iterable]
if payload: if payload:
return await self._send(payload, processor) return await self._send(payload, processor)
@ -192,27 +190,16 @@ class Daemon(object):
Results are cached and the daemon will generally not be queried with Results are cached and the daemon will generally not be queried with
the same method more than once.''' the same method more than once.'''
available = self._available_rpcs.get(method, None) available = self.available_rpcs.get(method)
if available is None: if available is None:
available = True
try: try:
await self._send_single(method) await self._send_single(method)
available = True
except DaemonError as e: except DaemonError as e:
err = e.args[0] err = e.args[0]
error_code = err.get("code") error_code = err.get("code")
if error_code == JSONRPC.METHOD_NOT_FOUND: available = error_code != JSONRPC.METHOD_NOT_FOUND
available = False self.available_rpcs[method] = available
elif error_code == self.RPC_MISC_ERROR:
# method found but exception was thrown in command handling
# probably because we did not provide arguments
available = True
else:
self.logger.warning('error (code {:d}: {}) when testing '
'RPC availability of method {}'
.format(error_code, err.get("message"),
method))
available = False
self._available_rpcs[method] = available
return available return available
async def block_hex_hashes(self, first, count): async def block_hex_hashes(self, first, count):
@ -235,12 +222,16 @@ class Daemon(object):
'''Update our record of the daemon's mempool hashes.''' '''Update our record of the daemon's mempool hashes.'''
return await self._send_single('getrawmempool') return await self._send_single('getrawmempool')
async def estimatefee(self, params): async def estimatefee(self, block_count):
'''Return the fee estimate for the given parameters.''' '''Return the fee estimate for the block count. Units are whole
currency units per KB, e.g. 0.00000995, or -1 if no estimate
is available.
'''
args = (block_count, )
if await self._is_rpc_available('estimatesmartfee'): if await self._is_rpc_available('estimatesmartfee'):
estimate = await self._send_single('estimatesmartfee', params) estimate = await self._send_single('estimatesmartfee', args)
return estimate.get('feerate', -1) return estimate.get('feerate', -1)
return await self._send_single('estimatefee', params) return await self._send_single('estimatefee', args)
async def getnetworkinfo(self): async def getnetworkinfo(self):
'''Return the result of the 'getnetworkinfo' RPC call.''' '''Return the result of the 'getnetworkinfo' RPC call.'''
@ -268,9 +259,9 @@ class Daemon(object):
# Convert hex strings to bytes # Convert hex strings to bytes
return [hex_to_bytes(tx) if tx else None for tx in txs] return [hex_to_bytes(tx) if tx else None for tx in txs]
async def sendrawtransaction(self, params): async def broadcast_transaction(self, raw_tx):
'''Broadcast a transaction to the network.''' '''Broadcast a transaction to the network.'''
return await self._send_single('sendrawtransaction', params) return await self._send_single('sendrawtransaction', (raw_tx, ))
async def height(self): async def height(self):
'''Query the daemon for its current height.''' '''Query the daemon for its current height.'''
@ -299,7 +290,7 @@ class FakeEstimateFeeDaemon(Daemon):
'''Daemon that simulates estimatefee and relayfee RPC calls. Coin that '''Daemon that simulates estimatefee and relayfee RPC calls. Coin that
wants to use this daemon must define ESTIMATE_FEE & RELAY_FEE''' wants to use this daemon must define ESTIMATE_FEE & RELAY_FEE'''
async def estimatefee(self, params): async def estimatefee(self, block_count):
'''Return the fee estimate for the given parameters.''' '''Return the fee estimate for the given parameters.'''
return self.coin.ESTIMATE_FEE return self.coin.ESTIMATE_FEE

View File

@ -370,6 +370,13 @@ class DB(object):
# Truncate header_mc: header count is 1 more than the height. # Truncate header_mc: header count is 1 more than the height.
self.header_mc.truncate(height + 1) self.header_mc.truncate(height + 1)
async def raw_header(self, height):
'''Return the binary header at the given height.'''
header, n = await self.read_headers(height, 1)
if n != 1:
raise IndexError(f'height {height:,d} out of range')
return header
async def read_headers(self, start_height, count): async def read_headers(self, start_height, count):
'''Requires start_height >= 0, count >= 0. Reads as many headers as '''Requires start_height >= 0, count >= 0. Reads as many headers as
are available starting at start_height up to count. This are available starting at start_height up to count. This

View File

@ -55,12 +55,12 @@ class PeerManager(object):
Attempts to maintain a connection with up to 8 peers. Attempts to maintain a connection with up to 8 peers.
Issues a 'peers.subscribe' RPC to them and tells them our data. Issues a 'peers.subscribe' RPC to them and tells them our data.
''' '''
def __init__(self, env, chain_state): def __init__(self, env, db):
self.logger = class_logger(__name__, self.__class__.__name__) self.logger = class_logger(__name__, self.__class__.__name__)
# Initialise the Peer class # Initialise the Peer class
Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS Peer.DEFAULT_PORTS = env.coin.PEER_DEFAULT_PORTS
self.env = env self.env = env
self.chain_state = chain_state self.db = db
# Our clearnet and Tor Peers, if any # Our clearnet and Tor Peers, if any
sclass = env.coin.SESSIONCLS sclass = env.coin.SESSIONCLS
@ -300,7 +300,7 @@ class PeerManager(object):
result = await session.send_request(message) result = await session.send_request(message)
assert_good(message, result, dict) assert_good(message, result, dict)
our_height = self.chain_state.db_height() our_height = self.db.db_height
if ptuple < (1, 3): if ptuple < (1, 3):
their_height = result.get('block_height') their_height = result.get('block_height')
else: else:
@ -313,7 +313,7 @@ class PeerManager(object):
# Check prior header too in case of hard fork. # Check prior header too in case of hard fork.
check_height = min(our_height, their_height) check_height = min(our_height, their_height)
raw_header = await self.chain_state.raw_header(check_height) raw_header = await self.db.raw_header(check_height)
if ptuple >= (1, 4): if ptuple >= (1, 4):
ours = raw_header.hex() ours = raw_header.hex()
message = 'blockchain.block.header' message = 'blockchain.block.header'

View File

@ -109,13 +109,15 @@ class SessionManager(object):
CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4) CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4)
def __init__(self, env, chain_state, mempool, notifications, def __init__(self, env, db, bp, daemon, mempool, notifications,
shutdown_event): shutdown_event):
env.max_send = max(350000, env.max_send) env.max_send = max(350000, env.max_send)
self.env = env self.env = env
self.chain_state = chain_state self.db = db
self.bp = bp
self.daemon = daemon
self.mempool = mempool self.mempool = mempool
self.peer_mgr = PeerManager(env, chain_state) self.peer_mgr = PeerManager(env, db)
self.shutdown_event = shutdown_event self.shutdown_event = shutdown_event
self.logger = util.class_logger(__name__, self.__class__.__name__) self.logger = util.class_logger(__name__, self.__class__.__name__)
self.servers = {} self.servers = {}
@ -127,8 +129,8 @@ class SessionManager(object):
self.state = self.CATCHING_UP self.state = self.CATCHING_UP
self.txs_sent = 0 self.txs_sent = 0
self.start_time = time.time() self.start_time = time.time()
self._history_cache = pylru.lrucache(256) self.history_cache = pylru.lrucache(256)
self._hc_height = 0 self.notified_height = None
# Cache some idea of room to avoid recounting on each subscription # Cache some idea of room to avoid recounting on each subscription
self.subs_room = 0 self.subs_room = 0
# Masternode stuff only for such coins # Masternode stuff only for such coins
@ -152,7 +154,7 @@ class SessionManager(object):
protocol_class = LocalRPC protocol_class = LocalRPC
else: else:
protocol_class = self.env.coin.SESSIONCLS protocol_class = self.env.coin.SESSIONCLS
protocol_factory = partial(protocol_class, self, self.chain_state, protocol_factory = partial(protocol_class, self, self.db,
self.mempool, self.peer_mgr, kind) self.mempool, self.peer_mgr, kind)
server = loop.create_server(protocol_factory, *args, **kw_args) server = loop.create_server(protocol_factory, *args, **kw_args)
@ -276,10 +278,11 @@ class SessionManager(object):
def _get_info(self): def _get_info(self):
'''A summary of server state.''' '''A summary of server state.'''
group_map = self._group_map() group_map = self._group_map()
result = self.chain_state.get_info() return {
result.update({
'version': electrumx.version,
'closing': len([s for s in self.sessions if s.is_closing()]), 'closing': len([s for s in self.sessions if s.is_closing()]),
'daemon': self.daemon.logged_url(),
'daemon_height': self.daemon.cached_height(),
'db_height': self.db.db_height,
'errors': sum(s.errors for s in self.sessions), 'errors': sum(s.errors for s in self.sessions),
'groups': len(group_map), 'groups': len(group_map),
'logged': len([s for s in self.sessions if s.log_me]), 'logged': len([s for s in self.sessions if s.log_me]),
@ -291,8 +294,8 @@ class SessionManager(object):
'subs': self._sub_count(), 'subs': self._sub_count(),
'txs_sent': self.txs_sent, 'txs_sent': self.txs_sent,
'uptime': util.formatted_time(time.time() - self.start_time), 'uptime': util.formatted_time(time.time() - self.start_time),
}) 'version': electrumx.version,
return result }
def _session_data(self, for_log): def _session_data(self, for_log):
'''Returned to the RPC 'sessions' call.''' '''Returned to the RPC 'sessions' call.'''
@ -329,6 +332,19 @@ class SessionManager(object):
]) ])
return result return result
async def _electrum_and_raw_headers(self, height):
raw_header = await self.raw_header(height)
electrum_header = self.env.coin.electrum_header(raw_header, height)
return electrum_header, raw_header
async def _refresh_hsub_results(self, height):
'''Refresh the cached header subscription responses to be for height,
and record that as notified_height.
'''
electrum, raw = await self._electrum_and_raw_headers(height)
self.hsub_results = (electrum, {'hex': raw.hex(), 'height': height})
self.notified_height = height
# --- LocalRPC command handlers # --- LocalRPC command handlers
async def rpc_add_peer(self, real_name): async def rpc_add_peer(self, real_name):
@ -367,10 +383,10 @@ class SessionManager(object):
'''Replace the daemon URL.''' '''Replace the daemon URL.'''
daemon_url = daemon_url or self.env.daemon_url daemon_url = daemon_url or self.env.daemon_url
try: try:
daemon_url = self.chain_state.set_daemon_url(daemon_url) self.daemon.set_url(daemon_url)
except Exception as e: except Exception as e:
raise RPCError(BAD_REQUEST, f'an error occured: {e!r}') raise RPCError(BAD_REQUEST, f'an error occured: {e!r}')
return f'now using daemon at {daemon_url}' return f'now using daemon at {self.daemon.logged_url()}'
async def rpc_stop(self): async def rpc_stop(self):
'''Shut down the server cleanly.''' '''Shut down the server cleanly.'''
@ -391,10 +407,54 @@ class SessionManager(object):
async def rpc_query(self, items, limit): async def rpc_query(self, items, limit):
'''Return a list of data about server peers.''' '''Return a list of data about server peers.'''
try: coin = self.env.coin
return await self.chain_state.query(items, limit) db = self.db
except Base58Error as e: lines = []
raise RPCError(BAD_REQUEST, e.args[0]) from None
def arg_to_hashX(arg):
try:
script = bytes.fromhex(arg)
lines.append(f'Script: {arg}')
return coin.hashX_from_script(script)
except ValueError:
pass
try:
hashX = coin.address_to_hashX(arg)
except Base58Error as e:
lines.append(e.args[0])
return None
lines.append(f'Address: {arg}')
return hashX
for arg in args:
hashX = arg_to_hashX(arg)
if not hashX:
continue
n = None
history = await db.limited_history(hashX, limit=limit)
for n, (tx_hash, height) in enumerate(history):
lines.append(f'History #{n:,d}: height {height:,d} '
f'tx_hash {hash_to_hex_str(tx_hash)}')
if n is None:
lines.append('No history found')
n = None
utxos = await db.all_utxos(hashX)
for n, utxo in enumerate(utxos, start=1):
lines.append(f'UTXO #{n:,d}: tx_hash '
f'{hash_to_hex_str(utxo.tx_hash)} '
f'tx_pos {utxo.tx_pos:,d} height '
f'{utxo.height:,d} value {utxo.value:,d}')
if n == limit:
break
if n is None:
lines.append('No UTXOs found')
balance = sum(utxo.value for utxo in utxos)
lines.append(f'Balance: {coin.decimal_value(balance):,f} '
f'{coin.SHORTNAME}')
return lines
async def rpc_sessions(self): async def rpc_sessions(self):
'''Return statistics about connected sessions.''' '''Return statistics about connected sessions.'''
@ -406,7 +466,7 @@ class SessionManager(object):
count: number of blocks to reorg count: number of blocks to reorg
''' '''
count = non_negative_integer(count) count = non_negative_integer(count)
if not self.chain_state.force_chain_reorg(count): if not self.bp.force_chain_reorg(count):
raise RPCError(BAD_REQUEST, 'still catching up with daemon') raise RPCError(BAD_REQUEST, 'still catching up with daemon')
return f'scheduled a reorg of {count:,d} blocks' return f'scheduled a reorg of {count:,d} blocks'
@ -454,31 +514,57 @@ class SessionManager(object):
'''The number of connections that we've sent something to.''' '''The number of connections that we've sent something to.'''
return len(self.sessions) return len(self.sessions)
async def daemon_request(self, method, *args):
'''Catch a DaemonError and convert it to an RPCError.'''
try:
return await getattr(self.daemon, method)(*args)
except DaemonError as e:
raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None
async def raw_header(self, height):
'''Return the binary header at the given height.'''
try:
return await self.db.raw_header(height)
except IndexError:
raise RPCError(BAD_REQUEST, f'height {height:,d} '
'out of range') from None
async def electrum_header(self, height):
'''Return the deserialized header at the given height.'''
electrum_header, _ = await self._electrum_and_raw_headers(height)
return electrum_header
async def broadcast_transaction(self, raw_tx):
hex_hash = await self.daemon.broadcast_transaction(raw_tx)
self.txs_sent += 1
return hex_hash
async def limited_history(self, hashX): async def limited_history(self, hashX):
'''A caching layer.''' '''A caching layer.'''
hc = self._history_cache hc = self.history_cache
if hashX not in hc: if hashX not in hc:
# History DoS limit. Each element of history is about 99 # History DoS limit. Each element of history is about 99
# bytes when encoded as JSON. This limits resource usage # bytes when encoded as JSON. This limits resource usage
# on bloated history requests, and uses a smaller divisor # on bloated history requests, and uses a smaller divisor
# so large requests are logged before refusing them. # so large requests are logged before refusing them.
limit = self.env.max_send // 97 limit = self.env.max_send // 97
hc[hashX] = await self.chain_state.limited_history(hashX, hc[hashX] = await self.db.limited_history(hashX, limit=limit)
limit=limit)
return hc[hashX] return hc[hashX]
async def _notify_sessions(self, height, touched): async def _notify_sessions(self, height, touched):
'''Notify sessions about height changes and touched addresses.''' '''Notify sessions about height changes and touched addresses.'''
# Invalidate our history cache for touched hashXs height_changed = height != self.notified_height
if height != self._hc_height: if height_changed:
self._hc_height = height # Paranoia: a reorg could race and leave db_height lower
hc = self._history_cache await self._refresh_hsub_results(min(height, self.db.db_height))
# Invalidate our history cache for touched hashXs
hc = self.history_cache
for hashX in set(hc).intersection(touched): for hashX in set(hc).intersection(touched):
del hc[hashX] del hc[hashX]
async with TaskGroup() as group: async with TaskGroup() as group:
for session in self.sessions: for session in self.sessions:
await group.spawn(session.notify(height, touched)) await group.spawn(session.notify(touched, height_changed))
def add_session(self, session): def add_session(self, session):
self.sessions.add(session) self.sessions.add(session)
@ -518,12 +604,12 @@ class SessionBase(ServerSession):
MAX_CHUNK_SIZE = 2016 MAX_CHUNK_SIZE = 2016
session_counter = itertools.count() session_counter = itertools.count()
def __init__(self, session_mgr, chain_state, mempool, peer_mgr, kind): def __init__(self, session_mgr, db, mempool, peer_mgr, kind):
connection = JSONRPCConnection(JSONRPCAutoDetect) connection = JSONRPCConnection(JSONRPCAutoDetect)
super().__init__(connection=connection) super().__init__(connection=connection)
self.logger = util.class_logger(__name__, self.__class__.__name__) self.logger = util.class_logger(__name__, self.__class__.__name__)
self.session_mgr = session_mgr self.session_mgr = session_mgr
self.chain_state = chain_state self.db = db
self.mempool = mempool self.mempool = mempool
self.peer_mgr = peer_mgr self.peer_mgr = peer_mgr
self.kind = kind # 'RPC', 'TCP' etc. self.kind = kind # 'RPC', 'TCP' etc.
@ -534,11 +620,12 @@ class SessionBase(ServerSession):
self.txs_sent = 0 self.txs_sent = 0
self.log_me = False self.log_me = False
self.bw_limit = self.env.bandwidth_limit self.bw_limit = self.env.bandwidth_limit
self.daemon_request = self.session_mgr.daemon_request
# Hijack the connection so we can log messages # Hijack the connection so we can log messages
self._receive_message_orig = self.connection.receive_message self._receive_message_orig = self.connection.receive_message
self.connection.receive_message = self.receive_message self.connection.receive_message = self.receive_message
async def notify(self, height, touched): async def notify(self, touched, height_changed):
pass pass
def peer_address_str(self, *, for_log=True): def peer_address_str(self, *, for_log=True):
@ -623,14 +710,12 @@ class ElectrumX(SessionBase):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.subscribe_headers = False self.subscribe_headers = False
self.subscribe_headers_raw = False self.subscribe_headers_raw = False
self.notified_height = None
self.connection.max_response_size = self.env.max_send self.connection.max_response_size = self.env.max_send
self.max_subs = self.env.max_session_subs self.max_subs = self.env.max_session_subs
self.hashX_subs = {} self.hashX_subs = {}
self.sv_seen = False self.sv_seen = False
self.mempool_statuses = {} self.mempool_statuses = {}
self.set_request_handlers(self.PROTOCOL_MIN) self.set_request_handlers(self.PROTOCOL_MIN)
self.db_height = self.chain_state.db_height
@classmethod @classmethod
def protocol_min_max_strings(cls): def protocol_min_max_strings(cls):
@ -662,96 +747,58 @@ class ElectrumX(SessionBase):
def protocol_version_string(self): def protocol_version_string(self):
return util.version_string(self.protocol_tuple) return util.version_string(self.protocol_tuple)
async def daemon_request(self, method, *args):
'''Catch a DaemonError and convert it to an RPCError.'''
try:
return await self.chain_state.daemon_request(method, args)
except DaemonError as e:
raise RPCError(DAEMON_ERROR, f'daemon error: {e!r}') from None
def sub_count(self): def sub_count(self):
return len(self.hashX_subs) return len(self.hashX_subs)
async def notify_touched(self, our_touched): async def notify(self, touched, height_changed):
changed = {}
for hashX in our_touched:
alias = self.hashX_subs[hashX]
status = await self.address_status(hashX)
changed[alias] = status
# Check mempool hashXs - the status is a function of the
# confirmed state of other transactions. Note: we cannot
# iterate over mempool_statuses as it changes size.
for hashX in tuple(self.mempool_statuses):
# Items can be evicted whilst await-ing below; False
# ensures such hashXs are notified
old_status = self.mempool_statuses.get(hashX, False)
status = await self.address_status(hashX)
if status != old_status:
alias = self.hashX_subs[hashX]
changed[alias] = status
for alias, status in changed.items():
if len(alias) == 64:
method = 'blockchain.scripthash.subscribe'
else:
method = 'blockchain.address.subscribe'
await self.send_notification(method, (alias, status))
if changed:
es = '' if len(changed) == 1 else 'es'
self.logger.info('notified of {:,d} address{}'
.format(len(changed), es))
async def notify(self, height, touched):
'''Notify the client about changes to touched addresses (from mempool '''Notify the client about changes to touched addresses (from mempool
updates or new blocks) and height. updates or new blocks) and height.
Return the set of addresses the session needs to be
asyncronously notified about. This can be empty if there are
possible mempool status updates.
Returns None if nothing needs to be notified asynchronously.
''' '''
height_changed = height != self.notified_height if height_changed and self.subscribe_headers:
if height_changed: args = (await self.subscribe_headers_result(), )
self.notified_height = height await self.send_notification('blockchain.headers.subscribe', args)
if self.subscribe_headers:
args = (await self.subscribe_headers_result(height), )
await self.send_notification('blockchain.headers.subscribe',
args)
touched = touched.intersection(self.hashX_subs) touched = touched.intersection(self.hashX_subs)
if touched or (height_changed and self.mempool_statuses): if touched or (height_changed and self.mempool_statuses):
await self.notify_touched(touched) changed = {}
async def raw_header(self, height): for hashX in touched:
'''Return the binary header at the given height.''' alias = self.hashX_subs[hashX]
try: status = await self.address_status(hashX)
return await self.chain_state.raw_header(height) changed[alias] = status
except IndexError:
raise RPCError(BAD_REQUEST, f'height {height:,d} '
'out of range') from None
async def electrum_header(self, height): # Check mempool hashXs - the status is a function of the
'''Return the deserialized header at the given height.''' # confirmed state of other transactions. Note: we cannot
raw_header = await self.raw_header(height) # iterate over mempool_statuses as it changes size.
return self.coin.electrum_header(raw_header, height) for hashX in tuple(self.mempool_statuses):
# Items can be evicted whilst await-ing status; False
# ensures such hashXs are notified
old_status = self.mempool_statuses.get(hashX, False)
status = await self.address_status(hashX)
if status != old_status:
alias = self.hashX_subs[hashX]
changed[alias] = status
async def subscribe_headers_result(self, height): for alias, status in changed.items():
'''The result of a header subscription for the given height.''' if len(alias) == 64:
if self.subscribe_headers_raw: method = 'blockchain.scripthash.subscribe'
raw_header = await self.raw_header(height) else:
return {'hex': raw_header.hex(), 'height': height} method = 'blockchain.address.subscribe'
return await self.electrum_header(height) await self.send_notification(method, (alias, status))
if changed:
es = '' if len(changed) == 1 else 'es'
self.logger.info(f'notified of {len(changed):,d} address{es}')
async def subscribe_headers_result(self):
'''The result of a header subscription or notification.'''
return self.session_mgr.hsub_results[self.subscribe_headers_raw]
async def _headers_subscribe(self, raw): async def _headers_subscribe(self, raw):
'''Subscribe to get headers of new blocks.''' '''Subscribe to get headers of new blocks.'''
self.subscribe_headers = True
self.subscribe_headers_raw = assert_boolean(raw) self.subscribe_headers_raw = assert_boolean(raw)
self.notified_height = self.db_height() self.subscribe_headers = True
return await self.subscribe_headers_result(self.notified_height) return await self.subscribe_headers_result()
async def headers_subscribe(self): async def headers_subscribe(self):
'''Subscribe to get raw headers of new blocks.''' '''Subscribe to get raw headers of new blocks.'''
@ -804,7 +851,7 @@ class ElectrumX(SessionBase):
async def hashX_listunspent(self, hashX): async def hashX_listunspent(self, hashX):
'''Return the list of UTXOs of a script hash, including mempool '''Return the list of UTXOs of a script hash, including mempool
effects.''' effects.'''
utxos = await self.chain_state.all_utxos(hashX) utxos = await self.db.all_utxos(hashX)
utxos = sorted(utxos) utxos = sorted(utxos)
utxos.extend(await self.mempool.unordered_UTXOs(hashX)) utxos.extend(await self.mempool.unordered_UTXOs(hashX))
spends = await self.mempool.potential_spends(hashX) spends = await self.mempool.potential_spends(hashX)
@ -861,7 +908,7 @@ class ElectrumX(SessionBase):
return await self.hashX_subscribe(hashX, address) return await self.hashX_subscribe(hashX, address)
async def get_balance(self, hashX): async def get_balance(self, hashX):
utxos = await self.chain_state.all_utxos(hashX) utxos = await self.db.all_utxos(hashX)
confirmed = sum(utxo.value for utxo in utxos) confirmed = sum(utxo.value for utxo in utxos)
unconfirmed = await self.mempool.balance_delta(hashX) unconfirmed = await self.mempool.balance_delta(hashX)
return {'confirmed': confirmed, 'unconfirmed': unconfirmed} return {'confirmed': confirmed, 'unconfirmed': unconfirmed}
@ -909,14 +956,14 @@ class ElectrumX(SessionBase):
return await self.hashX_subscribe(hashX, scripthash) return await self.hashX_subscribe(hashX, scripthash)
async def _merkle_proof(self, cp_height, height): async def _merkle_proof(self, cp_height, height):
max_height = self.db_height() max_height = self.db.db_height
if not height <= cp_height <= max_height: if not height <= cp_height <= max_height:
raise RPCError(BAD_REQUEST, raise RPCError(BAD_REQUEST,
f'require header height {height:,d} <= ' f'require header height {height:,d} <= '
f'cp_height {cp_height:,d} <= ' f'cp_height {cp_height:,d} <= '
f'chain height {max_height:,d}') f'chain height {max_height:,d}')
branch, root = await self.chain_state.header_branch_and_root( branch, root = await self.db.header_branch_and_root(cp_height + 1,
cp_height + 1, height) height)
return { return {
'branch': [hash_to_hex_str(elt) for elt in branch], 'branch': [hash_to_hex_str(elt) for elt in branch],
'root': hash_to_hex_str(root), 'root': hash_to_hex_str(root),
@ -927,7 +974,7 @@ class ElectrumX(SessionBase):
dictionary with a merkle proof.''' dictionary with a merkle proof.'''
height = non_negative_integer(height) height = non_negative_integer(height)
cp_height = non_negative_integer(cp_height) cp_height = non_negative_integer(cp_height)
raw_header_hex = (await self.raw_header(height)).hex() raw_header_hex = (await self.session_mgr.raw_header(height)).hex()
if cp_height == 0: if cp_height == 0:
return raw_header_hex return raw_header_hex
result = {'header': raw_header_hex} result = {'header': raw_header_hex}
@ -953,8 +1000,7 @@ class ElectrumX(SessionBase):
max_size = self.MAX_CHUNK_SIZE max_size = self.MAX_CHUNK_SIZE
count = min(count, max_size) count = min(count, max_size)
headers, count = await self.chain_state.read_headers(start_height, headers, count = await self.db.read_headers(start_height, count)
count)
result = {'hex': headers.hex(), 'count': count, 'max': max_size} result = {'hex': headers.hex(), 'count': count, 'max': max_size}
if count and cp_height: if count and cp_height:
last_height = start_height + count - 1 last_height = start_height + count - 1
@ -971,7 +1017,7 @@ class ElectrumX(SessionBase):
index = non_negative_integer(index) index = non_negative_integer(index)
size = self.coin.CHUNK_SIZE size = self.coin.CHUNK_SIZE
start_height = index * size start_height = index * size
headers, _ = await self.chain_state.read_headers(start_height, size) headers, _ = await self.db.read_headers(start_height, size)
return headers.hex() return headers.hex()
async def block_get_header(self, height): async def block_get_header(self, height):
@ -979,7 +1025,7 @@ class ElectrumX(SessionBase):
height: the header's height''' height: the header's height'''
height = non_negative_integer(height) height = non_negative_integer(height)
return await self.electrum_header(height) return await self.session_mgr.electrum_header(height)
def is_tor(self): def is_tor(self):
'''Try to detect if the connection is to a tor hidden service we are '''Try to detect if the connection is to a tor hidden service we are
@ -1042,7 +1088,7 @@ class ElectrumX(SessionBase):
number: the number of blocks number: the number of blocks
''' '''
number = non_negative_integer(number) number = non_negative_integer(number)
return await self.daemon_request('estimatefee', [number]) return await self.daemon_request('estimatefee', number)
async def ping(self): async def ping(self):
'''Serves as a connection keep-alive mechanism and for the client to '''Serves as a connection keep-alive mechanism and for the client to
@ -1091,15 +1137,14 @@ class ElectrumX(SessionBase):
raw_tx: the raw transaction as a hexadecimal string''' raw_tx: the raw transaction as a hexadecimal string'''
# This returns errors as JSON RPC errors, as is natural # This returns errors as JSON RPC errors, as is natural
try: try:
tx_hash = await self.chain_state.broadcast_transaction(raw_tx) hex_hash = await self.session_mgr.broadcast_transaction(raw_tx)
self.txs_sent += 1 self.txs_sent += 1
self.session_mgr.txs_sent += 1 self.logger.info(f'sent tx: {hex_hash}')
self.logger.info('sent tx: {}'.format(tx_hash)) return hex_hash
return tx_hash
except DaemonError as e: except DaemonError as e:
error, = e.args error, = e.args
message = error['message'] message = error['message']
self.logger.info(f'sendrawtransaction: {message}') self.logger.info(f'error sending transaction: {message}')
raise RPCError(BAD_REQUEST, 'the transaction was rejected by ' raise RPCError(BAD_REQUEST, 'the transaction was rejected by '
f'network rules.\n\n{message}\n[{raw_tx}]') f'network rules.\n\n{message}\n[{raw_tx}]')
@ -1135,7 +1180,7 @@ class ElectrumX(SessionBase):
tx_pos: index of transaction in tx_hashes to create branch for tx_pos: index of transaction in tx_hashes to create branch for
''' '''
hashes = [hex_str_to_hash(hash) for hash in tx_hashes] hashes = [hex_str_to_hash(hash) for hash in tx_hashes]
branch, root = self.chain_state.tx_branch_and_root(hashes, tx_pos) branch, root = self.db.merkle.branch_and_root(hashes, tx_pos)
branch = [hash_to_hex_str(hash) for hash in branch] branch = [hash_to_hex_str(hash) for hash in branch]
return branch return branch
@ -1264,9 +1309,9 @@ class DashElectrumX(ElectrumX):
'masternode.list': self.masternode_list 'masternode.list': self.masternode_list
}) })
async def notify(self, height, touched): async def notify(self, touched, height_changed):
'''Notify the client about changes in masternode list.''' '''Notify the client about changes in masternode list.'''
await super().notify(height, touched) await super().notify(touched, height_changed)
for mn in self.mns: for mn in self.mns:
status = await self.daemon_request('masternode_list', status = await self.daemon_request('masternode_list',
['status', mn]) ['status', mn])
@ -1358,7 +1403,7 @@ class DashElectrumX(ElectrumX):
# with the masternode information including the payment # with the masternode information including the payment
# position is returned. # position is returned.
cache = self.session_mgr.mn_cache cache = self.session_mgr.mn_cache
if not cache or self.session_mgr.mn_cache_height != self.db_height(): if not cache or self.session_mgr.mn_cache_height != self.db.db_height:
full_mn_list = await self.daemon_request('masternode_list', full_mn_list = await self.daemon_request('masternode_list',
['full']) ['full'])
mn_payment_queue = get_masternode_payment_queue(full_mn_list) mn_payment_queue = get_masternode_payment_queue(full_mn_list)
@ -1386,7 +1431,7 @@ class DashElectrumX(ElectrumX):
mn_list.append(mn_info) mn_list.append(mn_info)
cache.clear() cache.clear()
cache.extend(mn_list) cache.extend(mn_list)
self.session_mgr.mn_cache_height = self.db_height() self.session_mgr.mn_cache_height = self.db.db_height
# If payees is an empty list the whole masternode list is returned # If payees is an empty list the whole masternode list is returned
if payees: if payees:

View File

@ -1,5 +1,5 @@
import setuptools import setuptools
version = '1.8.3' version = '1.8.4'
setuptools.setup( setuptools.setup(
name='electrumX', name='electrumX',

489
tests/server/test_daemon.py Normal file
View File

@ -0,0 +1,489 @@
import aiohttp
import asyncio
import json
import logging
import pytest
from aiorpcx import (
JSONRPCv1, JSONRPCLoose, RPCError, ignore_after,
Request, Batch,
)
from electrumx.lib.coins import BitcoinCash, CoinError, Bitzeny
from electrumx.server.daemon import (
Daemon, FakeEstimateFeeDaemon, DaemonError
)
coin = BitcoinCash
# These should be full, canonical URLs
urls = ['http://rpc_user:rpc_pass@127.0.0.1:8332/',
'http://rpc_user:rpc_pass@192.168.0.1:8332/']
@pytest.fixture(params=[BitcoinCash, Bitzeny])
def daemon(request):
coin = request.param
return coin.DAEMON(coin, ','.join(urls))
class ResponseBase(object):
def __init__(self, headers, status):
self.headers = headers
self.status = status
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
pass
class JSONResponse(ResponseBase):
def __init__(self, result, msg_id, status=200):
super().__init__({'Content-Type': 'application/json'}, status)
self.result = result
self.msg_id = msg_id
async def json(self):
if isinstance(self.msg_id, int):
message = JSONRPCv1.response_message(self.result, self.msg_id)
else:
parts = [JSONRPCv1.response_message(item, msg_id)
for item, msg_id in zip(self.result, self.msg_id)]
message = JSONRPCv1.batch_message_from_parts(parts)
return json.loads(message.decode())
class HTMLResponse(ResponseBase):
def __init__(self, text, reason, status):
super().__init__({'Content-Type': 'text/html; charset=ISO-8859-1'},
status)
self._text = text
self.reason = reason
async def text(self):
return self._text
class ClientSessionBase(object):
def __enter__(self):
self.prior_class = aiohttp.ClientSession
aiohttp.ClientSession = lambda: self
def __exit__(self, exc_type, exc_value, traceback):
aiohttp.ClientSession = self.prior_class
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
pass
class ClientSessionGood(ClientSessionBase):
'''Imitate aiohttp for testing purposes.'''
def __init__(self, *triples):
self.triples = triples # each a (method, args, result)
self.count = 0
self.expected_url = urls[0]
def post(self, url, data=""):
assert url == self.expected_url
request, request_id = JSONRPCLoose.message_to_item(data.encode())
method, args, result = self.triples[self.count]
self.count += 1
if isinstance(request, Request):
assert request.method == method
assert request.args == args
return JSONResponse(result, request_id)
else:
assert isinstance(request, Batch)
for request, args in zip(request, args):
assert request.method == method
assert request.args == args
return JSONResponse(result, request_id)
class ClientSessionBadAuth(ClientSessionBase):
def post(self, url, data=""):
return HTMLResponse('', 'Unauthorized', 401)
class ClientSessionWorkQueueFull(ClientSessionGood):
def post(self, url, data=""):
self.post = super().post
return HTMLResponse('Work queue depth exceeded',
'Internal server error', 500)
class ClientSessionNoConnection(ClientSessionGood):
def __init__(self, *args):
self.args = args
async def __aenter__(self):
aiohttp.ClientSession = lambda: ClientSessionGood(*self.args)
raise aiohttp.ClientConnectionError
class ClientSessionPostError(ClientSessionGood):
def __init__(self, exception, *args):
self.exception = exception
self.args = args
def post(self, url, data=""):
aiohttp.ClientSession = lambda: ClientSessionGood(*self.args)
raise self.exception
class ClientSessionFailover(ClientSessionGood):
def post(self, url, data=""):
# If not failed over; simulate disconnecting
if url == self.expected_url:
raise aiohttp.ServerDisconnectedError
else:
self.expected_url = urls[1]
return super().post(url, data)
def in_caplog(caplog, message, count=1):
return sum(message in record.message
for record in caplog.records) == count
#
# Tests
#
def test_set_urls_bad():
with pytest.raises(CoinError):
Daemon(coin, '')
with pytest.raises(CoinError):
Daemon(coin, 'a')
def test_set_urls_one(caplog):
with caplog.at_level(logging.INFO):
daemon = Daemon(coin, urls[0])
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 1
logged_url = daemon.logged_url()
assert logged_url == '127.0.0.1:8332/'
assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)')
def test_set_urls_two(caplog):
with caplog.at_level(logging.INFO):
daemon = Daemon(coin, ','.join(urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
logged_url = daemon.logged_url()
assert logged_url == '127.0.0.1:8332/'
assert in_caplog(caplog, f'daemon #1 at {logged_url} (current)')
assert in_caplog(caplog, 'daemon #2 at 192.168.0.1:8332')
def test_set_urls_short():
no_prefix_urls = ['/'.join(part for part in url.split('/')[2:])
for url in urls]
daemon = Daemon(coin, ','.join(no_prefix_urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
no_slash_urls = [url[:-1] for url in urls]
daemon = Daemon(coin, ','.join(no_slash_urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
no_port_urls = [url[:url.rfind(':')] for url in urls]
daemon = Daemon(coin, ','.join(no_port_urls))
assert daemon.current_url() == urls[0]
assert len(daemon.urls) == 2
def test_failover_good(caplog):
daemon = Daemon(coin, ','.join(urls))
with caplog.at_level(logging.INFO):
result = daemon.failover()
assert result is True
assert daemon.current_url() == urls[1]
logged_url = daemon.logged_url()
assert in_caplog(caplog, f'failing over to {logged_url}')
# And again
result = daemon.failover()
assert result is True
assert daemon.current_url() == urls[0]
def test_failover_fail(caplog):
daemon = Daemon(coin, urls[0])
with caplog.at_level(logging.INFO):
result = daemon.failover()
assert result is False
assert daemon.current_url() == urls[0]
assert not in_caplog(caplog, f'failing over')
@pytest.mark.asyncio
async def test_height(daemon):
assert daemon.cached_height() is None
height = 300
with ClientSessionGood(('getblockcount', [], height)):
assert await daemon.height() == height
assert daemon.cached_height() == height
@pytest.mark.asyncio
async def test_broadcast_transaction(daemon):
raw_tx = 'deadbeef'
tx_hash = 'hash'
with ClientSessionGood(('sendrawtransaction', [raw_tx], tx_hash)):
assert await daemon.broadcast_transaction(raw_tx) == tx_hash
@pytest.mark.asyncio
async def test_relayfee(daemon):
response = {"relayfee": sats, "other:": "cruft"}
with ClientSessionGood(('getnetworkinfo', [], response)):
assert await daemon.getnetworkinfo() == response
@pytest.mark.asyncio
async def test_relayfee(daemon):
if isinstance(daemon, FakeEstimateFeeDaemon):
sats = daemon.coin.ESTIMATE_FEE
else:
sats = 2
response = {"relayfee": sats, "other:": "cruft"}
with ClientSessionGood(('getnetworkinfo', [], response)):
assert await daemon.relayfee() == sats
@pytest.mark.asyncio
async def test_mempool_hashes(daemon):
hashes = ['hex_hash1', 'hex_hash2']
with ClientSessionGood(('getrawmempool', [], hashes)):
assert await daemon.mempool_hashes() == hashes
@pytest.mark.asyncio
async def test_deserialised_block(daemon):
block_hash = 'block_hash'
result = {'some': 'mess'}
with ClientSessionGood(('getblock', [block_hash, True], result)):
assert await daemon.deserialised_block(block_hash) == result
@pytest.mark.asyncio
async def test_estimatefee(daemon):
method_not_found = RPCError(JSONRPCv1.METHOD_NOT_FOUND, 'nope')
if isinstance(daemon, FakeEstimateFeeDaemon):
result = daemon.coin.ESTIMATE_FEE
else:
result = -1
with ClientSessionGood(
('estimatesmartfee', [], method_not_found),
('estimatefee', [2], result)
):
assert await daemon.estimatefee(2) == result
@pytest.mark.asyncio
async def test_estimatefee_smart(daemon):
bad_args = RPCError(JSONRPCv1.INVALID_ARGS, 'bad args')
if isinstance(daemon, FakeEstimateFeeDaemon):
return
rate = 0.0002
result = {'feerate': rate}
with ClientSessionGood(
('estimatesmartfee', [], bad_args),
('estimatesmartfee', [2], result)
):
assert await daemon.estimatefee(2) == rate
# Test the rpc_available_cache is used
with ClientSessionGood(('estimatesmartfee', [2], result)):
assert await daemon.estimatefee(2) == rate
@pytest.mark.asyncio
async def test_getrawtransaction(daemon):
hex_hash = 'deadbeef'
simple = 'tx_in_hex'
verbose = {'hex': hex_hash, 'other': 'cruft'}
# Test False is converted to 0 - old daemon's reject False
with ClientSessionGood(('getrawtransaction', [hex_hash, 0], simple)):
assert await daemon.getrawtransaction(hex_hash) == simple
# Test True is converted to 1
with ClientSessionGood(('getrawtransaction', [hex_hash, 1], verbose)):
assert await daemon.getrawtransaction(
hex_hash, True) == verbose
# Batch tests
@pytest.mark.asyncio
async def test_empty_send(daemon):
first = 5
count = 0
with ClientSessionGood(('getblockhash', [], [])):
assert await daemon.block_hex_hashes(first, count) == []
@pytest.mark.asyncio
async def test_block_hex_hashes(daemon):
first = 5
count = 3
hashes = [f'hex_hash{n}' for n in range(count)]
with ClientSessionGood(('getblockhash',
[[n] for n in range(first, first + count)],
hashes)):
assert await daemon.block_hex_hashes(first, count) == hashes
@pytest.mark.asyncio
async def test_raw_blocks(daemon):
count = 3
hex_hashes = [f'hex_hash{n}' for n in range(count)]
args_list = [[hex_hash, False] for hex_hash in hex_hashes]
iterable = (hex_hash for hex_hash in hex_hashes)
blocks = ["00", "019a", "02fe"]
blocks_raw = [bytes.fromhex(block) for block in blocks]
with ClientSessionGood(('getblock', args_list, blocks)):
assert await daemon.raw_blocks(iterable) == blocks_raw
@pytest.mark.asyncio
async def test_get_raw_transactions(daemon):
hex_hashes = ['deadbeef0', 'deadbeef1']
args_list = [[hex_hash, 0] for hex_hash in hex_hashes]
raw_txs_hex = ['fffefdfc', '0a0b0c0d']
raw_txs = [bytes.fromhex(raw_tx) for raw_tx in raw_txs_hex]
# Test 0 - old daemon's reject False
with ClientSessionGood(('getrawtransaction', args_list, raw_txs_hex)):
assert await daemon.getrawtransactions(hex_hashes) == raw_txs
# Test one error
tx_not_found = RPCError(-1, 'some error message')
results = ['ff0b7d', tx_not_found]
raw_txs = [bytes.fromhex(results[0]), None]
with ClientSessionGood(('getrawtransaction', args_list, results)):
assert await daemon.getrawtransactions(hex_hashes) == raw_txs
# Other tests
@pytest.mark.asyncio
async def test_bad_auth(daemon, caplog):
with pytest.raises(DaemonError) as e:
with ClientSessionBadAuth():
await daemon.height()
assert "Unauthorized" in e.value.args[0]
assert in_caplog(caplog, "Unauthorized")
@pytest.mark.asyncio
async def test_workqueue_depth(daemon, caplog):
daemon.init_retry = 0.01
height = 125
with caplog.at_level(logging.INFO):
with ClientSessionWorkQueueFull(('getblockcount', [], height)):
await daemon.height() == height
assert in_caplog(caplog, "work queue full")
assert in_caplog(caplog, "running normally")
@pytest.mark.asyncio
async def test_connection_error(daemon, caplog):
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionNoConnection(('getblockcount', [], height)):
await daemon.height() == height
assert in_caplog(caplog, "connection problem - is your daemon running?")
assert in_caplog(caplog, "connection restored")
@pytest.mark.asyncio
async def test_timeout_error(daemon, caplog):
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionPostError(asyncio.TimeoutError,
('getblockcount', [], height)):
await daemon.height() == height
assert in_caplog(caplog, "timeout error")
@pytest.mark.asyncio
async def test_disconnected(daemon, caplog):
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionPostError(aiohttp.ServerDisconnectedError,
('getblockcount', [], height)):
await daemon.height() == height
assert in_caplog(caplog, "disconnected")
assert in_caplog(caplog, "connection restored")
@pytest.mark.asyncio
async def test_warming_up(daemon, caplog):
warming_up = RPCError(-28, 'reading block index')
height = 100
daemon.init_retry = 0.01
with caplog.at_level(logging.INFO):
with ClientSessionGood(
('getblockcount', [], warming_up),
('getblockcount', [], height)
):
assert await daemon.height() == height
assert in_caplog(caplog, "starting up checking blocks")
assert in_caplog(caplog, "running normally")
@pytest.mark.asyncio
async def test_warming_up_batch(daemon, caplog):
warming_up = RPCError(-28, 'reading block index')
first = 5
count = 1
daemon.init_retry = 0.01
hashes = ['hex_hash5']
with caplog.at_level(logging.INFO):
with ClientSessionGood(('getblockhash', [[first]], [warming_up]),
('getblockhash', [[first]], hashes)):
assert await daemon.block_hex_hashes(first, count) == hashes
assert in_caplog(caplog, "starting up checking blocks")
assert in_caplog(caplog, "running normally")
@pytest.mark.asyncio
async def test_failover(daemon, caplog):
height = 100
daemon.init_retry = 0.01
daemon.max_retry = 0.04
with caplog.at_level(logging.INFO):
with ClientSessionFailover(('getblockcount', [], height)):
await daemon.height() == height
assert in_caplog(caplog, "disconnected", 3)
assert in_caplog(caplog, "failing over")
assert in_caplog(caplog, "connection restored")