asyncio: add locks for more robust network handling

This commit is contained in:
Janus 2017-12-19 21:26:10 +01:00
parent 1cfdcf4e25
commit 3e2881bcfc
2 changed files with 84 additions and 58 deletions

View File

@ -189,7 +189,7 @@ class Interface(util.PrintError):
self.reader, self.writer = await asyncio.wait_for(open_coro, 5) self.reader, self.writer = await asyncio.wait_for(open_coro, 5)
else: else:
ssl_in_socks_coro = sslInSocksReaderWriter(self.addr, self.auth, self.host, self.port, ca_certs) ssl_in_socks_coro = sslInSocksReaderWriter(self.addr, self.auth, self.host, self.port, ca_certs)
self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 10) self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 5)
else: else:
context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None
self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5) self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5)

View File

@ -20,6 +20,7 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import sys
import collections import collections
from functools import partial from functools import partial
import time import time
@ -164,6 +165,7 @@ class Network(util.DaemonThread):
""" """
def __init__(self, config=None): def __init__(self, config=None):
self.disconnected_servers = {}
self.stopped = True self.stopped = True
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
if config is None: if config is None:
@ -382,7 +384,7 @@ class Network(util.DaemonThread):
return await self.new_interface(server) return await self.new_interface(server)
async def start_random_interface(self): async def start_random_interface(self):
exclude_set = self.disconnected_servers.union(set(self.interfaces.keys())) exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys()))
server = pick_random_server(self.get_servers(), self.protocol, exclude_set) server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
if server: if server:
return await self.start_interface(server) return await self.start_interface(server)
@ -398,7 +400,6 @@ class Network(util.DaemonThread):
assert not self.interface and not self.interfaces assert not self.interface and not self.interfaces
assert all(not i.locked() for i in self.connecting.values()) assert all(not i.locked() for i in self.connecting.values())
self.print_error('starting network') self.print_error('starting network')
self.disconnected_servers = set([])
self.protocol = protocol self.protocol = protocol
self.proxy = proxy self.proxy = proxy
await self.start_interfaces() await self.start_interfaces()
@ -407,20 +408,19 @@ class Network(util.DaemonThread):
self.stopped = True self.stopped = True
self.print_error("stopping network") self.print_error("stopping network")
async def stop(interface): async def stop(interface):
while True: await self.connection_down(interface.server, "stopping network")
try: await asyncio.wait_for(asyncio.shield(interface.future), 3)
await asyncio.wait_for(asyncio.shield(self.connection_down(interface.server)), 1.5) stopped_this_time = set()
break
except TimeoutError:
print("close_interface too slow...", interface.server)
try:
print(interface.server, await asyncio.wait_for(asyncio.shield(interface.future), 3))
except TimeoutError:
print("interface future too slow", interface.server)
if self.interface:
await stop(self.interface)
while self.interfaces: while self.interfaces:
await stop(next(iter(self.interfaces.values()))) do_next = next(iter(self.interfaces.values()))
assert do_next not in stopped_this_time
for i in self.disconnected_servers:
assert i not in self.interfaces.keys()
assert i != do_next.server
stopped_this_time.add(do_next)
await stop(do_next)
if self.interface:
assert self.interface.server in stopped_this_time, self.interface.server
await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5) await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5)
assert self.interface is None assert self.interface is None
for i in range(100): for i in range(100):
@ -430,8 +430,6 @@ class Network(util.DaemonThread):
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
if self.interfaces: if self.interfaces:
assert False, "interfaces not empty after waiting: " + repr(self.interfaces) assert False, "interfaces not empty after waiting: " + repr(self.interfaces)
for i in self.connecting.values():
assert not i.locked()
# called from the Qt thread # called from the Qt thread
def set_parameters(self, host, port, protocol, proxy, auto_connect): def set_parameters(self, host, port, protocol, proxy, auto_connect):
@ -455,13 +453,20 @@ class Network(util.DaemonThread):
if self.proxy != proxy or self.protocol != protocol: if self.proxy != proxy or self.protocol != protocol:
async def job(): async def job():
try: try:
# Restart the network defaulting to the given server async with self.restartLock:
await self.stop_network() # Restart the network defaulting to the given server
self.default_server = server await self.stop_network()
await self.start_network(protocol, proxy) print("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED")
self.default_server = server
async with self.all_server_locks("restart job"):
self.disconnected_servers = {}
await self.start_network(protocol, proxy)
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
print("exception from restart job") print("exception from restart job")
if self.restartLock.locked():
print("NOT RESTARTING, RESTART IN PROGRESS")
return
asyncio.run_coroutine_threadsafe(job(), self.loop) asyncio.run_coroutine_threadsafe(job(), self.loop)
elif self.default_server != server: elif self.default_server != server:
async def job(): async def job():
@ -523,7 +528,11 @@ class Network(util.DaemonThread):
for i in interface.jobs: for i in interface.jobs:
asyncio.wait_for(i, 3) asyncio.wait_for(i, 3)
assert interface.boot_job assert interface.boot_job
await asyncio.wait_for(asyncio.shield(interface.boot_job), 3) try:
await asyncio.wait_for(asyncio.shield(interface.boot_job), 6) # longer than any timeout while connecting
except TimeoutError:
print("taking too long", interface.server)
raise
interface.close() interface.close()
def add_recent_server(self, server): def add_recent_server(self, server):
@ -609,7 +618,7 @@ class Network(util.DaemonThread):
self.subscribed_addresses.add(params[0]) self.subscribed_addresses.add(params[0])
else: else:
if not response: # Closed remotely / misbehaving if not response: # Closed remotely / misbehaving
await self.connection_down(interface.server) if not self.stopped: await self.connection_down(interface.server, "no response in process responses")
return return
# Rewrite response shape to match subscription request response # Rewrite response shape to match subscription request response
method = response.get('method') method = response.get('method')
@ -703,15 +712,19 @@ class Network(util.DaemonThread):
if callback in v: if callback in v:
v.remove(callback) v.remove(callback)
async def connection_down(self, server): async def connection_down(self, server, reason=None):
'''A connection to server either went down, or was never made. '''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.''' We distinguish by whether it is in self.interfaces.'''
async with self.connecting[server]: async with self.all_server_locks("connection down"):
if server in self.disconnected_servers: if server in self.disconnected_servers:
self.print_error("already disconnected", server) try:
raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason))
except:
traceback.print_exc()
sys.exit(1)
return return
self.print_error("connection down", server) self.print_error("connection down", server)
self.disconnected_servers.add(server) self.disconnected_servers[server] = reason
if server == self.default_server: if server == self.default_server:
self.set_status('disconnected') self.set_status('disconnected')
if server in self.interfaces: if server in self.interfaces:
@ -735,6 +748,7 @@ class Network(util.DaemonThread):
interface.boot_job = None interface.boot_job = None
self.boot_interface(interface) self.boot_interface(interface)
assert server not in self.interfaces assert server not in self.interfaces
assert not self.stopped
self.interfaces[server] = interface self.interfaces[server] = interface
return interface return interface
@ -759,7 +773,7 @@ class Network(util.DaemonThread):
self.requested_chunks.remove(index) self.requested_chunks.remove(index)
connect = interface.blockchain.connect_chunk(index, result) connect = interface.blockchain.connect_chunk(index, result)
if not connect: if not connect:
await self.connection_down(interface.server) await self.connection_down(interface.server, "could not connect chunk")
return return
# If not finished, get the next chunk # If not finished, get the next chunk
if interface.blockchain.height() < interface.tip: if interface.blockchain.height() < interface.tip:
@ -781,12 +795,12 @@ class Network(util.DaemonThread):
header = response.get('result') header = response.get('result')
if not header: if not header:
interface.print_error(response) interface.print_error(response)
await self.connection_down(interface.server) await self.connection_down(interface.server, "no header in on_get_header")
return return
height = header.get('block_height') height = header.get('block_height')
if interface.request != height: if interface.request != height:
interface.print_error("unsolicited header",interface.request, height) interface.print_error("unsolicited header",interface.request, height)
await self.connection_down(interface.server) await self.connection_down(interface.server, "unsolicited header")
return return
chain = blockchain.check_header(header) chain = blockchain.check_header(header)
if interface.mode == 'backward': if interface.mode == 'backward':
@ -806,7 +820,7 @@ class Network(util.DaemonThread):
assert next_height >= self.max_checkpoint(), (interface.bad, interface.good) assert next_height >= self.max_checkpoint(), (interface.bad, interface.good)
else: else:
if height == 0: if height == 0:
await self.connection_down(interface.server) await self.connection_down(interface.server, "height zero in on_get_header")
next_height = None next_height = None
else: else:
interface.bad = height interface.bad = height
@ -825,7 +839,7 @@ class Network(util.DaemonThread):
next_height = (interface.bad + interface.good) // 2 next_height = (interface.bad + interface.good) // 2
assert next_height >= self.max_checkpoint() assert next_height >= self.max_checkpoint()
elif not interface.blockchain.can_connect(interface.bad_header, check_height=False): elif not interface.blockchain.can_connect(interface.bad_header, check_height=False):
await self.connection_down(interface.server) await self.connection_down(interface.server, "blockchain can't connect")
next_height = None next_height = None
else: else:
branch = self.blockchains.get(interface.bad) branch = self.blockchains.get(interface.bad)
@ -906,7 +920,7 @@ class Network(util.DaemonThread):
for interface in list(self.interfaces.values()): for interface in list(self.interfaces.values()):
if interface.request and time.time() - interface.request_time > 20: if interface.request and time.time() - interface.request_time > 20:
interface.print_error("blockchain request timed out") interface.print_error("blockchain request timed out")
await self.connection_down(interface.server) await self.connection_down(interface.server, "blockchain request timed out")
def make_send_requests_job(self, interface): def make_send_requests_job(self, interface):
async def job(): async def job():
@ -916,8 +930,8 @@ class Network(util.DaemonThread):
result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1) result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1)
except TimeoutError: except TimeoutError:
continue continue
if not result: if not result and not self.stopped:
await self.connection_down(interface.server) await self.connection_down(interface.server, "send_request returned false")
except GeneratorExit: except GeneratorExit:
pass pass
except: except:
@ -933,7 +947,7 @@ class Network(util.DaemonThread):
except GeneratorExit: except GeneratorExit:
pass pass
except OSError: except OSError:
await self.connection_down(interface.server) await self.connection_down(interface.server, "OSError in process_responses")
print("OS error, connection downed") print("OS error, connection downed")
except BaseException: except BaseException:
if not self.stopped: if not self.stopped:
@ -974,21 +988,21 @@ class Network(util.DaemonThread):
try: try:
await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface) await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface)
if not await interface.send_request(): if not await interface.send_request():
asyncio.ensure_future(self.connection_down(interface.server)) if not self.stopped:
asyncio.ensure_future(self.connection_down(interface.server, "send_request false in boot_interface"))
interface.future.set_result("could not send request") interface.future.set_result("could not send request")
return return
if self.stopped: if self.stopped:
asyncio.ensure_future(self.connection_down(interface.server))
interface.future.set_result("stopped after sending request") interface.future.set_result("stopped after sending request")
return return
try: try:
await asyncio.wait_for(interface.get_response(), 1) await asyncio.wait_for(interface.get_response(), 1)
except TimeoutError: except TimeoutError:
asyncio.ensure_future(self.connection_down(interface.server)) if not self.stopped:
asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response"))
interface.future.set_result("timeout while getting response") interface.future.set_result("timeout while getting response")
return return
if self.stopped: if self.stopped:
asyncio.ensure_future(self.connection_down(interface.server))
interface.future.set_result("stopped after getting response") interface.future.set_result("stopped after getting response")
return return
#self.interfaces[interface.server] = interface #self.interfaces[interface.server] = interface
@ -996,15 +1010,12 @@ class Network(util.DaemonThread):
if interface.server == self.default_server: if interface.server == self.default_server:
await asyncio.wait_for(self.switch_to_interface(interface.server), 1) await asyncio.wait_for(self.switch_to_interface(interface.server), 1)
interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]] interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]]
def cb(num, fut): gathered = asyncio.gather(*interface.jobs)
while not self.stopped:
try: try:
fut.exception() await asyncio.wait_for(asyncio.shield(gathered), 1)
except e: except TimeoutError:
interface.future.set_exception(e) pass
else:
if not interface.future.done(): interface.future.set_result(str(num) + " done")
for num, i in enumerate(interface.jobs):
i.add_done_callback(partial(cb, num))
interface.future.set_result("finished") interface.future.set_result("finished")
return return
#self.notify('interfaces') #self.notify('interfaces')
@ -1035,7 +1046,7 @@ class Network(util.DaemonThread):
# must use copy of values # must use copy of values
if interface.has_timed_out(): if interface.has_timed_out():
print(interface.server, "timed out") print(interface.server, "timed out")
await self.connection_down(interface.server) await self.connection_down(interface.server, "time out in ping_job")
return return
elif interface.ping_required(): elif interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION] params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
@ -1048,6 +1059,19 @@ class Network(util.DaemonThread):
print("FATAL ERROR in ping_job") print("FATAL ERROR in ping_job")
return asyncio.ensure_future(job()) return asyncio.ensure_future(job())
def all_server_locks(self, ctx, timeout=3):
    """Return an async context manager that holds every per-server lock.

    The server list is snapshotted from ``self.get_servers()`` when this
    method is called; on ``async with`` entry the matching locks in
    ``self.connecting`` are acquired one after another, each acquisition
    bounded by ``timeout`` seconds.

    If any acquisition fails (timeout, cancellation, ...), the locks that
    were already taken are released again before the exception propagates.
    Without that, a single slow lock would leave the earlier servers locked
    forever, because ``__aexit__`` is never invoked when ``__aenter__``
    raises.

    Args:
        ctx: Caller-supplied label; it is only stored on the returned
            object (handy when debugging who holds the locks).
        timeout: Seconds to wait for each individual lock (default 3,
            the previously hard-coded value).

    Returns:
        An object usable as ``async with self.all_server_locks("why"):``.

    Raises:
        asyncio.TimeoutError: on entry, when a lock cannot be acquired
            within ``timeout`` seconds.
    """
    class AllLocks:
        def __init__(self2):
            # Snapshot of the server names so entry and exit agree even if
            # the server list changes while the locks are held.
            self2.list = list(self.get_servers().keys())
            self2.ctx = ctx
            # Locks actually acquired so far, in acquisition order.
            self2.held = []

        async def __aenter__(self2):
            try:
                for server in self2.list:
                    lock = self.connecting[server]
                    await asyncio.wait_for(lock.acquire(), timeout)
                    self2.held.append(lock)
            except BaseException:
                # Roll back the partial acquisition; BaseException so that
                # cancellation is covered as well. Always re-raised.
                self2._release_held()
                raise

        async def __aexit__(self2, exc_type, exc, tb):
            self2._release_held()

        def _release_held(self2):
            # Release in reverse acquisition order and only what we own.
            while self2.held:
                self2.held.pop().release()

    return AllLocks()
async def maintain_interfaces(self): async def maintain_interfaces(self):
if self.stopped: return if self.stopped: return
@ -1057,7 +1081,8 @@ class Network(util.DaemonThread):
await self.start_random_interface() await self.start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections') self.print_error('network: retrying connections')
self.disconnected_servers = set([]) async with self.all_server_locks("maintain_interfaces"):
self.disconnected_servers = {}
self.nodes_retry_time = now self.nodes_retry_time = now
# main interface # main interface
@ -1066,13 +1091,13 @@ class Network(util.DaemonThread):
if not self.is_connecting(): if not self.is_connecting():
await self.switch_to_random_interface() await self.switch_to_random_interface()
else: else:
async with self.connecting[self.default_server]: if self.default_server in self.disconnected_servers:
if self.default_server in self.disconnected_servers: if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL: async with self.all_server_locks("maintain_interfaces 2"):
self.disconnected_servers.remove(self.default_server) del self.disconnected_servers[self.default_server]
self.server_retry_time = now self.server_retry_time = now
else: else:
await self.switch_to_interface(self.default_server) await self.switch_to_interface(self.default_server)
else: else:
if self.config.is_fee_estimates_update_required(): if self.config.is_fee_estimates_update_required():
await self.request_fee_estimates() await self.request_fee_estimates()
@ -1085,6 +1110,7 @@ class Network(util.DaemonThread):
self.init_headers_file() self.init_headers_file()
self.pending_sends = asyncio.Queue() self.pending_sends = asyncio.Queue()
self.connecting = collections.defaultdict(asyncio.Lock) self.connecting = collections.defaultdict(asyncio.Lock)
self.restartLock = asyncio.Lock()
async def job(): async def job():
try: try:
@ -1132,7 +1158,7 @@ class Network(util.DaemonThread):
if not height: if not height:
return return
if height < self.max_checkpoint(): if height < self.max_checkpoint():
await self.connection_down(interface.server) await self.connection_down(interface.server, "height under max checkpoint in on_notify_header")
return return
interface.tip_header = header interface.tip_header = header
interface.tip = height interface.tip = height