asyncio: add locks for more robust network handling
This commit is contained in:
parent
1cfdcf4e25
commit
3e2881bcfc
@ -189,7 +189,7 @@ class Interface(util.PrintError):
|
||||
self.reader, self.writer = await asyncio.wait_for(open_coro, 5)
|
||||
else:
|
||||
ssl_in_socks_coro = sslInSocksReaderWriter(self.addr, self.auth, self.host, self.port, ca_certs)
|
||||
self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 10)
|
||||
self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 5)
|
||||
else:
|
||||
context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None
|
||||
self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5)
|
||||
|
||||
140
lib/network.py
140
lib/network.py
@ -20,6 +20,7 @@
|
||||
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import sys
|
||||
import collections
|
||||
from functools import partial
|
||||
import time
|
||||
@ -164,6 +165,7 @@ class Network(util.DaemonThread):
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
self.disconnected_servers = {}
|
||||
self.stopped = True
|
||||
asyncio.set_event_loop(None)
|
||||
if config is None:
|
||||
@ -382,7 +384,7 @@ class Network(util.DaemonThread):
|
||||
return await self.new_interface(server)
|
||||
|
||||
async def start_random_interface(self):
|
||||
exclude_set = self.disconnected_servers.union(set(self.interfaces.keys()))
|
||||
exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys()))
|
||||
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
|
||||
if server:
|
||||
return await self.start_interface(server)
|
||||
@ -398,7 +400,6 @@ class Network(util.DaemonThread):
|
||||
assert not self.interface and not self.interfaces
|
||||
assert all(not i.locked() for i in self.connecting.values())
|
||||
self.print_error('starting network')
|
||||
self.disconnected_servers = set([])
|
||||
self.protocol = protocol
|
||||
self.proxy = proxy
|
||||
await self.start_interfaces()
|
||||
@ -407,20 +408,19 @@ class Network(util.DaemonThread):
|
||||
self.stopped = True
|
||||
self.print_error("stopping network")
|
||||
async def stop(interface):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(self.connection_down(interface.server)), 1.5)
|
||||
break
|
||||
except TimeoutError:
|
||||
print("close_interface too slow...", interface.server)
|
||||
try:
|
||||
print(interface.server, await asyncio.wait_for(asyncio.shield(interface.future), 3))
|
||||
except TimeoutError:
|
||||
print("interface future too slow", interface.server)
|
||||
if self.interface:
|
||||
await stop(self.interface)
|
||||
await self.connection_down(interface.server, "stopping network")
|
||||
await asyncio.wait_for(asyncio.shield(interface.future), 3)
|
||||
stopped_this_time = set()
|
||||
while self.interfaces:
|
||||
await stop(next(iter(self.interfaces.values())))
|
||||
do_next = next(iter(self.interfaces.values()))
|
||||
assert do_next not in stopped_this_time
|
||||
for i in self.disconnected_servers:
|
||||
assert i not in self.interfaces.keys()
|
||||
assert i != do_next.server
|
||||
stopped_this_time.add(do_next)
|
||||
await stop(do_next)
|
||||
if self.interface:
|
||||
assert self.interface.server in stopped_this_time, self.interface.server
|
||||
await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5)
|
||||
assert self.interface is None
|
||||
for i in range(100):
|
||||
@ -430,8 +430,6 @@ class Network(util.DaemonThread):
|
||||
await asyncio.sleep(0.1)
|
||||
if self.interfaces:
|
||||
assert False, "interfaces not empty after waiting: " + repr(self.interfaces)
|
||||
for i in self.connecting.values():
|
||||
assert not i.locked()
|
||||
|
||||
# called from the Qt thread
|
||||
def set_parameters(self, host, port, protocol, proxy, auto_connect):
|
||||
@ -455,13 +453,20 @@ class Network(util.DaemonThread):
|
||||
if self.proxy != proxy or self.protocol != protocol:
|
||||
async def job():
|
||||
try:
|
||||
# Restart the network defaulting to the given server
|
||||
await self.stop_network()
|
||||
self.default_server = server
|
||||
await self.start_network(protocol, proxy)
|
||||
async with self.restartLock:
|
||||
# Restart the network defaulting to the given server
|
||||
await self.stop_network()
|
||||
print("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED")
|
||||
self.default_server = server
|
||||
async with self.all_server_locks("restart job"):
|
||||
self.disconnected_servers = {}
|
||||
await self.start_network(protocol, proxy)
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
print("exception from restart job")
|
||||
if self.restartLock.locked():
|
||||
print("NOT RESTARTING, RESTART IN PROGRESS")
|
||||
return
|
||||
asyncio.run_coroutine_threadsafe(job(), self.loop)
|
||||
elif self.default_server != server:
|
||||
async def job():
|
||||
@ -523,7 +528,11 @@ class Network(util.DaemonThread):
|
||||
for i in interface.jobs:
|
||||
asyncio.wait_for(i, 3)
|
||||
assert interface.boot_job
|
||||
await asyncio.wait_for(asyncio.shield(interface.boot_job), 3)
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(interface.boot_job), 6) # longer than any timeout while connecting
|
||||
except TimeoutError:
|
||||
print("taking too long", interface.server)
|
||||
raise
|
||||
interface.close()
|
||||
|
||||
def add_recent_server(self, server):
|
||||
@ -609,7 +618,7 @@ class Network(util.DaemonThread):
|
||||
self.subscribed_addresses.add(params[0])
|
||||
else:
|
||||
if not response: # Closed remotely / misbehaving
|
||||
await self.connection_down(interface.server)
|
||||
if not self.stopped: await self.connection_down(interface.server, "no response in process responses")
|
||||
return
|
||||
# Rewrite response shape to match subscription request response
|
||||
method = response.get('method')
|
||||
@ -703,15 +712,19 @@ class Network(util.DaemonThread):
|
||||
if callback in v:
|
||||
v.remove(callback)
|
||||
|
||||
async def connection_down(self, server):
|
||||
async def connection_down(self, server, reason=None):
|
||||
'''A connection to server either went down, or was never made.
|
||||
We distinguish by whether it is in self.interfaces.'''
|
||||
async with self.connecting[server]:
|
||||
async with self.all_server_locks("connection down"):
|
||||
if server in self.disconnected_servers:
|
||||
self.print_error("already disconnected", server)
|
||||
try:
|
||||
raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason))
|
||||
except:
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
return
|
||||
self.print_error("connection down", server)
|
||||
self.disconnected_servers.add(server)
|
||||
self.disconnected_servers[server] = reason
|
||||
if server == self.default_server:
|
||||
self.set_status('disconnected')
|
||||
if server in self.interfaces:
|
||||
@ -735,6 +748,7 @@ class Network(util.DaemonThread):
|
||||
interface.boot_job = None
|
||||
self.boot_interface(interface)
|
||||
assert server not in self.interfaces
|
||||
assert not self.stopped
|
||||
self.interfaces[server] = interface
|
||||
return interface
|
||||
|
||||
@ -759,7 +773,7 @@ class Network(util.DaemonThread):
|
||||
self.requested_chunks.remove(index)
|
||||
connect = interface.blockchain.connect_chunk(index, result)
|
||||
if not connect:
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "could not connect chunk")
|
||||
return
|
||||
# If not finished, get the next chunk
|
||||
if interface.blockchain.height() < interface.tip:
|
||||
@ -781,12 +795,12 @@ class Network(util.DaemonThread):
|
||||
header = response.get('result')
|
||||
if not header:
|
||||
interface.print_error(response)
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "no header in on_get_header")
|
||||
return
|
||||
height = header.get('block_height')
|
||||
if interface.request != height:
|
||||
interface.print_error("unsolicited header",interface.request, height)
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "unsolicited header")
|
||||
return
|
||||
chain = blockchain.check_header(header)
|
||||
if interface.mode == 'backward':
|
||||
@ -806,7 +820,7 @@ class Network(util.DaemonThread):
|
||||
assert next_height >= self.max_checkpoint(), (interface.bad, interface.good)
|
||||
else:
|
||||
if height == 0:
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "height zero in on_get_header")
|
||||
next_height = None
|
||||
else:
|
||||
interface.bad = height
|
||||
@ -825,7 +839,7 @@ class Network(util.DaemonThread):
|
||||
next_height = (interface.bad + interface.good) // 2
|
||||
assert next_height >= self.max_checkpoint()
|
||||
elif not interface.blockchain.can_connect(interface.bad_header, check_height=False):
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "blockchain can't connect")
|
||||
next_height = None
|
||||
else:
|
||||
branch = self.blockchains.get(interface.bad)
|
||||
@ -906,7 +920,7 @@ class Network(util.DaemonThread):
|
||||
for interface in list(self.interfaces.values()):
|
||||
if interface.request and time.time() - interface.request_time > 20:
|
||||
interface.print_error("blockchain request timed out")
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "blockchain request timed out")
|
||||
|
||||
def make_send_requests_job(self, interface):
|
||||
async def job():
|
||||
@ -916,8 +930,8 @@ class Network(util.DaemonThread):
|
||||
result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1)
|
||||
except TimeoutError:
|
||||
continue
|
||||
if not result:
|
||||
await self.connection_down(interface.server)
|
||||
if not result and not self.stopped:
|
||||
await self.connection_down(interface.server, "send_request returned false")
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except:
|
||||
@ -933,7 +947,7 @@ class Network(util.DaemonThread):
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except OSError:
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "OSError in process_responses")
|
||||
print("OS error, connection downed")
|
||||
except BaseException:
|
||||
if not self.stopped:
|
||||
@ -974,21 +988,21 @@ class Network(util.DaemonThread):
|
||||
try:
|
||||
await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface)
|
||||
if not await interface.send_request():
|
||||
asyncio.ensure_future(self.connection_down(interface.server))
|
||||
if not self.stopped:
|
||||
asyncio.ensure_future(self.connection_down(interface.server, "send_request false in boot_interface"))
|
||||
interface.future.set_result("could not send request")
|
||||
return
|
||||
if self.stopped:
|
||||
asyncio.ensure_future(self.connection_down(interface.server))
|
||||
interface.future.set_result("stopped after sending request")
|
||||
return
|
||||
try:
|
||||
await asyncio.wait_for(interface.get_response(), 1)
|
||||
except TimeoutError:
|
||||
asyncio.ensure_future(self.connection_down(interface.server))
|
||||
if not self.stopped:
|
||||
asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response"))
|
||||
interface.future.set_result("timeout while getting response")
|
||||
return
|
||||
if self.stopped:
|
||||
asyncio.ensure_future(self.connection_down(interface.server))
|
||||
interface.future.set_result("stopped after getting response")
|
||||
return
|
||||
#self.interfaces[interface.server] = interface
|
||||
@ -996,15 +1010,12 @@ class Network(util.DaemonThread):
|
||||
if interface.server == self.default_server:
|
||||
await asyncio.wait_for(self.switch_to_interface(interface.server), 1)
|
||||
interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]]
|
||||
def cb(num, fut):
|
||||
gathered = asyncio.gather(*interface.jobs)
|
||||
while not self.stopped:
|
||||
try:
|
||||
fut.exception()
|
||||
except e:
|
||||
interface.future.set_exception(e)
|
||||
else:
|
||||
if not interface.future.done(): interface.future.set_result(str(num) + " done")
|
||||
for num, i in enumerate(interface.jobs):
|
||||
i.add_done_callback(partial(cb, num))
|
||||
await asyncio.wait_for(asyncio.shield(gathered), 1)
|
||||
except TimeoutError:
|
||||
pass
|
||||
interface.future.set_result("finished")
|
||||
return
|
||||
#self.notify('interfaces')
|
||||
@ -1035,7 +1046,7 @@ class Network(util.DaemonThread):
|
||||
# must use copy of values
|
||||
if interface.has_timed_out():
|
||||
print(interface.server, "timed out")
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "time out in ping_job")
|
||||
return
|
||||
elif interface.ping_required():
|
||||
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
|
||||
@ -1048,6 +1059,19 @@ class Network(util.DaemonThread):
|
||||
print("FATAL ERROR in ping_job")
|
||||
return asyncio.ensure_future(job())
|
||||
|
||||
def all_server_locks(self, ctx, timeout=3):
    """Return an async context manager that holds every per-server lock.

    The returned object, used with ``async with``, acquires
    ``self.connecting[server]`` for each key of ``self.get_servers()`` on
    entry and releases them again on exit.

    Args:
        ctx: free-form label identifying the caller; stored on the manager
            as ``ctx`` for debugging purposes only.
        timeout: seconds to wait for each individual lock before
            ``asyncio.wait_for`` gives up (defaults to the original 3).

    Raises (from ``__aenter__``):
        Whatever ``asyncio.wait_for`` raises (e.g. a timeout or a
        cancellation).  Any locks already taken are released before the
        exception propagates, so a failed entry never leaves locks held.
    """

    class AllLocks:
        # ``self2`` is this manager; ``self`` is the enclosing object that
        # owns ``get_servers()`` and the ``connecting`` lock mapping.
        def __init__(self2):
            # Snapshot the server list so entry and exit agree on it.
            self2.list = list(self.get_servers().keys())
            self2.ctx = ctx
            # Locks actually acquired so far, in acquisition order.
            self2.held = []

        def _release_held(self2):
            # Release in reverse acquisition order; only locks we own.
            while self2.held:
                self2.held.pop().release()

        async def __aenter__(self2):
            try:
                for i in self2.list:
                    lock = self.connecting[i]
                    await asyncio.wait_for(lock.acquire(), timeout)
                    self2.held.append(lock)
            except BaseException:
                # __aexit__ is not called when __aenter__ raises, so undo
                # the partial acquisition here to avoid leaking locks.
                self2._release_held()
                raise

        async def __aexit__(self2, exc_type, exc, tb):
            self2._release_held()

    return AllLocks()
|
||||
|
||||
async def maintain_interfaces(self):
|
||||
if self.stopped: return
|
||||
|
||||
@ -1057,7 +1081,8 @@ class Network(util.DaemonThread):
|
||||
await self.start_random_interface()
|
||||
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
|
||||
self.print_error('network: retrying connections')
|
||||
self.disconnected_servers = set([])
|
||||
async with self.all_server_locks("maintain_interfaces"):
|
||||
self.disconnected_servers = {}
|
||||
self.nodes_retry_time = now
|
||||
|
||||
# main interface
|
||||
@ -1066,13 +1091,13 @@ class Network(util.DaemonThread):
|
||||
if not self.is_connecting():
|
||||
await self.switch_to_random_interface()
|
||||
else:
|
||||
async with self.connecting[self.default_server]:
|
||||
if self.default_server in self.disconnected_servers:
|
||||
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
|
||||
self.disconnected_servers.remove(self.default_server)
|
||||
if self.default_server in self.disconnected_servers:
|
||||
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
|
||||
async with self.all_server_locks("maintain_interfaces 2"):
|
||||
del self.disconnected_servers[self.default_server]
|
||||
self.server_retry_time = now
|
||||
else:
|
||||
await self.switch_to_interface(self.default_server)
|
||||
else:
|
||||
await self.switch_to_interface(self.default_server)
|
||||
else:
|
||||
if self.config.is_fee_estimates_update_required():
|
||||
await self.request_fee_estimates()
|
||||
@ -1085,6 +1110,7 @@ class Network(util.DaemonThread):
|
||||
self.init_headers_file()
|
||||
self.pending_sends = asyncio.Queue()
|
||||
self.connecting = collections.defaultdict(asyncio.Lock)
|
||||
self.restartLock = asyncio.Lock()
|
||||
|
||||
async def job():
|
||||
try:
|
||||
@ -1132,7 +1158,7 @@ class Network(util.DaemonThread):
|
||||
if not height:
|
||||
return
|
||||
if height < self.max_checkpoint():
|
||||
await self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server, "height under max checkpoint in on_notify_header")
|
||||
return
|
||||
interface.tip_header = header
|
||||
interface.tip = height
|
||||
|
||||
Loading…
Reference in New Issue
Block a user