asyncio: add locks for more robust network handling

This commit is contained in:
Janus 2017-12-19 21:26:10 +01:00
parent 1cfdcf4e25
commit 3e2881bcfc
2 changed files with 84 additions and 58 deletions

View File

@ -189,7 +189,7 @@ class Interface(util.PrintError):
self.reader, self.writer = await asyncio.wait_for(open_coro, 5)
else:
ssl_in_socks_coro = sslInSocksReaderWriter(self.addr, self.auth, self.host, self.port, ca_certs)
self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 10)
self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 5)
else:
context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None
self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5)

View File

@ -20,6 +20,7 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import sys
import collections
from functools import partial
import time
@ -164,6 +165,7 @@ class Network(util.DaemonThread):
"""
def __init__(self, config=None):
self.disconnected_servers = {}
self.stopped = True
asyncio.set_event_loop(None)
if config is None:
@ -382,7 +384,7 @@ class Network(util.DaemonThread):
return await self.new_interface(server)
async def start_random_interface(self):
exclude_set = self.disconnected_servers.union(set(self.interfaces.keys()))
exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys()))
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
if server:
return await self.start_interface(server)
@ -398,7 +400,6 @@ class Network(util.DaemonThread):
assert not self.interface and not self.interfaces
assert all(not i.locked() for i in self.connecting.values())
self.print_error('starting network')
self.disconnected_servers = set([])
self.protocol = protocol
self.proxy = proxy
await self.start_interfaces()
@ -407,20 +408,19 @@ class Network(util.DaemonThread):
self.stopped = True
self.print_error("stopping network")
async def stop(interface):
while True:
try:
await asyncio.wait_for(asyncio.shield(self.connection_down(interface.server)), 1.5)
break
except TimeoutError:
print("close_interface too slow...", interface.server)
try:
print(interface.server, await asyncio.wait_for(asyncio.shield(interface.future), 3))
except TimeoutError:
print("interface future too slow", interface.server)
if self.interface:
await stop(self.interface)
await self.connection_down(interface.server, "stopping network")
await asyncio.wait_for(asyncio.shield(interface.future), 3)
stopped_this_time = set()
while self.interfaces:
await stop(next(iter(self.interfaces.values())))
do_next = next(iter(self.interfaces.values()))
assert do_next not in stopped_this_time
for i in self.disconnected_servers:
assert i not in self.interfaces.keys()
assert i != do_next.server
stopped_this_time.add(do_next)
await stop(do_next)
if self.interface:
assert self.interface.server in stopped_this_time, self.interface.server
await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5)
assert self.interface is None
for i in range(100):
@ -430,8 +430,6 @@ class Network(util.DaemonThread):
await asyncio.sleep(0.1)
if self.interfaces:
assert False, "interfaces not empty after waiting: " + repr(self.interfaces)
for i in self.connecting.values():
assert not i.locked()
# called from the Qt thread
def set_parameters(self, host, port, protocol, proxy, auto_connect):
@ -455,13 +453,20 @@ class Network(util.DaemonThread):
if self.proxy != proxy or self.protocol != protocol:
async def job():
try:
# Restart the network defaulting to the given server
await self.stop_network()
self.default_server = server
await self.start_network(protocol, proxy)
async with self.restartLock:
# Restart the network defaulting to the given server
await self.stop_network()
print("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED")
self.default_server = server
async with self.all_server_locks("restart job"):
self.disconnected_servers = {}
await self.start_network(protocol, proxy)
except BaseException as e:
traceback.print_exc()
print("exception from restart job")
if self.restartLock.locked():
print("NOT RESTARTING, RESTART IN PROGRESS")
return
asyncio.run_coroutine_threadsafe(job(), self.loop)
elif self.default_server != server:
async def job():
@ -523,7 +528,11 @@ class Network(util.DaemonThread):
for i in interface.jobs:
asyncio.wait_for(i, 3)
assert interface.boot_job
await asyncio.wait_for(asyncio.shield(interface.boot_job), 3)
try:
await asyncio.wait_for(asyncio.shield(interface.boot_job), 6) # longer than any timeout while connecting
except TimeoutError:
print("taking too long", interface.server)
raise
interface.close()
def add_recent_server(self, server):
@ -609,7 +618,7 @@ class Network(util.DaemonThread):
self.subscribed_addresses.add(params[0])
else:
if not response: # Closed remotely / misbehaving
await self.connection_down(interface.server)
if not self.stopped: await self.connection_down(interface.server, "no response in process responses")
return
# Rewrite response shape to match subscription request response
method = response.get('method')
@ -703,15 +712,19 @@ class Network(util.DaemonThread):
if callback in v:
v.remove(callback)
async def connection_down(self, server):
async def connection_down(self, server, reason=None):
'''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.'''
async with self.connecting[server]:
async with self.all_server_locks("connection down"):
if server in self.disconnected_servers:
self.print_error("already disconnected", server)
try:
raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason))
except:
traceback.print_exc()
sys.exit(1)
return
self.print_error("connection down", server)
self.disconnected_servers.add(server)
self.disconnected_servers[server] = reason
if server == self.default_server:
self.set_status('disconnected')
if server in self.interfaces:
@ -735,6 +748,7 @@ class Network(util.DaemonThread):
interface.boot_job = None
self.boot_interface(interface)
assert server not in self.interfaces
assert not self.stopped
self.interfaces[server] = interface
return interface
@ -759,7 +773,7 @@ class Network(util.DaemonThread):
self.requested_chunks.remove(index)
connect = interface.blockchain.connect_chunk(index, result)
if not connect:
await self.connection_down(interface.server)
await self.connection_down(interface.server, "could not connect chunk")
return
# If not finished, get the next chunk
if interface.blockchain.height() < interface.tip:
@ -781,12 +795,12 @@ class Network(util.DaemonThread):
header = response.get('result')
if not header:
interface.print_error(response)
await self.connection_down(interface.server)
await self.connection_down(interface.server, "no header in on_get_header")
return
height = header.get('block_height')
if interface.request != height:
interface.print_error("unsolicited header",interface.request, height)
await self.connection_down(interface.server)
await self.connection_down(interface.server, "unsolicited header")
return
chain = blockchain.check_header(header)
if interface.mode == 'backward':
@ -806,7 +820,7 @@ class Network(util.DaemonThread):
assert next_height >= self.max_checkpoint(), (interface.bad, interface.good)
else:
if height == 0:
await self.connection_down(interface.server)
await self.connection_down(interface.server, "height zero in on_get_header")
next_height = None
else:
interface.bad = height
@ -825,7 +839,7 @@ class Network(util.DaemonThread):
next_height = (interface.bad + interface.good) // 2
assert next_height >= self.max_checkpoint()
elif not interface.blockchain.can_connect(interface.bad_header, check_height=False):
await self.connection_down(interface.server)
await self.connection_down(interface.server, "blockchain can't connect")
next_height = None
else:
branch = self.blockchains.get(interface.bad)
@ -906,7 +920,7 @@ class Network(util.DaemonThread):
for interface in list(self.interfaces.values()):
if interface.request and time.time() - interface.request_time > 20:
interface.print_error("blockchain request timed out")
await self.connection_down(interface.server)
await self.connection_down(interface.server, "blockchain request timed out")
def make_send_requests_job(self, interface):
async def job():
@ -916,8 +930,8 @@ class Network(util.DaemonThread):
result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1)
except TimeoutError:
continue
if not result:
await self.connection_down(interface.server)
if not result and not self.stopped:
await self.connection_down(interface.server, "send_request returned false")
except GeneratorExit:
pass
except:
@ -933,7 +947,7 @@ class Network(util.DaemonThread):
except GeneratorExit:
pass
except OSError:
await self.connection_down(interface.server)
await self.connection_down(interface.server, "OSError in process_responses")
print("OS error, connection downed")
except BaseException:
if not self.stopped:
@ -974,21 +988,21 @@ class Network(util.DaemonThread):
try:
await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface)
if not await interface.send_request():
asyncio.ensure_future(self.connection_down(interface.server))
if not self.stopped:
asyncio.ensure_future(self.connection_down(interface.server, "send_request false in boot_interface"))
interface.future.set_result("could not send request")
return
if self.stopped:
asyncio.ensure_future(self.connection_down(interface.server))
interface.future.set_result("stopped after sending request")
return
try:
await asyncio.wait_for(interface.get_response(), 1)
except TimeoutError:
asyncio.ensure_future(self.connection_down(interface.server))
if not self.stopped:
asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response"))
interface.future.set_result("timeout while getting response")
return
if self.stopped:
asyncio.ensure_future(self.connection_down(interface.server))
interface.future.set_result("stopped after getting response")
return
#self.interfaces[interface.server] = interface
@ -996,15 +1010,12 @@ class Network(util.DaemonThread):
if interface.server == self.default_server:
await asyncio.wait_for(self.switch_to_interface(interface.server), 1)
interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]]
def cb(num, fut):
gathered = asyncio.gather(*interface.jobs)
while not self.stopped:
try:
fut.exception()
except e:
interface.future.set_exception(e)
else:
if not interface.future.done(): interface.future.set_result(str(num) + " done")
for num, i in enumerate(interface.jobs):
i.add_done_callback(partial(cb, num))
await asyncio.wait_for(asyncio.shield(gathered), 1)
except TimeoutError:
pass
interface.future.set_result("finished")
return
#self.notify('interfaces')
@ -1035,7 +1046,7 @@ class Network(util.DaemonThread):
# must use copy of values
if interface.has_timed_out():
print(interface.server, "timed out")
await self.connection_down(interface.server)
await self.connection_down(interface.server, "time out in ping_job")
return
elif interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
@ -1048,6 +1059,19 @@ class Network(util.DaemonThread):
print("FATAL ERROR in ping_job")
return asyncio.ensure_future(job())
def all_server_locks(self, ctx, timeout=3):
    """Return an async context manager that holds every per-server lock.

    The set of servers is snapshotted from ``self.get_servers()`` when this
    method is called; the matching locks are looked up in ``self.connecting``
    when the context is entered.

    ctx     -- free-form label describing the caller; stored on the manager
               for debugging only.
    timeout -- seconds to wait for each individual lock (default 3, the
               value that used to be hard-coded).

    Entering the context raises ``asyncio.TimeoutError`` if any single lock
    cannot be acquired within ``timeout``.  In that case (and on
    cancellation) every lock acquired so far is released before the
    exception propagates, so a failed entry never leaves locks held.
    """
    class AllLocks:
        def __init__(self2):
            self2.servers = list(self.get_servers().keys())
            self2.ctx = ctx
            # Lock objects this manager currently holds, in acquisition order.
            self2.held = []

        def release_held(self2):
            # Release in reverse acquisition order and forget each lock, so a
            # second call can never release a lock twice.
            while self2.held:
                self2.held.pop().release()

        async def __aenter__(self2):
            try:
                for server in self2.servers:
                    lock = self.connecting[server]
                    await asyncio.wait_for(lock.acquire(), timeout)
                    self2.held.append(lock)
            except BaseException:
                # __aexit__ is not invoked when __aenter__ raises, so the
                # partially acquired locks must be released here.
                self2.release_held()
                raise

        async def __aexit__(self2, exc_type, exc, tb):
            self2.release_held()

    return AllLocks()
async def maintain_interfaces(self):
if self.stopped: return
@ -1057,7 +1081,8 @@ class Network(util.DaemonThread):
await self.start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections')
self.disconnected_servers = set([])
async with self.all_server_locks("maintain_interfaces"):
self.disconnected_servers = {}
self.nodes_retry_time = now
# main interface
@ -1066,13 +1091,13 @@ class Network(util.DaemonThread):
if not self.is_connecting():
await self.switch_to_random_interface()
else:
async with self.connecting[self.default_server]:
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
self.disconnected_servers.remove(self.default_server)
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
async with self.all_server_locks("maintain_interfaces 2"):
del self.disconnected_servers[self.default_server]
self.server_retry_time = now
else:
await self.switch_to_interface(self.default_server)
else:
await self.switch_to_interface(self.default_server)
else:
if self.config.is_fee_estimates_update_required():
await self.request_fee_estimates()
@ -1085,6 +1110,7 @@ class Network(util.DaemonThread):
self.init_headers_file()
self.pending_sends = asyncio.Queue()
self.connecting = collections.defaultdict(asyncio.Lock)
self.restartLock = asyncio.Lock()
async def job():
try:
@ -1132,7 +1158,7 @@ class Network(util.DaemonThread):
if not height:
return
if height < self.max_checkpoint():
await self.connection_down(interface.server)
await self.connection_down(interface.server, "height under max checkpoint in on_notify_header")
return
interface.tip_header = header
interface.tip = height