electrumx/server/controller.py
Neil Booth 15051124af Make flushes and reorgs async
Apart from the flush on shutdown and the flush when caught up,
neither of which matters, this makes flushes asynchronous.

Also, block processing for reorgs is now asynchronous.

This also removes the FORCE_REORG debug envvar; I want to
put that into the RPC interface.

Closes #102
2017-01-09 16:15:17 +09:00

562 lines
22 KiB
Python

# Copyright (c) 2016-2017, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.
import asyncio
import json
import os
import ssl
import time
from bisect import bisect_left
from collections import defaultdict
from functools import partial
import pylru
from lib.jsonrpc import JSONRPC, RequestBase
import lib.util as util
from server.block_processor import BlockProcessor
from server.irc import IRC
from server.session import LocalRPC, ElectrumX
from server.mempool import MemPool
class Controller(util.LoggedClass):
    '''Manages the client servers, a mempool, and a block processor.

    The RPC server, if an RPC port is configured, is started straight
    away; the external TCP and SSL servers are started once the block
    processor has first caught up with the daemon.
    '''

    # Number of bandwidth bands built on each side of the configured
    # bandwidth limit by setup_bands().  A session whose priority
    # exceeds this value is treated as deprioritized.
    BANDS = 5
    # Values taken by self.state.
    CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4)
class NotificationRequest(RequestBase):
    '''A request that tells a session about a height and a set of
    touched hashXs.'''

    def __init__(self, height, touched):
        # NOTE(review): the base class is initialised with 1; this is
        # presumably the count of remaining sub-requests - confirm
        # against lib.jsonrpc.RequestBase.
        super().__init__(1)
        self.touched = touched
        self.height = height

    async def process(self, session):
        '''Mark the request as having nothing remaining, then notify
        the session.'''
        self.remaining = 0
        await session.notify(self.height, self.touched)
def __init__(self, env):
    '''Build the block processor, mempool and IRC helpers from env and
    initialise the session, cache and queue bookkeeping.

    Note that env.max_send is raised in place to at least 350000.
    '''
    super().__init__()
    self.loop = asyncio.get_event_loop()
    # Reference time; add_session() derives group ids from the offset
    # of a session's start time from this value.
    self.start = time.time()
    self.bp = BlockProcessor(env)
    self.mempool = MemPool(self.bp.daemon, env.coin, self.bp)
    self.irc = IRC(env)
    self.env = env
    # Map of server kind ('RPC', 'TCP', 'SSL') to its listening server
    self.servers = {}
    # Map of session to the key of its list in self.groups
    self.sessions = {}
    # Map of group id to the list of sessions in that group
    self.groups = defaultdict(list)
    self.txs_sent = 0
    self.next_log_sessions = 0
    self.state = self.CATCHING_UP
    self.max_sessions = env.max_sessions
    # Once paused, listening resumes when the session count is at or
    # below this value (95% of max_sessions, rounded down).
    self.low_watermark = self.max_sessions * 19 // 20
    self.max_subs = env.max_subs
    self.subscription_count = 0
    self.next_stale_check = 0
    # hashX -> history list; entries are dropped by notify() when the
    # hashX is touched.
    self.history_cache = pylru.lrucache(256)
    # height -> header; cleared by notify() when the DB height changes.
    self.header_cache = pylru.lrucache(8)
    # Holds (priority, queue_id, session) items
    self.queue = asyncio.PriorityQueue()
    # Holds (timeout, item) pairs waiting to be moved onto self.queue
    self.delayed_sessions = []
    self.next_queue_id = 0
    self.height = 0
    # Populated by main_loop(); shutdown() relies on index 0 being the
    # block processor's main loop.
    self.futures = []
    # Enforce a minimum max_send; this mutates the caller's env.
    env.max_send = max(350000, env.max_send)
    self.setup_bands()
async def mempool_transactions(self, hashX):
    '''Return the mempool's transactions for hashX.

    The result is whatever self.mempool.transactions() produces; it
    is described as (hex_hash, tx_fee, unconfirmed) tuples, where
    unconfirmed is True if any txin is unconfirmed.
    '''
    entries = await self.mempool.transactions(hashX)
    return entries
def mempool_value(self, hashX):
    '''Return the mempool's unconfirmed amount for hashX, as reported
    by self.mempool.value().  It can be positive or negative.
    '''
    unconfirmed = self.mempool.value(hashX)
    return unconfirmed
def sent_tx(self, tx_hash):
    '''Call when a TX is sent.

    Bumps the sent-transaction counter and asks the mempool to
    prioritize tx_hash.
    '''
    self.txs_sent = self.txs_sent + 1
    self.mempool.prioritize(tx_hash)
def setup_bands(self):
    '''Compute self.bands, the sorted bandwidth thresholds.

    BANDS thresholds are produced at and below the configured
    bandwidth limit (the limit itself, then repeated integer division
    by four) and BANDS above it (each one the previous plus half of
    itself, starting from the limit).
    '''
    base = self.env.bandwidth_limit
    count = self.BANDS

    lower = []
    step = base
    for _ in range(count):
        lower.append(step)
        step //= 4

    upper = []
    step = base
    for _ in range(count):
        step += step // 2
        upper.append(step)

    self.bands = sorted(lower + upper)
def session_priority(self, session):
    '''Return the queue priority of a session.

    LocalRPC sessions get 0.  Any other session gets 1 plus the
    integer average of two band indices: that of its own bandwidth
    use and that of its group's total bandwidth use.
    '''
    if isinstance(session, LocalRPC):
        return 0
    members = self.groups[self.sessions[session]]
    group_usage = sum(member.bandwidth_used for member in members)
    own_band = bisect_left(self.bands, session.bandwidth_used)
    group_band = bisect_left(self.bands, group_usage)
    return 1 + (own_band + group_band) // 2
def is_deprioritized(self, session):
    '''Return True if the session's priority exceeds BANDS.'''
    priority = self.session_priority(session)
    return priority > self.BANDS
async def enqueue_delayed_sessions(self):
    '''Once a second, move delayed sessions whose time has come onto
    the request queue, and resume listening if the server is paused
    and the session count has dropped far enough.

    Runs forever; it only ends by being cancelled.
    '''
    while True:
        current = time.time()
        still_waiting = []
        for entry in self.delayed_sessions:
            due, item = entry
            _priority, _queue_id, session = item
            ready = not session.pause and due <= current
            if ready:
                self.queue.put_nowait(item)
            else:
                still_waiting.append(entry)
        self.delayed_sessions = still_waiting

        # If paused and session count has fallen, start listening again
        paused = self.state == self.PAUSED
        if paused and len(self.sessions) <= self.low_watermark:
            await self.start_external_servers()

        await asyncio.sleep(1)
def enqueue_session(self, session):
    '''Put a session on the request queue, or on the delayed list if
    it is paused or over its bandwidth allowance.

    The delay in seconds is the larger of int(session.pause) and the
    amount by which the session's priority exceeds BANDS.  A change
    in the bandwidth-derived delay is logged on the session.
    '''
    # Might have disconnected whilst waiting
    if session not in self.sessions:
        return

    priority = self.session_priority(session)
    queue_id = self.next_queue_id
    self.next_queue_id = queue_id + 1
    item = (priority, queue_id, session)

    excess = max(0, priority - self.BANDS)
    if excess != session.last_delay:
        session.last_delay = excess
        if excess:
            message = ('high bandwidth use, deprioritizing by '
                       'delaying responses {:d}s'.format(excess))
        else:
            message = 'stopped delaying responses'
        session.log_info(message)

    delay = max(int(session.pause), excess)
    if not delay:
        self.queue.put_nowait(item)
        return
    self.delayed_sessions.append((time.time() + delay, item))
async def serve_requests(self):
    '''Asynchronously run through the task queue.

    Takes queued items one at a time and serves the session's
    requests, skipping sessions that are no longer registered.  Runs
    forever; it only ends by being cancelled.
    '''
    while True:
        _priority, _queue_id, session = await self.queue.get()
        if session not in self.sessions:
            continue
        await session.serve_requests()
async def main_loop(self):
    '''Server manager main loop.

    Schedules every long-running coroutine as a future, waits until
    one of them is cancelled (or all of them finish), then shuts
    down.  An exception other than cancellation raised by a future
    propagates to the caller.
    '''
    caught_up = self.bp.caught_up_event
    # shutdown() assumes bp.main_loop() is first
    coros = [
        self.bp.main_loop(self.mempool.touched),
        self.bp.prefetcher.main_loop(caught_up),
        self.irc.start(caught_up),
        self.start_servers(caught_up),
        self.mempool.main_loop(),
        self.enqueue_delayed_sessions(),
        self.notify(),
    ]
    # Four concurrent request servers
    coros.extend(self.serve_requests() for _ in range(4))
    for coro in coros:
        self.futures.append(asyncio.ensure_future(coro))

    # Note: the awaitables yielded here are not members of self.futures
    for completed in asyncio.as_completed(self.futures):
        try:
            await completed
        except asyncio.CancelledError:
            break

    await self.shutdown()
    await asyncio.sleep(1)
def close_servers(self, kinds):
    '''Close the servers of the given kinds (TCP etc.).

    Each kind is removed from self.servers; kinds with no server
    registered are ignored.
    '''
    popped = (self.servers.pop(kind, None) for kind in kinds)
    for server in popped:
        if not server:
            continue
        # Don't bother awaiting the close - we're not async
        server.close()
async def start_server(self, kind, *args, **kw_args):
    '''Start listening for connections of the given kind.

    args and kw_args are passed to the event loop's create_server();
    the first two positional arguments are taken to be the host and
    port.  On success the server is stored in self.servers[kind].  A
    failure to listen is logged, not raised.
    '''
    if kind == 'RPC':
        protocol_class = LocalRPC
    else:
        protocol_class = ElectrumX
    factory = partial(protocol_class, self, self.bp, self.env, kind)
    pending = self.loop.create_server(factory, *args, **kw_args)

    host, port = args[:2]
    try:
        self.servers[kind] = await pending
    except Exception as e:
        self.logger.error('{} server failed to listen on {}:{:d} :{}'
                          .format(kind, host, port, e))
        return
    self.logger.info('{} server listening on {}:{:d}'
                     .format(kind, host, port))
async def start_servers(self, caught_up):
    '''Start RPC, TCP and SSL servers once caught up.

    The RPC server, if a port is configured, is started on localhost
    immediately.  The limits are logged and the external servers
    started only after the caught_up event is set.
    '''
    env = self.env
    if env.rpc_port is not None:
        await self.start_server('RPC', 'localhost', env.rpc_port)

    await caught_up.wait()

    report = self.logger.info
    report('max session count: {:,d}'.format(self.max_sessions))
    report('session timeout: {:,d} seconds'.format(env.session_timeout))
    report('session bandwidth limit {:,d} bytes'
           .format(env.bandwidth_limit))
    report('max response size {:,d} bytes'.format(env.max_send))
    report('max subscriptions across all sessions: {:,d}'
           .format(self.max_subs))
    report('max subscriptions per session: {:,d}'
           .format(env.max_session_subs))
    report('bands: {}'.format(self.bands))

    await self.start_external_servers()
async def start_external_servers(self):
    '''Start listening on TCP and SSL ports, but only if the respective
    port was given in the environment.

    The state becomes LISTENING before any server is started.
    '''
    self.state = self.LISTENING
    env = self.env

    if env.tcp_port is not None:
        await self.start_server('TCP', env.host, env.tcp_port)

    if env.ssl_port is None:
        return
    # Python 3.5.3: use PROTOCOL_TLS
    context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    context.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile)
    await self.start_server('SSL', env.host, env.ssl_port, ssl=context)
async def notify(self):
    '''Notify sessions about height changes and touched addresses.

    Each time the mempool's touched event fires, the touched set is
    taken and reset, cached histories for touched hashXs are dropped,
    the header cache is cleared if the DB height changed, and a
    NotificationRequest is enqueued on every ElectrumX session.  If
    env.log_sessions is set, session data is also logged at that
    interval.  Runs forever; it only ends by being cancelled.
    '''
    while True:
        await self.mempool.touched_event.wait()
        touched = self.mempool.touched.copy()
        self.mempool.touched.clear()
        self.mempool.touched_event.clear()

        # Invalidate caches
        cache = self.history_cache
        for hashX in set(cache) & set(touched):
            del cache[hashX]
        if self.bp.db_height != self.height:
            self.height = self.bp.db_height
            self.header_cache.clear()

        for session in self.sessions:
            if not isinstance(session, ElectrumX):
                continue
            request = self.NotificationRequest(self.bp.db_height,
                                               touched)
            session.enqueue_request(request)

        # Periodically log sessions
        if self.env.log_sessions and time.time() > self.next_log_sessions:
            # Nothing is logged on the first pass; it only arms the timer
            if self.next_log_sessions:
                data = self.session_data(for_log=True)
                for line in Controller.sessions_text_lines(data):
                    self.logger.info(line)
                self.logger.info(json.dumps(self.server_summary()))
            self.next_log_sessions = time.time() + self.env.log_sessions
def electrum_header(self, height):
    '''Return the header at the given height, in the form produced by
    the coin's electrum_header().

    Results are kept in self.header_cache.  Raises JSONRPC.RPCError
    if height is not between 0 and the DB height inclusive.
    '''
    if not 0 <= height <= self.bp.db_height:
        raise JSONRPC.RPCError('height {:,d} out of range'.format(height))

    cache = self.header_cache
    if height in cache:
        return cache[height]

    raw = self.bp.read_headers(height, 1)
    header = self.env.coin.electrum_header(raw, height)
    cache[height] = header
    return header
async def async_get_history(self, hashX):
    '''Get history asynchronously to reduce latency.

    A cached history is returned directly.  Otherwise the history is
    read in the event loop's default executor, cached and returned
    as a list.
    '''
    cache = self.history_cache
    if hashX in cache:
        return cache[hashX]

    def fetch():
        # History DoS limit.  Each element of history is about 99
        # bytes when encoded as JSON.  This limits resource usage
        # on bloated history requests, and uses a smaller divisor
        # so large requests are logged before refusing them.
        max_items = self.env.max_send // 97
        return list(self.bp.get_history(hashX, limit=max_items))

    history = await asyncio.get_event_loop().run_in_executor(None, fetch)
    self.history_cache[hashX] = history
    return history
async def shutdown(self):
    '''Call to shutdown everything.  Returns when done.

    Closes every server, cancels all futures except the first,
    closes any remaining sessions and finally waits for the first
    future to finish.
    '''
    self.state = self.SHUTTING_DOWN
    kinds = list(self.servers)
    self.close_servers(kinds)

    # Don't cancel the block processor main loop - let it close itself
    others = self.futures[1:]
    for future in others:
        future.cancel()

    if self.sessions:
        await self.close_sessions()
    await self.futures[0]
async def close_sessions(self, secs=30):
    '''Close every session, then wait up to secs seconds for them to
    go away.

    While waiting, stale sessions are cleared every two seconds with
    a grace period of half of secs.
    '''
    log = self.logger.info
    log('cleanly closing client sessions, please wait...')
    for session in self.sessions:
        self.close_session(session)
    log('listening sockets closed, waiting up to '
        '{:d} seconds for socket cleanup'.format(secs))

    deadline = time.time() + secs
    while self.sessions and time.time() < deadline:
        self.clear_stale_sessions(grace=secs//2)
        await asyncio.sleep(2)

    log('{:,d} sessions remaining'.format(len(self.sessions)))
def add_session(self, session):
    '''Register a new session.

    The session joins the group for the 900-second window in which
    it started, measured from the controller's start time.  Stale
    sessions are cleared first if that has not happened in the last
    300 seconds.  If the session limit is reached whilst listening,
    the TCP and SSL servers are closed and the state becomes PAUSED.
    '''
    now = time.time()
    if now > self.next_stale_check:
        self.next_stale_check = now + 300
        self.clear_stale_sessions()

    gid = int(session.start - self.start) // 900
    self.groups[gid].append(session)
    self.sessions[session] = gid
    session.log_info('{} {}, {:,d} total'
                     .format(session.kind, session.peername(),
                             len(self.sessions)))

    if len(self.sessions) < self.max_sessions:
        return
    if self.state != self.LISTENING:
        return
    self.state = self.PAUSED
    session.log_info('maximum sessions {:,d} reached, stopping new '
                     'connections until count drops to {:,d}'
                     .format(self.max_sessions, self.low_watermark))
    self.close_servers(['TCP', 'SSL'])
def remove_session(self, session):
    '''Remove a session from our sessions list if there.

    Also removes it from its group and deducts its subscriptions
    from the server-wide count.
    '''
    if session not in self.sessions:
        return
    gid = self.sessions.pop(session)
    assert gid in self.groups
    self.groups[gid].remove(session)
    self.subscription_count -= session.sub_count()
def close_session(self, session):
    '''Close the session's connection and return a message naming
    the session that was disconnected.'''
    session.close_connection()
    message = 'disconnected {:d}'.format(session.id_)
    return message
def toggle_logging(self, session):
    '''Toggle logging of the session and return a message giving
    the new setting.'''
    enabled = not session.log_me
    session.log_me = enabled
    return 'log {:d}: {}'.format(session.id_, enabled)
def clear_stale_sessions(self, grace=15):
    '''Cut off sessions that have received nothing for longer than
    the configured session timeout.  Force close stubborn connections
    that won't close cleanly after a short grace period.

    Afterwards small, low-bandwidth groups are merged into one.
    '''
    now = time.time()
    shutdown_cutoff = now - grace
    stale_cutoff = now - self.env.session_timeout

    stale = []
    for session in self.sessions:
        if not session.is_closing():
            if session.last_recv < stale_cutoff:
                self.close_session(session)
                stale.append(session.id_)
            continue
        # Already closing: abort the transport once grace has elapsed
        if session.stop <= shutdown_cutoff:
            session.transport.abort()
    if stale:
        self.logger.info('closing stale connections {}'.format(stale))

    # Consolidate small groups
    small = [gid for gid, members in self.groups.items()
             if len(members) <= 4
             and sum(member.bandwidth_used for member in members) < 10000]
    if len(small) <= 1:
        return
    new_gid = max(small)
    merged = []
    for gid in small:
        merged.extend(self.groups.pop(gid))
    for session in merged:
        self.sessions[session] = new_gid
    self.groups[new_gid] = merged
def new_subscription(self):
    '''Count one more subscription.

    Raises JSONRPC.RPCError if the server-wide subscription limit
    has already been reached.
    '''
    limit = self.max_subs
    if self.subscription_count >= limit:
        raise JSONRPC.RPCError('server subscription limit {:,d} reached'
                               .format(limit))
    self.subscription_count += 1
def irc_peers(self):
    '''Return the peers held by the IRC object.'''
    peers = self.irc.peers
    return peers
def session_count(self):
    '''Return the number of sessions currently registered.'''
    total = len(self.sessions)
    return total
def server_summary(self):
    '''Return a dictionary summarising server state: heights, session
    and group counts, the process id, peer count and totals of
    errors, outstanding requests, subscriptions and sent txs.'''
    sessions = self.sessions
    return {
        'daemon_height': self.bp.daemon.cached_height(),
        'db_height': self.bp.db_height,
        'closing': sum(1 for s in sessions if s.is_closing()),
        'errors': sum(s.error_count for s in sessions),
        'groups': len(self.groups),
        'logged': sum(1 for s in sessions if s.log_me),
        'paused': sum(s.pause for s in sessions),
        'pid': os.getpid(),
        'peers': len(self.irc.peers),
        'requests': sum(s.requests_remaining() for s in sessions),
        'sessions': self.session_count(),
        'subs': self.subscription_count,
        'txs_sent': self.txs_sent,
    }
@staticmethod
def text_lines(method, data):
    '''Return the text lines for data: the sessions table if method
    is 'sessions', the groups table for any other method.'''
    formatter = (Controller.sessions_text_lines if method == 'sessions'
                 else Controller.groups_text_lines)
    return formatter(data)
@staticmethod
def groups_text_lines(data):
    '''A generator returning lines for a list of groups.

    data is the return value of rpc_groups().  The first line
    yielded is the column header.  Sizes are shown in whole KB.
    '''
    def commas(value):
        return '{:,d}'.format(value)

    fmt = ('{:<6} {:>9} {:>9} {:>6} {:>6} {:>8}'
           '{:>7} {:>9} {:>7} {:>9}')
    yield fmt.format('ID', 'Sessions', 'Bwidth KB', 'Reqs', 'Txs', 'Subs',
                     'Recv', 'Recv KB', 'Sent', 'Sent KB')
    for row in data:
        (id_, session_count, bandwidth, reqs, txs_sent, subs,
         recv_count, recv_size, send_count, send_size) = row
        yield fmt.format(id_,
                         commas(session_count),
                         commas(bandwidth // 1024),
                         commas(reqs),
                         commas(txs_sent),
                         commas(subs),
                         commas(recv_count),
                         commas(recv_size // 1024),
                         commas(send_count),
                         commas(send_size // 1024))
def group_data(self):
    '''Returned to the RPC 'groups' call.

    Gives one list per group, in ascending group id order, holding
    the id, the session count and the group's totals for bandwidth,
    outstanding requests, sent txs, subscriptions and receive / send
    counts and sizes.
    '''
    def summarize(gid, members):
        return [gid,
                len(members),
                sum(s.bandwidth_used for s in members),
                sum(s.requests_remaining() for s in members),
                sum(s.txs_sent for s in members),
                sum(s.sub_count() for s in members),
                sum(s.recv_count for s in members),
                sum(s.recv_size for s in members),
                sum(s.send_count for s in members),
                sum(s.send_size for s in members),
                ]

    return [summarize(gid, self.groups[gid])
            for gid in sorted(self.groups.keys())]
@staticmethod
def sessions_text_lines(data):
    '''A generator returning lines for a list of sessions.

    data is the return value of rpc_sessions().  The first line
    yielded is the column header.  Sizes are shown in whole KB and
    the session age as hours:minutes:seconds.
    '''
    def commas(value):
        return '{:,d}'.format(value)

    def time_fmt(t):
        hours, remainder = divmod(int(t), 3600)
        minutes, seconds = divmod(remainder, 60)
        return '{:3d}:{:02d}:{:02d}'.format(hours, minutes, seconds)

    fmt = ('{:<6} {:<5} {:>15} {:>5} {:>5} '
           '{:>7} {:>7} {:>7} {:>7} {:>7} {:>9} {:>21}')
    yield fmt.format('ID', 'Flags', 'Client', 'Reqs', 'Txs', 'Subs',
                     'Recv', 'Recv KB', 'Sent', 'Sent KB', 'Time', 'Peer')
    for row in data:
        (id_, flags, peer, client, reqs, txs_sent, subs,
         recv_count, recv_size, send_count, send_size, age) = row
        yield fmt.format(id_, flags, client,
                         commas(reqs),
                         commas(txs_sent),
                         commas(subs),
                         commas(recv_count),
                         commas(recv_size // 1024),
                         commas(send_count),
                         commas(send_size // 1024),
                         time_fmt(age), peer)
def session_data(self, for_log):
    '''Returned to the RPC 'sessions' call.

    Gives one tuple per session, oldest first.  for_log is passed
    through to each session's peername().  The last element of each
    tuple is the session's age in seconds.
    '''
    now = time.time()
    rows = []
    for session in sorted(self.sessions, key=lambda s: s.start):
        rows.append((session.id_,
                     session.flags(),
                     session.peername(for_log=for_log),
                     session.client,
                     session.requests_remaining(),
                     session.txs_sent,
                     session.sub_count(),
                     session.recv_count, session.recv_size,
                     session.send_count, session.send_size,
                     now - session.start))
    return rows
def lookup_session(self, param):
    '''Return the session whose id_ equals int(param), or None.

    None is returned both when param cannot be converted to an
    integer and when no registered session has that id.
    '''
    try:
        id_ = int(param)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures mean "no such session"; anything
        # else (e.g. KeyboardInterrupt) must not be swallowed.
        return None
    for session in self.sessions:
        if session.id_ == id_:
            return session
    return None
def for_each_session(self, params, operation):
    '''Apply operation to the session identified by each of params.

    Returns a list with one entry per param, in order: the result of
    operation(session), or an 'unknown session' message if the
    lookup found nothing.
    '''
    def apply(param):
        session = self.lookup_session(param)
        if not session:
            return 'unknown session: {}'.format(param)
        return operation(session)

    return [apply(param) for param in params]
async def rpc_disconnect(self, params):
    '''RPC handler: close the session named by each of params and
    return the per-session result messages.'''
    results = self.for_each_session(params, self.close_session)
    return results
async def rpc_log(self, params):
    '''RPC handler: toggle logging of the session named by each of
    params and return the per-session result messages.'''
    results = self.for_each_session(params, self.toggle_logging)
    return results
async def rpc_getinfo(self, params):
    '''RPC handler: return the server summary.  params is unused.'''
    summary = self.server_summary()
    return summary
async def rpc_groups(self, params):
    '''RPC handler: return the group data.  params is unused.'''
    groups = self.group_data()
    return groups
async def rpc_sessions(self, params):
    '''RPC handler: return the session data, with peer names not
    formatted for logging.  params is unused.'''
    sessions = self.session_data(for_log=False)
    return sessions
async def rpc_peers(self, params):
    '''RPC handler: return the peers held by the IRC object.  params
    is unused.'''
    peers = self.irc.peers
    return peers