Use session.spawn()

Fixes #632 properly

Requires aiorpcX 0.10.x
This commit is contained in:
Neil Booth 2018-11-05 16:13:26 -04:00
parent 2271c8a45c
commit 6209300e35
3 changed files with 15 additions and 18 deletions

View File

@ -82,8 +82,8 @@ class Controller(ServerBase):
'''Start the RPC server and wait for the mempool to synchronize. Then '''Start the RPC server and wait for the mempool to synchronize. Then
start serving external clients. start serving external clients.
''' '''
if not (0, 9, 1) <= aiorpcx_version < (0, 10): if not (0, 10, 0) <= aiorpcx_version < (0, 11):
raise RuntimeError('aiorpcX version 0.9.x with x >= 1 required') raise RuntimeError('aiorpcX version 0.10.x required')
env = self.env env = self.env
min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings() min_str, max_str = env.coin.SESSIONCLS.protocol_min_max_strings()

View File

@ -262,9 +262,8 @@ class SessionManager(object):
for session in stale_sessions) for session in stale_sessions)
self.logger.info(f'closing stale connections {text}') self.logger.info(f'closing stale connections {text}')
# Give the sockets some time to close gracefully # Give the sockets some time to close gracefully
async with TaskGroup() as group: for session in stale_sessions:
for session in stale_sessions: await session.spawn(session.close())
await group.spawn(session.close())
# Consolidate small groups # Consolidate small groups
bw_limit = self.env.bandwidth_limit bw_limit = self.env.bandwidth_limit
@ -512,9 +511,8 @@ class SessionManager(object):
finally: finally:
# Close servers and sessions # Close servers and sessions
await self._close_servers(list(self.servers.keys())) await self._close_servers(list(self.servers.keys()))
async with TaskGroup() as group: for session in self.sessions:
for session in list(self.sessions): await session.spawn(session.close(force_after=1))
await group.spawn(session.close(force_after=1))
def session_count(self): def session_count(self):
'''The number of connections that we've sent something to.''' '''The number of connections that we've sent something to.'''
@ -567,9 +565,8 @@ class SessionManager(object):
for hashX in set(hc).intersection(touched): for hashX in set(hc).intersection(touched):
del hc[hashX] del hc[hashX]
async with TaskGroup() as group: for session in self.sessions:
for session in self.sessions: await session.spawn(session.notify, touched, height_changed)
await group.spawn(session.notify(touched, height_changed))
def add_session(self, session): def add_session(self, session):
self.sessions.add(session) self.sessions.add(session)
@ -649,7 +646,7 @@ class SessionBase(RPCSession):
status += 'C' status += 'C'
if self.log_me: if self.log_me:
status += 'L' status += 'L'
status += str(self.concurrency.max_concurrent) status += str(self._concurrency.max_concurrent)
return status return status
def connection_made(self, transport): def connection_made(self, transport):
@ -664,12 +661,11 @@ class SessionBase(RPCSession):
def connection_lost(self, exc): def connection_lost(self, exc):
'''Handle client disconnection.''' '''Handle client disconnection.'''
super().connection_lost(exc)
self.session_mgr.remove_session(self) self.session_mgr.remove_session(self)
msg = '' msg = ''
if not self.can_send.is_set(): if not self._can_send.is_set():
msg += ' whilst paused' msg += ' with full socket buffer'
if self.concurrency.max_concurrent != self.max_concurrent: if self._concurrency.max_concurrent != self.max_concurrent:
msg += ' whilst throttled' msg += ' whilst throttled'
if self.send_size >= 1024*1024: if self.send_size >= 1024*1024:
msg += ('. Sent {:,d} bytes in {:,d} messages' msg += ('. Sent {:,d} bytes in {:,d} messages'
@ -677,12 +673,13 @@ class SessionBase(RPCSession):
if msg: if msg:
msg = 'disconnected' + msg msg = 'disconnected' + msg
self.logger.info(msg) self.logger.info(msg)
super().connection_lost(exc)
def count_pending_items(self): def count_pending_items(self):
return len(self.connection.pending_requests()) return len(self.connection.pending_requests())
def semaphore(self): def semaphore(self):
return Semaphores([self.concurrency.semaphore, self.group.semaphore]) return Semaphores([self._concurrency.semaphore, self.group.semaphore])
def sub_count(self): def sub_count(self):
return 0 return 0

View File

@ -12,7 +12,7 @@ setuptools.setup(
# "blake256" package is required to sync Decred network. # "blake256" package is required to sync Decred network.
# "xevan_hash" package is required to sync Xuez network. # "xevan_hash" package is required to sync Xuez network.
# "groestlcoin_hash" package is required to sync Groestlcoin network. # "groestlcoin_hash" package is required to sync Groestlcoin network.
install_requires=['aiorpcX>=0.9.1,<0.10', 'attrs', install_requires=['aiorpcX>=0.10.0,<0.11', 'attrs',
'plyvel', 'pylru', 'aiohttp >= 2'], 'plyvel', 'pylru', 'aiohttp >= 2'],
packages=setuptools.find_packages(include=('electrumx*',)), packages=setuptools.find_packages(include=('electrumx*',)),
description='ElectrumX Server', description='ElectrumX Server',