From 7523735f99c5d564a9ab5ac0e077ccf21a9a2f6e Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Sun, 13 Nov 2016 14:43:13 +0900 Subject: [PATCH] Split out server and session management --- server/protocol.py | 335 +++++++++++++++++++++++---------------------- 1 file changed, 170 insertions(+), 165 deletions(-) diff --git a/server/protocol.py b/server/protocol.py index f3b6e2a..5a17257 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -27,28 +27,43 @@ from server.version import VERSION class BlockServer(BlockProcessor): - '''Like BlockProcessor but also starts servers when caught up.''' + '''Like BlockProcessor but also has a server manager and starts + servers when caught up.''' def __init__(self, env): super().__init__(env) - self.servers = [] - self.irc = IRC(env) + self.server_mgr = ServerManager(self, env) async def caught_up(self, mempool_hashes): await super().caught_up(mempool_hashes) - if not self.servers: - await self.start_servers() - if self.env.irc: - self.logger.info('starting IRC coroutine') - asyncio.ensure_future(self.irc.start()) - else: - self.logger.info('IRC disabled') - ElectrumX.notify(self.height, self.touched) + self.server_mgr.notify(self.height, self.touched) - async def start_server(self, class_name, kind, host, port, *, ssl=None): + def stop(self): + '''Close the listening servers.''' + self.server_mgr.stop() + + +class ServerManager(LoggedClass): + '''Manages the servers.''' + + AsyncTask = namedtuple('AsyncTask', 'session job') + + def __init__(self, bp, env): + super().__init__() + self.bp = bp + self.env = env + self.servers = [] + self.irc = IRC(env) + self.sessions = set() + self.tasks = asyncio.Queue() + self.current_task = None + + async def start_server(self, kind, *args, **kw_args): loop = asyncio.get_event_loop() - protocol = partial(class_name, self.env, kind) - server = loop.create_server(protocol, host, port, ssl=ssl) + protocol_class = LocalRPC if kind == 'RPC' else ElectrumX + protocol = partial(protocol_class, self, self.bp, self.env, kind) + server = loop.create_server(protocol, *args, **kw_args) + try: self.servers.append(await server) except asyncio.CancelledError: @@ -61,45 +76,50 @@ class BlockServer(BlockProcessor): .format(kind, host, port)) async def start_servers(self): - '''Start listening on RPC, TCP and SSL ports. + '''Connect to IRC and start listening for incoming connections. - Does not start a server if the port wasn't specified. + Only connect to IRC if enabled. Start listening on RCP, TCP + and SSL ports only if the port wasn pecified. ''' env = self.env - Session.init(self, self.daemon, self.coin) + if env.rpc_port is not None: - await self.start_server(LocalRPC, 'RPC', 'localhost', env.rpc_port) + await self.start_server('RPC', 'localhost', env.rpc_port) if env.tcp_port is not None: - await self.start_server(ElectrumX, 'TCP', env.host, env.tcp_port) + await self.start_server('TCP', env.host, env.tcp_port) if env.ssl_port is not None: # FIXME: update if we want to require Python >= 3.5.3 sslc = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile) - await self.start_server(ElectrumX, 'SSL', env.host, - env.ssl_port, ssl=sslc) + await self.start_server('SSL', env.host, env.ssl_port, ssl=sslc) + + asyncio.ensure_future(self.run_tasks()) + + if env.irc: + self.logger.info('starting IRC coroutine') + asyncio.ensure_future(self.irc.start()) + else: + self.logger.info('IRC disabled') + + async def notify(self, height, touched): + '''Notify electrum clients about height changes and touched addresses. + + Start listening if not yet listening. + ''' + if not self.servers: + await self.start_servers() + + sessions = [session for session in self.sessions + if isinstance(session, ElectrumX)] + self.ElectrumX.notify(sessions, height, touched) def stop(self): '''Close the listening servers.''' for server in self.servers: server.close() - def irc_peers(self): - return self.irc.peers - - -AsyncTask = namedtuple('AsyncTask', 'session job') - -class SessionManager(LoggedClass): - - def __init__(self): - super().__init__() - self.sessions = set() - self.tasks = asyncio.Queue() - self.current_task = None - asyncio.ensure_future(self.run_tasks()) - def add_session(self, session): assert session not in self.sessions self.sessions.add(session) @@ -113,7 +133,7 @@ class SessionManager(LoggedClass): def add_task(self, session, job): assert session in self.sessions task = asyncio.ensure_future(job) - self.tasks.put_nowait(AsyncTask(session, task)) + self.tasks.put_nowait(self.AsyncTask(session, task)) async def run_tasks(self): '''Asynchronously run through the task queue.''' @@ -133,22 +153,55 @@ class SessionManager(LoggedClass): finally: self.current_task = None + def irc_peers(self): + return self.irc.peers + + def session_count(self): + return len(self.manager.sessions) + + def info(self): + '''Returned in the RPC 'getinfo' call.''' + address_count = sum(len(session.hash168s) + for session in self.sessions + if isinstance(session, ElectrumX)) + return { + 'blocks': self.bp.height, + 'peers': len(self.irc_peers()), + 'sessions': self.session_count(), + 'watched': address_count, + 'cached': 0, + } + + def sessions_info(self): + '''Returned to the RPC 'sessions' call.''' + now = time.time() + return [(session.kind, + session.peername(), + len(session.hash168s), + 'RPC' if isinstance(session, LocalRPC) else session.client, + now - session.start) + for session in self.sessions] + class Session(JSONRPC): '''Base class of ElectrumX JSON session protocols.''' - def __init__(self, env, kind): + def __init__(self, manager, bp, env, kind): super().__init__() + self.manager = manager + self.bp = bp + self.env = env + self.daemon = bp.daemon + self.coin = bp.coin + self.kind = kind self.hash168s = set() self.client = 'unknown' - self.env = env - self.kind = kind def connection_made(self, transport): '''Handle an incoming client connection.''' super().connection_made(transport) self.logger.info('connection from {}'.format(self.peername())) - self.SESSION_MGR.add_session(self) + self.manager.add_session(self) def connection_lost(self, exc): '''Handle client disconnection.''' @@ -158,7 +211,7 @@ class Session(JSONRPC): 'Sent {:,d} bytes in {:,d} messages {:,d} errors' .format(self.peername(), self.send_size, self.send_count, self.error_count)) - self.SESSION_MGR.remove_session(self) + self.maanger.remove_session(self) def method_handler(self, method): '''Return the handler that will handle the RPC method.''' @@ -166,14 +219,13 @@ class Session(JSONRPC): def on_json_request(self, request): '''Queue the request for asynchronous handling.''' - self.SESSION_MGR.add_task(self, self.handle_json_request(request)) + self.manager.add_task(self, self.handle_json_request(request)) def peername(self): info = self.peer_info() return 'unknown' if not info else '{}:{}'.format(info[0], info[1]) - @classmethod - def tx_hash_from_param(cls, param): + def tx_hash_from_param(self, param): '''Raise an RPCError if the parameter is not a valid transaction hash.''' if isinstance(param, str) and len(param) == 64: @@ -185,17 +237,15 @@ class Session(JSONRPC): raise RPCError('parameter should be a transaction hash: {}' .format(param)) - @classmethod - def hash168_from_param(cls, param): + def hash168_from_param(self, param): if isinstance(param, str): try: - return cls.COIN.address_to_hash168(param) + return self.coin.address_to_hash168(param) except: pass raise RPCError('parameter should be a valid address: {}'.format(param)) - @classmethod - def non_negative_integer_from_param(cls, param): + def non_negative_integer_from_param(self, param): try: param = int(param) except ValueError: @@ -207,60 +257,28 @@ class Session(JSONRPC): raise RPCError('param should be a non-negative integer: {}' .format(param)) - @classmethod - def extract_hash168(cls, params): + def extract_hash168(self, params): if len(params) == 1: - return cls.hash168_from_param(params[0]) + return self.hash168_from_param(params[0]) raise RPCError('params should contain a single address: {}' .format(params)) - @classmethod - def extract_non_negative_integer(cls, params): + def extract_non_negative_integer(self, params): if len(params) == 1: - return cls.non_negative_integer_from_param(params[0]) + return self.non_negative_integer_from_param(params[0]) raise RPCError('params should contain a non-negative integer: {}' .format(params)) - @classmethod - def require_empty_params(cls, params): + def require_empty_params(self, params): if params: raise RPCError('params should be empty: {}'.format(params)) - @classmethod - def init(cls, block_processor, daemon, coin): - cls.BLOCK_PROCESSOR = block_processor - cls.DAEMON = daemon - cls.COIN = coin - cls.SESSION_MGR = SessionManager() - - @classmethod - def irc_peers(cls): - return cls.BLOCK_PROCESSOR.irc_peers() - - @classmethod - def height(cls): - '''Return the current height.''' - return cls.BLOCK_PROCESSOR.height - - @classmethod - def electrum_header(cls, height=None): - '''Return the binary header at the given height.''' - if not 0 <= height <= cls.height(): - raise RPCError('height {:,d} out of range'.format(height)) - header = cls.BLOCK_PROCESSOR.read_headers(height, 1) - return cls.COIN.electrum_header(header, height) - - @classmethod - def current_electrum_header(cls): - '''Used as response to a headers subscription request.''' - return cls.electrum_header(cls.height()) - class ElectrumX(Session): '''A TCP server that handles incoming Electrum connections.''' - def __init__(self, env, kind): - super().__init__(env, kind) + def __init__(self, *args): + super().__init__(*args) self.subscribe_headers = False self.subscribe_height = False self.notified_height = None @@ -280,49 +298,57 @@ class ElectrumX(Session): for suffix in suffixes.split()} @classmethod - def watched_address_count(cls): - sessions = cls.SESSION_MGR.sessions - return sum(len(session.hash168s) for session in sessions) - - @classmethod - def notify(cls, height, touched): - '''Notify electrum clients about height changes and touched - addresses.''' - headers_payload = json_notification_payload( - 'blockchain.headers.subscribe', - (cls.electrum_header(height), ), - ) - height_payload = json_notification_payload( - 'blockchain.numblocks.subscribe', - (height, ), - ) - hash168_to_address = cls.COIN.hash168_to_address - - for session in cls.SESSION_MGR.sessions: - if not isinstance(session, ElectrumX): - continue + def notify(cls, sessions, height, touched): + headers_payload = height_payload = None + for session in sessions: if height != session.notified_height: session.notified_height = height if session.subscribe_headers: + if headers_payload is None: + headers_payload = json_notification_payload( + 'blockchain.headers.subscribe', + (session.electrum_header(height), ), + ) session.send_json(headers_payload) + if session.subscribe_height: + if height_payload is None: + height_payload = json_notification_payload( + 'blockchain.numblocks.subscribe', + (height, ), + ) session.send_json(height_payload) + hash168_to_address = session.coin.hash168_to_address for hash168 in session.hash168s.intersection(touched): address = hash168_to_address(hash168) - status = cls.address_status(hash168) + status = session.address_status(hash168) payload = json_notification_payload( 'blockchain.address.subscribe', (address, status)) session.send_json(payload) - @classmethod - def address_status(cls, hash168): + def height(self): + '''Return the block processor's current height.''' + return self.bp.height + + def current_electrum_header(self): + '''Used as response to a headers subscription request.''' + return self.electrum_header(self.height()) + + def electrum_header(self, height): + '''Return the binary header at the given height.''' + if not 0 <= height <= self.height(): + raise RPCError('height {:,d} out of range'.format(height)) + header = self.bp.read_headers(height, 1) + return self.coin.electrum_header(header, height) + + def address_status(self, hash168): '''Returns status as 32 bytes.''' # Note history is ordered and mempool unordered in electrum-server # For mempool, height is -1 if unconfirmed txins, otherwise 0 - history = cls.BLOCK_PROCESSOR.get_history(hash168) - mempool = cls.BLOCK_PROCESSOR.mempool_transactions(hash168) + history = self.bp.get_history(hash168) + mempool = self.bp.mempool_transactions(hash168) status = ''.join('{}:{:d}:'.format(hash_to_str(tx_hash), height) for tx_hash, height in history) @@ -332,11 +358,10 @@ class ElectrumX(Session): return sha256(status.encode()).hex() return None - @classmethod - async def tx_merkle(cls, tx_hash, height): + async def tx_merkle(self, tx_hash, height): '''tx_hash is a hex string.''' - hex_hashes = await cls.DAEMON.block_hex_hashes(height, 1) - block = await cls.DAEMON.deserialised_block(hex_hashes[0]) + hex_hashes = await self.daemon.block_hex_hashes(height, 1) + block = await self.daemon.deserialised_block(hex_hashes[0]) tx_hashes = block['tx'] # This will throw if the tx_hash is bad pos = tx_hashes.index(tx_hash) @@ -355,16 +380,11 @@ class ElectrumX(Session): return {"block_height": height, "merkle": merkle_branch, "pos": pos} - @classmethod - def height(cls): - return cls.BLOCK_PROCESSOR.height - - @classmethod - def get_history(cls, hash168): + def get_history(self, hash168): # Note history is ordered and mempool unordered in electrum-server # For mempool, height is -1 if unconfirmed txins, otherwise 0 - history = cls.BLOCK_PROCESSOR.get_history(hash168, limit=None) - mempool = cls.BLOCK_PROCESSOR.mempool_transactions(hash168) + history = self.bp.get_history(hash168, limit=None) + mempool = self.bp.mempool_transactions(hash168) conf = tuple({'tx_hash': hash_to_str(tx_hash), 'height': height} for tx_hash, height in history) @@ -372,24 +392,21 @@ class ElectrumX(Session): for tx_hash, fee, unconfirmed in mempool) return conf + unconf - @classmethod - def get_chunk(cls, index): + def get_chunk(self, index): '''Return header chunk as hex. Index is a non-negative integer.''' - chunk_size = cls.COIN.CHUNK_SIZE - next_height = cls.height() + 1 + chunk_size = self.coin.CHUNK_SIZE + next_height = self.height() + 1 start_height = min(index * chunk_size, next_height) count = min(next_height - start_height, chunk_size) - return cls.BLOCK_PROCESSOR.read_headers(start_height, count).hex() + return self.bp.read_headers(start_height, count).hex() - @classmethod - def get_balance(cls, hash168): - confirmed = cls.BLOCK_PROCESSOR.get_balance(hash168) - unconfirmed = cls.BLOCK_PROCESSOR.mempool_value(hash168) + def get_balance(self, hash168): + confirmed = self.bp.get_balance(hash168) + unconfirmed = self.bp.mempool_value(hash168) return {'confirmed': confirmed, 'unconfirmed': unconfirmed} - @classmethod - def list_unspent(cls, hash168): - utxos = cls.BLOCK_PROCESSOR.get_utxos_sorted(hash168) + def list_unspent(self, hash168): + utxos = self.bp.get_utxos_sorted(hash168) return tuple({'tx_hash': hash_to_str(utxo.tx_hash), 'tx_pos': utxo.tx_pos, 'height': utxo.height, 'value': utxo.value} @@ -431,7 +448,7 @@ class ElectrumX(Session): return self.electrum_header(height) async def estimatefee(self, params): - return await self.DAEMON.estimatefee(params) + return await self.daemon.estimatefee(params) async def headers_subscribe(self, params): self.require_empty_params(params) @@ -447,7 +464,7 @@ class ElectrumX(Session): '''The minimum fee a low-priority tx must pay in order to be accepted to the daemon's memory pool.''' self.require_empty_params(params) - return await self.DAEMON.relayfee() + return await self.daemon.relayfee() async def transaction_broadcast(self, params): '''Pass through the parameters to the daemon. @@ -458,7 +475,7 @@ class ElectrumX(Session): user interface job here. ''' try: - tx_hash = await self.DAEMON.sendrawtransaction(params) + tx_hash = await self.daemon.sendrawtransaction(params) self.logger.info('sent tx: {}'.format(tx_hash)) return tx_hash except DaemonError as e: @@ -483,7 +500,7 @@ class ElectrumX(Session): # in anticipation it might be dropped in the future. if 1 <= len(params) <= 2: tx_hash = self.tx_hash_from_param(params[0]) - return await self.DAEMON.getrawtransaction(tx_hash) + return await self.daemon.getrawtransaction(tx_hash) raise RPCError('params wrong length: {}'.format(params)) @@ -500,9 +517,9 @@ class ElectrumX(Session): tx_hash = self.tx_hash_from_param(params[0]) index = self.non_negative_integer_from_param(params[1]) tx_hash = hex_str_to_hash(tx_hash) - hash168 = self.BLOCK_PROCESSOR.get_utxo_hash168(tx_hash, index) + hash168 = self.bp.get_utxo_hash168(tx_hash, index) if hash168: - return self.COIN.hash168_to_address(hash168) + return self.coin.hash168_to_address(hash168) return None raise RPCError('params should contain a transaction hash and index') @@ -537,7 +554,7 @@ class ElectrumX(Session): subscription. ''' self.require_empty_params(params) - return list(self.irc_peers().values()) + return list(self.manager.irc_peers().values()) async def version(self, params): '''Return the server version as a string.''' @@ -550,34 +567,22 @@ class ElectrumX(Session): class LocalRPC(Session): '''A local TCP RPC server for querying status.''' - def __init__(self, env, kind): - super().__init__(env, kind) + def __init__(self, *args): + super().__init__(*args) cmds = 'getinfo sessions numsessions peers numpeers'.split() self.handlers = {cmd: getattr(self, cmd) for cmd in cmds} async def getinfo(self, params): - return { - 'blocks': self.height(), - 'peers': len(self.irc_peers()), - 'sessions': len(self.SESSION_MGR.sessions), - 'watched': ElectrumX.watched_address_count(), - 'cached': 0, - } + return self.manager.info() async def sessions(self, params): - now = time.time() - return [(session.kind, - '' if session == self else session.peername(), - len(session.hash168s), - 'this RPC client' if session == self else session.client, - now - session.start) - for session in self.SESSION_MGR.sessions] + return self.manager.sessions_info() async def numsessions(self, params): - return len(self.SESSION_MGR.sessions) + return self.manager.session_count() async def peers(self, params): - return self.irc_peers() + return self.manager.irc_peers() async def numpeers(self, params): - return len(self.irc_peers()) + return len(self.manager.irc_peers())