asyncio: more graceful shutdown
This commit is contained in:
parent
200a085778
commit
dcb0a24e6f
@ -228,20 +228,18 @@ class Interface(util.PrintError):
|
||||
self.buf = self.buf[pos+1:]
|
||||
self.last_action = time.time()
|
||||
return obj
|
||||
async def get(self):
|
||||
async def get(self, is_running):
|
||||
reader, _ = await self._get_read_write()
|
||||
|
||||
while True:
|
||||
while is_running():
|
||||
tried = self._try_extract()
|
||||
if tried: return tried
|
||||
temp = io.BytesIO()
|
||||
starttime = time.time()
|
||||
while time.time() - starttime < 1:
|
||||
try:
|
||||
data = await asyncio.wait_for(reader.read(2**8), 1)
|
||||
temp.write(data)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
try:
|
||||
data = await asyncio.wait_for(reader.read(2**10), 1)
|
||||
temp.write(data)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
self.buf += temp.getvalue()
|
||||
|
||||
def idle_time(self):
|
||||
@ -266,7 +264,10 @@ class Interface(util.PrintError):
|
||||
'''Sends queued requests. Returns False on failure.'''
|
||||
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
|
||||
n = self.num_requests()
|
||||
prio, request = await self.unsent_requests.get()
|
||||
try:
|
||||
prio, request = await asyncio.wait_for(self.unsent_requests.get(), 1.5)
|
||||
except TimeoutError:
|
||||
return False
|
||||
try:
|
||||
await self.send_all([make_dict(*request)])
|
||||
except (SocksError, OSError, TimeoutError) as e:
|
||||
@ -298,7 +299,7 @@ class Interface(util.PrintError):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_response(self):
|
||||
async def get_response(self, is_running):
|
||||
'''Call if there is data available on the socket. Returns a list of
|
||||
(request, response) pairs. Notifications are singleton
|
||||
unsolicited responses presumably as a result of prior
|
||||
@ -307,12 +308,12 @@ class Interface(util.PrintError):
|
||||
corresponding request. If the connection was closed remotely
|
||||
or the remote server is misbehaving, a (None, None) will appear.
|
||||
'''
|
||||
response = await self.get()
|
||||
response = await self.get(is_running)
|
||||
if not type(response) is dict:
|
||||
print("response type not dict!", response)
|
||||
if response is None:
|
||||
self.closed_remotely = True
|
||||
self.print_error("connection closed remotely")
|
||||
if is_running():
|
||||
self.print_error("connection closed remotely")
|
||||
return None, None
|
||||
if self.debug:
|
||||
self.print_error("<--", response)
|
||||
|
||||
111
lib/network.py
111
lib/network.py
@ -20,6 +20,7 @@
|
||||
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
from functools import partial
|
||||
import time
|
||||
import queue
|
||||
import os
|
||||
@ -45,7 +46,7 @@ from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
|
||||
NODES_RETRY_INTERVAL = 60
|
||||
SERVER_RETRY_INTERVAL = 10
|
||||
|
||||
from concurrent.futures import CancelledError
|
||||
from concurrent.futures import TimeoutError, CancelledError
|
||||
|
||||
def parse_servers(result):
|
||||
""" parse servers list into dict format"""
|
||||
@ -403,12 +404,16 @@ class Network(util.DaemonThread):
|
||||
|
||||
async def stop_network(self):
|
||||
self.print_error("stopping network")
|
||||
for interface in list(self.interfaces.values()):
|
||||
self.close_interface(interface)
|
||||
if self.interface:
|
||||
self.close_interface(self.interface)
|
||||
for num, interface in enumerate(list(self.interfaces.values())):
|
||||
await self.close_interface(interface)
|
||||
await interface.future
|
||||
#if self.interface:
|
||||
# await self.close_interface(self.interface)
|
||||
# await interface.future
|
||||
await asyncio.wait_for(self.process_pending_sends_job, 5)
|
||||
assert self.interface is None
|
||||
assert not self.interfaces
|
||||
while self.interfaces:
|
||||
asyncio.sleep(0.1)
|
||||
self.connecting = set()
|
||||
|
||||
# called from the Qt thread
|
||||
@ -486,19 +491,20 @@ class Network(util.DaemonThread):
|
||||
self.set_status('connected')
|
||||
self.notify('updated')
|
||||
|
||||
def close_interface(self, interface):
|
||||
async def close_interface(self, interface):
|
||||
self.print_error('closing connection', interface.server)
|
||||
if interface:
|
||||
if interface.server in self.interfaces:
|
||||
self.interfaces.pop(interface.server)
|
||||
if interface.server == self.default_server:
|
||||
self.interface = None
|
||||
if interface.jobs:
|
||||
interface.jobs.cancel()
|
||||
if interface.boot_job is not None:
|
||||
interface.boot_job.cancel()
|
||||
if self.process_pending_sends_job is not None:
|
||||
self.process_pending_sends_job.cancel()
|
||||
while True:
|
||||
for i in interface.jobs:
|
||||
if not i.done():
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
break
|
||||
assert interface.boot_job.done()
|
||||
interface.close()
|
||||
|
||||
def add_recent_server(self, server):
|
||||
@ -559,7 +565,7 @@ class Network(util.DaemonThread):
|
||||
|
||||
async def process_responses(self, interface):
|
||||
while self.is_running():
|
||||
request, response = await interface.get_response()
|
||||
request, response = await interface.get_response(lambda: self.is_running())
|
||||
if request:
|
||||
method, params, message_id = request
|
||||
k = self.get_index(method, params)
|
||||
@ -584,7 +590,7 @@ class Network(util.DaemonThread):
|
||||
self.subscribed_addresses.add(params[0])
|
||||
else:
|
||||
if not response: # Closed remotely / misbehaving
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
return
|
||||
# Rewrite response shape to match subscription request response
|
||||
method = response.get('method')
|
||||
@ -676,7 +682,7 @@ class Network(util.DaemonThread):
|
||||
if callback in v:
|
||||
v.remove(callback)
|
||||
|
||||
def connection_down(self, server):
|
||||
async def connection_down(self, server):
|
||||
'''A connection to server either went down, or was never made.
|
||||
We distinguish by whether it is in self.interfaces.'''
|
||||
self.print_error("connection down", server)
|
||||
@ -684,7 +690,7 @@ class Network(util.DaemonThread):
|
||||
if server == self.default_server:
|
||||
self.set_status('disconnected')
|
||||
if server in self.interfaces:
|
||||
self.close_interface(self.interfaces[server])
|
||||
await self.close_interface(self.interfaces[server])
|
||||
self.notify('interfaces')
|
||||
for b in self.blockchains.values():
|
||||
if b.catch_up == server:
|
||||
@ -694,6 +700,7 @@ class Network(util.DaemonThread):
|
||||
# todo: get tip first, then decide which checkpoint to use.
|
||||
self.add_recent_server(server)
|
||||
interface = Interface(server, self.config.path, self.proxy)
|
||||
interface.future = asyncio.Future()
|
||||
interface.blockchain = None
|
||||
interface.tip_header = None
|
||||
interface.tip = 0
|
||||
@ -726,7 +733,7 @@ class Network(util.DaemonThread):
|
||||
self.requested_chunks.remove(index)
|
||||
connect = interface.blockchain.connect_chunk(index, result)
|
||||
if not connect:
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
return
|
||||
# If not finished, get the next chunk
|
||||
if interface.blockchain.height() < interface.tip:
|
||||
@ -748,12 +755,12 @@ class Network(util.DaemonThread):
|
||||
header = response.get('result')
|
||||
if not header:
|
||||
interface.print_error(response)
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
return
|
||||
height = header.get('block_height')
|
||||
if interface.request != height:
|
||||
interface.print_error("unsolicited header",interface.request, height)
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
return
|
||||
chain = blockchain.check_header(header)
|
||||
if interface.mode == 'backward':
|
||||
@ -773,7 +780,7 @@ class Network(util.DaemonThread):
|
||||
assert next_height >= self.max_checkpoint(), (interface.bad, interface.good)
|
||||
else:
|
||||
if height == 0:
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
next_height = None
|
||||
else:
|
||||
interface.bad = height
|
||||
@ -792,7 +799,7 @@ class Network(util.DaemonThread):
|
||||
next_height = (interface.bad + interface.good) // 2
|
||||
assert next_height >= self.max_checkpoint()
|
||||
elif not interface.blockchain.can_connect(interface.bad_header, check_height=False):
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
next_height = None
|
||||
else:
|
||||
branch = self.blockchains.get(interface.bad)
|
||||
@ -873,16 +880,19 @@ class Network(util.DaemonThread):
|
||||
for interface in list(self.interfaces.values()):
|
||||
if interface.request and time.time() - interface.request_time > 20:
|
||||
interface.print_error("blockchain request timed out")
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
|
||||
def make_send_requests_job(self, interface):
|
||||
async def job():
|
||||
try:
|
||||
while self.is_running():
|
||||
result = await interface.send_request()
|
||||
try:
|
||||
result = await asyncio.wait_for(interface.send_request(), 1)
|
||||
except TimeoutError:
|
||||
continue
|
||||
if not result:
|
||||
self.connection_down(interface.server)
|
||||
except CancelledError:
|
||||
await self.connection_down(interface.server)
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except:
|
||||
traceback.print_exc()
|
||||
@ -893,10 +903,10 @@ class Network(util.DaemonThread):
|
||||
async def job():
|
||||
try:
|
||||
await self.process_responses(interface)
|
||||
except CancelledError:
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except OSError:
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
print("OS error, connection downed")
|
||||
except BaseException:
|
||||
traceback.print_exc()
|
||||
@ -907,9 +917,12 @@ class Network(util.DaemonThread):
|
||||
async def job():
|
||||
try:
|
||||
while self.is_running():
|
||||
await self.process_pending_sends()
|
||||
except CancelledError:
|
||||
pass
|
||||
try:
|
||||
await asyncio.wait_for(self.process_pending_sends(), 1)
|
||||
except TimeoutError:
|
||||
continue
|
||||
#except CancelledError:
|
||||
# pass
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
print("FATAL ERROR in process_pending_sends")
|
||||
@ -931,28 +944,34 @@ class Network(util.DaemonThread):
|
||||
async def job():
|
||||
try:
|
||||
await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface)
|
||||
def rem():
|
||||
if self.is_running(): self.connecting.remove(interface.server)
|
||||
if not await interface.send_request():
|
||||
self.connection_down(interface.server)
|
||||
self.connecting.remove(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
rem()
|
||||
return
|
||||
self.connecting.remove(interface.server)
|
||||
rem()
|
||||
self.interfaces[interface.server] = interface
|
||||
await self.queue_request('blockchain.headers.subscribe', [], interface)
|
||||
if interface.server == self.default_server:
|
||||
await self.switch_to_interface(interface.server)
|
||||
interface.jobs = asyncio.ensure_future(asyncio.gather(self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)))
|
||||
def cb(fut):
|
||||
interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]]
|
||||
def cb(num, fut):
|
||||
try:
|
||||
fut.exception()
|
||||
except:
|
||||
pass
|
||||
interface.jobs.add_done_callback(cb)
|
||||
except e:
|
||||
interface.future.set_exception(e)
|
||||
else:
|
||||
if not interface.future.done(): interface.future.set_result(str(num) + " done")
|
||||
for num, i in enumerate(interface.jobs):
|
||||
i.add_done_callback(partial(cb, num))
|
||||
#self.notify('interfaces')
|
||||
except GeneratorExit:
|
||||
print(interface.server, "GENERATOR EXIT")
|
||||
pass
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
print("FATAL ERROR in start_interface")
|
||||
print("FATAL ERROR in boot_interface")
|
||||
raise e
|
||||
interface.boot_job = asyncio.ensure_future(job())
|
||||
|
||||
@ -965,11 +984,11 @@ class Network(util.DaemonThread):
|
||||
# must use copy of values
|
||||
if interface.has_timed_out():
|
||||
print(interface.server, "timed out")
|
||||
self.connection_down(interface.server)
|
||||
await self.connection_down(interface.server)
|
||||
elif interface.ping_required():
|
||||
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
|
||||
await self.queue_request('server.version', params, interface)
|
||||
except CancelledError:
|
||||
except GeneratorExit:
|
||||
pass
|
||||
except:
|
||||
traceback.print_exc()
|
||||
@ -1037,6 +1056,12 @@ class Network(util.DaemonThread):
|
||||
self.run_jobs()
|
||||
await self.stop_network()
|
||||
self.on_stop()
|
||||
for i in asyncio.Task.all_tasks():
|
||||
if asyncio.Task.current_task() == i: continue
|
||||
try:
|
||||
await i
|
||||
except CancelledError:
|
||||
pass
|
||||
future.set_result("run_async done")
|
||||
except BaseException as e:
|
||||
future.set_exception(e)
|
||||
@ -1046,7 +1071,7 @@ class Network(util.DaemonThread):
|
||||
if not height:
|
||||
return
|
||||
if height < self.max_checkpoint():
|
||||
self.connection_down(interface)
|
||||
await self.connection_down(interface)
|
||||
return
|
||||
interface.tip_header = header
|
||||
interface.tip = height
|
||||
|
||||
Loading…
Reference in New Issue
Block a user