From 3d11afbda22389c82c68119fac9ffb6921866f2c Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Sun, 30 Oct 2016 12:31:53 +0900 Subject: [PATCH] Enable servers --- server/block_processor.py | 40 ++++++++++++--------- server/controller.py | 74 ++++++++++++++++++++++----------------- server/env.py | 3 ++ server/protocol.py | 38 ++++++++++++-------- 4 files changed, 92 insertions(+), 63 deletions(-) diff --git a/server/block_processor.py b/server/block_processor.py index ba42a0e..6be4e7d 100644 --- a/server/block_processor.py +++ b/server/block_processor.py @@ -126,17 +126,18 @@ class BlockProcessor(LoggedClass): Coordinate backing up in case of chain reorganisations. ''' - def __init__(self, env, daemon): + def __init__(self, env, daemon, on_catchup=None): super().__init__() self.daemon = daemon + self.on_catchup = on_catchup # Meta self.utxo_MB = env.utxo_MB self.hist_MB = env.hist_MB self.next_cache_check = 0 self.coin = env.coin - self.caught_up = False + self.have_caught_up = False self.reorg_limit = env.reorg_limit # Chain state (initialize to genesis in case of new DB) @@ -192,6 +193,17 @@ class BlockProcessor(LoggedClass): else: return [self.start(), self.prefetcher.start()] + async def caught_up(self): + '''Call when we catch up to the daemon's height.''' + # Flush everything when in caught-up state as queries + # are performed on DB and not in-memory. + self.flush(True) + if not self.have_caught_up: + self.have_caught_up = True + self.logger.info('caught up to height {:,d}'.format(self.height)) + if self.on_catchup: + await self.on_catchup() + async def start(self): '''External entry point for block processing. @@ -199,32 +211,26 @@ class BlockProcessor(LoggedClass): shutdown. ''' try: - await self.advance_blocks() + # If we're caught up so the start servers immediately + if self.height == await self.daemon.height(): + await self.caught_up() + await self.wait_for_blocks() finally: self.flush(True) - async def advance_blocks(self): + async def wait_for_blocks(self): '''Loop forever processing blocks in the forward direction.''' while True: blocks = await self.prefetcher.get_blocks() for block in blocks: if not self.advance_block(block): await self.handle_chain_reorg() - self.caught_up = False + self.have_caught_up = False break await asyncio.sleep(0) # Yield - if self.height != self.daemon.cached_height(): - continue - - if not self.caught_up: - self.caught_up = True - self.logger.info('caught up to height {:,d}' - .format(self.height)) - - # Flush everything when in caught-up state as queries - # are performed on DB not in-memory - self.flush(True) + if self.height == self.daemon.cached_height(): + await self.caught_up() async def force_chain_reorg(self, to_genesis): try: @@ -360,7 +366,7 @@ class BlockProcessor(LoggedClass): def flush_state(self, batch): '''Flush chain state to the batch.''' - if self.caught_up: + if self.have_caught_up: self.first_sync = False now = time.time() self.wall_time += now - self.last_flush diff --git a/server/controller.py b/server/controller.py index 9d1cf64..a7af664 100644 --- a/server/controller.py +++ b/server/controller.py @@ -13,6 +13,7 @@ client-serving data such as histories. import asyncio import signal +import ssl import traceback from functools import partial @@ -35,51 +36,62 @@ class Controller(LoggedClass): self.loop = loop self.env = env self.daemon = Daemon(env.daemon_url) - self.block_processor = BlockProcessor(env, self.daemon) + self.block_processor = BlockProcessor(env, self.daemon, + on_catchup=self.start_servers) self.servers = [] self.sessions = set() self.addresses = {} - self.jobs = set() + self.jobs = asyncio.Queue() self.peers = {} def start(self): - '''Prime the event loop with asynchronous servers and jobs.''' - env = self.env - loop = self.loop - + '''Prime the event loop with asynchronous jobs.''' coros = self.block_processor.coros() - - if False: - self.start_servers() - coros.append(self.reap_jobs()) + coros.append(self.run_jobs()) for coro in coros: asyncio.ensure_future(coro) # Signal handlers for signame in ('SIGINT', 'SIGTERM'): - loop.add_signal_handler(getattr(signal, signame), - partial(self.on_signal, signame)) + self.loop.add_signal_handler(getattr(signal, signame), + partial(self.on_signal, signame)) + + async def start_servers(self): + '''Start listening on RPC, TCP and SSL ports. + + Does not start a server if the port wasn't specified. Does + nothing if servers are already running. + ''' + if self.servers: + return + + env = self.env + loop = self.loop - def start_servers(self): protocol = partial(LocalRPC, self) if env.rpc_port is not None: host = 'localhost' rpc_server = loop.create_server(protocol, host, env.rpc_port) - self.servers.append(loop.run_until_complete(rpc_server)) + self.servers.append(await rpc_server) self.logger.info('RPC server listening on {}:{:d}' .format(host, env.rpc_port)) protocol = partial(ElectrumX, self, self.daemon, env) if env.tcp_port is not None: tcp_server = loop.create_server(protocol, env.host, env.tcp_port) - self.servers.append(loop.run_until_complete(tcp_server)) + self.servers.append(await tcp_server) self.logger.info('TCP server listening on {}:{:d}' .format(env.host, env.tcp_port)) if env.ssl_port is not None: - ssl_server = loop.create_server(protocol, env.host, env.ssl_port) - self.servers.append(loop.run_until_complete(ssl_server)) + # FIXME: update if we want to require Python >= 3.5.3 + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + ssl_context.load_cert_chain(env.ssl_certfile, + keyfile=env.ssl_keyfile) + ssl_server = loop.create_server(protocol, env.host, env.ssl_port, + ssl=ssl_context) + self.servers.append(await ssl_server) self.logger.info('SSL server listening on {}:{:d}' .format(env.host, env.ssl_port)) @@ -96,30 +108,28 @@ class Controller(LoggedClass): task.cancel() def add_session(self, session): + '''Add a session representing one incoming connection.''' self.sessions.add(session) def remove_session(self, session): + '''Remove a session.''' self.sessions.remove(session) def add_job(self, coro): '''Queue a job for asynchronous processing.''' - self.jobs.add(asyncio.ensure_future(coro)) + self.jobs.put_nowait(coro) - async def reap_jobs(self): + async def run_jobs(self): + '''Asynchronously run through the job queue.''' while True: - jobs = set() - for job in self.jobs: - if job.done(): - try: - job.result() - except Exception as e: - traceback.print_exc() - else: - jobs.add(job) - self.logger.info('reaped {:d} jobs, {:d} jobs pending' - .format(len(self.jobs) - len(jobs), len(jobs))) - self.jobs = jobs - await asyncio.sleep(5) + job = await self.jobs.get() + try: + await job + except asyncio.CancelledError: + raise + except Exception: + # Getting here should probably be considered a bug and fixed + traceback.print_exc() def address_status(self, hash168): '''Returns status as 32 bytes.''' diff --git a/server/env.py b/server/env.py index e27f570..e3e8bf7 100644 --- a/server/env.py +++ b/server/env.py @@ -34,6 +34,9 @@ class Env(LoggedClass): # Server stuff self.tcp_port = self.integer('TCP_PORT', None) self.ssl_port = self.integer('SSL_PORT', None) + if self.ssl_port: + self.ssl_certfile = self.required('SSL_CERTFILE') + self.ssl_keyfile = self.required('SSL_KEYFILE') self.rpc_port = self.integer('RPC_PORT', 8000) self.max_subscriptions = self.integer('MAX_SUBSCRIPTIONS', 10000) self.banner_file = self.default('BANNER_FILE', None) diff --git a/server/protocol.py b/server/protocol.py index ebd8df7..2d11b7a 100644 --- a/server/protocol.py +++ b/server/protocol.py @@ -24,6 +24,14 @@ class Error(Exception): class JSONRPC(asyncio.Protocol, LoggedClass): + '''Base class that manages a JSONRPC connection. + + When a request comes in for an RPC method M, then a member + function handle_M is called with the request params array, except + that periods in M are replaced with underscores. So a RPC call + for method 'blockchain.estimatefee' will be passed to + handle_blockchain_estimatefee. + ''' def __init__(self, controller): super().__init__() @@ -31,39 +39,41 @@ class JSONRPC(asyncio.Protocol, LoggedClass): self.parts = [] def connection_made(self, transport): + '''Handle an incoming client connection.''' self.transport = transport - peername = transport.get_extra_info('peername') - self.logger.info('connection from {}'.format(peername)) + self.peername = transport.get_extra_info('peername') + self.logger.info('connection from {}'.format(self.peername)) self.controller.add_session(self) def connection_lost(self, exc): - self.logger.info('disconnected') + '''Handle client disconnection.''' + self.logger.info('disconnected: {}'.format(self.peername)) self.controller.remove_session(self) def data_received(self, data): + '''Handle incoming data (synchronously). + + Requests end in newline characters. Pass complete requests to + decode_message for handling. + ''' while True: npos = data.find(ord('\n')) if npos == -1: + self.parts.append(data) break tail, data = data[:npos], data[npos + 1:] - parts = self.parts - self.parts = [] + parts, self.parts = self.parts, [] parts.append(tail) self.decode_message(b''.join(parts)) - if data: - self.parts.append(data) - def decode_message(self, message): - '''Message is a binary message.''' + '''Decode a binary message and queue it for asynchronous handling.''' try: message = json.loads(message.decode()) except Exception as e: - self.logger.info('caught exception decoding message'.format(e)) - return - - job = self.request_handler(message) - self.controller.add_job(job) + self.logger.info('error decoding JSON message'.format(e)) + else: + self.controller.add_job(self.request_handler(message)) async def request_handler(self, request): '''Called asynchronously.'''