use a single lock for servers

This commit is contained in:
ThomasV 2018-03-08 15:27:21 +01:00
parent 7a113fb954
commit 0879b98c60

View File

@ -168,6 +168,8 @@ class Network(util.DaemonThread):
def __init__(self, config=None):
self.disconnected_servers = {}
self.connecting = set()
self.servers_lock = asyncio.Lock()
self.stopped = True
asyncio.set_event_loop(None)
if config is None:
@ -384,13 +386,16 @@ class Network(util.DaemonThread):
return out
async def start_interface(self, server):
assert not self.connecting[server].locked()
async with self.connecting[server]:
assert server not in self.connecting
async with self.servers_lock:
if (not server in self.interfaces):
if server == self.default_server:
self.print_error("connecting to %s as new interface" % server)
self.set_status('connecting')
return await self.new_interface(server)
self.connecting.add(server)
i = await self.new_interface(server)
self.connecting.remove(server)
return i
async def start_random_interface(self):
exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys()))
@ -407,7 +412,7 @@ class Network(util.DaemonThread):
async def start_network(self, protocol, proxy):
self.stopped = False
assert not self.interface and not self.interfaces
assert all(not i.locked() for i in self.connecting.values())
assert len(self.connecting) == 0
self.print_error('starting network')
self.protocol = protocol
self.proxy = proxy
@ -467,7 +472,7 @@ class Network(util.DaemonThread):
await self.stop_network()
self.print_error("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED")
self.default_server = server
async with self.all_server_locks("restart job"):
async with self.servers_lock:
self.disconnected_servers = {}
await self.start_network(protocol, proxy)
except BaseException as e:
@ -734,7 +739,7 @@ class Network(util.DaemonThread):
async def connection_down(self, server, reason=None):
'''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.'''
async with self.all_server_locks("connection down"):
async with self.servers_lock:
if server in self.disconnected_servers:
try:
raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason))
@ -1075,29 +1080,16 @@ class Network(util.DaemonThread):
self.print_error("FATAL ERROR in ping_job")
return asyncio.ensure_future(job())
def all_server_locks(self, ctx):
class AllLocks:
def __init__(self2):
self2.list = list(self.get_servers().keys())
self2.ctx = ctx
async def __aenter__(self2):
for i in self2.list:
await asyncio.wait_for(self.connecting[i].acquire(), 3)
async def __aexit__(self2, exc_type, exc, tb):
for i in self2.list:
self.connecting[i].release()
return AllLocks()
async def maintain_interfaces(self):
if self.stopped: return
now = time.time()
# nodes
if len(self.interfaces) + sum((1 if x.locked() else 0) for x in self.connecting.values()) < self.num_server:
if len(self.interfaces) + len(self.connecting) < self.num_server:
await self.start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections')
async with self.all_server_locks("maintain_interfaces"):
async with self.servers_lock:
self.disconnected_servers = {}
self.nodes_retry_time = now
@ -1109,7 +1101,7 @@ class Network(util.DaemonThread):
else:
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
async with self.all_server_locks("maintain_interfaces 2"):
async with self.servers_locks:
del self.disconnected_servers[self.default_server]
self.server_retry_time = now
else:
@ -1125,7 +1117,6 @@ class Network(util.DaemonThread):
self.loop = loop # so we store it in the instance too
self.init_headers_file()
self.pending_sends = asyncio.Queue()
self.connecting = collections.defaultdict(asyncio.Lock)
self.restartLock = asyncio.Lock()
async def job():