asyncio: support switching servers

This commit is contained in:
Janus 2017-12-18 12:05:25 +01:00
parent 683205a3fa
commit 2d1ccfcc69
2 changed files with 69 additions and 33 deletions

View File

@ -143,10 +143,9 @@ class Interface(util.PrintError):
reader, writer = await asyncio.wait_for(self.conn_coro(context), 5) reader, writer = await asyncio.wait_for(self.conn_coro(context), 5)
dercert = writer.get_extra_info('ssl_object').getpeercert(True) dercert = writer.get_extra_info('ssl_object').getpeercert(True)
writer.close() writer.close()
except ConnectionError: except OSError as e: # not ConnectionError because we need socket.gaierror too
if self.is_running(): if self.is_running():
traceback.print_exc() print("Exception in _save_certificate", type(e))
print("Previous exception from _save_certificate")
return return
except TimeoutError: except TimeoutError:
return return
@ -199,8 +198,9 @@ class Interface(util.PrintError):
raise raise
except BaseException as e: except BaseException as e:
if self.is_running(): if self.is_running():
traceback.print_exc() if not isinstance(e, OSError):
print("Previous exception will now be reraised") traceback.print_exc()
print("Previous exception will now be reraised")
raise e raise e
if self.use_ssl and is_new: if self.use_ssl and is_new:
self.print_error("saving new certificate for", self.host) self.print_error("saving new certificate for", self.host)

View File

@ -163,6 +163,7 @@ class Network(util.DaemonThread):
""" """
def __init__(self, config=None): def __init__(self, config=None):
self.stopped = True
asyncio.set_event_loop(None) asyncio.set_event_loop(None)
if config is None: if config is None:
config = {} # Do not use mutables as default values! config = {} # Do not use mutables as default values!
@ -393,6 +394,7 @@ class Network(util.DaemonThread):
await self.start_random_interface() await self.start_random_interface()
async def start_network(self, protocol, proxy): async def start_network(self, protocol, proxy):
self.stopped = False
# TODO proxy # TODO proxy
assert not self.interface and not self.interfaces assert not self.interface and not self.interfaces
assert not self.connecting assert not self.connecting
@ -403,17 +405,24 @@ class Network(util.DaemonThread):
await self.start_interfaces() await self.start_interfaces()
async def stop_network(self): async def stop_network(self):
self.stopped = True
self.print_error("stopping network") self.print_error("stopping network")
for num, interface in enumerate(list(self.interfaces.values())): list_with_main = ([self.interface] if self.interface else [])
await self.close_interface(interface) for num, interface in enumerate(list(self.interfaces.values()) + list_with_main):
await interface.future try:
#if self.interface: await asyncio.wait_for(asyncio.shield(self.close_interface(interface)), 5)
# await self.close_interface(self.interface) await asyncio.wait_for(asyncio.shield(interface.future), 5)
# await interface.future except TimeoutError:
await asyncio.wait_for(self.process_pending_sends_job, 5) print("close_interface too slow...")
await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5)
assert self.interface is None assert self.interface is None
while self.interfaces: for i in range(100):
asyncio.sleep(0.1) if not self.interfaces:
break
else:
await asyncio.sleep(0.1)
if self.interfaces:
assert False, "interfaces not empty after waiting: " + repr(self.interfaces)
self.connecting = set() self.connecting = set()
# called from the Qt thread # called from the Qt thread
@ -437,10 +446,15 @@ class Network(util.DaemonThread):
self.auto_connect = auto_connect self.auto_connect = auto_connect
if self.proxy != proxy or self.protocol != protocol: if self.proxy != proxy or self.protocol != protocol:
async def job(): async def job():
# Restart the network defaulting to the given server try:
await self.stop_network() # Restart the network defaulting to the given server
self.default_server = server await self.stop_network()
await self.start_network(protocol, proxy) self.default_server = server
await self.start_network(protocol, proxy)
self.notify('interfaces')
except BaseException as e:
traceback.print_exc()
print("exception from restart job")
asyncio.run_coroutine_threadsafe(job(), self.loop) asyncio.run_coroutine_threadsafe(job(), self.loop)
elif self.default_server != server: elif self.default_server != server:
async def job(): async def job():
@ -564,7 +578,7 @@ class Network(util.DaemonThread):
return str(method) + (':' + str(params[0]) if params else '') return str(method) + (':' + str(params[0]) if params else '')
async def process_responses(self, interface): async def process_responses(self, interface):
while self.is_running(): while not self.stopped:
request, response = await interface.get_response() request, response = await interface.get_response()
if request: if request:
method, params, message_id = request method, params, message_id = request
@ -698,7 +712,7 @@ class Network(util.DaemonThread):
async def new_interface(self, server): async def new_interface(self, server):
# todo: get tip first, then decide which checkpoint to use. # todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server) self.add_recent_server(server)
interface = Interface(server, self.config.path, self.proxy, lambda: self.is_running()) interface = Interface(server, self.config.path, self.proxy, lambda: not self.stopped)
interface.future = asyncio.Future() interface.future = asyncio.Future()
interface.blockchain = None interface.blockchain = None
interface.tip_header = None interface.tip_header = None
@ -884,9 +898,9 @@ class Network(util.DaemonThread):
def make_send_requests_job(self, interface): def make_send_requests_job(self, interface):
async def job(): async def job():
try: try:
while self.is_running(): while not self.stopped:
try: try:
result = await asyncio.wait_for(interface.send_request(), 1) result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1)
except TimeoutError: except TimeoutError:
continue continue
if not result: if not result:
@ -894,7 +908,7 @@ class Network(util.DaemonThread):
except GeneratorExit: except GeneratorExit:
pass pass
except: except:
if self.is_running(): if not self.stopped:
traceback.print_exc() traceback.print_exc()
print("FATAL ERROR ^^^") print("FATAL ERROR ^^^")
return asyncio.ensure_future(job()) return asyncio.ensure_future(job())
@ -909,7 +923,7 @@ class Network(util.DaemonThread):
await self.connection_down(interface.server) await self.connection_down(interface.server)
print("OS error, connection downed") print("OS error, connection downed")
except BaseException: except BaseException:
if self.is_running(): if not self.stopped:
traceback.print_exc() traceback.print_exc()
print("FATAL ERROR in process_responses") print("FATAL ERROR in process_responses")
return asyncio.ensure_future(job()) return asyncio.ensure_future(job())
@ -917,15 +931,16 @@ class Network(util.DaemonThread):
def make_process_pending_sends_job(self): def make_process_pending_sends_job(self):
async def job(): async def job():
try: try:
while self.is_running(): while not self.stopped:
print("pend send")
try: try:
await asyncio.wait_for(self.process_pending_sends(), 1) await asyncio.wait_for(asyncio.shield(self.process_pending_sends()), 1)
except TimeoutError: except TimeoutError:
continue continue
#except CancelledError: #except CancelledError:
# pass # pass
except BaseException as e: except BaseException as e:
if self.is_running(): if not self.stopped:
traceback.print_exc() traceback.print_exc()
print("FATAL ERROR in process_pending_sends") print("FATAL ERROR in process_pending_sends")
return asyncio.ensure_future(job()) return asyncio.ensure_future(job())
@ -947,11 +962,20 @@ class Network(util.DaemonThread):
try: try:
await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface) await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface)
def rem(): def rem():
if self.is_running(): self.connecting.remove(interface.server) if not self.stopped:
self.connecting.remove(interface.server)
if not await interface.send_request(): if not await interface.send_request():
await self.connection_down(interface.server)
rem() rem()
await self.connection_down(interface.server)
return return
if self.stopped: return
try:
await asyncio.wait_for(interface.get_response(), 1)
except TimeoutError:
rem()
await self.connection_down(interface.server)
return
if self.stopped: return
rem() rem()
self.interfaces[interface.server] = interface self.interfaces[interface.server] = interface
await self.queue_request('blockchain.headers.subscribe', [], interface) await self.queue_request('blockchain.headers.subscribe', [], interface)
@ -972,34 +996,45 @@ class Network(util.DaemonThread):
print(interface.server, "GENERATOR EXIT") print(interface.server, "GENERATOR EXIT")
pass pass
except BaseException as e: except BaseException as e:
if self.is_running(): if not self.stopped:
traceback.print_exc() traceback.print_exc()
print("FATAL ERROR in boot_interface") print("FATAL ERROR in boot_interface")
raise e raise e
interface.boot_job = asyncio.ensure_future(job()) interface.boot_job = asyncio.ensure_future(job())
interface.boot_job.server = interface.server
def boot_job_cb(fut):
try:
fut.exception()
except:
traceback.print_exc()
print("Previous exception in boot_job")
interface.boot_job.add_done_callback(boot_job_cb)
def make_ping_job(self, interface): def make_ping_job(self, interface):
async def job(): async def job():
try: try:
while self.is_running(): while not self.stopped:
await asyncio.sleep(1) await asyncio.sleep(1)
# Send pings and shut down stale interfaces # Send pings and shut down stale interfaces
# must use copy of values # must use copy of values
if interface.has_timed_out(): if interface.has_timed_out():
print(interface.server, "timed out") print(interface.server, "timed out")
await self.connection_down(interface.server) await self.connection_down(interface.server)
return
elif interface.ping_required(): elif interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION] params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
await self.queue_request('server.version', params, interface) await self.queue_request('server.version', params, interface)
except GeneratorExit: except GeneratorExit:
pass pass
except: except:
if self.is_running(): if not self.stopped:
traceback.print_exc() traceback.print_exc()
print("FATAL ERROR in ping_job") print("FATAL ERROR in ping_job")
return asyncio.ensure_future(job()) return asyncio.ensure_future(job())
async def maintain_interfaces(self): async def maintain_interfaces(self):
if self.stopped: return
now = time.time() now = time.time()
# nodes # nodes
if len(self.interfaces) + len(self.connecting) < self.num_server: if len(self.interfaces) + len(self.connecting) < self.num_server:
@ -1054,6 +1089,7 @@ class Network(util.DaemonThread):
async def run_async(self, future): async def run_async(self, future):
try: try:
while self.is_running(): while self.is_running():
#print(len(asyncio.Task.all_tasks()))
await asyncio.sleep(1) await asyncio.sleep(1)
await self.maintain_requests() await self.maintain_requests()
await self.maintain_interfaces() await self.maintain_interfaces()
@ -1063,7 +1099,7 @@ class Network(util.DaemonThread):
for i in asyncio.Task.all_tasks(): for i in asyncio.Task.all_tasks():
if asyncio.Task.current_task() == i: continue if asyncio.Task.current_task() == i: continue
try: try:
await asyncio.wait_for(i, 5) await asyncio.wait_for(asyncio.shield(i), 5)
except TimeoutError: except TimeoutError:
print("TOO SLOW TO SHUT DOWN, CANCELLING", i) print("TOO SLOW TO SHUT DOWN, CANCELLING", i)
i.cancel() i.cancel()