Make jobs truly asynchronous.

However we need to rate-limit the daemon...
This commit is contained in:
Neil Booth 2016-11-07 22:22:47 +09:00
parent e452c0bca7
commit f05a5414c1

View File

@ -12,6 +12,7 @@ import asyncio
import codecs import codecs
import json import json
import traceback import traceback
from collections import namedtuple
from functools import partial from functools import partial
from server.daemon import DaemonError from server.daemon import DaemonError
@ -29,11 +30,53 @@ def json_notification(method, params):
return {'id': None, 'method': method, 'params': params} return {'id': None, 'method': method, 'params': params}
AsyncTask = namedtuple('AsyncTask', 'session job')
class SessionManager(LoggedClass):
def __init__(self):
super().__init__()
self.sessions = set()
self.tasks = asyncio.Queue()
self.current_task = None
asyncio.ensure_future(self.run_tasks())
def add_session(self, session):
assert session not in self.sessions
self.sessions.add(session)
def remove_session(self, session):
self.sessions.remove(session)
if self.current_task and session == self.current_task.session:
self.logger.info('cancelling running task')
self.current_task.cancel()
def add_task(self, session, job):
assert session in self.sessions
task = asyncio.ensure_future(job)
self.tasks.put_nowait(AsyncTask(session, task))
async def run_tasks(self):
'''Asynchronously run through the task queue.'''
while True:
task = await self.tasks.get()
try:
if task.session in self.sessions:
self.current_task = task
await task.job
else:
task.job.cancel()
except asyncio.CancelledError:
self.logger.info('cancelled task noted')
except Exception:
# Getting here should probably be considered a bug and fixed
traceback.print_exc()
finally:
self.current_task = None
class JSONRPC(asyncio.Protocol, LoggedClass): class JSONRPC(asyncio.Protocol, LoggedClass):
'''Base class that manages a JSONRPC connection.''' '''Base class that manages a JSONRPC connection.'''
SESSIONS = set()
# Queue for aynchronous job processing.
JOBS = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -41,33 +84,13 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
self.send_count = 0 self.send_count = 0
self.send_size = 0 self.send_size = 0
self.error_count = 0 self.error_count = 0
self.init_jobs()
@classmethod
def init_jobs(cls):
if not cls.JOBS:
cls.JOBS = asyncio.Queue()
asyncio.ensure_future(cls.run_jobs())
@classmethod
async def run_jobs(cls):
'''Asynchronously run through the job queue.'''
while True:
job = await cls.JOBS.get()
try:
await job
except asyncio.CancelledError:
raise
except Exception:
# Getting here should probably be considered a bug and fixed
traceback.print_exc()
def connection_made(self, transport): def connection_made(self, transport):
'''Handle an incoming client connection.''' '''Handle an incoming client connection.'''
self.transport = transport self.transport = transport
self.peername = transport.get_extra_info('peername') self.peername = transport.get_extra_info('peername')
self.logger.info('connection from {}'.format(self.peername)) self.logger.info('connection from {}'.format(self.peername))
self.SESSIONS.add(self) self.SESSION_MGR.add_session(self)
def connection_lost(self, exc): def connection_lost(self, exc):
'''Handle client disconnection.''' '''Handle client disconnection.'''
@ -75,7 +98,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
'Sent {:,d} bytes in {:,d} messages {:,d} errors' 'Sent {:,d} bytes in {:,d} messages {:,d} errors'
.format(self.peername, self.send_size, .format(self.peername, self.send_size,
self.send_count, self.error_count)) self.send_count, self.error_count))
self.SESSIONS.remove(self) self.SESSION_MGR.remove_session(self)
def data_received(self, data): def data_received(self, data):
'''Handle incoming data (synchronously). '''Handle incoming data (synchronously).
@ -100,7 +123,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
except Exception as e: except Exception as e:
self.logger.info('error decoding JSON message: {}'.format(e)) self.logger.info('error decoding JSON message: {}'.format(e))
else: else:
self.JOBS.put_nowait(self.request_handler(message)) self.SESSION_MGR.add_task(self, self.request_handler(message))
async def request_handler(self, request): async def request_handler(self, request):
'''Called asynchronously.''' '''Called asynchronously.'''
@ -113,13 +136,21 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
self.error_count += 1 self.error_count += 1
error = {'code': 1, 'message': e.args[0]} error = {'code': 1, 'message': e.args[0]}
payload = {'id': request.get('id'), 'error': error, 'result': result} payload = {'id': request.get('id'), 'error': error, 'result': result}
self.json_send(payload) if not self.json_send(payload):
# Let asyncio call connection_lost() so we stop this
# session's tasks
await asyncio.sleep(0)
def json_send(self, payload): def json_send(self, payload):
if self.transport.is_closing():
self.logger.info('connection closing, not writing')
return False
data = (json.dumps(payload) + '\n').encode() data = (json.dumps(payload) + '\n').encode()
self.transport.write(data) self.transport.write(data)
self.send_count += 1 self.send_count += 1
self.send_size += len(data) self.send_size += len(data)
return True
def rpc_handler(self, method, params): def rpc_handler(self, method, params):
handler = None handler = None
@ -193,6 +224,7 @@ class JSONRPC(asyncio.Protocol, LoggedClass):
cls.BLOCK_PROCESSOR = block_processor cls.BLOCK_PROCESSOR = block_processor
cls.DAEMON = daemon cls.DAEMON = daemon
cls.COIN = coin cls.COIN = coin
cls.SESSION_MGR = SessionManager()
@classmethod @classmethod
def height(cls): def height(cls):
@ -240,7 +272,8 @@ class ElectrumX(JSONRPC):
@classmethod @classmethod
def watched_address_count(cls): def watched_address_count(cls):
return sum(len(session.hash168s) for session in self.SESSIONS sessions = self.SESSION_MGR.sessions
return sum(len(session.hash168s) for session in session
if isinstance(session, cls)) if isinstance(session, cls))
@classmethod @classmethod
@ -257,7 +290,7 @@ class ElectrumX(JSONRPC):
) )
hash168_to_address = cls.COIN.hash168_to_address hash168_to_address = cls.COIN.hash168_to_address
for session in cls.SESSIONS: for session in cls.SESSION_MGR.sessions:
if height != session.notified_height: if height != session.notified_height:
session.notified_height = height session.notified_height = height
if session.subscribe_headers: if session.subscribe_headers:
@ -519,7 +552,7 @@ class LocalRPC(JSONRPC):
return { return {
'blocks': self.height(), 'blocks': self.height(),
'peers': len(ElectrumX.irc_peers()), 'peers': len(ElectrumX.irc_peers()),
'sessions': len(self.SESSIONS), 'sessions': len(self.SESSION_MGR.sessions),
'watched': ElectrumX.watched_address_count(), 'watched': ElectrumX.watched_address_count(),
'cached': 0, 'cached': 0,
} }
@ -528,7 +561,7 @@ class LocalRPC(JSONRPC):
return [] return []
async def numsessions(self, params): async def numsessions(self, params):
return len(self.SESSIONS) return len(self.SESSION_MGR.sessions)
async def peers(self, params): async def peers(self, params):
return tuple(ElectrumX.irc_peers().keys()) return tuple(ElectrumX.irc_peers().keys())