From 200a0857789788a2d597c5af3c141359e4c69ec8 Mon Sep 17 00:00:00 2001 From: Janus Date: Thu, 14 Dec 2017 16:57:08 +0100 Subject: [PATCH] asyncio: do not pin CA certificates, poll for cert differently --- lib/interface.py | 94 +++++++++++++++++++++++++++++++++++------------- lib/network.py | 1 + 2 files changed, 70 insertions(+), 25 deletions(-) diff --git a/lib/interface.py b/lib/interface.py index d87c1ced..b36a5aaf 100644 --- a/lib/interface.py +++ b/lib/interface.py @@ -99,6 +99,66 @@ class Interface(util.PrintError): def conn_coro(self, context): return asyncio.open_connection(self.host, self.port, ssl=context) + async def _save_certificate(self, cert_path, require_ca): + dercert = None + if require_ca: + context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path) + else: + context = get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None) + try: + if self.addr is not None: + proto_factory = lambda: SSLProtocol(asyncio.get_event_loop(), asyncio.Protocol(), context, None) + socks_create_coro = aiosocks.create_connection(proto_factory, \ + proxy=self.addr, \ + proxy_auth=self.auth, \ + dst=(self.host, self.port)) + transport, protocol = await asyncio.wait_for(socks_create_coro, 5) + async def job(fut): + try: + if protocol._sslpipe is not None: + fut.set_result(protocol._sslpipe.ssl_object.getpeercert(True)) + except BaseException as e: + fut.set_exception(e) + while True: + fut = asyncio.Future() + asyncio.ensure_future(job(fut)) + try: + await fut + except: + pass + try: + fut.exception() + dercert = fut.result() + except ValueError: + await asyncio.sleep(1) + continue + except: + traceback.print_exc() + continue + break + print("done sleeping") + transport.close() + else: + reader, writer = await asyncio.wait_for(self.conn_coro(context), 5) + dercert = writer.get_extra_info('ssl_object').getpeercert(True) + writer.close() + except ConnectionError: + traceback.print_exc() + print("Previous exception from _save_certificate") + return + except TimeoutError: + return + assert dercert + if not require_ca: + cert = ssl.DER_cert_to_PEM_cert(dercert) + else: + # Don't pin a CA signed certificate + cert = "" + temporary_path = cert_path + '.temp' + with open(temporary_path, "w") as f: + f.write(cert) + return temporary_path + async def _get_read_write(self): async with self.lock: if self.reader is not None and self.writer is not None: @@ -106,31 +166,12 @@ class Interface(util.PrintError): if self.use_ssl: cert_path = os.path.join(self.config_path, 'certs', self.host) if not os.path.exists(cert_path): - context = get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None) - if self.addr is not None: - proto_factory = lambda: SSLProtocol(asyncio.get_event_loop(), asyncio.Protocol(), context, None) - socks_create_coro = aiosocks.create_connection(proto_factory, \ - proxy=self.addr, \ - proxy_auth=self.auth, \ - dst=(self.host, self.port)) - transport, protocol = await asyncio.wait_for(socks_create_coro, 5) - while True: - try: - if protocol._sslpipe is not None: - dercert = protocol._sslpipe.ssl_object.getpeercert(True) - break - except ValueError: - print("sleeping for cert") - await asyncio.sleep(1) - transport.close() - else: - reader, writer = await asyncio.wait_for(self.conn_coro(context), 5) - dercert = writer.get_extra_info('ssl_object').getpeercert(True) - writer.close() - cert = ssl.DER_cert_to_PEM_cert(dercert) - temporary_path = cert_path + '.temp' - with open(temporary_path, "w") as f: - f.write(cert) + temporary_path = await self._save_certificate(cert_path, False) + if not temporary_path: + temporary_path = await self._save_certificate(cert_path, True) + if not temporary_path: + raise ConnectionError("Could not get certificate on second try") + is_new = True else: is_new = False @@ -151,6 +192,9 @@ class Interface(util.PrintError): else: context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5) + except TimeoutError: + print("TimeoutError after getting certificate successfully...") + raise except BaseException as e: traceback.print_exc() print("Previous exception will now be reraised") diff --git a/lib/network.py b/lib/network.py index 58295594..be184564 100644 --- a/lib/network.py +++ b/lib/network.py @@ -162,6 +162,7 @@ class Network(util.DaemonThread): """ def __init__(self, config=None): + asyncio.set_event_loop(None) if config is None: config = {} # Do not use mutables as default values! util.DaemonThread.__init__(self)