Apart from the flush on shutdown and the flush when caught up, neither of which matter, this makes flushes asynchronous. Also, block processing for reorgs is now asynchronous. This also removes the FORCE_REORG debug envvar; I want to put that into the RPC interface. Closes #102
# Copyright (c) 2016-2017, Neil Booth
#
# All rights reserved.
#
# See the file "LICENCE" for information about the copyright
# and warranty status of this software.

import asyncio
import json
import os
import ssl
import time
from bisect import bisect_left
from collections import defaultdict
from functools import partial

import pylru

from lib.jsonrpc import JSONRPC, RequestBase
import lib.util as util
from server.block_processor import BlockProcessor
from server.irc import IRC
from server.session import LocalRPC, ElectrumX
from server.mempool import MemPool


class Controller(util.LoggedClass):
    '''Manages the client servers, a mempool, and a block processor.

    Servers are started as soon as the block processor first catches
    up with the daemon.
    '''

    BANDS = 5
    CATCHING_UP, LISTENING, PAUSED, SHUTTING_DOWN = range(4)

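    # Note (illustrative, not in the original source): the server state
    # starts at CATCHING_UP, moves to LISTENING once the external servers
    # start, flips between PAUSED and LISTENING as the session count
    # crosses max_sessions / low_watermark, and ends at SHUTTING_DOWN.
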
    class NotificationRequest(RequestBase):
        def __init__(self, height, touched):
            super().__init__(1)
            self.height = height
            self.touched = touched

        async def process(self, session):
            self.remaining = 0
            await session.notify(self.height, self.touched)

    def __init__(self, env):
        super().__init__()
        self.loop = asyncio.get_event_loop()
        self.start = time.time()
        self.bp = BlockProcessor(env)
        self.mempool = MemPool(self.bp.daemon, env.coin, self.bp)
        self.irc = IRC(env)
        self.env = env
        self.servers = {}
        # Map of session to the key of its list in self.groups
        self.sessions = {}
        self.groups = defaultdict(list)
        self.txs_sent = 0
        self.next_log_sessions = 0
        self.state = self.CATCHING_UP
        self.max_sessions = env.max_sessions
        self.low_watermark = self.max_sessions * 19 // 20
        self.max_subs = env.max_subs
        self.subscription_count = 0
        self.next_stale_check = 0
        self.history_cache = pylru.lrucache(256)
        self.header_cache = pylru.lrucache(8)
        self.queue = asyncio.PriorityQueue()
        self.delayed_sessions = []
        self.next_queue_id = 0
        self.height = 0
        self.futures = []
        env.max_send = max(350000, env.max_send)
        self.setup_bands()

    async def mempool_transactions(self, hashX):
        '''Generate (hex_hash, tx_fee, unconfirmed) tuples for mempool
        entries for the hashX.

        unconfirmed is True if any txin is unconfirmed.
        '''
        return await self.mempool.transactions(hashX)

    def mempool_value(self, hashX):
        '''Return the unconfirmed amount in the mempool for hashX.

        Can be positive or negative.
        '''
        return self.mempool.value(hashX)

    def sent_tx(self, tx_hash):
        '''Call when a TX is sent. Tells mempool to prioritize it.'''
        self.txs_sent += 1
        self.mempool.prioritize(tx_hash)

    def setup_bands(self):
        bands = []
        limit = self.env.bandwidth_limit
        for n in range(self.BANDS):
            bands.append(limit)
            limit //= 4
        limit = self.env.bandwidth_limit
        for n in range(self.BANDS):
            limit += limit // 2
            bands.append(limit)
        self.bands = sorted(bands)

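    # Note (illustrative, not in the original source): assuming a
    # bandwidth_limit of 2,000,000 bytes, the loops above produce
    #   bands = [7812, 31250, 125000, 500000, 2000000,
    #            3000000, 4500000, 6750000, 10125000, 15187500]
    # i.e. five bands shrinking by 4x below the limit and five growing by
    # 1.5x above it.
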
    def session_priority(self, session):
        if isinstance(session, LocalRPC):
            return 0
        gid = self.sessions[session]
        group_bandwidth = sum(s.bandwidth_used for s in self.groups[gid])
        return 1 + (bisect_left(self.bands, session.bandwidth_used)
                    + bisect_left(self.bands, group_bandwidth)) // 2

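    # Note (illustrative, not in the original): bisect_left on the sorted
    # bands counts how many bands a usage figure has crossed, so the result
    # is 1 plus the average of the session's own band index and its group's
    # band index.  Once that exceeds BANDS (see is_deprioritized below),
    # enqueue_session() starts delaying the session's responses.
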
    def is_deprioritized(self, session):
        return self.session_priority(session) > self.BANDS

    async def enqueue_delayed_sessions(self):
        while True:
            now = time.time()
            keep = []
            for pair in self.delayed_sessions:
                timeout, item = pair
                priority, queue_id, session = item
                if not session.pause and timeout <= now:
                    self.queue.put_nowait(item)
                else:
                    keep.append(pair)
            self.delayed_sessions = keep

            # If paused and session count has fallen, start listening again
            if (len(self.sessions) <= self.low_watermark
                    and self.state == self.PAUSED):
                await self.start_external_servers()

            await asyncio.sleep(1)

    def enqueue_session(self, session):
        # Might have disconnected whilst waiting
        if session not in self.sessions:
            return
        priority = self.session_priority(session)
        item = (priority, self.next_queue_id, session)
        self.next_queue_id += 1

        excess = max(0, priority - self.BANDS)
        if excess != session.last_delay:
            session.last_delay = excess
            if excess:
                session.log_info('high bandwidth use, deprioritizing by '
                                 'delaying responses {:d}s'.format(excess))
            else:
                session.log_info('stopped delaying responses')
        delay = max(int(session.pause), excess)
        if delay:
            self.delayed_sessions.append((time.time() + delay, item))
        else:
            self.queue.put_nowait(item)

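    # Note (illustrative, not in the original): queue items are
    # (priority, queue_id, session) tuples.  The monotonically increasing
    # queue_id keeps ordering FIFO within a priority level and guarantees
    # the PriorityQueue never has to compare two session objects, which
    # would raise TypeError.
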
    async def serve_requests(self):
        '''Asynchronously run through the task queue.'''
        while True:
            priority_, id_, session = await self.queue.get()
            if session in self.sessions:
                await session.serve_requests()

    async def main_loop(self):
        '''Server manager main loop.'''
        def add_future(coro):
            self.futures.append(asyncio.ensure_future(coro))

        # shutdown() assumes bp.main_loop() is first
        add_future(self.bp.main_loop(self.mempool.touched))
        add_future(self.bp.prefetcher.main_loop(self.bp.caught_up_event))
        add_future(self.irc.start(self.bp.caught_up_event))
        add_future(self.start_servers(self.bp.caught_up_event))
        add_future(self.mempool.main_loop())
        add_future(self.enqueue_delayed_sessions())
        add_future(self.notify())
        for n in range(4):
            add_future(self.serve_requests())

        for future in asyncio.as_completed(self.futures):
            try:
                await future  # Note: future is not one of self.futures
            except asyncio.CancelledError:
                break
        await self.shutdown()
        await asyncio.sleep(1)

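    # Note (illustrative, not in the original): asyncio.as_completed yields
    # wrapper futures in completion order, so this loop wakes as soon as any
    # one task finishes or is cancelled; either way the whole server then
    # shuts down.  The four serve_requests() workers all drain the single
    # shared PriorityQueue.
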
    def close_servers(self, kinds):
        '''Close the servers of the given kinds (TCP etc.).'''
        for kind in kinds:
            server = self.servers.pop(kind, None)
            if server:
                server.close()
                # Don't bother awaiting the close - we're not async

    async def start_server(self, kind, *args, **kw_args):
        protocol_class = LocalRPC if kind == 'RPC' else ElectrumX
        protocol = partial(protocol_class, self, self.bp, self.env, kind)
        server = self.loop.create_server(protocol, *args, **kw_args)

        host, port = args[:2]
        try:
            self.servers[kind] = await server
        except Exception as e:
            self.logger.error('{} server failed to listen on {}:{:d} :{}'
                              .format(kind, host, port, e))
        else:
            self.logger.info('{} server listening on {}:{:d}'
                             .format(kind, host, port))

    async def start_servers(self, caught_up):
        '''Start RPC, TCP and SSL servers once caught up.'''
        if self.env.rpc_port is not None:
            await self.start_server('RPC', 'localhost', self.env.rpc_port)
        await caught_up.wait()
        self.logger.info('max session count: {:,d}'.format(self.max_sessions))
        self.logger.info('session timeout: {:,d} seconds'
                         .format(self.env.session_timeout))
        self.logger.info('session bandwidth limit {:,d} bytes'
                         .format(self.env.bandwidth_limit))
        self.logger.info('max response size {:,d} bytes'
                         .format(self.env.max_send))
        self.logger.info('max subscriptions across all sessions: {:,d}'
                         .format(self.max_subs))
        self.logger.info('max subscriptions per session: {:,d}'
                         .format(self.env.max_session_subs))
        self.logger.info('bands: {}'.format(self.bands))
        await self.start_external_servers()

    async def start_external_servers(self):
        '''Start listening on TCP and SSL ports, but only if the respective
        port was given in the environment.
        '''
        self.state = self.LISTENING

        env = self.env
        if env.tcp_port is not None:
            await self.start_server('TCP', env.host, env.tcp_port)
        if env.ssl_port is not None:
            # Python 3.5.3: use PROTOCOL_TLS
            sslc = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            sslc.load_cert_chain(env.ssl_certfile, keyfile=env.ssl_keyfile)
            await self.start_server('SSL', env.host, env.ssl_port, ssl=sslc)

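    # Note (illustrative, not in the original): despite its name,
    # ssl.PROTOCOL_SSLv23 negotiates the highest TLS version both sides
    # support; it was renamed PROTOCOL_TLS in Python 3.5.3 (hence the
    # comment above) and later superseded by PROTOCOL_TLS_SERVER.
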
    async def notify(self):
        '''Notify sessions about height changes and touched addresses.'''
        while True:
            await self.mempool.touched_event.wait()
            touched = self.mempool.touched.copy()
            self.mempool.touched.clear()
            self.mempool.touched_event.clear()

            # Invalidate caches
            hc = self.history_cache
            for hashX in set(hc).intersection(touched):
                del hc[hashX]
            if self.bp.db_height != self.height:
                self.height = self.bp.db_height
                self.header_cache.clear()

            for session in self.sessions:
                if isinstance(session, ElectrumX):
                    request = self.NotificationRequest(self.bp.db_height,
                                                       touched)
                    session.enqueue_request(request)
            # Periodically log sessions
            if self.env.log_sessions and time.time() > self.next_log_sessions:
                if self.next_log_sessions:
                    data = self.session_data(for_log=True)
                    for line in Controller.sessions_text_lines(data):
                        self.logger.info(line)
                    self.logger.info(json.dumps(self.server_summary()))
                self.next_log_sessions = time.time() + self.env.log_sessions

    def electrum_header(self, height):
        '''Return the binary header at the given height.'''
        if not 0 <= height <= self.bp.db_height:
            raise JSONRPC.RPCError('height {:,d} out of range'.format(height))
        if height in self.header_cache:
            return self.header_cache[height]
        header = self.bp.read_headers(height, 1)
        header = self.env.coin.electrum_header(header, height)
        self.header_cache[height] = header
        return header

    async def async_get_history(self, hashX):
        '''Get history asynchronously to reduce latency.'''
        if hashX in self.history_cache:
            return self.history_cache[hashX]

        def job():
            # History DoS limit.  Each element of history is about 99
            # bytes when encoded as JSON.  This limits resource usage
            # on bloated history requests, and uses a smaller divisor
            # so large requests are logged before refusing them.
            limit = self.env.max_send // 97
            return list(self.bp.get_history(hashX, limit=limit))

        loop = asyncio.get_event_loop()
        history = await loop.run_in_executor(None, job)
        self.history_cache[hashX] = history
        return history

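    # Note (illustrative, not in the original): run_in_executor(None, job)
    # runs the blocking database read on the event loop's default thread
    # pool, so a large history lookup cannot stall every other session
    # sharing the loop.
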
    async def shutdown(self):
        '''Call to shut down everything. Returns when done.'''
        self.state = self.SHUTTING_DOWN
        self.close_servers(list(self.servers.keys()))
        # Don't cancel the block processor main loop - let it close itself
        for future in self.futures[1:]:
            future.cancel()
        if self.sessions:
            await self.close_sessions()
        await self.futures[0]

    async def close_sessions(self, secs=30):
        self.logger.info('cleanly closing client sessions, please wait...')
        for session in self.sessions:
            self.close_session(session)
        self.logger.info('listening sockets closed, waiting up to '
                         '{:d} seconds for socket cleanup'.format(secs))
        limit = time.time() + secs
        while self.sessions and time.time() < limit:
            self.clear_stale_sessions(grace=secs//2)
            await asyncio.sleep(2)
            self.logger.info('{:,d} sessions remaining'
                             .format(len(self.sessions)))

    def add_session(self, session):
        now = time.time()
        if now > self.next_stale_check:
            self.next_stale_check = now + 300
            self.clear_stale_sessions()
        gid = int(session.start - self.start) // 900
        self.groups[gid].append(session)
        self.sessions[session] = gid
        session.log_info('{} {}, {:,d} total'
                         .format(session.kind, session.peername(),
                                 len(self.sessions)))
        if (len(self.sessions) >= self.max_sessions
                and self.state == self.LISTENING):
            self.state = self.PAUSED
            session.log_info('maximum sessions {:,d} reached, stopping new '
                             'connections until count drops to {:,d}'
                             .format(self.max_sessions, self.low_watermark))
            self.close_servers(['TCP', 'SSL'])

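    # Note (illustrative, not in the original): gid buckets sessions into
    # 900-second (15-minute) windows measured from server start, so each
    # entry in self.groups holds the sessions that connected in the same
    # quarter-hour; session_priority() charges a session for its whole
    # group's bandwidth.
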
    def remove_session(self, session):
        '''Remove a session from our sessions list if there.'''
        if session in self.sessions:
            gid = self.sessions.pop(session)
            assert gid in self.groups
            self.groups[gid].remove(session)
            self.subscription_count -= session.sub_count()

    def close_session(self, session):
        '''Close the session's transport and cancel its future.'''
        session.close_connection()
        return 'disconnected {:d}'.format(session.id_)

    def toggle_logging(self, session):
        '''Toggle logging of the session.'''
        session.log_me = not session.log_me
        return 'log {:d}: {}'.format(session.id_, session.log_me)

    def clear_stale_sessions(self, grace=15):
        '''Cut off sessions that have been idle longer than the session
        timeout. Force close stubborn connections that won't close
        cleanly after a short grace period.
        '''
        now = time.time()
        shutdown_cutoff = now - grace
        stale_cutoff = now - self.env.session_timeout

        stale = []
        for session in self.sessions:
            if session.is_closing():
                if session.stop <= shutdown_cutoff:
                    session.transport.abort()
            elif session.last_recv < stale_cutoff:
                self.close_session(session)
                stale.append(session.id_)
        if stale:
            self.logger.info('closing stale connections {}'.format(stale))

        # Consolidate small groups
        gids = [gid for gid, l in self.groups.items() if len(l) <= 4
                and sum(session.bandwidth_used for session in l) < 10000]
        if len(gids) > 1:
            sessions = sum([self.groups[gid] for gid in gids], [])
            new_gid = max(gids)
            for gid in gids:
                del self.groups[gid]
            for session in sessions:
                self.sessions[session] = new_gid
            self.groups[new_gid] = sessions

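    # Note (illustrative, not in the original): merging groups of at most
    # four sessions with under 10,000 bytes of combined bandwidth keeps
    # self.groups from accumulating one near-empty bucket per quarter-hour
    # of uptime; the merged sessions remain accounted together under the
    # highest of the old group ids.
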
    def new_subscription(self):
        if self.subscription_count >= self.max_subs:
            raise JSONRPC.RPCError('server subscription limit {:,d} reached'
                                   .format(self.max_subs))
        self.subscription_count += 1

    def irc_peers(self):
        return self.irc.peers

    def session_count(self):
        '''The number of sessions currently being managed.'''
        return len(self.sessions)

    def server_summary(self):
        '''A one-line summary of server state.'''
        return {
            'daemon_height': self.bp.daemon.cached_height(),
            'db_height': self.bp.db_height,
            'closing': len([s for s in self.sessions if s.is_closing()]),
            'errors': sum(s.error_count for s in self.sessions),
            'groups': len(self.groups),
            'logged': len([s for s in self.sessions if s.log_me]),
            'paused': sum(s.pause for s in self.sessions),
            'pid': os.getpid(),
            'peers': len(self.irc.peers),
            'requests': sum(s.requests_remaining() for s in self.sessions),
            'sessions': self.session_count(),
            'subs': self.subscription_count,
            'txs_sent': self.txs_sent,
        }

    @staticmethod
    def text_lines(method, data):
        if method == 'sessions':
            return Controller.sessions_text_lines(data)
        else:
            return Controller.groups_text_lines(data)

    @staticmethod
    def groups_text_lines(data):
        '''A generator returning lines for a list of groups.

        data is the return value of rpc_groups().'''

        fmt = ('{:<6} {:>9} {:>9} {:>6} {:>6} {:>8} '
               '{:>7} {:>9} {:>7} {:>9}')
        yield fmt.format('ID', 'Sessions', 'Bwidth KB', 'Reqs', 'Txs', 'Subs',
                         'Recv', 'Recv KB', 'Sent', 'Sent KB')
        for (id_, session_count, bandwidth, reqs, txs_sent, subs,
             recv_count, recv_size, send_count, send_size) in data:
            yield fmt.format(id_,
                             '{:,d}'.format(session_count),
                             '{:,d}'.format(bandwidth // 1024),
                             '{:,d}'.format(reqs),
                             '{:,d}'.format(txs_sent),
                             '{:,d}'.format(subs),
                             '{:,d}'.format(recv_count),
                             '{:,d}'.format(recv_size // 1024),
                             '{:,d}'.format(send_count),
                             '{:,d}'.format(send_size // 1024))

    def group_data(self):
        '''Returned to the RPC 'groups' call.'''
        result = []
        for gid in sorted(self.groups.keys()):
            sessions = self.groups[gid]
            result.append([gid,
                           len(sessions),
                           sum(s.bandwidth_used for s in sessions),
                           sum(s.requests_remaining() for s in sessions),
                           sum(s.txs_sent for s in sessions),
                           sum(s.sub_count() for s in sessions),
                           sum(s.recv_count for s in sessions),
                           sum(s.recv_size for s in sessions),
                           sum(s.send_count for s in sessions),
                           sum(s.send_size for s in sessions),
                           ])
        return result

    @staticmethod
    def sessions_text_lines(data):
        '''A generator returning lines for a list of sessions.

        data is the return value of rpc_sessions().'''

        def time_fmt(t):
            t = int(t)
            return ('{:3d}:{:02d}:{:02d}'
                    .format(t // 3600, (t % 3600) // 60, t % 60))

        fmt = ('{:<6} {:<5} {:>15} {:>5} {:>5} '
               '{:>7} {:>7} {:>7} {:>7} {:>7} {:>9} {:>21}')
        yield fmt.format('ID', 'Flags', 'Client', 'Reqs', 'Txs', 'Subs',
                         'Recv', 'Recv KB', 'Sent', 'Sent KB', 'Time', 'Peer')
        for (id_, flags, peer, client, reqs, txs_sent, subs,
             recv_count, recv_size, send_count, send_size, time) in data:
            yield fmt.format(id_, flags, client,
                             '{:,d}'.format(reqs),
                             '{:,d}'.format(txs_sent),
                             '{:,d}'.format(subs),
                             '{:,d}'.format(recv_count),
                             '{:,d}'.format(recv_size // 1024),
                             '{:,d}'.format(send_count),
                             '{:,d}'.format(send_size // 1024),
                             time_fmt(time), peer)

    def session_data(self, for_log):
        '''Returned to the RPC 'sessions' call.'''
        now = time.time()
        sessions = sorted(self.sessions, key=lambda s: s.start)
        return [(session.id_,
                 session.flags(),
                 session.peername(for_log=for_log),
                 session.client,
                 session.requests_remaining(),
                 session.txs_sent,
                 session.sub_count(),
                 session.recv_count, session.recv_size,
                 session.send_count, session.send_size,
                 now - session.start)
                for session in sessions]

    def lookup_session(self, param):
        try:
            id_ = int(param)
        except (ValueError, TypeError):
            pass
        else:
            for session in self.sessions:
                if session.id_ == id_:
                    return session
        return None

    def for_each_session(self, params, operation):
        result = []
        for param in params:
            session = self.lookup_session(param)
            if session:
                result.append(operation(session))
            else:
                result.append('unknown session: {}'.format(param))
        return result

    async def rpc_disconnect(self, params):
        return self.for_each_session(params, self.close_session)

    async def rpc_log(self, params):
        return self.for_each_session(params, self.toggle_logging)

    async def rpc_getinfo(self, params):
        return self.server_summary()

    async def rpc_groups(self, params):
        return self.group_data()

    async def rpc_sessions(self, params):
        return self.session_data(for_log=False)

    async def rpc_peers(self, params):
        return self.irc.peers
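    # Note (illustrative, not in the original): these rpc_* coroutines
    # appear to back the local RPC commands (disconnect, log, getinfo,
    # groups, sessions, peers) served over the LocalRPC sessions started
    # in start_servers(); the commit message says a forced-reorg command
    # is intended to join this interface in place of the removed
    # FORCE_REORG environment variable.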