diff --git a/contrib/deterministic-build/requirements.txt b/contrib/deterministic-build/requirements.txt index e7d8925b..638eba7f 100644 --- a/contrib/deterministic-build/requirements.txt +++ b/contrib/deterministic-build/requirements.txt @@ -7,7 +7,7 @@ jsonrpclib-pelix==0.3.1 pbkdf2==1.3 protobuf==3.5.1 pyaes==1.6.1 -PySocks==1.6.8 +aiosocks==0.2.6 qrcode==5.3 requests==2.18.4 six==1.11.0 diff --git a/contrib/requirements/requirements.txt b/contrib/requirements/requirements.txt index 227ec1cd..7caa043e 100644 --- a/contrib/requirements/requirements.txt +++ b/contrib/requirements/requirements.txt @@ -6,4 +6,4 @@ qrcode protobuf dnspython jsonrpclib-pelix -PySocks>=1.6.6 +aiosocks diff --git a/lib/__init__.py b/lib/__init__.py index 286e4b6a..eb54db47 100644 --- a/lib/__init__.py +++ b/lib/__init__.py @@ -4,7 +4,7 @@ from .wallet import Synchronizer, Wallet from .storage import WalletStorage from .coinchooser import COIN_CHOOSERS from .network import Network, pick_random_server -from .interface import Connection, Interface +from .interface import Interface from .simple_config import SimpleConfig, get_config, set_config from . import bitcoin from . import transaction diff --git a/lib/interface.py b/lib/interface.py index ac1495fb..95c30f5c 100644 --- a/lib/interface.py +++ b/lib/interface.py @@ -22,287 +22,272 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import aiosocks import os +import stat import re -import socket import ssl import sys import threading import time import traceback +import asyncio +import json +import asyncio.streams +from asyncio.sslproto import SSLProtocol +import io import requests -from .util import print_error +from aiosocks.errors import SocksError +from concurrent.futures import TimeoutError ca_path = requests.certs.where() +from .util import print_error +from .ssl_in_socks import sslInSocksReaderWriter from . import util from . import x509 from . import pem - -def Connection(server, queue, config_path): - """Makes asynchronous connections to a remote electrum server. - Returns the running thread that is making the connection. - - Once the thread has connected, it finishes, placing a tuple on the - queue of the form (server, socket), where socket is None if - connection failed. - """ - host, port, protocol = server.rsplit(':', 2) - if not protocol in 'st': - raise Exception('Unknown protocol: %s' % protocol) - c = TcpConnection(server, queue, config_path) - c.start() - return c - - -class TcpConnection(threading.Thread, util.PrintError): - - def __init__(self, server, queue, config_path): - threading.Thread.__init__(self) - self.config_path = config_path - self.queue = queue - self.server = server - self.host, self.port, self.protocol = self.server.rsplit(':', 2) - self.host = str(self.host) - self.port = int(self.port) - self.use_ssl = (self.protocol == 's') - self.daemon = True - - def diagnostic_name(self): - return self.host - - def check_host_name(self, peercert, name): - """Simple certificate/host name checker. Returns True if the - certificate matches, False otherwise. Does not support - wildcards.""" - # Check that the peer has supplied a certificate. - # None/{} is not acceptable. - if not peercert: - return False - if 'subjectAltName' in peercert: - for typ, val in peercert["subjectAltName"]: - if typ == "DNS" and val == name: - return True - else: - # Only check the subject DN if there is no subject alternative - # name. - cn = None - for attr, val in peercert["subject"]: - # Use most-specific (last) commonName attribute. - if attr == "commonName": - cn = val - if cn is not None: - return cn == name - return False - - def get_simple_socket(self): - try: - l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.gaierror: - self.print_error("cannot resolve hostname") - return - e = None - for res in l: - try: - s = socket.socket(res[0], socket.SOCK_STREAM) - s.settimeout(10) - s.connect(res[4]) - s.settimeout(2) - s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - return s - except BaseException as _e: - e = _e - continue - else: - self.print_error("failed to connect", str(e)) - - @staticmethod - def get_ssl_context(cert_reqs, ca_certs): - context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs) - context.check_hostname = False - context.verify_mode = cert_reqs - - context.options |= ssl.OP_NO_SSLv2 - context.options |= ssl.OP_NO_SSLv3 - context.options |= ssl.OP_NO_TLSv1 - - return context - - def get_socket(self): - if self.use_ssl: - cert_path = os.path.join(self.config_path, 'certs', self.host) - if not os.path.exists(cert_path): - is_new = True - s = self.get_simple_socket() - if s is None: - return - # try with CA first - try: - context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path) - s = context.wrap_socket(s, do_handshake_on_connect=True) - except ssl.SSLError as e: - print_error(e) - s = None - except: - return - - if s and self.check_host_name(s.getpeercert(), self.host): - self.print_error("SSL certificate signed by CA") - return s - # get server certificate. - # Do not use ssl.get_server_certificate because it does not work with proxy - s = self.get_simple_socket() - if s is None: - return - try: - context = self.get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None) - s = context.wrap_socket(s) - except ssl.SSLError as e: - self.print_error("SSL error retrieving SSL certificate:", e) - return - except: - return - - dercert = s.getpeercert(True) - s.close() - cert = ssl.DER_cert_to_PEM_cert(dercert) - # workaround android bug - cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert) - temporary_path = cert_path + '.temp' - with open(temporary_path,"w") as f: - f.write(cert) - else: - is_new = False - - s = self.get_simple_socket() - if s is None: - return - - if self.use_ssl: - try: - context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, - ca_certs=(temporary_path if is_new else cert_path)) - s = context.wrap_socket(s, do_handshake_on_connect=True) - except socket.timeout: - self.print_error('timeout') - return - except ssl.SSLError as e: - self.print_error("SSL error:", e) - if e.errno != 1: - return - if is_new: - rej = cert_path + '.rej' - if os.path.exists(rej): - os.unlink(rej) - os.rename(temporary_path, rej) - else: - with open(cert_path) as f: - cert = f.read() - try: - b = pem.dePem(cert, 'CERTIFICATE') - x = x509.X509(b) - except: - traceback.print_exc(file=sys.stderr) - self.print_error("wrong certificate") - return - try: - x.check_date() - except: - self.print_error("certificate has expired:", cert_path) - os.unlink(cert_path) - return - self.print_error("wrong certificate") - if e.errno == 104: - return - return - except BaseException as e: - self.print_error(e) - traceback.print_exc(file=sys.stderr) - return - - if is_new: - self.print_error("saving certificate") - os.rename(temporary_path, cert_path) - - return s - - def run(self): - socket = self.get_socket() - if socket: - self.print_error("connected") - self.queue.put((self.server, socket)) - +def get_ssl_context(cert_reqs, ca_certs): + context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs) + context.check_hostname = False + context.verify_mode = cert_reqs + context.options |= ssl.OP_NO_SSLv2 + context.options |= ssl.OP_NO_SSLv3 + context.options |= ssl.OP_NO_TLSv1 + return context class Interface(util.PrintError): """The Interface class handles a socket connected to a single remote electrum server. It's exposed API is: - - Member functions close(), fileno(), get_responses(), has_timed_out(), - ping_required(), queue_request(), send_requests() + - Member functions close(), fileno(), get_response(), has_timed_out(), + ping_required(), queue_request(), send_request() - Member variable server. """ - def __init__(self, server, socket): - self.server = server - self.host, _, _ = server.rsplit(':', 2) - self.socket = socket + def __init__(self, server, config_path, proxy_config, is_running): + self.is_running = is_running + self.addr = self.auth = None + if proxy_config is not None: + if proxy_config["mode"] == "socks5": + self.addr = aiosocks.Socks5Addr(proxy_config["host"], proxy_config["port"]) + self.auth = aiosocks.Socks5Auth(proxy_config["user"], proxy_config["password"]) if proxy_config["user"] != "" else None + elif proxy_config["mode"] == "socks4": + self.addr = aiosocks.Socks4Addr(proxy_config["host"], proxy_config["port"]) + self.auth = aiosocks.Socks4Auth(proxy_config["password"]) if proxy_config["password"] != "" else None + else: + raise Exception("proxy mode not supported") - self.pipe = util.SocketPipe(socket) - self.pipe.set_timeout(0.0) # Don't wait for data + self.server = server + self.config_path = config_path + host, port, protocol = self.server.split(':') + self.host = host + self.port = int(port) + self.use_ssl = (protocol=='s') + self.reader = self.writer = None + self.lock = asyncio.Lock() # Dump network messages. Set at runtime from the console. self.debug = False - self.unsent_requests = [] + self.unsent_requests = asyncio.PriorityQueue() self.unanswered_requests = {} - # Set last ping to zero to ensure immediate ping - self.last_request = time.time() self.last_ping = 0 self.closed_remotely = False + self.buf = bytes() + + def conn_coro(self, context): + return asyncio.open_connection(self.host, self.port, ssl=context) + + async def _save_certificate(self, cert_path, require_ca): + dercert = None + if require_ca: + context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path) + else: + context = get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None) + try: + if self.addr is not None: + proto_factory = lambda: SSLProtocol(asyncio.get_event_loop(), asyncio.Protocol(), context, None) + socks_create_coro = aiosocks.create_connection(proto_factory, \ + proxy=self.addr, \ + proxy_auth=self.auth, \ + dst=(self.host, self.port)) + transport, protocol = await asyncio.wait_for(socks_create_coro, 5) + async def job(fut): + try: + if protocol._sslpipe is not None: + fut.set_result(protocol._sslpipe.ssl_object.getpeercert(True)) + except BaseException as e: + fut.set_exception(e) + while self.is_running(): + fut = asyncio.Future() + asyncio.ensure_future(job(fut)) + try: + await fut + except: + pass + try: + fut.exception() + dercert = fut.result() + except ValueError: + await asyncio.sleep(1) + continue + except: + if self.is_running(): + traceback.print_exc() + print("Previous exception from _save_certificate") + continue + break + if not self.is_running(): return + transport.close() + else: + reader, writer = await asyncio.wait_for(self.conn_coro(context), 3) + dercert = writer.get_extra_info('ssl_object').getpeercert(True) + writer.close() + except OSError as e: # not ConnectionError because we need socket.gaierror too + if self.is_running(): + self.print_error(self.server, "Exception in _save_certificate", type(e)) + return + except TimeoutError: + return + assert dercert + if not require_ca: + cert = ssl.DER_cert_to_PEM_cert(dercert) + else: + # Don't pin a CA signed certificate + cert = "" + temporary_path = cert_path + '.temp' + with open(temporary_path, "w") as f: + f.write(cert) + return temporary_path + + async def _get_read_write(self): + async with self.lock: + if self.reader is not None and self.writer is not None: + return self.reader, self.writer, True + if self.use_ssl: + cert_path = os.path.join(self.config_path, 'certs', self.host) + if not os.path.exists(cert_path): + temporary_path = await self._save_certificate(cert_path, True) + if not temporary_path: + temporary_path = await self._save_certificate(cert_path, False) + if not temporary_path: + raise ConnectionError("Could not get certificate on second try") + + is_new = True + else: + is_new = False + ca_certs = temporary_path if is_new else cert_path + + size = os.stat(ca_certs)[stat.ST_SIZE] + self_signed = size != 0 + if not self_signed: + ca_certs = ca_path + try: + if self.addr is not None: + if not self.use_ssl: + open_coro = aiosocks.open_connection(proxy=self.addr, proxy_auth=self.auth, dst=(self.host, self.port)) + self.reader, self.writer = await asyncio.wait_for(open_coro, 5) + else: + ssl_in_socks_coro = sslInSocksReaderWriter(self.addr, self.auth, self.host, self.port, ca_certs) + self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 5) + else: + context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None + self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5) + except TimeoutError: + self.print_error("TimeoutError after getting certificate successfully...") + raise + except BaseException as e: + if self.is_running(): + if not isinstance(e, OSError): + traceback.print_exc() + self.print_error("Previous exception will now be reraised") + raise e + if self.use_ssl and is_new: + self.print_error("saving new certificate for", self.host) + os.rename(temporary_path, cert_path) + return self.reader, self.writer, False + + async def send_all(self, list_of_requests): + _, w, usedExisting = await self._get_read_write() + starttime = time.time() + for i in list_of_requests: + w.write(json.dumps(i).encode("ascii") + b"\n") + await w.drain() + if time.time() - starttime > 2.5: + self.print_error("send_all: sending is taking too long. Used existing connection: ", usedExisting) + raise ConnectionError("sending is taking too long") + + def close(self): + if self.writer: + self.writer.close() + + def _try_extract(self): + try: + pos = self.buf.index(b"\n") + except ValueError: + return + obj = self.buf[:pos] + try: + obj = json.loads(obj.decode("ascii")) + except ValueError: + return + else: + self.buf = self.buf[pos+1:] + self.last_action = time.time() + return obj + async def get(self): + reader, _, _ = await self._get_read_write() + + while self.is_running(): + tried = self._try_extract() + if tried: return tried + temp = io.BytesIO() + try: + data = await asyncio.wait_for(reader.read(2**10), 1) + temp.write(data) + except asyncio.TimeoutError: + continue + self.buf += temp.getvalue() + + def idle_time(self): + return time.time() - self.last_action def diagnostic_name(self): return self.host - def fileno(self): - # Needed for select - return self.socket.fileno() - - def close(self): - if not self.closed_remotely: - try: - self.socket.shutdown(socket.SHUT_RDWR) - except socket.error: - pass - self.socket.close() - - def queue_request(self, *args): # method, params, _id + async def queue_request(self, *args): # method, params, _id '''Queue a request, later to be send with send_requests when the socket is available for writing. ''' self.request_time = time.time() - self.unsent_requests.append(args) + await self.unsent_requests.put((self.request_time, args)) def num_requests(self): '''Keep unanswered requests below 100''' n = 100 - len(self.unanswered_requests) - return min(n, len(self.unsent_requests)) + return min(n, self.unsent_requests.qsize()) - def send_requests(self): + async def send_request(self): '''Sends queued requests. Returns False on failure.''' make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i} n = self.num_requests() - wire_requests = self.unsent_requests[0:n] try: - self.pipe.send_all([make_dict(*r) for r in wire_requests]) - except socket.error as e: - self.print_error("socket error:", e) + prio, request = await asyncio.wait_for(self.unsent_requests.get(), 1.5) + except TimeoutError: return False - self.unsent_requests = self.unsent_requests[n:] - for request in wire_requests: - if self.debug: - self.print_error("-->", request) - self.unanswered_requests[request[2]] = request + try: + await self.send_all([make_dict(*request)]) + except (SocksError, OSError, TimeoutError) as e: + if type(e) is SocksError: + self.print_error(e) + await self.unsent_requests.put((prio, request)) + return False + if self.debug: + self.print_error("-->", request) + self.unanswered_requests[request[2]] = request + self.last_action = time.time() return True def ping_required(self): @@ -318,13 +303,12 @@ class Interface(util.PrintError): def has_timed_out(self): '''Returns True if the interface has timed out.''' if (self.unanswered_requests and time.time() - self.request_time > 10 - and self.pipe.idle_time() > 10): + and self.idle_time() > 10): self.print_error("timeout", len(self.unanswered_requests)) return True - return False - def get_responses(self): + async def get_response(self): '''Call if there is data available on the socket. Returns a list of (request, response) pairs. Notifications are singleton unsolicited responses presumably as a result of prior @@ -333,34 +317,25 @@ class Interface(util.PrintError): corresponding request. If the connection was closed remotely or the remote server is misbehaving, a (None, None) will appear. ''' - responses = [] - while True: - try: - response = self.pipe.get() - except util.timeout: - break - if not type(response) is dict: - responses.append((None, None)) - if response is None: - self.closed_remotely = True + response = await self.get() + if not type(response) is dict: + if response is None: + self.closed_remotely = True + if self.is_running(): self.print_error("connection closed remotely") - break - if self.debug: - self.print_error("<--", response) - wire_id = response.get('id', None) - if wire_id is None: # Notification - responses.append((None, response)) + return None, None + if self.debug: + self.print_error("<--", response) + wire_id = response.get('id', None) + if wire_id is None: # Notification + return None, response + else: + request = self.unanswered_requests.pop(wire_id, None) + if request: + return request, response else: - request = self.unanswered_requests.pop(wire_id, None) - if request: - responses.append((request, response)) - else: - self.print_error("unknown wire ID", wire_id) - responses.append((None, None)) # Signal - break - - return responses - + self.print_error("unknown wire ID", wire_id) + return None, None # Signal def check_cert(host, cert): try: diff --git a/lib/network.py b/lib/network.py index b803a782..3bf478ad 100644 --- a/lib/network.py +++ b/lib/network.py @@ -20,6 +20,9 @@ # ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import sys +import collections +from functools import partial import time import queue import os @@ -30,23 +33,22 @@ import re import select from collections import defaultdict import threading -import socket import json +import asyncio +import traceback -import socks from . import util from . import bitcoin from .bitcoin import * -from . import constants -from .interface import Connection, Interface +from .interface import Interface from . import blockchain from .version import ELECTRUM_VERSION, PROTOCOL_VERSION -from .i18n import _ NODES_RETRY_INTERVAL = 60 SERVER_RETRY_INTERVAL = 10 +from concurrent.futures import TimeoutError, CancelledError def parse_servers(result): """ parse servers list into dict format""" @@ -61,7 +63,7 @@ def parse_servers(result): for v in item[2]: if re.match("[st]\d*", v): protocol, port = v[0], v[1:] - if port == '': port = constants.net.DEFAULT_PORTS[protocol] + if port == '': port = bitcoin.NetworkConstants.DEFAULT_PORTS[protocol] out[protocol] = port elif re.match("v(.?)+", v): version = v[1:] @@ -78,7 +80,7 @@ def filter_version(servers): def is_recent(version): try: return util.normalize_version(version) >= util.normalize_version(PROTOCOL_VERSION) - except Exception as e: + except BaseException as e: return False return {k: v for k, v in servers.items() if is_recent(v.get('version'))} @@ -95,7 +97,7 @@ def filter_protocol(hostmap, protocol = 's'): def pick_random_server(hostmap = None, protocol = 's', exclude_set = set()): if hostmap is None: - hostmap = constants.net.DEFAULT_SERVERS + hostmap = bitcoin.NetworkConstants.DEFAULT_SERVERS eligible = list(set(filter_protocol(hostmap, protocol)) - exclude_set) return random.choice(eligible) if eligible else None @@ -159,10 +161,13 @@ class Network(util.DaemonThread): - Member functions get_header(), get_interfaces(), get_local_height(), get_parameters(), get_server_height(), get_status_value(), - is_connected(), set_parameters(), stop() + is_connected(), set_parameters(), stop(), follow_chain() """ def __init__(self, config=None): + self.disconnected_servers = {} + self.stopped = True + asyncio.set_event_loop(None) if config is None: config = {} # Do not use mutables as default values! util.DaemonThread.__init__(self) @@ -185,7 +190,6 @@ class Network(util.DaemonThread): if not self.default_server: self.default_server = pick_random_server() self.lock = threading.Lock() - self.pending_sends = [] self.message_id = 0 self.debug = False self.irc_servers = {} # returned by interface (list from irc) @@ -219,11 +223,6 @@ class Network(util.DaemonThread): self.interface = None self.interfaces = {} self.auto_connect = self.config.get('auto_connect', True) - self.connecting = set() - self.requested_chunks = set() - self.socket_queue = queue.Queue() - self.start_network(deserialize_server(self.default_server)[2], - deserialize_proxy(self.config.get('proxy'))) def register_callback(self, callback, events): with self.lock: @@ -290,19 +289,20 @@ class Network(util.DaemonThread): def is_up_to_date(self): return self.unanswered_requests == {} - def queue_request(self, method, params, interface=None): + async def queue_request(self, method, params, interface=None): # If you want to queue a request on any interface it must go # through this function so message ids are properly tracked if interface is None: + assert self.interface is not None interface = self.interface message_id = self.message_id self.message_id += 1 if self.debug: self.print_error(interface.host, "-->", method, params, message_id) - interface.queue_request(method, params, message_id) + await interface.queue_request(method, params, message_id) return message_id - def send_subscriptions(self): + async def send_subscriptions(self): self.print_error('sending subscriptions to', self.interface.server, len(self.unanswered_requests), len(self.subscribed_addresses)) self.sub_cache.clear() # Resend unanswered requests @@ -310,24 +310,25 @@ class Network(util.DaemonThread): self.unanswered_requests = {} if self.interface.ping_required(): params = [ELECTRUM_VERSION, PROTOCOL_VERSION] - self.queue_request('server.version', params, self.interface) + await self.queue_request('server.version', params, self.interface) for request in requests: - message_id = self.queue_request(request[0], request[1]) + message_id = await self.queue_request(request[0], request[1]) self.unanswered_requests[message_id] = request - self.queue_request('server.banner', []) - self.queue_request('server.donation_address', []) - self.queue_request('server.peers.subscribe', []) - self.request_fee_estimates() - self.queue_request('blockchain.relayfee', []) + await self.queue_request('server.banner', []) + await self.queue_request('server.donation_address', []) + await self.queue_request('server.peers.subscribe', []) + await self.request_fee_estimates() + await self.queue_request('blockchain.relayfee', []) + if self.interface.ping_required(): + params = [ELECTRUM_VERSION, PROTOCOL_VERSION] + await self.queue_request('server.version', params, self.interface) for h in self.subscribed_addresses: - self.queue_request('blockchain.scripthash.subscribe', [h]) + await self.queue_request('blockchain.scripthash.subscribe', [h]) - def request_fee_estimates(self): - from .simple_config import FEE_ETA_TARGETS + async def request_fee_estimates(self): self.config.requested_fee_estimates() - self.queue_request('mempool.get_fee_histogram', []) - for i in FEE_ETA_TARGETS: - self.queue_request('blockchain.estimatefee', [i]) + for i in bitcoin.FEE_TARGETS: + await self.queue_request('blockchain.estimatefee', [i]) def get_status_value(self, key): if key == 'status': @@ -336,8 +337,6 @@ class Network(util.DaemonThread): value = self.banner elif key == 'fee': value = self.config.fee_estimates - elif key == 'fee_histogram': - value = self.config.mempool_fees elif key == 'updated': value = (self.get_local_height(), self.get_server_height()) elif key == 'servers': @@ -365,7 +364,7 @@ class Network(util.DaemonThread): return list(self.interfaces.keys()) def get_servers(self): - out = constants.net.DEFAULT_SERVERS + out = bitcoin.NetworkConstants.DEFAULT_SERVERS if self.irc_servers: out.update(filter_version(self.irc_servers.copy())) else: @@ -378,68 +377,64 @@ class Network(util.DaemonThread): out[host] = { protocol:port } return out - def start_interface(self, server): - if (not server in self.interfaces and not server in self.connecting): - if server == self.default_server: - self.print_error("connecting to %s as new interface" % server) - self.set_status('connecting') - self.connecting.add(server) - c = Connection(server, self.socket_queue, self.config.path) + async def start_interface(self, server): + assert not self.connecting[server].locked() + async with self.connecting[server]: + if (not server in self.interfaces): + if server == self.default_server: + self.print_error("connecting to %s as new interface" % server) + self.set_status('connecting') + return await self.new_interface(server) - def start_random_interface(self): - exclude_set = self.disconnected_servers.union(set(self.interfaces)) + async def start_random_interface(self): + exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys())) server = pick_random_server(self.get_servers(), self.protocol, exclude_set) if server: - self.start_interface(server) + return await self.start_interface(server) - def start_interfaces(self): - self.start_interface(self.default_server) + async def start_interfaces(self): + await self.start_interface(self.default_server) + self.print_error("started default server interface") for i in range(self.num_server - 1): - self.start_random_interface() + await self.start_random_interface() - def set_proxy(self, proxy): - self.proxy = proxy - # Store these somewhere so we can un-monkey-patch - if not hasattr(socket, "_socketobject"): - socket._socketobject = socket.socket - socket._getaddrinfo = socket.getaddrinfo - if proxy: - self.print_error('setting proxy', proxy) - proxy_mode = proxy_modes.index(proxy["mode"]) + 1 - socks.setdefaultproxy(proxy_mode, - proxy["host"], - int(proxy["port"]), - # socks.py seems to want either None or a non-empty string - username=(proxy.get("user", "") or None), - password=(proxy.get("password", "") or None)) - socket.socket = socks.socksocket - # prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy - socket.getaddrinfo = lambda *args: [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))] - else: - socket.socket = socket._socketobject - socket.getaddrinfo = socket._getaddrinfo - - def start_network(self, protocol, proxy): + async def start_network(self, protocol, proxy): + self.stopped = False assert not self.interface and not self.interfaces - assert not self.connecting and self.socket_queue.empty() + assert all(not i.locked() for i in self.connecting.values()) self.print_error('starting network') - self.disconnected_servers = set([]) self.protocol = protocol - self.set_proxy(proxy) - self.start_interfaces() + self.proxy = proxy + await self.start_interfaces() - def stop_network(self): + async def stop_network(self): + self.stopped = True self.print_error("stopping network") - for interface in list(self.interfaces.values()): - self.close_interface(interface) + async def stop(interface): + await self.connection_down(interface.server, "stopping network") + await asyncio.wait_for(asyncio.shield(interface.future), 3) + stopped_this_time = set() + while self.interfaces: + do_next = next(iter(self.interfaces.values())) + assert do_next not in stopped_this_time + for i in self.disconnected_servers: + assert i not in self.interfaces.keys() + assert i != do_next.server + stopped_this_time.add(do_next) + await stop(do_next) if self.interface: - self.close_interface(self.interface) + assert self.interface.server in stopped_this_time, self.interface.server + await asyncio.wait_for(asyncio.shield(self.process_pending_sends_job), 5) assert self.interface is None - assert not self.interfaces - self.connecting = set() - # Get a new queue - no old pending connections thanks! - self.socket_queue = queue.Queue() + for i in range(100): + if not self.interfaces: + break + else: + await asyncio.sleep(0.1) + if self.interfaces: + assert False, "interfaces not empty after waiting: " + repr(self.interfaces) + # called from the Qt thread def set_parameters(self, host, port, protocol, proxy, auto_connect): proxy_str = serialize_proxy(proxy) server = serialize_server(host, port, protocol) @@ -459,25 +454,42 @@ class Network(util.DaemonThread): return self.auto_connect = auto_connect if self.proxy != proxy or self.protocol != protocol: - # Restart the network defaulting to the given server - self.stop_network() - self.default_server = server - self.start_network(protocol, proxy) + async def job(): + try: + async with self.restartLock: + # Restart the network defaulting to the given server + await self.stop_network() + self.print_error("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED") + self.default_server = server + async with self.all_server_locks("restart job"): + self.disconnected_servers = {} + await self.start_network(protocol, proxy) + except BaseException as e: + traceback.print_exc() + self.print_error("exception from restart job") + if self.restartLock.locked(): + self.print_error("NOT RESTARTING, RESTART IN PROGRESS") + return + asyncio.run_coroutine_threadsafe(job(), self.loop) elif self.default_server != server: - self.switch_to_interface(server) + async def job(): + await self.switch_to_interface(server) + asyncio.run_coroutine_threadsafe(job(), self.loop) else: - self.switch_lagging_interface() - self.notify('updated') + async def job(): + await self.switch_lagging_interface() + self.notify('updated') + asyncio.run_coroutine_threadsafe(job(), self.loop) - def switch_to_random_interface(self): + async def switch_to_random_interface(self): '''Switch to a random connected server other than the current one''' servers = self.get_interfaces() # Those in connected state if self.default_server in servers: servers.remove(self.default_server) if servers: - self.switch_to_interface(random.choice(servers)) + await self.switch_to_interface(random.choice(servers)) - def switch_lagging_interface(self): + async def switch_lagging_interface(self): '''If auto_connect and lagging, switch interface''' if self.server_is_lagging() and self.auto_connect: # switch to one that has the correct header (not height) @@ -485,9 +497,9 @@ class Network(util.DaemonThread): filtered = list(map(lambda x:x[0], filter(lambda x: x[1].tip_header==header, self.interfaces.items()))) if filtered: choice = random.choice(filtered) - self.switch_to_interface(choice) + await self.switch_to_interface(choice) - def switch_to_interface(self, server): + async def switch_to_interface(self, server): '''Switch to server as our interface. If no connection exists nor being opened, start a thread to connect. The actual switch will happen on receipt of the connection notification. Do nothing @@ -495,7 +507,7 @@ class Network(util.DaemonThread): self.default_server = server if server not in self.interfaces: self.interface = None - self.start_interface(server) + await self.start_interface(server) return i = self.interfaces[server] if self.interface != i: @@ -504,16 +516,26 @@ class Network(util.DaemonThread): # fixme: we don't want to close headers sub #self.close_interface(self.interface) self.interface = i - self.send_subscriptions() + await self.send_subscriptions() self.set_status('connected') self.notify('updated') - def close_interface(self, interface): + async def close_interface(self, interface): + self.print_error('closing connection', interface.server) if interface: if interface.server in self.interfaces: self.interfaces.pop(interface.server) if interface.server == self.default_server: self.interface = None + if interface.jobs: + for i in interface.jobs: + asyncio.wait_for(i, 3) + assert interface.boot_job + try: + await asyncio.wait_for(asyncio.shield(interface.boot_job), 6) # longer than any timeout while connecting + except TimeoutError: + self.print_error("taking too long", interface.server) + raise interface.close() def add_recent_server(self, server): @@ -524,7 +546,7 @@ class Network(util.DaemonThread): self.recent_servers = self.recent_servers[0:20] self.save_recent_servers() - def process_response(self, interface, response, callbacks): + async def process_response(self, interface, response, callbacks): if self.debug: self.print_error("<--", response) error = response.get('error') @@ -537,7 +559,7 @@ class Network(util.DaemonThread): interface.server_version = result elif method == 'blockchain.headers.subscribe': if error is None: - self.on_notify_header(interface, result) + await self.on_notify_header(interface, result) elif method == 'server.peers.subscribe': if error is None: self.irc_servers = parse_servers(result) @@ -549,11 +571,6 @@ class Network(util.DaemonThread): elif method == 'server.donation_address': if error is None: self.donation_address = result - elif method == 'mempool.get_fee_histogram': - if error is None: - self.print_error('fee_histogram', result) - self.config.mempool_fees = result - self.notify('fee_histogram') elif method == 'blockchain.estimatefee': if error is None and result > 0: i = params[0] @@ -566,20 +583,25 @@ class Network(util.DaemonThread): self.relay_fee = int(result * COIN) self.print_error("relayfee", self.relay_fee) elif method == 'blockchain.block.get_chunk': - self.on_get_chunk(interface, response) + await self.on_get_chunk(interface, response) elif method == 'blockchain.block.get_header': - self.on_get_header(interface, response) + await self.on_get_header(interface, response) for callback in callbacks: - callback(response) + if asyncio.iscoroutinefunction(callback): + if response is None: + print("RESPONSE IS NONE") + await callback(response) + else: + callback(response) def get_index(self, method, params): """ hashable index for subscriptions and cache""" return str(method) + (':' + str(params[0]) if params else '') - def process_responses(self, interface): - responses = interface.get_responses() - for request, response in responses: + async def process_responses(self, interface): + while interface.is_running(): + request, response = await interface.get_response() if request: method, params, message_id = request k = self.get_index(method, params) @@ -604,8 +626,8 @@ class Network(util.DaemonThread): self.subscribed_addresses.add(params[0]) else: if not response: # Closed remotely / misbehaving - self.connection_down(interface.server) - break + if interface.is_running(): await self.connection_down(interface.server, "no response in process responses") + return # Rewrite response shape to match subscription request response method = response.get('method') params = response.get('params') @@ -622,7 +644,9 @@ class Network(util.DaemonThread): if method.endswith('.subscribe'): self.sub_cache[k] = response # Response is now in canonical form - self.process_response(interface, response, callbacks) + await self.process_response(interface, response, callbacks) + await self.run_coroutines() # Synchronizer and Verifier + def addr_to_scripthash(self, addr): h = bitcoin.address_to_scripthash(addr) @@ -651,37 +675,40 @@ class Network(util.DaemonThread): def send(self, messages, callback): '''Messages is a list of (method, params) tuples''' messages = list(messages) - with self.lock: - self.pending_sends.append((messages, callback)) + async def job(future): + await self.pending_sends.put((messages, callback)) + if future: future.set_result("put pending send: " + repr(messages)) + asyncio.run_coroutine_threadsafe(job(None), self.loop) - def process_pending_sends(self): + async def process_pending_sends(self): # Requests needs connectivity. If we don't have an interface, # we cannot process them. if not self.interface: + await asyncio.sleep(1) return - with self.lock: - sends = self.pending_sends - self.pending_sends = [] + try: + messages, callback = await asyncio.wait_for(self.pending_sends.get(), 1) + except TimeoutError: + return - for messages, callback in sends: - for method, params in messages: - r = None - if method.endswith('.subscribe'): - k = self.get_index(method, params) - # add callback to list - l = self.subscriptions.get(k, []) - if callback not in l: - l.append(callback) - self.subscriptions[k] = l - # check cached response for subscriptions - r = self.sub_cache.get(k) - if r is not None: - util.print_error("cache hit", k) - callback(r) - else: - message_id = self.queue_request(method, params) - self.unanswered_requests[message_id] = method, params, callback + for method, params in messages: + r = None + if method.endswith('.subscribe'): + k = self.get_index(method, params) + # add callback to list + l = self.subscriptions.get(k, []) + if callback not in l: + l.append(callback) + self.subscriptions[k] = l + # check cached response for subscriptions + r = self.sub_cache.get(k) + if r is not None: + util.print_error("cache hit", k) + callback(r) + else: + message_id = await self.queue_request(method, params) + self.unanswered_requests[message_id] = method, params, callback def unsubscribe(self, callback): '''Unsubscribe a callback to free object references to enable GC.''' @@ -693,134 +720,91 @@ class Network(util.DaemonThread): if callback in v: v.remove(callback) - def connection_down(self, server): + async def connection_down(self, server, reason=None): '''A connection to server either went down, or was never made. We distinguish by whether it is in self.interfaces.''' - self.disconnected_servers.add(server) - if server == self.default_server: - self.set_status('disconnected') - if server in self.interfaces: - self.close_interface(self.interfaces[server]) - self.notify('interfaces') - for b in self.blockchains.values(): - if b.catch_up == server: - b.catch_up = None + async with self.all_server_locks("connection down"): + if server in self.disconnected_servers: + try: + raise Exception("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason)) + except: + traceback.print_exc() + sys.exit(1) + return + self.print_error("connection down", server) + self.disconnected_servers[server] = reason + if server == self.default_server: + self.set_status('disconnected') + if server in self.interfaces: + await self.close_interface(self.interfaces[server]) + self.notify('interfaces') + for b in self.blockchains.values(): + if b.catch_up == server: + b.catch_up = None - def new_interface(self, server, socket): + async def new_interface(self, server): # todo: get tip first, then decide which checkpoint to use. self.add_recent_server(server) - interface = Interface(server, socket) + interface = Interface(server, self.config.path, self.proxy, lambda: not self.stopped and server in self.interfaces) + interface.future = asyncio.Future() interface.blockchain = None interface.tip_header = None interface.tip = 0 interface.mode = 'default' interface.request = None + interface.jobs = None + interface.boot_job = None + self.boot_interface(interface) + assert server not in self.interfaces + assert not self.stopped self.interfaces[server] = interface - self.queue_request('blockchain.headers.subscribe', [], interface) - if server == self.default_server: - self.switch_to_interface(server) - #self.notify('interfaces') + return interface - def maintain_sockets(self): - '''Socket maintenance.''' - # Responses to connection attempts? - while not self.socket_queue.empty(): - server, socket = self.socket_queue.get() - if server in self.connecting: - self.connecting.remove(server) - if socket: - self.new_interface(server, socket) - else: - self.connection_down(server) + async def request_chunk(self, interface, idx): + interface.print_error("requesting chunk %d" % idx) + await self.queue_request('blockchain.block.get_chunk', [idx], interface) + interface.request = idx + interface.req_time = time.time() - # Send pings and shut down stale interfaces - # must use copy of values - for interface in list(self.interfaces.values()): - if interface.has_timed_out(): - self.connection_down(interface.server) - elif interface.ping_required(): - params = [ELECTRUM_VERSION, PROTOCOL_VERSION] - self.queue_request('server.version', params, interface) - - now = time.time() - # nodes - if len(self.interfaces) + len(self.connecting) < self.num_server: - self.start_random_interface() - if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: - self.print_error('network: retrying connections') - self.disconnected_servers = set([]) - self.nodes_retry_time = now - - # main interface - if not self.is_connected(): - if self.auto_connect: - if not self.is_connecting(): - self.switch_to_random_interface() - else: - if self.default_server in self.disconnected_servers: - if now - self.server_retry_time > SERVER_RETRY_INTERVAL: - self.disconnected_servers.remove(self.default_server) - self.server_retry_time = now - else: - self.switch_to_interface(self.default_server) - else: - if self.config.is_fee_estimates_update_required(): - self.request_fee_estimates() - - def request_chunk(self, interface, index): - if index in self.requested_chunks: - return - interface.print_error("requesting chunk %d" % index) - self.requested_chunks.add(index) - self.queue_request('blockchain.block.get_chunk', [index], interface) - - def on_get_chunk(self, interface, response): + async def on_get_chunk(self, interface, response): '''Handle receiving a chunk of block headers''' error = response.get('error') result = response.get('result') params = response.get('params') - blockchain = interface.blockchain if result is None or params is None or error is not None: interface.print_error(error or 'bad response') return index = params[0] - # Ignore unsolicited chunks - if index not in self.requested_chunks: - interface.print_error("received chunk %d (unsolicited)" % index) - return - else: - interface.print_error("received chunk %d" % index) - self.requested_chunks.remove(index) - connect = blockchain.connect_chunk(index, result) + connect = interface.blockchain.connect_chunk(index, result) if not connect: - self.connection_down(interface.server) + await self.connection_down(interface.server, "could not connect chunk") return # If not finished, get the next chunk - if index >= len(blockchain.checkpoints) and blockchain.height() < interface.tip: - self.request_chunk(interface, index+1) + if interface.blockchain.height() < interface.tip: + await self.request_chunk(interface, index+1) else: interface.mode = 'default' - interface.print_error('catch up done', blockchain.height()) - blockchain.catch_up = None + interface.print_error('catch up done', interface.blockchain.height()) + interface.blockchain.catch_up = None self.notify('updated') - def request_header(self, interface, height): + async def request_header(self, interface, height): #interface.print_error("requesting header %d" % height) - self.queue_request('blockchain.block.get_header', [height], interface) + await self.queue_request('blockchain.block.get_header', [height], interface) interface.request = height interface.req_time = time.time() - def on_get_header(self, interface, response): + async def on_get_header(self, interface, response): '''Handle receiving a single block header''' header = response.get('result') if not header: interface.print_error(response) - self.connection_down(interface.server) + await self.connection_down(interface.server, "no header in on_get_header") return height = header.get('block_height') if interface.request != height: interface.print_error("unsolicited header",interface.request, height) - self.connection_down(interface.server) + await self.connection_down(interface.server, "unsolicited header") return chain = blockchain.check_header(header) if interface.mode == 'backward': @@ -840,7 +824,7 @@ class Network(util.DaemonThread): assert next_height >= self.max_checkpoint(), (interface.bad, interface.good) else: if height == 0: - self.connection_down(interface.server) + await self.connection_down(interface.server, "height zero in on_get_header") next_height = None else: interface.bad = height @@ -859,7 +843,7 @@ class Network(util.DaemonThread): next_height = (interface.bad + interface.good) // 2 assert next_height >= self.max_checkpoint() elif not interface.blockchain.can_connect(interface.bad_header, check_height=False): - self.connection_down(interface.server) + await self.connection_down(interface.server, "blockchain can't connect") next_height = None else: branch = self.blockchains.get(interface.bad) @@ -918,7 +902,7 @@ class Network(util.DaemonThread): # exit catch_up state interface.print_error('catch up done', interface.blockchain.height()) interface.blockchain.catch_up = None - self.switch_lagging_interface() + await self.switch_lagging_interface() self.notify('updated') else: @@ -926,9 +910,9 @@ class Network(util.DaemonThread): # If not finished, get the next header if next_height: if interface.mode == 'catch_up' and interface.tip > next_height + 50: - self.request_chunk(interface, next_height // 2016) + await self.request_chunk(interface, next_height // 2016) else: - self.request_header(interface, next_height) + await self.request_header(interface, next_height) else: interface.mode = 'default' interface.request = None @@ -936,39 +920,65 @@ class Network(util.DaemonThread): # refresh network dialog self.notify('interfaces') - def maintain_requests(self): + async def maintain_requests(self): for interface in list(self.interfaces.values()): if interface.request and time.time() - interface.request_time > 20: interface.print_error("blockchain request timed out") - self.connection_down(interface.server) - continue + await self.connection_down(interface.server, "blockchain request timed out") - def wait_on_sockets(self): - # Python docs say Windows doesn't like empty selects. - # Sleep to prevent busy looping - if not self.interfaces: - time.sleep(0.1) - return - rin = [i for i in self.interfaces.values()] - win = [i for i in self.interfaces.values() if i.num_requests()] - try: - rout, wout, xout = select.select(rin, win, [], 0.1) - except socket.error as e: - # TODO: py3, get code from e - code = None - if code == errno.EINTR: - return - raise - assert not xout - for interface in wout: - interface.send_requests() - for interface in rout: - self.process_responses(interface) + def make_send_requests_job(self, interface): + async def job(): + try: + while interface.is_running(): + try: + result = await asyncio.wait_for(asyncio.shield(interface.send_request()), 1) + except TimeoutError: + continue + if not result and interface.is_running(): + await self.connection_down(interface.server, "send_request returned false") + except GeneratorExit: + pass + except: + if interface.is_running(): + traceback.print_exc() + self.print_error("FATAL ERROR ^^^") + return asyncio.ensure_future(job()) + + def make_process_responses_job(self, interface): + async def job(): + try: + await self.process_responses(interface) + except GeneratorExit: + pass + except OSError: + await self.connection_down(interface.server, "OSError in process_responses") + self.print_error("OS error, connection downed") + except BaseException: + if interface.is_running(): + traceback.print_exc() + self.print_error("FATAL ERROR in process_responses") + return asyncio.ensure_future(job()) + + def make_process_pending_sends_job(self): + async def job(): + try: + while not self.stopped: + try: + await asyncio.wait_for(asyncio.shield(self.process_pending_sends()), 1) + except TimeoutError: + continue + #except CancelledError: + # pass + except BaseException as e: + if not self.stopped: + traceback.print_exc() + self.print_error("FATAL ERROR in process_pending_sends") + return asyncio.ensure_future(job()) def init_headers_file(self): b = self.blockchains[0] filename = b.path() - length = 80 * len(constants.net.CHECKPOINTS) * 2016 + length = 80 * len(bitcoin.NetworkConstants.CHECKPOINTS) * 2016 if not os.path.exists(filename) or os.path.getsize(filename) < length: with open(filename, 'wb') as f: if length>0: @@ -977,23 +987,184 @@ class Network(util.DaemonThread): with b.lock: b.update_size() - def run(self): - self.init_headers_file() - while self.is_running(): - self.maintain_sockets() - self.wait_on_sockets() - self.maintain_requests() - self.run_jobs() # Synchronizer and Verifier - self.process_pending_sends() - self.stop_network() - self.on_stop() + def boot_interface(self, interface): + async def job(): + try: + await self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface) + if not await interface.send_request(): + if interface.is_running(): + asyncio.ensure_future(self.connection_down(interface.server, "send_request false in boot_interface")) + interface.future.set_result("could not send request") + return + if not interface.is_running(): + interface.future.set_result("stopped after sending request") + return + try: + await asyncio.wait_for(interface.get_response(), 1) + except TimeoutError: + if interface.is_running(): + asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response")) + interface.future.set_result("timeout while getting response") + return + if not interface.is_running(): + interface.future.set_result("stopped after getting response") + return + #self.interfaces[interface.server] = interface + await self.queue_request('blockchain.headers.subscribe', [], interface) + if interface.server == self.default_server: + await asyncio.wait_for(self.switch_to_interface(interface.server), 1) + interface.jobs = [asyncio.ensure_future(x) for x in [self.make_ping_job(interface), self.make_send_requests_job(interface), self.make_process_responses_job(interface)]] + gathered = asyncio.gather(*interface.jobs) + while interface.is_running(): + try: + await asyncio.wait_for(asyncio.shield(gathered), 1) + except TimeoutError: + pass + interface.future.set_result("finished") + return + #self.notify('interfaces') + except GeneratorExit: + self.print_error(interface.server, "GENERATOR EXIT") + pass + except BaseException as e: + if interface.is_running(): + traceback.print_exc() + self.print_error("FATAL ERROR in boot_interface") + raise e + interface.boot_job = asyncio.ensure_future(job()) + interface.boot_job.server = interface.server + def boot_job_cb(fut): + try: + fut.exception() + except: + traceback.print_exc() + self.print_error("Previous exception in boot_job") + interface.boot_job.add_done_callback(boot_job_cb) - def on_notify_header(self, interface, header): + def make_ping_job(self, interface): + async def job(): + try: + while interface.is_running(): + await asyncio.sleep(1) + # Send pings and shut down stale interfaces + # must use copy of values + if interface.has_timed_out(): + self.print_error(interface.server, "timed out") + await self.connection_down(interface.server, "time out in ping_job") + return + elif interface.ping_required(): + params = [ELECTRUM_VERSION, PROTOCOL_VERSION] + await self.queue_request('server.version', params, interface) + except GeneratorExit: + pass + except: + if interface.is_running(): + traceback.print_exc() + self.print_error("FATAL ERROR in ping_job") + return asyncio.ensure_future(job()) + + def all_server_locks(self, ctx): + class AllLocks: + def __init__(self2): + self2.list = list(self.get_servers().keys()) + self2.ctx = ctx + async def __aenter__(self2): + for i in self2.list: + await asyncio.wait_for(self.connecting[i].acquire(), 3) + async def __aexit__(self2, exc_type, exc, tb): + for i in self2.list: + self.connecting[i].release() + return AllLocks() + + async def maintain_interfaces(self): + if self.stopped: return + + now = time.time() + # nodes + if len(self.interfaces) + sum((1 if x.locked() else 0) for x in self.connecting.values()) < self.num_server: + await self.start_random_interface() + if now - self.nodes_retry_time > NODES_RETRY_INTERVAL: + self.print_error('network: retrying connections') + async with self.all_server_locks("maintain_interfaces"): + self.disconnected_servers = {} + self.nodes_retry_time = now + + # main interface + if not self.is_connected(): + if self.auto_connect: + if not self.is_connecting(): + await self.switch_to_random_interface() + else: + if self.default_server in self.disconnected_servers: + if now - self.server_retry_time > SERVER_RETRY_INTERVAL: + async with self.all_server_locks("maintain_interfaces 2"): + del self.disconnected_servers[self.default_server] + self.server_retry_time = now + else: + await self.switch_to_interface(self.default_server) + else: + if self.config.is_fee_estimates_update_required(): + await self.request_fee_estimates() + + + def run(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) # this does not set the loop on the qt thread + self.loop = loop # so we store it in the instance too + self.init_headers_file() + self.pending_sends = asyncio.Queue() + self.connecting = collections.defaultdict(asyncio.Lock) + self.restartLock = asyncio.Lock() + + async def job(): + try: + await self.start_network(deserialize_server(self.default_server)[2], + deserialize_proxy(self.config.get('proxy'))) + self.process_pending_sends_job = self.make_process_pending_sends_job() + except: + traceback.print_exc() + self.print_error("Previous exception in start_network") + raise + asyncio.ensure_future(job()) + run_future = asyncio.Future() + self.run_forever_coroutines() + asyncio.ensure_future(self.run_async(run_future)) + + loop.run_until_complete(run_future) + assert self.forever_coroutines_task.done() + run_future.exception() + self.print_error("run future result", run_future.result()) + loop.close() + + async def run_async(self, future): + try: + while self.is_running(): + #self.print_error(len(asyncio.Task.all_tasks())) + await asyncio.sleep(1) + await self.maintain_requests() + await self.maintain_interfaces() + self.run_jobs() + await self.stop_network() + self.on_stop() + for i in asyncio.Task.all_tasks(): + if asyncio.Task.current_task() == i: continue + try: + await asyncio.wait_for(asyncio.shield(i), 2) + except TimeoutError: + self.print_error("TOO SLOW TO SHUT DOWN, CANCELLING", i) + i.cancel() + except CancelledError: + pass + future.set_result("run_async done") + except BaseException as e: + future.set_exception(e) + + async def on_notify_header(self, interface, header): height = header.get('block_height') if not height: return if height < self.max_checkpoint(): - self.connection_down(interface.server) + await self.connection_down(interface.server, "height under max checkpoint in on_notify_header") return interface.tip_header = header interface.tip = height @@ -1002,7 +1173,7 @@ class Network(util.DaemonThread): b = blockchain.check_header(header) if b: interface.blockchain = b - self.switch_lagging_interface() + await self.switch_lagging_interface() self.notify('updated') self.notify('interfaces') return @@ -1010,7 +1181,7 @@ class Network(util.DaemonThread): if b: interface.blockchain = b b.save_header(header) - self.switch_lagging_interface() + await self.switch_lagging_interface() self.notify('updated') self.notify('interfaces') return @@ -1019,17 +1190,14 @@ class Network(util.DaemonThread): interface.mode = 'backward' interface.bad = height interface.bad_header = header - self.request_header(interface, min(tip +1, height - 1)) + await self.request_header(interface, min(tip + 1, height - 1)) else: chain = self.blockchains[0] if chain.catch_up is None: chain.catch_up = interface interface.mode = 'catch_up' interface.blockchain = chain - self.print_error("switching to catchup mode", tip, self.blockchains) - self.request_header(interface, 0) - else: - self.print_error("chain already catching up with", chain.catch_up.server) + await self.request_header(interface, 0) def blockchain(self): if self.interface and self.interface.blockchain is not None: @@ -1044,6 +1212,7 @@ class Network(util.DaemonThread): out[k] = r return out + # called from the Qt thread def follow_chain(self, index): blockchain = self.blockchains.get(index) if blockchain: @@ -1051,16 +1220,19 @@ class Network(util.DaemonThread): self.config.set_key('blockchain_index', index) for i in self.interfaces.values(): if i.blockchain == blockchain: - self.switch_to_interface(i.server) + asyncio.run_coroutine_threadsafe(self.switch_to_interface(i.server), self.loop) break else: raise BaseException('blockchain not found', index) - if self.interface: - server = self.interface.server - host, port, protocol, proxy, auto_connect = self.get_parameters() - host, port, protocol = server.split(':') - self.set_parameters(host, port, protocol, proxy, auto_connect) + # commented out on migration to asyncio. not clear if it + # relies on the coroutine to be done: + + #if self.interface: + # server = self.interface.server + # host, port, protocol, proxy, auto_connect = self.get_parameters() + # host, port, protocol = server.split(':') + # self.set_parameters(host, port, protocol, proxy, auto_connect) def get_local_height(self): return self.blockchain().height() @@ -1071,7 +1243,7 @@ class Network(util.DaemonThread): try: r = q.get(True, timeout) except queue.Empty: - raise util.TimeoutException(_('Server did not answer')) + raise BaseException('Server did not answer') if r.get('error'): raise BaseException(r.get('error')) return r.get('result') @@ -1093,4 +1265,35 @@ class Network(util.DaemonThread): f.write(json.dumps(cp, indent=4)) def max_checkpoint(self): - return max(0, len(constants.net.CHECKPOINTS) * 2016 - 1) + return max(0, len(bitcoin.NetworkConstants.CHECKPOINTS) * 2016 - 1) + + async def send_async(self, messages, callback=None): + """ if callback is None, it returns the result """ + chosenCallback = callback + if callback is None: + queue = asyncio.Queue() + chosenCallback = queue.put + assert type(messages[0]) is tuple and len(messages[0]) == 2, repr(messages) + " does not contain a pair-tuple in first position" + await self.pending_sends.put((messages, chosenCallback)) + if callback is None: + #assert queue.qsize() == 1, "queue does not have a single result, it has length " + str(queue.qsize()) + return await asyncio.wait_for(queue.get(), 5) + + async def asynchronous_get(self, request): + assert type(request) is tuple + assert type(request[1]) is list + res = await self.send_async([request]) + try: + return res.get("result") + except: + print("asynchronous_get could not get result from", res) + raise BaseException("Could not get result: " + repr(res)) + + async def broadcast_async(self, tx): + tx_hash = tx.txid() + try: + return True, await self.asynchronous_get(('blockchain.transaction.broadcast', [str(tx)])) + except BaseException as e: + traceback.print_exc() + print("previous trace was captured and printed in broadcast_async") + return False, str(e) diff --git a/lib/ssl_in_socks.py b/lib/ssl_in_socks.py new file mode 100644 index 00000000..9f46fbde --- /dev/null +++ b/lib/ssl_in_socks.py @@ -0,0 +1,85 @@ +import traceback +import ssl +from asyncio.sslproto import SSLProtocol +import aiosocks +import asyncio +from . import interface + +class AppProto(asyncio.Protocol): + def __init__(self, receivedQueue, connUpLock): + self.buf = bytearray() + self.receivedQueue = receivedQueue + self.connUpLock = connUpLock + def connection_made(self, transport): + self.connUpLock.release() + def data_received(self, data): + self.buf.extend(data) + NEWLINE = b"\n"[0] + for idx, val in enumerate(self.buf): + if NEWLINE == val: + asyncio.ensure_future(self.receivedQueue.put(bytes(self.buf[:idx+1]))) + self.buf = self.buf[idx+1:] + +def makeProtocolFactory(receivedQueue, connUpLock, ca_certs): + class MySSLProtocol(SSLProtocol): + def __init__(self): + context = interface.get_ssl_context(\ + cert_reqs=ssl.CERT_REQUIRED if ca_certs is not None else ssl.CERT_NONE,\ + ca_certs=ca_certs) + proto = AppProto(receivedQueue, connUpLock) + super().__init__(asyncio.get_event_loop(), proto, context, None) + return MySSLProtocol + +class ReaderEmulator: + def __init__(self, receivedQueue): + self.receivedQueue = receivedQueue + async def read(self, _bufferSize): + return await self.receivedQueue.get() + +class WriterEmulator: + def __init__(self, transport): + self.transport = transport + def write(self, data): + self.transport.write(data) + async def drain(self): + pass + def close(self): + self.transport.close() + +async def sslInSocksReaderWriter(socksAddr, socksAuth, host, port, ca_certs): + receivedQueue = asyncio.Queue() + connUpLock = asyncio.Lock() + await connUpLock.acquire() + transport, protocol = await aiosocks.create_connection(\ + makeProtocolFactory(receivedQueue, connUpLock, ca_certs),\ + proxy=socksAddr,\ + proxy_auth=socksAuth, dst=(host, port)) + await connUpLock.acquire() + return ReaderEmulator(receivedQueue), WriterEmulator(protocol._app_transport) + +if __name__ == "__main__": + async def l(fut): + try: + # aiosocks.Socks4Addr("127.0.0.1", 9050), None, "songbird.bauerj.eu", 50002, None) + args = aiosocks.Socks4Addr("127.0.0.1", 9050), None, "electrum.akinbo.org", 51002, None + reader, writer = await sslInSocksReaderWriter(*args) + writer.write(b'{"id":0,"method":"server.version","args":["3.0.2", "1.1"]}\n') + await writer.drain() + print(await reader.read(4096)) + writer.write(b'{"id":0,"method":"server.version","args":["3.0.2", "1.1"]}\n') + await writer.drain() + print(await reader.read(4096)) + writer.close() + fut.set_result("finished") + except BaseException as e: + fut.set_exception(e) + + def f(): + loop = asyncio.get_event_loop() + fut = asyncio.Future() + asyncio.ensure_future(l(fut)) + loop.run_until_complete(fut) + print(fut.result()) + loop.close() + + f() diff --git a/lib/util.py b/lib/util.py index 1714c78d..9122609c 100644 --- a/lib/util.py +++ b/lib/util.py @@ -148,6 +148,24 @@ class PrintError(object): def print_msg(self, *msg): print_msg("[%s]" % self.diagnostic_name(), *msg) +class ForeverCoroutineJob(PrintError): + """A job that is run from a thread's main loop. run() is + called from that thread's context. + """ + + async def run(self, is_running): + """Called once from the thread""" + pass + +class CoroutineJob(PrintError): + """A job that is run periodically from a thread's main loop. run() is + called from that thread's context. + """ + + async def run(self): + """Called periodically from the thread""" + pass + class ThreadJob(PrintError): """A job that is run periodically from a thread's main loop. run() is called from that thread's context. @@ -192,6 +210,38 @@ class DaemonThread(threading.Thread, PrintError): self.running_lock = threading.Lock() self.job_lock = threading.Lock() self.jobs = [] + self.coroutines = [] + self.forever_coroutines_task = None + + def add_coroutines(self, jobs): + for i in jobs: assert isinstance(i, CoroutineJob), i.__class__.__name__ + " does not inherit from CoroutineJob" + self.coroutines.extend(jobs) + + def set_forever_coroutines(self, jobs): + for i in jobs: assert isinstance(i, ForeverCoroutineJob), i.__class__.__name__ + " does not inherit from ForeverCoroutineJob" + async def put(): + await self.forever_coroutines_queue.put(jobs) + asyncio.run_coroutine_threadsafe(put(), self.loop) + + def run_forever_coroutines(self): + self.forever_coroutines_queue = asyncio.Queue() # making queue here because __init__ is called from non-network thread + self.loop = asyncio.get_event_loop() + async def getFromQueueAndStart(): + jobs = await self.forever_coroutines_queue.get() + await asyncio.gather(*[i.run(self.is_running) for i in jobs]) + self.print_error("FOREVER JOBS DONE") + self.forever_coroutines_task = asyncio.ensure_future(getFromQueueAndStart()) + return self.forever_coroutines_task + + async def run_coroutines(self): + for coroutine in self.coroutines: + assert isinstance(coroutine, CoroutineJob) + await coroutine.run() + + def remove_coroutines(self, jobs): + for i in jobs: assert isinstance(i, CoroutineJob) + for job in jobs: + self.coroutines.remove(job) def add_jobs(self, jobs): with self.job_lock: @@ -203,6 +253,7 @@ class DaemonThread(threading.Thread, PrintError): # malformed or malicious server responses with self.job_lock: for job in self.jobs: + assert isinstance(job, ThreadJob) try: job.run() except Exception as e: