From 0879b98c60beec8f7ba86baf03a1a506e589c917 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Thu, 8 Mar 2018 15:27:21 +0100 Subject: [PATCH] use a single lock for servers --- lib/network.py | 37 ++++++++++++++----------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/lib/network.py b/lib/network.py index 99e78641..faae5f18 100644 --- a/lib/network.py +++ b/lib/network.py @@ -168,6 +168,8 @@ class Network(util.DaemonThread): def __init__(self, config=None): self.disconnected_servers = {} + self.connecting = set() + self.servers_lock = asyncio.Lock() self.stopped = True asyncio.set_event_loop(None) if config is None: @@ -384,13 +386,16 @@ class Network(util.DaemonThread): return out async def start_interface(self, server): - assert not self.connecting[server].locked() - async with self.connecting[server]: + assert server not in self.connecting + async with self.servers_lock: if (not server in self.interfaces): if server == self.default_server: self.print_error("connecting to %s as new interface" % server) self.set_status('connecting') - return await self.new_interface(server) + self.connecting.add(server) + i = await self.new_interface(server) + self.connecting.remove(server) + return i async def start_random_interface(self): exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys())) @@ -407,7 +412,7 @@ class Network(util.DaemonThread): async def start_network(self, protocol, proxy): self.stopped = False assert not self.interface and not self.interfaces - assert all(not i.locked() for i in self.connecting.values()) + assert len(self.connecting) == 0 self.print_error('starting network') self.protocol = protocol self.proxy = proxy @@ -467,7 +472,7 @@ class Network(util.DaemonThread): await self.stop_network() self.print_error("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED") self.default_server = server - async with self.all_server_locks("restart job"): + async with self.servers_lock: self.disconnected_servers = {} await self.start_network(protocol, proxy) except BaseException as e: @@ -734,7 +739,7 @@ class Network(util.DaemonThread): async def connection_down(self, server, reason=None): '''A connection to server either went down, or was never made. We distinguish by whether it is in self.interfaces.''' - async with self.all_server_locks("connection down"): + async with self.servers_lock: if server in self.disconnected_servers: try: raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason)) @@ -1075,29 +1080,16 @@ class Network(util.DaemonThread): self.print_error("FATAL ERROR in ping_job") return asyncio.ensure_future(job()) - def all_server_locks(self, ctx): - class AllLocks: - def __init__(self2): - self2.list = list(self.get_servers().keys()) - self2.ctx = ctx - async def __aenter__(self2): - for i in self2.list: - await asyncio.wait_for(self.connecting[i].acquire(), 3) - async def __aexit__(self2, exc_type, exc, tb): - for i in self2.list: - self.connecting[i].release() - return AllLocks() - async def maintain_interfaces(self): if self.stopped: return now = time.time() # nodes - if len(self.interfaces) + sum((1 if x.locked() else 0) for x in self.connecting.values()) < self.num_server: + if len(self.interfaces) + len(self.connecting) < self.num_server: await self.start_random_interface() if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: self.print_error('network: retrying connections') - async with self.all_server_locks("maintain_interfaces"): + async with self.servers_lock: self.disconnected_servers = {} self.nodes_retry_time = now @@ -1109,7 +1101,7 @@ class Network(util.DaemonThread): else: if self.default_server in self.disconnected_servers: if now - self.server_retry_time > SERVER_RETRY_INTERVAL: - async with self.all_server_locks("maintain_interfaces 2"): + async with self.servers_locks: del self.disconnected_servers[self.default_server] self.server_retry_time = now else: @@ -1125,7 +1117,6 @@ class Network(util.DaemonThread): self.loop = loop # so we store it in the instance too self.init_headers_file() self.pending_sends = asyncio.Queue() - self.connecting = collections.defaultdict(asyncio.Lock) self.restartLock = asyncio.Lock() async def job():