From 3e2881bcfc52a2e9788e47eb347ffb9ed9510af1 Mon Sep 17 00:00:00 2001 From: Janus Date: Tue, 19 Dec 2017 21:26:10 +0100 Subject: [PATCH] asyncio: add locks for more robust network handling --- lib/interface.py | 2 +- lib/network.py | 140 ++++++++++++++++++++++++++++------------------- 2 files changed, 84 insertions(+), 58 deletions(-) diff --git a/lib/interface.py b/lib/interface.py index a6299994..ae413ec3 100644 --- a/lib/interface.py +++ b/lib/interface.py @@ -189,7 +189,7 @@ class Interface(util.PrintError): self.reader, self.writer = await asyncio.wait_for(open_coro, 5) else: ssl_in_socks_coro = sslInSocksReaderWriter(self.addr, self.auth, self.host, self.port, ca_certs) - self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 10) + self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 5) else: context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5) diff --git a/lib/network.py b/lib/network.py index 34a98e66..b2775210 100644 --- a/lib/network.py +++ b/lib/network.py @@ -20,6 +20,7 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import sys import collections from functools import partial import time @@ -164,6 +165,7 @@ class Network(util.DaemonThread): """ def __init__(self, config=None): + self.disconnected_servers = {} self.stopped = True asyncio.set_event_loop(None) if config is None: @@ -382,7 +384,7 @@ class Network(util.DaemonThread): return await self.new_interface(server) async def start_random_interface(self): - exclude_set = self.disconnected_servers.union(set(self.interfaces.keys())) + exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys())) server = pick_random_server(self.get_servers(), self.protocol, exclude_set) if server: return await self.start_interface(server) @@ -398,7 +400,6 @@ class Network(util.DaemonThread): assert not self.interface and not self.interfaces assert all(not i.locked() for i in self.connecting.values()) self.print_error('starting network') - self.disconnected_servers = set([]) self.protocol = protocol self.proxy = proxy await self.start_interfaces() @@ -407,20 +408,19 @@ class Network(util.DaemonThread): self.stopped = True self.print_error("stopping network") async def stop(interface): - while True: - try: - await asyncio.wait_for(asyncio.shield(self.connection_down(interface.server)), 1.5) - break - except TimeoutError: - print("close_interface too slow...", interface.server) - try: - print(interface.server, await asyncio.wait_for(asyncio.shield(interface.future), 3)) - except TimeoutError: - print("interface future too slow", interface.server) - if self.interface: - await stop(self.interface) + await self.connection_down(interface.server, "stopping network") + await asyncio.wait_for(asyncio.shield(interface.future), 3) + stopped_this_time = set() while self.interfaces: - await stop(next(iter(self.interfaces.values()))) + do_next = next(iter(self.interfaces.values())) + assert do_next not in stopped_this_time + for i in self.disconnected_servers: + assert i not in self.interfaces.keys() + assert i != do_next.server + stopped_this_time.add(do_next) + await stop(do_next) + if self.interface: + assert self.interface.server in stopped_this_time, self.interface.server await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5) assert self.interface is None for i in range(100): @@ -430,8 +430,6 @@ class Network(util.DaemonThread): await asyncio.sleep(0.1) if self.interfaces: assert False, "interfaces not empty after waiting: " + repr(self.interfaces) - for i in self.connecting.values(): - assert not i.locked() # called from the Qt thread def set_parameters(self, host, port, protocol, proxy, auto_connect): @@ -455,13 +453,20 @@ class Network(util.DaemonThread): if self.proxy != proxy or self.protocol != protocol: async def job(): try: - # Restart the network defaulting to the given server - await self.stop_network() - self.default_server = server - await self.start_network(protocol, proxy) + async with self.restartLock: + # Restart the network defaulting to the given server + await self.stop_network() + print("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED") + self.default_server = server + async with self.all_server_locks("restart job"): + self.disconnected_servers = {} + await self.start_network(protocol, proxy) except BaseException as e: traceback.print_exc() print("exception from restart job") + if self.restartLock.locked(): + print("NOT RESTARTING, RESTART IN PROGRESS") + return asyncio.run_coroutine_threadsafe(job(), self.loop) elif self.default_server != server: async def job(): @@ -523,7 +528,11 @@ class Network(util.DaemonThread): for i in interface.jobs: asyncio.wait_for(i, 3) assert interface.boot_job - await asyncio.wait_for(asyncio.shield(interface.boot_job), 3) + try: + await asyncio.wait_for(asyncio.shield(interface.boot_job), 6) # longer than any timeout while connecting + except TimeoutError: + print("taking too long", interface.server) + raise interface.close() def add_recent_server(self, server): @@ -609,7 +618,7 @@ class Network(util.DaemonThread): self.subscribed_addresses.add(params[0]) else: if not response: # Closed remotely / misbehaving - await self.connection_down(interface.server) + if not self.stopped: await self.connection_down(interface.server, "no response in process responses") return # Rewrite response shape to match subscription request response method = response.get('method') @@ -703,15 +712,19 @@ class Network(util.DaemonThread): if callback in v: v.remove(callback) - async def connection_down(self, server): + async def connection_down(self, server, reason=None): '''A connection to server either went down, or was never made. We distinguish by whether it is in self.interfaces.''' - async with self.connecting[server]: + async with self.all_server_locks("connection down"): if server in self.disconnected_servers: - self.print_error("already disconnected", server) + try: + raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason)) + except: + traceback.print_exc() + sys.exit(1) return self.print_error("connection down", server) - self.disconnected_servers.add(server) + self.disconnected_servers[server] = reason if server == self.default_server: self.set_status('disconnected') if server in self.interfaces: @@ -735,6 +748,7 @@ class Network(util.DaemonThread): interface.boot_job = None self.boot_interface(interface) assert server not in self.interfaces + assert not self.stopped self.interfaces[server] = interface return interface @@ -759,7 +773,7 @@ class Network(util.DaemonThread): self.requested_chunks.remove(index) connect = interface.blockchain.connect_chunk(index, result) if not connect: - await self.connection_down(interface.server) + await self.connection_down(interface.server, "could not connect chunk") return # If not finished, get the next chunk if interface.blockchain.height() < interface.tip: @@ -781,12 +795,12 @@ class Network(util.DaemonThread): header = response.get('result') if not header: interface.print_error(response) - await self.connection_down(interface.server) + await self.connection_down(interface.server, "no header in on_get_header") return height = header.get('block_height') if interface.request != height: interface.print_error("unsolicited header",interface.request, height) - await self.connection_down(interface.server) + await self.connection_down(interface.server, "unsolicited header") return chain = blockchain.check_header(header) if interface.mode == 'backward': @@ -806,7 +820,7 @@ class Network(util.DaemonThread): assert next_height >= self.max_checkpoint(), (interface.bad, interface.good) else: if height == 0: - await self.connection_down(interface.server) + await self.connection_down(interface.server, "height zero in on_get_header") next_height = None else: interface.bad = height @@ -825,7 +839,7 @@ class Network(util.DaemonThread): next_height = (interface.bad + interface.good) // 2 assert next_height >= self.max_checkpoint() elif not interface.blockchain.can_connect(interface.bad_header, check_height=False): - await self.connection_down(interface.server) + await self.connection_down(interface.server, "blockchain can't connect") next_height = None else: branch = self.blockchains.get(interface.bad) @@ -906,7 +920,7 @@ class Network(util.DaemonThread): for interface in list(self.interfaces.values()): if interface.request and time.time() - interface.request_time > 20: interface.print_error("blockchain request timed out") - await self.connection_down(interface.server) + await self.connection_down(interface.server, "blockchain request timed out") def make_send_requests_job(self, interface): async def job(): @@ -916,8 +930,8 @@ class Network(util.DaemonThread): result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1) except TimeoutError: continue - if not result: - await self.connection_down(interface.server) + if not result and not self.stopped: + await self.connection_down(interface.server, "send_request returned false") except GeneratorExit: pass except: @@ -933,7 +947,7 @@ class Network(util.DaemonThread): except GeneratorExit: pass except OSError: - await self.connection_down(interface.server) + await self.connection_down(interface.server, "OSError in process_responses") print("OS error, connection downed") except BaseException: if not self.stopped: @@ -974,21 +988,21 @@ class Network(util.DaemonThread): try: await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface) if not await interface.send_request(): - asyncio.ensure_future(self.connection_down(interface.server)) + if not self.stopped: + asyncio.ensure_future(self.connection_down(interface.server, "send_request false in boot_interface")) interface.future.set_result("could not send request") return if self.stopped: - asyncio.ensure_future(self.connection_down(interface.server)) interface.future.set_result("stopped after sending request") return try: await asyncio.wait_for(interface.get_response(), 1) except TimeoutError: - asyncio.ensure_future(self.connection_down(interface.server)) + if not self.stopped: + asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response")) interface.future.set_result("timeout while getting response") return if self.stopped: - asyncio.ensure_future(self.connection_down(interface.server)) interface.future.set_result("stopped after getting response") return #self.interfaces[interface.server] = interface @@ -996,15 +1010,12 @@ class Network(util.DaemonThread): if interface.server == self.default_server: await asyncio.wait_for(self.switch_to_interface(interface.server), 1) interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]] - def cb(num, fut): + gathered = asyncio.gather(*interface.jobs) + while not self.stopped: try: - fut.exception() - except e: - interface.future.set_exception(e) - else: - if not interface.future.done(): interface.future.set_result(str(num) + " done") - for num, i in enumerate(interface.jobs): - i.add_done_callback(partial(cb, num)) + await asyncio.wait_for(asyncio.shield(gathered), 1) + except TimeoutError: + pass interface.future.set_result("finished") return #self.notify('interfaces') @@ -1035,7 +1046,7 @@ class Network(util.DaemonThread): # must use copy of values if interface.has_timed_out(): print(interface.server, "timed out") - await self.connection_down(interface.server) + await self.connection_down(interface.server, "time out in ping_job") return elif interface.ping_required(): params = [ELECTRUM_VERSION, PROTOCOL_VERSION] @@ -1048,6 +1059,19 @@ class Network(util.DaemonThread): print("FATAL ERROR in ping_job") return asyncio.ensure_future(job()) + def all_server_locks(self, ctx): + class AllLocks: + def __init__(self2): + self2.list = list(self.get_servers().keys()) + self2.ctx = ctx + async def __aenter__(self2): + for i in self2.list: + await asyncio.wait_for(self.connecting[i].acquire(), 3) + async def __aexit__(self2, exc_type, exc, tb): + for i in self2.list: + self.connecting[i].release() + return AllLocks() + async def maintain_interfaces(self): if self.stopped: return @@ -1057,7 +1081,8 @@ class Network(util.DaemonThread): await self.start_random_interface() if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: self.print_error('network: retrying connections') - self.disconnected_servers = set([]) + async with self.all_server_locks("maintain_interfaces"): + self.disconnected_servers = {} self.nodes_retry_time = now # main interface @@ -1066,13 +1091,13 @@ class Network(util.DaemonThread): if not self.is_connecting(): await self.switch_to_random_interface() else: - async with self.connecting[self.default_server]: - if self.default_server in self.disconnected_servers: - if now - self.server_retry_time > SERVER_RETRY_INTERVAL: - self.disconnected_servers.remove(self.default_server) + if self.default_server in self.disconnected_servers: + if now - self.server_retry_time > SERVER_RETRY_INTERVAL: + async with self.all_server_locks("maintain_interfaces 2"): + del self.disconnected_servers[self.default_server] self.server_retry_time = now - else: - await self.switch_to_interface(self.default_server) + else: + await self.switch_to_interface(self.default_server) else: if self.config.is_fee_estimates_update_required(): await self.request_fee_estimates() @@ -1085,6 +1110,7 @@ class Network(util.DaemonThread): self.init_headers_file() self.pending_sends = asyncio.Queue() self.connecting = collections.defaultdict(asyncio.Lock) + self.restartLock = asyncio.Lock() async def job(): try: @@ -1132,7 +1158,7 @@ class Network(util.DaemonThread): if not height: return if height < self.max_checkpoint(): - await self.connection_down(interface.server) + await self.connection_down(interface.server, "height under max checkpoint in on_notify_header") return interface.tip_header = header interface.tip = height