asyncio: more graceful shutdown

This commit is contained in:
Janus 2017-12-15 16:18:52 +01:00
parent 200a085778
commit dcb0a24e6f
2 changed files with 83 additions and 57 deletions

View File

@ -228,20 +228,18 @@ class Interface(util.PrintError):
self.buf = self.buf[pos+1:]
self.last_action = time.time()
return obj
async def get(self):
async def get(self, is_running):
reader, _ = await self._get_read_write()
while True:
while is_running():
tried = self._try_extract()
if tried: return tried
temp = io.BytesIO()
starttime = time.time()
while time.time() - starttime < 1:
try:
data = await asyncio.wait_for(reader.read(2**8), 1)
temp.write(data)
except asyncio.TimeoutError:
break
try:
data = await asyncio.wait_for(reader.read(2**10), 1)
temp.write(data)
except asyncio.TimeoutError:
continue
self.buf += temp.getvalue()
def idle_time(self):
@ -266,7 +264,10 @@ class Interface(util.PrintError):
'''Sends queued requests. Returns False on failure.'''
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
n = self.num_requests()
prio, request = await self.unsent_requests.get()
try:
prio, request = await asyncio.wait_for(self.unsent_requests.get(), 1.5)
except TimeoutError:
return False
try:
await self.send_all([make_dict(*request)])
except (SocksError, OSError, TimeoutError) as e:
@ -298,7 +299,7 @@ class Interface(util.PrintError):
return True
return False
async def get_response(self):
async def get_response(self, is_running):
'''Call if there is data available on the socket. Returns a list of
(request, response) pairs. Notifications are singleton
unsolicited responses presumably as a result of prior
@ -307,12 +308,12 @@ class Interface(util.PrintError):
corresponding request. If the connection was closed remotely
or the remote server is misbehaving, a (None, None) will appear.
'''
response = await self.get()
response = await self.get(is_running)
if not type(response) is dict:
print("response type not dict!", response)
if response is None:
self.closed_remotely = True
self.print_error("connection closed remotely")
if is_running():
self.print_error("connection closed remotely")
return None, None
if self.debug:
self.print_error("<--", response)

View File

@ -20,6 +20,7 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from functools import partial
import time
import queue
import os
@ -45,7 +46,7 @@ from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10
from concurrent.futures import CancelledError
from concurrent.futures import TimeoutError, CancelledError
def parse_servers(result):
""" parse servers list into dict format"""
@ -403,12 +404,16 @@ class Network(util.DaemonThread):
async def stop_network(self):
self.print_error("stopping network")
for interface in list(self.interfaces.values()):
self.close_interface(interface)
if self.interface:
self.close_interface(self.interface)
for num, interface in enumerate(list(self.interfaces.values())):
await self.close_interface(interface)
await interface.future
#if self.interface:
# await self.close_interface(self.interface)
# await interface.future
await asyncio.wait_for(self.process_pending_sends_job, 5)
assert self.interface is None
assert not self.interfaces
while self.interfaces:
asyncio.sleep(0.1)
self.connecting = set()
# called from the Qt thread
@ -486,19 +491,20 @@ class Network(util.DaemonThread):
self.set_status('connected')
self.notify('updated')
def close_interface(self, interface):
async def close_interface(self, interface):
self.print_error('closing connection', interface.server)
if interface:
if interface.server in self.interfaces:
self.interfaces.pop(interface.server)
if interface.server == self.default_server:
self.interface = None
if interface.jobs:
interface.jobs.cancel()
if interface.boot_job is not None:
interface.boot_job.cancel()
if self.process_pending_sends_job is not None:
self.process_pending_sends_job.cancel()
while True:
for i in interface.jobs:
if not i.done():
await asyncio.sleep(0.1)
continue
break
assert interface.boot_job.done()
interface.close()
def add_recent_server(self, server):
@ -559,7 +565,7 @@ class Network(util.DaemonThread):
async def process_responses(self, interface):
while self.is_running():
request, response = await interface.get_response()
request, response = await interface.get_response(lambda: self.is_running())
if request:
method, params, message_id = request
k = self.get_index(method, params)
@ -584,7 +590,7 @@ class Network(util.DaemonThread):
self.subscribed_addresses.add(params[0])
else:
if not response: # Closed remotely / misbehaving
self.connection_down(interface.server)
await self.connection_down(interface.server)
return
# Rewrite response shape to match subscription request response
method = response.get('method')
@ -676,7 +682,7 @@ class Network(util.DaemonThread):
if callback in v:
v.remove(callback)
def connection_down(self, server):
async def connection_down(self, server):
'''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.'''
self.print_error("connection down", server)
@ -684,7 +690,7 @@ class Network(util.DaemonThread):
if server == self.default_server:
self.set_status('disconnected')
if server in self.interfaces:
self.close_interface(self.interfaces[server])
await self.close_interface(self.interfaces[server])
self.notify('interfaces')
for b in self.blockchains.values():
if b.catch_up == server:
@ -694,6 +700,7 @@ class Network(util.DaemonThread):
# todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server)
interface = Interface(server, self.config.path, self.proxy)
interface.future = asyncio.Future()
interface.blockchain = None
interface.tip_header = None
interface.tip = 0
@ -726,7 +733,7 @@ class Network(util.DaemonThread):
self.requested_chunks.remove(index)
connect = interface.blockchain.connect_chunk(index, result)
if not connect:
self.connection_down(interface.server)
await self.connection_down(interface.server)
return
# If not finished, get the next chunk
if interface.blockchain.height() < interface.tip:
@ -748,12 +755,12 @@ class Network(util.DaemonThread):
header = response.get('result')
if not header:
interface.print_error(response)
self.connection_down(interface.server)
await self.connection_down(interface.server)
return
height = header.get('block_height')
if interface.request != height:
interface.print_error("unsolicited header",interface.request, height)
self.connection_down(interface.server)
await self.connection_down(interface.server)
return
chain = blockchain.check_header(header)
if interface.mode == 'backward':
@ -773,7 +780,7 @@ class Network(util.DaemonThread):
assert next_height >= self.max_checkpoint(), (interface.bad, interface.good)
else:
if height == 0:
self.connection_down(interface.server)
await self.connection_down(interface.server)
next_height = None
else:
interface.bad = height
@ -792,7 +799,7 @@ class Network(util.DaemonThread):
next_height = (interface.bad + interface.good) // 2
assert next_height >= self.max_checkpoint()
elif not interface.blockchain.can_connect(interface.bad_header, check_height=False):
self.connection_down(interface.server)
await self.connection_down(interface.server)
next_height = None
else:
branch = self.blockchains.get(interface.bad)
@ -873,16 +880,19 @@ class Network(util.DaemonThread):
for interface in list(self.interfaces.values()):
if interface.request and time.time() - interface.request_time > 20:
interface.print_error("blockchain request timed out")
self.connection_down(interface.server)
await self.connection_down(interface.server)
def make_send_requests_job(self, interface):
async def job():
try:
while self.is_running():
result = await interface.send_request()
try:
result = await asyncio.wait_for(interface.send_request(), 1)
except TimeoutError:
continue
if not result:
self.connection_down(interface.server)
except CancelledError:
await self.connection_down(interface.server)
except GeneratorExit:
pass
except:
traceback.print_exc()
@ -893,10 +903,10 @@ class Network(util.DaemonThread):
async def job():
try:
await self.process_responses(interface)
except CancelledError:
except GeneratorExit:
pass
except OSError:
self.connection_down(interface.server)
await self.connection_down(interface.server)
print("OS error, connection downed")
except BaseException:
traceback.print_exc()
@ -907,9 +917,12 @@ class Network(util.DaemonThread):
async def job():
try:
while self.is_running():
await self.process_pending_sends()
except CancelledError:
pass
try:
await asyncio.wait_for(self.process_pending_sends(), 1)
except TimeoutError:
continue
#except CancelledError:
# pass
except BaseException as e:
traceback.print_exc()
print("FATAL ERROR in process_pending_sends")
@ -931,28 +944,34 @@ class Network(util.DaemonThread):
async def job():
try:
await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface)
def rem():
if self.is_running(): self.connecting.remove(interface.server)
if not await interface.send_request():
self.connection_down(interface.server)
self.connecting.remove(interface.server)
await self.connection_down(interface.server)
rem()
return
self.connecting.remove(interface.server)
rem()
self.interfaces[interface.server] = interface
await self.queue_request('blockchain.headers.subscribe', [], interface)
if interface.server == self.default_server:
await self.switch_to_interface(interface.server)
interface.jobs = asyncio.ensure_future(asyncio.gather(self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)))
def cb(fut):
interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]]
def cb(num, fut):
try:
fut.exception()
except:
pass
interface.jobs.add_done_callback(cb)
except e:
interface.future.set_exception(e)
else:
if not interface.future.done(): interface.future.set_result(str(num) + " done")
for num, i in enumerate(interface.jobs):
i.add_done_callback(partial(cb, num))
#self.notify('interfaces')
except GeneratorExit:
print(interface.server, "GENERATOR EXIT")
pass
except BaseException as e:
traceback.print_exc()
print("FATAL ERROR in start_interface")
print("FATAL ERROR in boot_interface")
raise e
interface.boot_job = asyncio.ensure_future(job())
@ -965,11 +984,11 @@ class Network(util.DaemonThread):
# must use copy of values
if interface.has_timed_out():
print(interface.server, "timed out")
self.connection_down(interface.server)
await self.connection_down(interface.server)
elif interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
await self.queue_request('server.version', params, interface)
except CancelledError:
except GeneratorExit:
pass
except:
traceback.print_exc()
@ -1037,6 +1056,12 @@ class Network(util.DaemonThread):
self.run_jobs()
await self.stop_network()
self.on_stop()
for i in asyncio.Task.all_tasks():
if asyncio.Task.current_task() == i: continue
try:
await i
except CancelledError:
pass
future.set_result("run_async done")
except BaseException as e:
future.set_exception(e)
@ -1046,7 +1071,7 @@ class Network(util.DaemonThread):
if not height:
return
if height < self.max_checkpoint():
self.connection_down(interface)
await self.connection_down(interface)
return
interface.tip_header = header
interface.tip = height