From 2d1ccfcc69907bda5c18a23569f8ad322b7eecf1 Mon Sep 17 00:00:00 2001 From: Janus Date: Mon, 18 Dec 2017 12:05:25 +0100 Subject: [PATCH] asyncio: support switching servers --- lib/interface.py | 10 +++--- lib/network.py | 92 +++++++++++++++++++++++++++++++++--------------- 2 files changed, 69 insertions(+), 33 deletions(-) diff --git a/lib/interface.py b/lib/interface.py index 7cbcfa20..64629054 100644 --- a/lib/interface.py +++ b/lib/interface.py @@ -143,10 +143,9 @@ class Interface(util.PrintError): reader, writer = await asyncio.wait_for(self.conn_coro(context), 5) dercert = writer.get_extra_info('ssl_object').getpeercert(True) writer.close() - except ConnectionError: + except OSError as e: # not ConnectionError because we need socket.gaierror too if self.is_running(): - traceback.print_exc() - print("Previous exception from _save_certificate") + print("Exception in _save_certificate", type(e)) return except TimeoutError: return @@ -199,8 +198,9 @@ class Interface(util.PrintError): raise except BaseException as e: if self.is_running(): - traceback.print_exc() - print("Previous exception will now be reraised") + if not isinstance(e, OSError): + traceback.print_exc() + print("Previous exception will now be reraised") raise e if self.use_ssl and is_new: self.print_error("saving new certificate for", self.host) diff --git a/lib/network.py b/lib/network.py index eac5867f..4600c12c 100644 --- a/lib/network.py +++ b/lib/network.py @@ -163,6 +163,7 @@ class Network(util.DaemonThread): """ def __init__(self, config=None): + self.stopped = True asyncio.set_event_loop(None) if config is None: config = {} # Do not use mutables as default values! @@ -393,6 +394,7 @@ class Network(util.DaemonThread): await self.start_random_interface() async def start_network(self, protocol, proxy): + self.stopped = False # TODO proxy assert not self.interface and not self.interfaces assert not self.connecting @@ -403,17 +405,24 @@ class Network(util.DaemonThread): await self.start_interfaces() async def stop_network(self): + self.stopped = True self.print_error("stopping network") - for num, interface in enumerate(list(self.interfaces.values())): - await self.close_interface(interface) - await interface.future - #if self.interface: - # await self.close_interface(self.interface) - # await interface.future - await asyncio.wait_for(self.process_pending_sends_job, 5) + list_with_main = ([self.interface] if self.interface else []) + for num, interface in enumerate(list(self.interfaces.values()) + list_with_main): + try: + await asyncio.wait_for(asyncio.shield(self.close_interface(interface)), 5) + await asyncio.wait_for(asyncio.shield(interface.future), 5) + except TimeoutError: + print("close_interface too slow...") + await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5) assert self.interface is None - while self.interfaces: - asyncio.sleep(0.1) + for i in range(100): + if not self.interfaces: + break + else: + await asyncio.sleep(0.1) + if self.interfaces: + assert False, "interfaces not empty after waiting: " + repr(self.interfaces) self.connecting = set() # called from the Qt thread @@ -437,10 +446,15 @@ class Network(util.DaemonThread): self.auto_connect = auto_connect if self.proxy != proxy or self.protocol != protocol: async def job(): - # Restart the network defaulting to the given server - await self.stop_network() - self.default_server = server - await self.start_network(protocol, proxy) + try: + # Restart the network defaulting to the given server + await self.stop_network() + self.default_server = server + await self.start_network(protocol, proxy) + self.notify('interfaces') + except BaseException as e: + traceback.print_exc() + print("exception from restart job") asyncio.run_coroutine_threadsafe(job(), self.loop) elif self.default_server != server: async def job(): @@ -564,7 +578,7 @@ class Network(util.DaemonThread): return str(method) + (':' + str(params[0]) if params else '') async def process_responses(self, interface): - while self.is_running(): + while not self.stopped: request, response = await interface.get_response() if request: method, params, message_id = request @@ -698,7 +712,7 @@ class Network(util.DaemonThread): async def new_interface(self, server): # todo: get tip first, then decide which checkpoint to use. self.add_recent_server(server) - interface = Interface(server, self.config.path, self.proxy, lambda: self.is_running()) + interface = Interface(server, self.config.path, self.proxy, lambda: not self.stopped) interface.future = asyncio.Future() interface.blockchain = None interface.tip_header = None @@ -884,9 +898,9 @@ class Network(util.DaemonThread): def make_send_requests_job(self, interface): async def job(): try: - while self.is_running(): + while not self.stopped: try: - result = await asyncio.wait_for(interface.send_request(), 1) + result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1) except TimeoutError: continue if not result: @@ -894,7 +908,7 @@ class Network(util.DaemonThread): except GeneratorExit: pass except: - if self.is_running(): + if not self.stopped: traceback.print_exc() print("FATAL ERROR ^^^") return asyncio.ensure_future(job()) @@ -909,7 +923,7 @@ class Network(util.DaemonThread): await self.connection_down(interface.server) print("OS error, connection downed") except BaseException: - if self.is_running(): + if not self.stopped: traceback.print_exc() print("FATAL ERROR in process_responses") return asyncio.ensure_future(job()) @@ -917,15 +931,16 @@ class Network(util.DaemonThread): def make_process_pending_sends_job(self): async def job(): try: - while self.is_running(): + while not self.stopped: + print("pend send") try: - await asyncio.wait_for(self.process_pending_sends(), 1) + await asyncio.wait_for(asyncio.shield(self.process_pending_sends()), 1) except TimeoutError: continue #except CancelledError: # pass except BaseException as e: - if self.is_running(): + if not self.stopped: traceback.print_exc() print("FATAL ERROR in process_pending_sends") return asyncio.ensure_future(job()) @@ -947,11 +962,20 @@ class Network(util.DaemonThread): try: await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface) def rem(): - if self.is_running(): self.connecting.remove(interface.server) + if not self.stopped: + self.connecting.remove(interface.server) if not await interface.send_request(): - await self.connection_down(interface.server) rem() + await self.connection_down(interface.server) return + if self.stopped: return + try: + await asyncio.wait_for(interface.get_response(), 1) + except TimeoutError: + rem() + await self.connection_down(interface.server) + return + if self.stopped: return rem() self.interfaces[interface.server] = interface await self.queue_request('blockchain.headers.subscribe', [], interface) @@ -972,34 +996,45 @@ class Network(util.DaemonThread): print(interface.server, "GENERATOR EXIT") pass except BaseException as e: - if self.is_running(): + if not self.stopped: traceback.print_exc() print("FATAL ERROR in boot_interface") raise e interface.boot_job = asyncio.ensure_future(job()) + interface.boot_job.server = interface.server + def boot_job_cb(fut): + try: + fut.exception() + except: + traceback.print_exc() + print("Previous exception in boot_job") + interface.boot_job.add_done_callback(boot_job_cb) def make_ping_job(self, interface): async def job(): try: - while self.is_running(): + while not self.stopped: await asyncio.sleep(1) # Send pings and shut down stale interfaces # must use copy of values if interface.has_timed_out(): print(interface.server, "timed out") await self.connection_down(interface.server) + return elif interface.ping_required(): params = [ELECTRUM_VERSION, PROTOCOL_VERSION] await self.queue_request('server.version', params, interface) except GeneratorExit: pass except: - if self.is_running(): + if not self.stopped: traceback.print_exc() print("FATAL ERROR in ping_job") return asyncio.ensure_future(job()) async def maintain_interfaces(self): + if self.stopped: return + now = time.time() # nodes if len(self.interfaces) + len(self.connecting) < self.num_server: @@ -1054,6 +1089,7 @@ class Network(util.DaemonThread): async def run_async(self, future): try: while self.is_running(): + #print(len(asyncio.Task.all_tasks())) await asyncio.sleep(1) await self.maintain_requests() await self.maintain_interfaces() @@ -1063,7 +1099,7 @@ class Network(util.DaemonThread): for i in asyncio.Task.all_tasks(): if asyncio.Task.current_task() == i: continue try: - await asyncio.wait_for(i, 5) + await asyncio.wait_for(asyncio.shield(i), 5) except TimeoutError: print("TOO SLOW TO SHUT DOWN, CANCELLING", i) i.cancel()