From da3d10fd096762727950b5b21e5ee1b1b6405ba6 Mon Sep 17 00:00:00 2001 From: Janus Date: Mon, 12 Mar 2018 16:32:27 +0100 Subject: [PATCH] asyncio: network/interface: remove broken code, use connecting/interfaces correctly --- lib/interface.py | 87 +++++++++++--------------- lib/network.py | 159 +++++++++++++++++++++-------------------------- lib/util.py | 2 +- 3 files changed, 108 insertions(+), 140 deletions(-) diff --git a/lib/interface.py b/lib/interface.py index a9766288..aa2e9d84 100644 --- a/lib/interface.py +++ b/lib/interface.py @@ -108,58 +108,21 @@ class Interface(util.PrintError): context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path) else: context = get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None) - try: - if self.addr is not None: - proto_factory = lambda: SSLProtocol(asyncio.get_event_loop(), asyncio.Protocol(), context, None) - socks_create_coro = aiosocks.create_connection(proto_factory, \ - proxy=self.addr, \ - proxy_auth=self.auth, \ - dst=(self.host, self.port)) - transport, protocol = await asyncio.wait_for(socks_create_coro, 5) - async def job(fut): - try: - if protocol._sslpipe is not None: - fut.set_result(protocol._sslpipe.ssl_object.getpeercert(True)) - except BaseException as e: - fut.set_exception(e) - while self.is_running(): - fut = asyncio.Future() - asyncio.ensure_future(job(fut)) - try: - await fut - except: - pass - try: - fut.exception() - dercert = fut.result() - except ValueError: - await asyncio.sleep(1) - continue - except: - if self.is_running(): - traceback.print_exc() - print("Previous exception from _save_certificate") - continue - break - if not self.is_running(): return - transport.close() - else: - reader, writer = await asyncio.wait_for(self.conn_coro(context), 3) - dercert = writer.get_extra_info('ssl_object').getpeercert(True) - writer.close() - except OSError as e: # not ConnectionError because we need socket.gaierror too - if self.is_running(): - self.print_error(self.server, "Exception in _save_certificate", type(e)) - if not self.error_future.done(): self.error_future.set_result(e) - return - except TimeoutError: - return - assert dercert - if not require_ca: - cert = ssl.DER_cert_to_PEM_cert(dercert) - else: - # Don't pin a CA signed certificate + if self.addr is not None: + self.print_error("can't save certificate through socks!") + # just save the empty file to force use of PKI + # this will break all self-signed servers, of course cert = "" + else: + reader, writer = await asyncio.wait_for(self.conn_coro(context), 3) + dercert = writer.get_extra_info('ssl_object').getpeercert(True) + # an exception will be thrown by now if require_ca is True (e.g. a certificate was supplied) + writer.close() + if not require_ca: + cert = ssl.DER_cert_to_PEM_cert(dercert) + else: + # Don't pin a CA signed certificate + cert = "" temporary_path = cert_path + '.temp' with open(temporary_path, "w") as f: f.write(cert) @@ -172,10 +135,26 @@ class Interface(util.PrintError): if self.use_ssl: cert_path = os.path.join(self.config_path, 'certs', self.host) if not os.path.exists(cert_path): - temporary_path = await self._save_certificate(cert_path, True) + temporary_path = None + # first, we try to save a certificate signed through the PKI + try: + temporary_path = await self._save_certificate(cert_path, True) + except ssl.SSLError: + pass + except (TimeoutError, OSError) as e: + if not self.error_future.done(): self.error_future.set_result(e) + raise + # if the certificate verification failed, we try to save a self-signed certificate if not temporary_path: + try: temporary_path = await self._save_certificate(cert_path, False) + # we also catch SSLError here, but it shouldn't matter since no certificate is required, + # so the SSLError wouldn't mean certificate validation failed + except (TimeoutError, OSError) as e: + if not self.error_future.done(): self.error_future.set_result(e) + raise if not temporary_path: + if not self.error_future.done(): self.error_future.set_result(ConnectionError("Could not get certificate")) raise ConnectionError("Could not get certificate on second try") is_new = True @@ -201,6 +180,10 @@ class Interface(util.PrintError): except TimeoutError: self.print_error("TimeoutError after getting certificate successfully...") raise + except ssl.SSLError: + # FIXME TODO + assert not self_signed, "we shouldn't reject self-signed here since the certificate has been saved (has size {})".format(size) + raise except BaseException as e: if self.is_running(): if not isinstance(e, OSError): diff --git a/lib/network.py b/lib/network.py index f0a538f1..a0839348 100644 --- a/lib/network.py +++ b/lib/network.py @@ -167,8 +167,12 @@ class Network(util.DaemonThread): """ def __init__(self, config=None): + self.lock = threading.Lock() + # callbacks set by the GUI + self.callbacks = defaultdict(list) + self.set_status("disconnected") self.disconnected_servers = {} - self.connecting = set() + self.connecting = {} self.stopped = True asyncio.set_event_loop(None) if config is None: @@ -192,7 +196,6 @@ class Network(util.DaemonThread): self.default_server = None if not self.default_server: self.default_server = pick_random_server() - self.lock = threading.Lock() self.message_id = 0 self.debug = False self.irc_servers = {} # returned by interface (list from irc) @@ -204,8 +207,6 @@ class Network(util.DaemonThread): # callbacks passed with subscriptions self.subscriptions = defaultdict(list) self.sub_cache = {} - # callbacks set by the GUI - self.callbacks = defaultdict(list) dir_path = os.path.join( self.config.path, 'certs') if not os.path.exists(dir_path): @@ -221,7 +222,7 @@ class Network(util.DaemonThread): self.server_retry_time = time.time() self.nodes_retry_time = time.time() # kick off the network. interface is the main server we are currently - # communicating with. interfaces is the set of servers we are connecting + # communicating with. interfaces is the set of servers we are connected # to or have an ongoing connection with self.interface = None self.interfaces = {} @@ -385,18 +386,26 @@ class Network(util.DaemonThread): return out async def start_interface(self, server): - if server not in self.interfaces and server not in self.connecting: - self.connecting.add(server) + if server not in self.interfaces and server not in self.connecting and not self.stopped: + self.connecting[server] = asyncio.Future() if server == self.default_server: self.print_error("connecting to %s as new interface" % server) self.set_status('connecting') - return self.new_interface(server) + iface = self.new_interface(server) + success = await iface.launched + # if the interface was launched sucessfully, we save it in interfaces + # since that dictionary only stores "good" interfaces + if success: + assert server not in self.interfaces + self.interfaces[server] = iface + if not self.connecting[server].done(): self.connecting[server].set_result(True) + del self.connecting[server] async def start_random_interface(self): exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys())) server = pick_random_server(self.get_servers(), self.protocol, exclude_set) if server: - return await self.start_interface(server) + await self.start_interface(server) def start_network(self, protocol, proxy): self.stopped = False @@ -411,8 +420,14 @@ class Network(util.DaemonThread): self.print_error("stopping network") async def stop(interface): await self.connection_down(interface.server, "stopping network") - await interface.future + await interface.boot_job stopped_this_time = set() + while self.connecting: + try: + await next(iter(self.connecting.values())) + except asyncio.CancelledError: + # ok since we are shutting down + pass while self.interfaces: do_next = next(iter(self.interfaces.values())) stopped_this_time.add(do_next) @@ -507,15 +522,14 @@ class Network(util.DaemonThread): async def close_interface(self, interface): self.print_error('closing connection', interface.server) - if interface: - if interface.server in self.interfaces: - self.interfaces.pop(interface.server) - if interface.server == self.default_server: - self.interface = None - if interface.jobs: - await asyncio.gather(*interface.jobs) - await interface.boot_job - interface.close() + if interface.server in self.interfaces: + self.interfaces.pop(interface.server) + if interface.server == self.default_server: + self.interface = None + assert not interface.boot_job.cancelled() + interface.boot_job.cancel() + await interface.boot_job + interface.close() def add_recent_server(self, server): # list is ordered @@ -708,18 +722,14 @@ class Network(util.DaemonThread): '''A connection to server either went down, or was never made. We distinguish by whether it is in self.interfaces.''' if server in self.disconnected_servers: - try: - raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason)) - except: - traceback.print_exc() + print("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason)) return self.print_error("connection down", server, reason) - for i in self.interfaces[server].jobs: - assert not i.cancelled() - i.cancel() self.disconnected_servers[server] = reason if server == self.default_server: self.set_status('disconnected') + if server in self.connecting: + await self.connecting[server] if server in self.interfaces: await self.close_interface(self.interfaces[server]) self.notify('interfaces') @@ -733,11 +743,12 @@ class Network(util.DaemonThread): def interface_exception_handler(exception_future): try: raise exception_future.result() - except Exception as e: - print(type(e), " handled in new_interface") - asyncio.ensure_future(self.connection_down(server, "error in interface")) - interface = Interface(server, self.config.path, self.proxy, lambda: not self.stopped and server in self.interfaces, interface_exception_handler) - interface.future = asyncio.Future() + except BaseException as e: + asyncio.ensure_future(self.connection_down(server, "error in interface: " + str(e))) + interface = Interface(server, self.config.path, self.proxy, lambda: not self.stopped and (server in self.connecting or server in self.interfaces), interface_exception_handler) + # this future has its result set when the interface should be considered opened, + # e.g. should be moved from connecting to interfaces + interface.launched = asyncio.Future() interface.blockchain = None interface.tip_header = None interface.tip = 0 @@ -748,7 +759,6 @@ class Network(util.DaemonThread): self.boot_interface(interface) assert server not in self.interfaces assert not self.stopped - self.interfaces[server] = interface return interface async def request_chunk(self, interface, idx): @@ -916,15 +926,17 @@ class Network(util.DaemonThread): async def job(): try: while True: - await interface.send_request() + # this wait_for is necessary since CancelledError + # doesn't seem to get thrown without it + # when using ssl_in_socks + try: + await asyncio.wait_for(interface.send_request(), 1) + except TimeoutError: + pass except CancelledError: pass - except Exception as e: - print("send requests exp", str(e)) - if interface.is_running() or self.num_server == 0: - traceback.print_exc() - self.print_error("FATAL ERROR ^^^") - await self.connection_down(interface.server, "exp while send_request") + except BaseException as e: + await self.connection_down(interface.server, "exp while send_request: " + str(e)) return asyncio.ensure_future(job()) @@ -932,15 +944,10 @@ class Network(util.DaemonThread): async def job(): try: await self.process_responses(interface) - except OSError: - await self.connection_down(interface.server, "OSError in process_responses") - self.print_error("OS error, connection downed") except CancelledError: pass except: - if interface.is_running() or self.num_server == 0: - traceback.print_exc() - self.print_error("FATAL ERROR in process_responses") + await self.connection_down(interface.server, "exp in process_responses") return asyncio.ensure_future(job()) def make_process_pending_sends_job(self): @@ -951,9 +958,7 @@ class Network(util.DaemonThread): except CancelledError: pass except: - if not self.stopped or self.num_server == 0: - traceback.print_exc() - self.print_error("FATAL ERROR in process_pending_sends") + await self.connection_down(interface.server, "exp in process_pending_sends_job") return asyncio.ensure_future(job()) def init_headers_file(self): @@ -972,30 +977,17 @@ class Network(util.DaemonThread): async def job(): try: interface.jobs = [self.make_send_requests_job(interface)] # we need this job to process the request queued below - try: - await asyncio.wait_for(self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface), 1) - except TimeoutError: - asyncio.ensure_future(self.connection_down(interface.server, "couldn't send initial version")) - interface.future.set_result("couldn't send initial version") - return + await asyncio.wait_for(self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface), 10) + handshakeResult = await asyncio.wait_for(interface.get_response(), 10) if not interface.is_running(): - interface.future.set_result("stopped after sending request") + self.print_error("WARNING: quitting bootjob instead of handling CancelledError!") return - try: - await asyncio.wait_for(interface.get_response(), 1) - except TimeoutError: - asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response")) - interface.future.set_result("timeout while getting response") - return - if not interface.is_running(): - interface.future.set_result("stopped after getting response") - return - #self.interfaces[interface.server] = interface await self.queue_request('blockchain.headers.subscribe', [], interface) if interface.server == self.default_server: - await asyncio.wait_for(self.switch_to_interface(interface.server), 1) + await asyncio.wait_for(self.switch_to_interface(interface.server), 5) interface.jobs += [self.make_process_responses_job(interface)] gathered = asyncio.gather(*interface.jobs) + interface.launched.set_result(handshakeResult) while interface.is_running(): try: await asyncio.wait_for(asyncio.shield(gathered), 1) @@ -1008,26 +1000,18 @@ class Network(util.DaemonThread): elif interface.ping_required(): params = [ELECTRUM_VERSION, PROTOCOL_VERSION] await self.queue_request('server.version', params, interface) - interface.future.set_result("finished") - return #self.notify('interfaces') - except BaseException as e: - if interface.is_running() or self.num_server == 0: - traceback.print_exc() - self.print_error("FATAL ERROR in boot_interface") - raise e - finally: - self.connecting.remove(interface.server) - interface.boot_job = asyncio.ensure_future(job()) - interface.boot_job.server = interface.server - def boot_job_cb(fut): - try: - fut.exception() + except TimeoutError: + asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response")) + except CancelledError: + pass except: traceback.print_exc() - self.print_error("Previous exception in boot_job") - interface.boot_job.add_done_callback(boot_job_cb) - + finally: + for i in interface.jobs: i.cancel() + if not interface.launched.done(): interface.launched.set_result(None) + interface.boot_job = asyncio.ensure_future(job()) + interface.boot_job.server = interface.server async def maintain_interfaces(self): if self.stopped: @@ -1035,7 +1019,8 @@ class Network(util.DaemonThread): now = time.time() # nodes if len(self.interfaces) + len(self.connecting) < self.num_server: - await self.start_random_interface() + # ensure future so that servers can be connected to in parallel + asyncio.ensure_future(self.start_random_interface()) if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: self.print_error('network: retrying connections') self.disconnected_servers = {} @@ -1072,10 +1057,10 @@ class Network(util.DaemonThread): asyncio.ensure_future(self.run_async(run_future)) loop.run_until_complete(run_future) - assert self.forever_coroutines_task.done() + loop.run_until_complete(self.forever_coroutines_task) run_future.exception() - self.print_error("run future result", run_future.result()) - loop.close() + # we don't want to wait for the closure of all connections, the OS can do that + #loop.close() async def run_async(self, future): try: diff --git a/lib/util.py b/lib/util.py index 67467d68..fa88fa6c 100644 --- a/lib/util.py +++ b/lib/util.py @@ -228,6 +228,7 @@ class DaemonThread(threading.Thread, PrintError): self.forever_coroutines_queue = asyncio.Queue() # making queue here because __init__ is called from non-network thread self.loop = asyncio.get_event_loop() async def getFromQueueAndStart(): + jobs = [] while True: try: jobs = await asyncio.wait_for(self.forever_coroutines_queue.get(), 1) @@ -236,7 +237,6 @@ class DaemonThread(threading.Thread, PrintError): if not self.is_running(): break continue await asyncio.gather(*[i.run(self.is_running) for i in jobs]) - self.print_error("FOREVER JOBS DONE") self.forever_coroutines_task = asyncio.ensure_future(getFromQueueAndStart()) return self.forever_coroutines_task