asyncio: more graceful shutdown

This commit is contained in:
Janus 2017-12-15 16:18:52 +01:00
parent 200a085778
commit dcb0a24e6f
2 changed files with 83 additions and 57 deletions

View File

@ -228,20 +228,18 @@ class Interface(util.PrintError):
self.buf = self.buf[pos+1:] self.buf = self.buf[pos+1:]
self.last_action = time.time() self.last_action = time.time()
return obj return obj
async def get(self): async def get(self, is_running):
reader, _ = await self._get_read_write() reader, _ = await self._get_read_write()
while True: while is_running():
tried = self._try_extract() tried = self._try_extract()
if tried: return tried if tried: return tried
temp = io.BytesIO() temp = io.BytesIO()
starttime = time.time() try:
while time.time() - starttime < 1: data = await asyncio.wait_for(reader.read(2**10), 1)
try: temp.write(data)
data = await asyncio.wait_for(reader.read(2**8), 1) except asyncio.TimeoutError:
temp.write(data) continue
except asyncio.TimeoutError:
break
self.buf += temp.getvalue() self.buf += temp.getvalue()
def idle_time(self): def idle_time(self):
@ -266,7 +264,10 @@ class Interface(util.PrintError):
'''Sends queued requests. Returns False on failure.''' '''Sends queued requests. Returns False on failure.'''
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i} make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
n = self.num_requests() n = self.num_requests()
prio, request = await self.unsent_requests.get() try:
prio, request = await asyncio.wait_for(self.unsent_requests.get(), 1.5)
except TimeoutError:
return False
try: try:
await self.send_all([make_dict(*request)]) await self.send_all([make_dict(*request)])
except (SocksError, OSError, TimeoutError) as e: except (SocksError, OSError, TimeoutError) as e:
@ -298,7 +299,7 @@ class Interface(util.PrintError):
return True return True
return False return False
async def get_response(self): async def get_response(self, is_running):
'''Call if there is data available on the socket. Returns a list of '''Call if there is data available on the socket. Returns a list of
(request, response) pairs. Notifications are singleton (request, response) pairs. Notifications are singleton
unsolicited responses presumably as a result of prior unsolicited responses presumably as a result of prior
@ -307,12 +308,12 @@ class Interface(util.PrintError):
corresponding request. If the connection was closed remotely corresponding request. If the connection was closed remotely
or the remote server is misbehaving, a (None, None) will appear. or the remote server is misbehaving, a (None, None) will appear.
''' '''
response = await self.get() response = await self.get(is_running)
if not type(response) is dict: if not type(response) is dict:
print("response type not dict!", response)
if response is None: if response is None:
self.closed_remotely = True self.closed_remotely = True
self.print_error("connection closed remotely") if is_running():
self.print_error("connection closed remotely")
return None, None return None, None
if self.debug: if self.debug:
self.print_error("<--", response) self.print_error("<--", response)

View File

@ -20,6 +20,7 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
from functools import partial
import time import time
import queue import queue
import os import os
@ -45,7 +46,7 @@ from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
NODES_RETRY_INTERVAL = 60 NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10 SERVER_RETRY_INTERVAL = 10
from concurrent.futures import CancelledError from concurrent.futures import TimeoutError, CancelledError
def parse_servers(result): def parse_servers(result):
""" parse servers list into dict format""" """ parse servers list into dict format"""
@ -403,12 +404,16 @@ class Network(util.DaemonThread):
async def stop_network(self): async def stop_network(self):
self.print_error("stopping network") self.print_error("stopping network")
for interface in list(self.interfaces.values()): for num, interface in enumerate(list(self.interfaces.values())):
self.close_interface(interface) await self.close_interface(interface)
if self.interface: await interface.future
self.close_interface(self.interface) #if self.interface:
# await self.close_interface(self.interface)
# await interface.future
await asyncio.wait_for(self.process_pending_sends_job, 5)
assert self.interface is None assert self.interface is None
assert not self.interfaces while self.interfaces:
asyncio.sleep(0.1)
self.connecting = set() self.connecting = set()
# called from the Qt thread # called from the Qt thread
@ -486,19 +491,20 @@ class Network(util.DaemonThread):
self.set_status('connected') self.set_status('connected')
self.notify('updated') self.notify('updated')
def close_interface(self, interface): async def close_interface(self, interface):
self.print_error('closing connection', interface.server) self.print_error('closing connection', interface.server)
if interface: if interface:
if interface.server in self.interfaces: if interface.server in self.interfaces:
self.interfaces.pop(interface.server) self.interfaces.pop(interface.server)
if interface.server == self.default_server: if interface.server == self.default_server:
self.interface = None self.interface = None
if interface.jobs: while True:
interface.jobs.cancel() for i in interface.jobs:
if interface.boot_job is not None: if not i.done():
interface.boot_job.cancel() await asyncio.sleep(0.1)
if self.process_pending_sends_job is not None: continue
self.process_pending_sends_job.cancel() break
assert interface.boot_job.done()
interface.close() interface.close()
def add_recent_server(self, server): def add_recent_server(self, server):
@ -559,7 +565,7 @@ class Network(util.DaemonThread):
async def process_responses(self, interface): async def process_responses(self, interface):
while self.is_running(): while self.is_running():
request, response = await interface.get_response() request, response = await interface.get_response(lambda: self.is_running())
if request: if request:
method, params, message_id = request method, params, message_id = request
k = self.get_index(method, params) k = self.get_index(method, params)
@ -584,7 +590,7 @@ class Network(util.DaemonThread):
self.subscribed_addresses.add(params[0]) self.subscribed_addresses.add(params[0])
else: else:
if not response: # Closed remotely / misbehaving if not response: # Closed remotely / misbehaving
self.connection_down(interface.server) await self.connection_down(interface.server)
return return
# Rewrite response shape to match subscription request response # Rewrite response shape to match subscription request response
method = response.get('method') method = response.get('method')
@ -676,7 +682,7 @@ class Network(util.DaemonThread):
if callback in v: if callback in v:
v.remove(callback) v.remove(callback)
def connection_down(self, server): async def connection_down(self, server):
'''A connection to server either went down, or was never made. '''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.''' We distinguish by whether it is in self.interfaces.'''
self.print_error("connection down", server) self.print_error("connection down", server)
@ -684,7 +690,7 @@ class Network(util.DaemonThread):
if server == self.default_server: if server == self.default_server:
self.set_status('disconnected') self.set_status('disconnected')
if server in self.interfaces: if server in self.interfaces:
self.close_interface(self.interfaces[server]) await self.close_interface(self.interfaces[server])
self.notify('interfaces') self.notify('interfaces')
for b in self.blockchains.values(): for b in self.blockchains.values():
if b.catch_up == server: if b.catch_up == server:
@ -694,6 +700,7 @@ class Network(util.DaemonThread):
# todo: get tip first, then decide which checkpoint to use. # todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server) self.add_recent_server(server)
interface = Interface(server, self.config.path, self.proxy) interface = Interface(server, self.config.path, self.proxy)
interface.future = asyncio.Future()
interface.blockchain = None interface.blockchain = None
interface.tip_header = None interface.tip_header = None
interface.tip = 0 interface.tip = 0
@ -726,7 +733,7 @@ class Network(util.DaemonThread):
self.requested_chunks.remove(index) self.requested_chunks.remove(index)
connect = interface.blockchain.connect_chunk(index, result) connect = interface.blockchain.connect_chunk(index, result)
if not connect: if not connect:
self.connection_down(interface.server) await self.connection_down(interface.server)
return return
# If not finished, get the next chunk # If not finished, get the next chunk
if interface.blockchain.height() < interface.tip: if interface.blockchain.height() < interface.tip:
@ -748,12 +755,12 @@ class Network(util.DaemonThread):
header = response.get('result') header = response.get('result')
if not header: if not header:
interface.print_error(response) interface.print_error(response)
self.connection_down(interface.server) await self.connection_down(interface.server)
return return
height = header.get('block_height') height = header.get('block_height')
if interface.request != height: if interface.request != height:
interface.print_error("unsolicited header",interface.request, height) interface.print_error("unsolicited header",interface.request, height)
self.connection_down(interface.server) await self.connection_down(interface.server)
return return
chain = blockchain.check_header(header) chain = blockchain.check_header(header)
if interface.mode == 'backward': if interface.mode == 'backward':
@ -773,7 +780,7 @@ class Network(util.DaemonThread):
assert next_height >= self.max_checkpoint(), (interface.bad, interface.good) assert next_height >= self.max_checkpoint(), (interface.bad, interface.good)
else: else:
if height == 0: if height == 0:
self.connection_down(interface.server) await self.connection_down(interface.server)
next_height = None next_height = None
else: else:
interface.bad = height interface.bad = height
@ -792,7 +799,7 @@ class Network(util.DaemonThread):
next_height = (interface.bad + interface.good) // 2 next_height = (interface.bad + interface.good) // 2
assert next_height >= self.max_checkpoint() assert next_height >= self.max_checkpoint()
elif not interface.blockchain.can_connect(interface.bad_header, check_height=False): elif not interface.blockchain.can_connect(interface.bad_header, check_height=False):
self.connection_down(interface.server) await self.connection_down(interface.server)
next_height = None next_height = None
else: else:
branch = self.blockchains.get(interface.bad) branch = self.blockchains.get(interface.bad)
@ -873,16 +880,19 @@ class Network(util.DaemonThread):
for interface in list(self.interfaces.values()): for interface in list(self.interfaces.values()):
if interface.request and time.time() - interface.request_time > 20: if interface.request and time.time() - interface.request_time > 20:
interface.print_error("blockchain request timed out") interface.print_error("blockchain request timed out")
self.connection_down(interface.server) await self.connection_down(interface.server)
def make_send_requests_job(self, interface): def make_send_requests_job(self, interface):
async def job(): async def job():
try: try:
while self.is_running(): while self.is_running():
result = await interface.send_request() try:
result = await asyncio.wait_for(interface.send_request(), 1)
except TimeoutError:
continue
if not result: if not result:
self.connection_down(interface.server) await self.connection_down(interface.server)
except CancelledError: except GeneratorExit:
pass pass
except: except:
traceback.print_exc() traceback.print_exc()
@ -893,10 +903,10 @@ class Network(util.DaemonThread):
async def job(): async def job():
try: try:
await self.process_responses(interface) await self.process_responses(interface)
except CancelledError: except GeneratorExit:
pass pass
except OSError: except OSError:
self.connection_down(interface.server) await self.connection_down(interface.server)
print("OS error, connection downed") print("OS error, connection downed")
except BaseException: except BaseException:
traceback.print_exc() traceback.print_exc()
@ -907,9 +917,12 @@ class Network(util.DaemonThread):
async def job(): async def job():
try: try:
while self.is_running(): while self.is_running():
await self.process_pending_sends() try:
except CancelledError: await asyncio.wait_for(self.process_pending_sends(), 1)
pass except TimeoutError:
continue
#except CancelledError:
# pass
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
print("FATAL ERROR in process_pending_sends") print("FATAL ERROR in process_pending_sends")
@ -931,28 +944,34 @@ class Network(util.DaemonThread):
async def job(): async def job():
try: try:
await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface) await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface)
def rem():
if self.is_running(): self.connecting.remove(interface.server)
if not await interface.send_request(): if not await interface.send_request():
self.connection_down(interface.server) await self.connection_down(interface.server)
self.connecting.remove(interface.server) rem()
return return
self.connecting.remove(interface.server) rem()
self.interfaces[interface.server] = interface self.interfaces[interface.server] = interface
await self.queue_request('blockchain.headers.subscribe', [], interface) await self.queue_request('blockchain.headers.subscribe', [], interface)
if interface.server == self.default_server: if interface.server == self.default_server:
await self.switch_to_interface(interface.server) await self.switch_to_interface(interface.server)
interface.jobs = asyncio.ensure_future(asyncio.gather(self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface))) interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]]
def cb(fut): def cb(num, fut):
try: try:
fut.exception() fut.exception()
except: except e:
pass interface.future.set_exception(e)
interface.jobs.add_done_callback(cb) else:
if not interface.future.done(): interface.future.set_result(str(num) + " done")
for num, i in enumerate(interface.jobs):
i.add_done_callback(partial(cb, num))
#self.notify('interfaces') #self.notify('interfaces')
except GeneratorExit: except GeneratorExit:
print(interface.server, "GENERATOR EXIT")
pass pass
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
print("FATAL ERROR in start_interface") print("FATAL ERROR in boot_interface")
raise e raise e
interface.boot_job = asyncio.ensure_future(job()) interface.boot_job = asyncio.ensure_future(job())
@ -965,11 +984,11 @@ class Network(util.DaemonThread):
# must use copy of values # must use copy of values
if interface.has_timed_out(): if interface.has_timed_out():
print(interface.server, "timed out") print(interface.server, "timed out")
self.connection_down(interface.server) await self.connection_down(interface.server)
elif interface.ping_required(): elif interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION] params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
await self.queue_request('server.version', params, interface) await self.queue_request('server.version', params, interface)
except CancelledError: except GeneratorExit:
pass pass
except: except:
traceback.print_exc() traceback.print_exc()
@ -1037,6 +1056,12 @@ class Network(util.DaemonThread):
self.run_jobs() self.run_jobs()
await self.stop_network() await self.stop_network()
self.on_stop() self.on_stop()
for i in asyncio.Task.all_tasks():
if asyncio.Task.current_task() == i: continue
try:
await i
except CancelledError:
pass
future.set_result("run_async done") future.set_result("run_async done")
except BaseException as e: except BaseException as e:
future.set_exception(e) future.set_exception(e)
@ -1046,7 +1071,7 @@ class Network(util.DaemonThread):
if not height: if not height:
return return
if height < self.max_checkpoint(): if height < self.max_checkpoint():
self.connection_down(interface) await self.connection_down(interface)
return return
interface.tip_header = header interface.tip_header = header
interface.tip = height interface.tip = height