Compare commits

...

10 Commits

7 changed files with 765 additions and 541 deletions

View File

@ -7,7 +7,7 @@ jsonrpclib-pelix==0.3.1
pbkdf2==1.3
protobuf==3.5.1
pyaes==1.6.1
PySocks==1.6.8
aiosocks==0.2.6
qrcode==5.3
requests==2.18.4
six==1.11.0

View File

@ -6,4 +6,4 @@ qrcode
protobuf
dnspython
jsonrpclib-pelix
PySocks>=1.6.6
aiosocks

View File

@ -4,7 +4,7 @@ from .wallet import Synchronizer, Wallet
from .storage import WalletStorage
from .coinchooser import COIN_CHOOSERS
from .network import Network, pick_random_server
from .interface import Connection, Interface
from .interface import Interface
from .simple_config import SimpleConfig, get_config, set_config
from . import bitcoin
from . import transaction

View File

@ -22,288 +22,257 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import aiosocks
import os
import stat
import re
import socket
import ssl
import sys
import threading
import time
import traceback
import asyncio
import json
import asyncio.streams
from asyncio.sslproto import SSLProtocol
import io
import requests
from .util import print_error
from aiosocks.errors import SocksError
from concurrent.futures import TimeoutError
ca_path = requests.certs.where()
from .util import print_error
from .ssl_in_socks import sslInSocksReaderWriter
from . import util
from . import x509
from . import pem
def Connection(server, queue, config_path):
"""Makes asynchronous connections to a remote electrum server.
Returns the running thread that is making the connection.
Once the thread has connected, it finishes, placing a tuple on the
queue of the form (server, socket), where socket is None if
connection failed.
"""
host, port, protocol = server.rsplit(':', 2)
if not protocol in 'st':
raise Exception('Unknown protocol: %s' % protocol)
c = TcpConnection(server, queue, config_path)
c.start()
return c
class TcpConnection(threading.Thread, util.PrintError):
def __init__(self, server, queue, config_path):
threading.Thread.__init__(self)
self.config_path = config_path
self.queue = queue
self.server = server
self.host, self.port, self.protocol = self.server.rsplit(':', 2)
self.host = str(self.host)
self.port = int(self.port)
self.use_ssl = (self.protocol == 's')
self.daemon = True
def diagnostic_name(self):
return self.host
def check_host_name(self, peercert, name):
"""Simple certificate/host name checker. Returns True if the
certificate matches, False otherwise. Does not support
wildcards."""
# Check that the peer has supplied a certificate.
# None/{} is not acceptable.
if not peercert:
return False
if 'subjectAltName' in peercert:
for typ, val in peercert["subjectAltName"]:
if typ == "DNS" and val == name:
return True
else:
# Only check the subject DN if there is no subject alternative
# name.
cn = None
for attr, val in peercert["subject"]:
# Use most-specific (last) commonName attribute.
if attr == "commonName":
cn = val
if cn is not None:
return cn == name
return False
def get_simple_socket(self):
try:
l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.gaierror:
self.print_error("cannot resolve hostname")
return
e = None
for res in l:
try:
s = socket.socket(res[0], socket.SOCK_STREAM)
s.settimeout(10)
s.connect(res[4])
s.settimeout(2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
return s
except BaseException as _e:
e = _e
continue
else:
self.print_error("failed to connect", str(e))
@staticmethod
def get_ssl_context(cert_reqs, ca_certs):
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs)
context.check_hostname = False
context.verify_mode = cert_reqs
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
return context
def get_socket(self):
if self.use_ssl:
cert_path = os.path.join(self.config_path, 'certs', self.host)
if not os.path.exists(cert_path):
is_new = True
s = self.get_simple_socket()
if s is None:
return
# try with CA first
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path)
s = context.wrap_socket(s, do_handshake_on_connect=True)
except ssl.SSLError as e:
print_error(e)
s = None
except:
return
if s and self.check_host_name(s.getpeercert(), self.host):
self.print_error("SSL certificate signed by CA")
return s
# get server certificate.
# Do not use ssl.get_server_certificate because it does not work with proxy
s = self.get_simple_socket()
if s is None:
return
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None)
s = context.wrap_socket(s)
except ssl.SSLError as e:
self.print_error("SSL error retrieving SSL certificate:", e)
return
except:
return
dercert = s.getpeercert(True)
s.close()
cert = ssl.DER_cert_to_PEM_cert(dercert)
# workaround android bug
cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
temporary_path = cert_path + '.temp'
with open(temporary_path,"w") as f:
f.write(cert)
else:
is_new = False
s = self.get_simple_socket()
if s is None:
return
if self.use_ssl:
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED,
ca_certs=(temporary_path if is_new else cert_path))
s = context.wrap_socket(s, do_handshake_on_connect=True)
except socket.timeout:
self.print_error('timeout')
return
except ssl.SSLError as e:
self.print_error("SSL error:", e)
if e.errno != 1:
return
if is_new:
rej = cert_path + '.rej'
if os.path.exists(rej):
os.unlink(rej)
os.rename(temporary_path, rej)
else:
with open(cert_path) as f:
cert = f.read()
try:
b = pem.dePem(cert, 'CERTIFICATE')
x = x509.X509(b)
except:
traceback.print_exc(file=sys.stderr)
self.print_error("wrong certificate")
return
try:
x.check_date()
except:
self.print_error("certificate has expired:", cert_path)
os.unlink(cert_path)
return
self.print_error("wrong certificate")
if e.errno == 104:
return
return
except BaseException as e:
self.print_error(e)
traceback.print_exc(file=sys.stderr)
return
if is_new:
self.print_error("saving certificate")
os.rename(temporary_path, cert_path)
return s
def run(self):
socket = self.get_socket()
if socket:
self.print_error("connected")
self.queue.put((self.server, socket))
class Interface(util.PrintError):
"""The Interface class handles a socket connected to a single remote
electrum server. It's exposed API is:
- Member functions close(), fileno(), get_responses(), has_timed_out(),
ping_required(), queue_request(), send_requests()
- Member functions close(), fileno(), get_response(), has_timed_out(),
ping_required(), queue_request(), send_request()
- Member variable server.
"""
def __init__(self, server, socket):
self.server = server
self.host, _, _ = server.rsplit(':', 2)
self.socket = socket
def __init__(self, server, config_path, proxy_config, is_running, exception_handler):
self.error_future = asyncio.Future()
self.error_future.add_done_callback(exception_handler)
self.is_running = lambda: is_running() and not self.error_future.done()
self.addr = self.auth = None
if proxy_config is not None:
if proxy_config["mode"] == "socks5":
self.addr = aiosocks.Socks5Addr(proxy_config["host"], proxy_config["port"])
self.auth = aiosocks.Socks5Auth(proxy_config["user"], proxy_config["password"]) if proxy_config["user"] != "" else None
elif proxy_config["mode"] == "socks4":
self.addr = aiosocks.Socks4Addr(proxy_config["host"], proxy_config["port"])
self.auth = aiosocks.Socks4Auth(proxy_config["password"]) if proxy_config["password"] != "" else None
else:
raise Exception("proxy mode not supported")
self.pipe = util.SocketPipe(socket)
self.pipe.set_timeout(0.0) # Don't wait for data
self.server = server
self.config_path = config_path
host, port, protocol = self.server.split(':')
self.host = host
self.port = int(port)
self.use_ssl = (protocol=='s')
self.reader = self.writer = None
self.lock = asyncio.Lock()
# Dump network messages. Set at runtime from the console.
self.debug = False
self.unsent_requests = []
self.unsent_requests = asyncio.PriorityQueue()
self.unanswered_requests = {}
# Set last ping to zero to ensure immediate ping
self.last_request = time.time()
self.last_ping = 0
self.closed_remotely = False
self.buf = bytes()
def conn_coro(self, context):
return asyncio.open_connection(self.host, self.port, ssl=context)
async def _save_certificate(self, cert_path, require_ca):
dercert = None
if require_ca:
context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path)
else:
context = get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None)
if self.addr is not None:
self.print_error("can't save certificate through socks!")
# just save the empty file to force use of PKI
# this will break all self-signed servers, of course
cert = ""
else:
reader, writer = await asyncio.wait_for(self.conn_coro(context), 3)
dercert = writer.get_extra_info('ssl_object').getpeercert(True)
# an exception will be thrown by now if require_ca is True (e.g. a certificate was supplied)
writer.close()
if not require_ca:
cert = ssl.DER_cert_to_PEM_cert(dercert)
else:
# Don't pin a CA signed certificate
cert = ""
temporary_path = cert_path + '.temp'
with open(temporary_path, "w") as f:
f.write(cert)
return temporary_path
async def _get_read_write(self):
async with self.lock:
if self.reader is not None and self.writer is not None:
return self.reader, self.writer, True
if self.use_ssl:
cert_path = os.path.join(self.config_path, 'certs', self.host)
if not os.path.exists(cert_path):
temporary_path = None
# first, we try to save a certificate signed through the PKI
try:
temporary_path = await self._save_certificate(cert_path, True)
except ssl.SSLError:
pass
except (TimeoutError, OSError) as e:
if not self.error_future.done(): self.error_future.set_result(e)
raise
# if the certificate verification failed, we try to save a self-signed certificate
if not temporary_path:
try:
temporary_path = await self._save_certificate(cert_path, False)
# we also catch SSLError here, but it shouldn't matter since no certificate is required,
# so the SSLError wouldn't mean certificate validation failed
except (TimeoutError, OSError) as e:
if not self.error_future.done(): self.error_future.set_result(e)
raise
if not temporary_path:
if not self.error_future.done(): self.error_future.set_result(ConnectionError("Could not get certificate"))
raise ConnectionError("Could not get certificate on second try")
is_new = True
else:
is_new = False
ca_certs = temporary_path if is_new else cert_path
size = os.stat(ca_certs)[stat.ST_SIZE]
self_signed = size != 0
if not self_signed:
ca_certs = ca_path
try:
if self.addr is not None:
if not self.use_ssl:
open_coro = aiosocks.open_connection(proxy=self.addr, proxy_auth=self.auth, dst=(self.host, self.port))
self.reader, self.writer = await asyncio.wait_for(open_coro, 5)
else:
ssl_in_socks_coro = sslInSocksReaderWriter(self.addr, self.auth, self.host, self.port, ca_certs)
self.reader, self.writer = await asyncio.wait_for(ssl_in_socks_coro, 5)
else:
context = get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_certs) if self.use_ssl else None
self.reader, self.writer = await asyncio.wait_for(self.conn_coro(context), 5)
except TimeoutError:
self.print_error("TimeoutError after getting certificate successfully...")
raise
except ssl.SSLError:
# FIXME TODO
assert not self_signed, "we shouldn't reject self-signed here since the certificate has been saved (has size {})".format(size)
raise
except BaseException as e:
if self.is_running():
if not isinstance(e, OSError):
traceback.print_exc()
self.print_error("Previous exception will now be reraised")
raise e
if self.use_ssl and is_new:
self.print_error("saving new certificate for", self.host)
os.rename(temporary_path, cert_path)
return self.reader, self.writer, False
async def send_all(self, list_of_requests):
_, w, usedExisting = await self._get_read_write()
starttime = time.time()
for i in list_of_requests:
w.write(json.dumps(i).encode("ascii") + b"\n")
await w.drain()
if time.time() - starttime > 2.5:
self.print_error("send_all: sending is taking too long. Used existing connection: ", usedExisting)
raise ConnectionError("sending is taking too long")
def close(self):
if self.writer:
self.writer.close()
def _try_extract(self):
try:
pos = self.buf.index(b"\n")
except ValueError:
return
obj = self.buf[:pos]
try:
obj = json.loads(obj.decode("ascii"))
except ValueError:
return
else:
self.buf = self.buf[pos+1:]
self.last_action = time.time()
return obj
async def get(self):
reader, _, _ = await self._get_read_write()
while self.is_running():
tried = self._try_extract()
if tried: return tried
temp = io.BytesIO()
try:
data = await asyncio.wait_for(reader.read(2**10), 1)
temp.write(data)
except asyncio.TimeoutError:
continue
self.buf += temp.getvalue()
def idle_time(self):
return time.time() - self.last_action
def diagnostic_name(self):
return self.host
def fileno(self):
# Needed for select
return self.socket.fileno()
def close(self):
if not self.closed_remotely:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
def queue_request(self, *args): # method, params, _id
async def queue_request(self, *args): # method, params, _id
'''Queue a request, later to be send with send_requests when the
socket is available for writing.
'''
self.request_time = time.time()
self.unsent_requests.append(args)
await self.unsent_requests.put((self.request_time, args))
await self.unsent_requests.join()
def num_requests(self):
'''Keep unanswered requests below 100'''
n = 100 - len(self.unanswered_requests)
return min(n, len(self.unsent_requests))
return min(n, self.unsent_requests.qsize())
def send_requests(self):
'''Sends queued requests. Returns False on failure.'''
async def send_request(self):
'''Sends a queued request.'''
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
n = self.num_requests()
wire_requests = self.unsent_requests[0:n]
prio, request = await self.unsent_requests.get()
try:
self.pipe.send_all([make_dict(*r) for r in wire_requests])
except socket.error as e:
self.print_error("socket error:", e)
return False
self.unsent_requests = self.unsent_requests[n:]
for request in wire_requests:
await self.send_all([make_dict(*request)])
except (SocksError, OSError, TimeoutError) as e:
if type(e) is SocksError:
self.print_error(e)
await self.unsent_requests.put((prio, request))
return
self.unsent_requests.task_done()
if self.debug:
self.print_error("-->", request)
self.unanswered_requests[request[2]] = request
return True
self.last_action = time.time()
def ping_required(self):
'''Maintains time since last ping. Returns True if a ping should
@ -317,14 +286,14 @@ class Interface(util.PrintError):
def has_timed_out(self):
'''Returns True if the interface has timed out.'''
if self.error_future.done(): return True
if (self.unanswered_requests and time.time() - self.request_time > 10
and self.pipe.idle_time() > 10):
and self.idle_time() > 10):
self.print_error("timeout", len(self.unanswered_requests))
return True
return False
def get_responses(self):
async def get_response(self):
'''Call if there is data available on the socket. Returns a list of
(request, response) pairs. Notifications are singleton
unsolicited responses presumably as a result of prior
@ -333,34 +302,25 @@ class Interface(util.PrintError):
corresponding request. If the connection was closed remotely
or the remote server is misbehaving, a (None, None) will appear.
'''
responses = []
while True:
try:
response = self.pipe.get()
except util.timeout:
break
response = await self.get()
if not type(response) is dict:
responses.append((None, None))
if response is None:
self.closed_remotely = True
if self.is_running():
self.print_error("connection closed remotely")
break
return None, None
if self.debug:
self.print_error("<--", response)
wire_id = response.get('id', None)
if wire_id is None: # Notification
responses.append((None, response))
return None, response
else:
request = self.unanswered_requests.pop(wire_id, None)
if request:
responses.append((request, response))
return request, response
else:
self.print_error("unknown wire ID", wire_id)
responses.append((None, None)) # Signal
break
return responses
return None, None # Signal
def check_cert(host, cert):
try:

View File

@ -20,6 +20,9 @@
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import sys
import collections
from functools import partial
import time
import queue
import os
@ -30,15 +33,15 @@ import re
import select
from collections import defaultdict
import threading
import socket
import json
import asyncio
import traceback
import socks
from . import util
from . import bitcoin
from .bitcoin import *
from . import constants
from .interface import Connection, Interface
from .interface import Interface
from . import blockchain
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
from .i18n import _
@ -47,6 +50,7 @@ from .i18n import _
NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10
from concurrent.futures import TimeoutError, CancelledError
def parse_servers(result):
""" parse servers list into dict format"""
@ -78,7 +82,7 @@ def filter_version(servers):
def is_recent(version):
try:
return util.normalize_version(version) >= util.normalize_version(PROTOCOL_VERSION)
except Exception as e:
except BaseException as e:
return False
return {k: v for k, v in servers.items() if is_recent(v.get('version'))}
@ -159,10 +163,18 @@ class Network(util.DaemonThread):
- Member functions get_header(), get_interfaces(), get_local_height(),
get_parameters(), get_server_height(), get_status_value(),
is_connected(), set_parameters(), stop()
is_connected(), set_parameters(), stop(), follow_chain()
"""
def __init__(self, config=None):
self.lock = threading.Lock()
# callbacks set by the GUI
self.callbacks = defaultdict(list)
self.set_status("disconnected")
self.disconnected_servers = {}
self.connecting = {}
self.stopped = True
asyncio.set_event_loop(None)
if config is None:
config = {} # Do not use mutables as default values!
util.DaemonThread.__init__(self)
@ -184,8 +196,6 @@ class Network(util.DaemonThread):
self.default_server = None
if not self.default_server:
self.default_server = pick_random_server()
self.lock = threading.Lock()
self.pending_sends = []
self.message_id = 0
self.debug = False
self.irc_servers = {} # returned by interface (list from irc)
@ -197,8 +207,6 @@ class Network(util.DaemonThread):
# callbacks passed with subscriptions
self.subscriptions = defaultdict(list)
self.sub_cache = {}
# callbacks set by the GUI
self.callbacks = defaultdict(list)
dir_path = os.path.join( self.config.path, 'certs')
if not os.path.exists(dir_path):
@ -214,16 +222,11 @@ class Network(util.DaemonThread):
self.server_retry_time = time.time()
self.nodes_retry_time = time.time()
# kick off the network. interface is the main server we are currently
# communicating with. interfaces is the set of servers we are connecting
# communicating with. interfaces is the set of servers we are connected
# to or have an ongoing connection with
self.interface = None
self.interfaces = {}
self.auto_connect = self.config.get('auto_connect', True)
self.connecting = set()
self.requested_chunks = set()
self.socket_queue = queue.Queue()
self.start_network(deserialize_server(self.default_server)[2],
deserialize_proxy(self.config.get('proxy')))
def register_callback(self, callback, events):
with self.lock:
@ -290,19 +293,20 @@ class Network(util.DaemonThread):
def is_up_to_date(self):
return self.unanswered_requests == {}
def queue_request(self, method, params, interface=None):
async def queue_request(self, method, params, interface=None):
# If you want to queue a request on any interface it must go
# through this function so message ids are properly tracked
if interface is None:
assert self.interface is not None
interface = self.interface
message_id = self.message_id
self.message_id += 1
if self.debug:
self.print_error(interface.host, "-->", method, params, message_id)
interface.queue_request(method, params, message_id)
await interface.queue_request(method, params, message_id)
return message_id
def send_subscriptions(self):
async def send_subscriptions(self):
self.print_error('sending subscriptions to', self.interface.server, len(self.unanswered_requests), len(self.subscribed_addresses))
self.sub_cache.clear()
# Resend unanswered requests
@ -310,24 +314,27 @@ class Network(util.DaemonThread):
self.unanswered_requests = {}
if self.interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
self.queue_request('server.version', params, self.interface)
await self.queue_request('server.version', params, self.interface)
for request in requests:
message_id = self.queue_request(request[0], request[1])
message_id = await self.queue_request(request[0], request[1])
self.unanswered_requests[message_id] = request
self.queue_request('server.banner', [])
self.queue_request('server.donation_address', [])
self.queue_request('server.peers.subscribe', [])
self.request_fee_estimates()
self.queue_request('blockchain.relayfee', [])
await self.queue_request('server.banner', [])
await self.queue_request('server.donation_address', [])
await self.queue_request('server.peers.subscribe', [])
await self.request_fee_estimates()
await self.queue_request('blockchain.relayfee', [])
if self.interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
await self.queue_request('server.version', params, self.interface)
for h in self.subscribed_addresses:
self.queue_request('blockchain.scripthash.subscribe', [h])
await self.queue_request('blockchain.scripthash.subscribe', [h])
def request_fee_estimates(self):
async def request_fee_estimates(self):
from .simple_config import FEE_ETA_TARGETS
self.config.requested_fee_estimates()
self.queue_request('mempool.get_fee_histogram', [])
await self.queue_request("mempool.get_fee_histogram", [])
for i in FEE_ETA_TARGETS:
self.queue_request('blockchain.estimatefee', [i])
await self.queue_request('blockchain.estimatefee', [i])
def get_status_value(self, key):
if key == 'status':
@ -336,7 +343,7 @@ class Network(util.DaemonThread):
value = self.banner
elif key == 'fee':
value = self.config.fee_estimates
elif key == 'fee_histogram':
elif key == "fee_histogram":
value = self.config.mempool_fees
elif key == 'updated':
value = (self.get_local_height(), self.get_server_height())
@ -378,68 +385,57 @@ class Network(util.DaemonThread):
out[host] = { protocol:port }
return out
def start_interface(self, server):
if (not server in self.interfaces and not server in self.connecting):
async def start_interface(self, server):
if server not in self.interfaces and server not in self.connecting and not self.stopped:
self.connecting[server] = asyncio.Future()
if server == self.default_server:
self.print_error("connecting to %s as new interface" % server)
self.set_status('connecting')
self.connecting.add(server)
c = Connection(server, self.socket_queue, self.config.path)
iface = self.new_interface(server)
success = await iface.launched
# if the interface was launched sucessfully, we save it in interfaces
# since that dictionary only stores "good" interfaces
if success:
assert server not in self.interfaces
self.interfaces[server] = iface
if not self.connecting[server].done(): self.connecting[server].set_result(True)
del self.connecting[server]
def start_random_interface(self):
exclude_set = self.disconnected_servers.union(set(self.interfaces))
async def start_random_interface(self):
exclude_set = set(self.disconnected_servers.keys()).union(set(self.interfaces.keys()))
server = pick_random_server(self.get_servers(), self.protocol, exclude_set)
if server:
self.start_interface(server)
def start_interfaces(self):
self.start_interface(self.default_server)
for i in range(self.num_server - 1):
self.start_random_interface()
def set_proxy(self, proxy):
self.proxy = proxy
# Store these somewhere so we can un-monkey-patch
if not hasattr(socket, "_socketobject"):
socket._socketobject = socket.socket
socket._getaddrinfo = socket.getaddrinfo
if proxy:
self.print_error('setting proxy', proxy)
proxy_mode = proxy_modes.index(proxy["mode"]) + 1
socks.setdefaultproxy(proxy_mode,
proxy["host"],
int(proxy["port"]),
# socks.py seems to want either None or a non-empty string
username=(proxy.get("user", "") or None),
password=(proxy.get("password", "") or None))
socket.socket = socks.socksocket
# prevent dns leaks, see http://stackoverflow.com/questions/13184205/dns-over-proxy
socket.getaddrinfo = lambda *args: [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))]
else:
socket.socket = socket._socketobject
socket.getaddrinfo = socket._getaddrinfo
await self.start_interface(server)
def start_network(self, protocol, proxy):
self.stopped = False
assert not self.interface and not self.interfaces
assert not self.connecting and self.socket_queue.empty()
assert len(self.connecting) == 0
self.print_error('starting network')
self.disconnected_servers = set([])
self.protocol = protocol
self.set_proxy(proxy)
self.start_interfaces()
self.proxy = proxy
def stop_network(self):
async def stop_network(self):
self.stopped = True
self.print_error("stopping network")
for interface in list(self.interfaces.values()):
self.close_interface(interface)
if self.interface:
self.close_interface(self.interface)
assert self.interface is None
assert not self.interfaces
self.connecting = set()
# Get a new queue - no old pending connections thanks!
self.socket_queue = queue.Queue()
async def stop(interface):
await self.connection_down(interface.server, "stopping network")
await interface.boot_job
stopped_this_time = set()
while self.connecting:
try:
await next(iter(self.connecting.values()))
except asyncio.CancelledError:
# ok since we are shutting down
pass
while self.interfaces:
do_next = next(iter(self.interfaces.values()))
stopped_this_time.add(do_next)
await stop(do_next)
self.process_pending_sends_job.cancel()
await self.process_pending_sends_job
# called from the Qt thread
def set_parameters(self, host, port, protocol, proxy, auto_connect):
proxy_str = serialize_proxy(proxy)
server = serialize_server(host, port, protocol)
@ -459,25 +455,41 @@ class Network(util.DaemonThread):
return
self.auto_connect = auto_connect
if self.proxy != proxy or self.protocol != protocol:
async def job():
try:
async with self.restartLock:
# Restart the network defaulting to the given server
self.stop_network()
await self.stop_network()
self.print_error("STOOOOOOOOOOOOOOOOOOOOOOOOOOPPED")
self.default_server = server
self.disconnected_servers = {}
self.start_network(protocol, proxy)
except BaseException as e:
traceback.print_exc()
self.print_error("exception from restart job")
if self.restartLock.locked():
self.print_error("NOT RESTARTING, RESTART IN PROGRESS")
return
asyncio.run_coroutine_threadsafe(job(), self.loop)
elif self.default_server != server:
self.switch_to_interface(server)
async def job():
await self.switch_to_interface(server)
asyncio.run_coroutine_threadsafe(job(), self.loop)
else:
self.switch_lagging_interface()
async def job():
await self.switch_lagging_interface()
self.notify('updated')
asyncio.run_coroutine_threadsafe(job(), self.loop)
def switch_to_random_interface(self):
async def switch_to_random_interface(self):
'''Switch to a random connected server other than the current one'''
servers = self.get_interfaces() # Those in connected state
if self.default_server in servers:
servers.remove(self.default_server)
if servers:
self.switch_to_interface(random.choice(servers))
await self.switch_to_interface(random.choice(servers))
def switch_lagging_interface(self):
async def switch_lagging_interface(self):
'''If auto_connect and lagging, switch interface'''
if self.server_is_lagging() and self.auto_connect:
# switch to one that has the correct header (not height)
@ -485,9 +497,9 @@ class Network(util.DaemonThread):
filtered = list(map(lambda x:x[0], filter(lambda x: x[1].tip_header==header, self.interfaces.items())))
if filtered:
choice = random.choice(filtered)
self.switch_to_interface(choice)
await self.switch_to_interface(choice)
def switch_to_interface(self, server):
async def switch_to_interface(self, server):
'''Switch to server as our interface. If no connection exists nor
being opened, start a thread to connect. The actual switch will
happen on receipt of the connection notification. Do nothing
@ -495,25 +507,28 @@ class Network(util.DaemonThread):
self.default_server = server
if server not in self.interfaces:
self.interface = None
self.start_interface(server)
await self.start_interface(server)
return
i = self.interfaces[server]
if self.interface != i:
self.print_error("switching to", server)
if self.interface != i and self.num_server != 0:
self.print_error("switching to", server, "from", self.interface.server if self.interface else "None")
# stop any current interface in order to terminate subscriptions
# fixme: we don't want to close headers sub
#self.close_interface(self.interface)
self.interface = i
self.send_subscriptions()
await self.send_subscriptions()
self.set_status('connected')
self.notify('updated')
def close_interface(self, interface):
if interface:
async def close_interface(self, interface):
self.print_error('closing connection', interface.server)
if interface.server in self.interfaces:
self.interfaces.pop(interface.server)
if interface.server == self.default_server:
self.interface = None
assert not interface.boot_job.cancelled()
interface.boot_job.cancel()
await interface.boot_job
interface.close()
def add_recent_server(self, server):
@ -524,7 +539,7 @@ class Network(util.DaemonThread):
self.recent_servers = self.recent_servers[0:20]
self.save_recent_servers()
def process_response(self, interface, response, callbacks):
async def process_response(self, interface, response, callbacks):
if self.debug:
self.print_error("<--", response)
error = response.get('error')
@ -537,7 +552,7 @@ class Network(util.DaemonThread):
interface.server_version = result
elif method == 'blockchain.headers.subscribe':
if error is None:
self.on_notify_header(interface, result)
await self.on_notify_header(interface, result)
elif method == 'server.peers.subscribe':
if error is None:
self.irc_servers = parse_servers(result)
@ -549,11 +564,11 @@ class Network(util.DaemonThread):
elif method == 'server.donation_address':
if error is None:
self.donation_address = result
elif method == 'mempool.get_fee_histogram':
elif method == "mempool.get_fee_histogram":
if error is None:
self.print_error('fee_histogram', result)
self.print_error("fee_histogram", result)
self.config.mempool_fees = result
self.notify('fee_histogram')
self.notify("fee_histogram")
elif method == 'blockchain.estimatefee':
if error is None and result > 0:
i = params[0]
@ -566,20 +581,25 @@ class Network(util.DaemonThread):
self.relay_fee = int(result * COIN)
self.print_error("relayfee", self.relay_fee)
elif method == 'blockchain.block.get_chunk':
self.on_get_chunk(interface, response)
await self.on_get_chunk(interface, response)
elif method == 'blockchain.block.get_header':
self.on_get_header(interface, response)
await self.on_get_header(interface, response)
for callback in callbacks:
if asyncio.iscoroutinefunction(callback):
if response is None:
print("RESPONSE IS NONE")
await callback(response)
else:
callback(response)
def get_index(self, method, params):
""" hashable index for subscriptions and cache"""
return str(method) + (':' + str(params[0]) if params else '')
def process_responses(self, interface):
responses = interface.get_responses()
for request, response in responses:
async def process_responses(self, interface):
while interface.is_running():
request, response = await interface.get_response()
if request:
method, params, message_id = request
k = self.get_index(method, params)
@ -604,8 +624,8 @@ class Network(util.DaemonThread):
self.subscribed_addresses.add(params[0])
else:
if not response: # Closed remotely / misbehaving
self.connection_down(interface.server)
break
if interface.is_running(): await self.connection_down(interface.server, "no response in process responses")
return
# Rewrite response shape to match subscription request response
method = response.get('method')
params = response.get('params')
@ -622,7 +642,9 @@ class Network(util.DaemonThread):
if method.endswith('.subscribe'):
self.sub_cache[k] = response
# Response is now in canonical form
self.process_response(interface, response, callbacks)
await self.process_response(interface, response, callbacks)
await self.run_coroutines() # Synchronizer and Verifier
def addr_to_scripthash(self, addr):
h = bitcoin.address_to_scripthash(addr)
@ -651,20 +673,23 @@ class Network(util.DaemonThread):
def send(self, messages, callback):
'''Messages is a list of (method, params) tuples'''
messages = list(messages)
with self.lock:
self.pending_sends.append((messages, callback))
async def job(future):
await self.pending_sends.put((messages, callback))
if future: future.set_result("put pending send: " + repr(messages))
asyncio.run_coroutine_threadsafe(job(None), self.loop)
def process_pending_sends(self):
async def process_pending_sends(self):
# Requests needs connectivity. If we don't have an interface,
# we cannot process them.
if not self.interface:
await asyncio.sleep(1)
return
with self.lock:
sends = self.pending_sends
self.pending_sends = []
try:
messages, callback = await asyncio.wait_for(self.pending_sends.get(), 1)
except TimeoutError:
return
for messages, callback in sends:
for method, params in messages:
r = None
if method.endswith('.subscribe'):
@ -680,7 +705,7 @@ class Network(util.DaemonThread):
util.print_error("cache hit", k)
callback(r)
else:
message_id = self.queue_request(method, params)
message_id = await self.queue_request(method, params)
self.unanswered_requests[message_id] = method, params, callback
def unsubscribe(self, callback):
@ -693,88 +718,56 @@ class Network(util.DaemonThread):
if callback in v:
v.remove(callback)
def connection_down(self, server):
async def connection_down(self, server, reason=None):
'''A connection to server either went down, or was never made.
We distinguish by whether it is in self.interfaces.'''
self.disconnected_servers.add(server)
if server in self.disconnected_servers:
print("already disconnected " + server + " because " + repr(self.disconnected_servers[server]) + ". new reason: " + repr(reason))
return
self.print_error("connection down", server, reason)
self.disconnected_servers[server] = reason
if server == self.default_server:
self.set_status('disconnected')
if server in self.connecting:
await self.connecting[server]
if server in self.interfaces:
self.close_interface(self.interfaces[server])
await self.close_interface(self.interfaces[server])
self.notify('interfaces')
for b in self.blockchains.values():
if b.catch_up == server:
b.catch_up = None
def new_interface(self, server, socket):
def new_interface(self, server):
# todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server)
interface = Interface(server, socket)
def interface_exception_handler(exception_future):
try:
raise exception_future.result()
except BaseException as e:
asyncio.ensure_future(self.connection_down(server, "error in interface: " + str(e)))
interface = Interface(server, self.config.path, self.proxy, lambda: not self.stopped and (server in self.connecting or server in self.interfaces), interface_exception_handler)
# this future has its result set when the interface should be considered opened,
# e.g. should be moved from connecting to interfaces
interface.launched = asyncio.Future()
interface.blockchain = None
interface.tip_header = None
interface.tip = 0
interface.mode = 'default'
interface.request = None
self.interfaces[server] = interface
self.queue_request('blockchain.headers.subscribe', [], interface)
if server == self.default_server:
self.switch_to_interface(server)
#self.notify('interfaces')
interface.jobs = None
interface.boot_job = None
self.boot_interface(interface)
assert server not in self.interfaces
assert not self.stopped
return interface
def maintain_sockets(self):
'''Socket maintenance.'''
# Responses to connection attempts?
while not self.socket_queue.empty():
server, socket = self.socket_queue.get()
if server in self.connecting:
self.connecting.remove(server)
if socket:
self.new_interface(server, socket)
else:
self.connection_down(server)
async def request_chunk(self, interface, idx):
interface.print_error("requesting chunk %d" % idx)
await self.queue_request('blockchain.block.get_chunk', [idx], interface)
interface.request = idx
interface.req_time = time.time()
# Send pings and shut down stale interfaces
# must use copy of values
for interface in list(self.interfaces.values()):
if interface.has_timed_out():
self.connection_down(interface.server)
elif interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
self.queue_request('server.version', params, interface)
now = time.time()
# nodes
if len(self.interfaces) + len(self.connecting) < self.num_server:
self.start_random_interface()
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections')
self.disconnected_servers = set([])
self.nodes_retry_time = now
# main interface
if not self.is_connected():
if self.auto_connect:
if not self.is_connecting():
self.switch_to_random_interface()
else:
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
self.disconnected_servers.remove(self.default_server)
self.server_retry_time = now
else:
self.switch_to_interface(self.default_server)
else:
if self.config.is_fee_estimates_update_required():
self.request_fee_estimates()
def request_chunk(self, interface, index):
if index in self.requested_chunks:
return
interface.print_error("requesting chunk %d" % index)
self.requested_chunks.add(index)
self.queue_request('blockchain.block.get_chunk', [index], interface)
def on_get_chunk(self, interface, response):
async def on_get_chunk(self, interface, response):
'''Handle receiving a chunk of block headers'''
error = response.get('error')
result = response.get('result')
@ -784,43 +777,36 @@ class Network(util.DaemonThread):
interface.print_error(error or 'bad response')
return
index = params[0]
# Ignore unsolicited chunks
if index not in self.requested_chunks:
interface.print_error("received chunk %d (unsolicited)" % index)
return
else:
interface.print_error("received chunk %d" % index)
self.requested_chunks.remove(index)
connect = blockchain.connect_chunk(index, result)
if not connect:
self.connection_down(interface.server)
await self.connection_down(interface.server, "could not connect chunk")
return
# If not finished, get the next chunk
if index >= len(blockchain.checkpoints) and blockchain.height() < interface.tip:
self.request_chunk(interface, index+1)
if index >= len(blockchain.checkpoints) and interface.blockchain.height() < interface.tip:
await self.request_chunk(interface, index+1)
else:
interface.mode = 'default'
interface.print_error('catch up done', blockchain.height())
blockchain.catch_up = None
self.notify('updated')
def request_header(self, interface, height):
async def request_header(self, interface, height):
#interface.print_error("requesting header %d" % height)
self.queue_request('blockchain.block.get_header', [height], interface)
await self.queue_request('blockchain.block.get_header', [height], interface)
interface.request = height
interface.req_time = time.time()
def on_get_header(self, interface, response):
async def on_get_header(self, interface, response):
'''Handle receiving a single block header'''
header = response.get('result')
if not header:
interface.print_error(response)
self.connection_down(interface.server)
await self.connection_down(interface.server, "no header in on_get_header")
return
height = header.get('block_height')
if interface.request != height:
interface.print_error("unsolicited header",interface.request, height)
self.connection_down(interface.server)
await self.connection_down(interface.server, "unsolicited header")
return
chain = blockchain.check_header(header)
if interface.mode == 'backward':
@ -840,7 +826,7 @@ class Network(util.DaemonThread):
assert next_height >= self.max_checkpoint(), (interface.bad, interface.good)
else:
if height == 0:
self.connection_down(interface.server)
await self.connection_down(interface.server, "height zero in on_get_header")
next_height = None
else:
interface.bad = height
@ -859,7 +845,7 @@ class Network(util.DaemonThread):
next_height = (interface.bad + interface.good) // 2
assert next_height >= self.max_checkpoint()
elif not interface.blockchain.can_connect(interface.bad_header, check_height=False):
self.connection_down(interface.server)
await self.connection_down(interface.server, "blockchain can't connect")
next_height = None
else:
branch = self.blockchains.get(interface.bad)
@ -918,7 +904,7 @@ class Network(util.DaemonThread):
# exit catch_up state
interface.print_error('catch up done', interface.blockchain.height())
interface.blockchain.catch_up = None
self.switch_lagging_interface()
await self.switch_lagging_interface()
self.notify('updated')
else:
@ -926,9 +912,9 @@ class Network(util.DaemonThread):
# If not finished, get the next header
if next_height:
if interface.mode == 'catch_up' and interface.tip > next_height + 50:
self.request_chunk(interface, next_height // 2016)
await self.request_chunk(interface, next_height // 2016)
else:
self.request_header(interface, next_height)
await self.request_header(interface, next_height)
else:
interface.mode = 'default'
interface.request = None
@ -936,34 +922,44 @@ class Network(util.DaemonThread):
# refresh network dialog
self.notify('interfaces')
def maintain_requests(self):
for interface in list(self.interfaces.values()):
if interface.request and time.time() - interface.request_time > 20:
interface.print_error("blockchain request timed out")
self.connection_down(interface.server)
continue
def wait_on_sockets(self):
# Python docs say Windows doesn't like empty selects.
# Sleep to prevent busy looping
if not self.interfaces:
time.sleep(0.1)
return
rin = [i for i in self.interfaces.values()]
win = [i for i in self.interfaces.values() if i.num_requests()]
def make_send_requests_job(self, interface):
async def job():
try:
rout, wout, xout = select.select(rin, win, [], 0.1)
except socket.error as e:
# TODO: py3, get code from e
code = None
if code == errno.EINTR:
return
raise
assert not xout
for interface in wout:
interface.send_requests()
for interface in rout:
self.process_responses(interface)
while True:
# this wait_for is necessary since CancelledError
# doesn't seem to get thrown without it
# when using ssl_in_socks
try:
await asyncio.wait_for(interface.send_request(), 1)
except TimeoutError:
pass
except CancelledError:
pass
except BaseException as e:
await self.connection_down(interface.server, "exp while send_request: " + str(e))
return asyncio.ensure_future(job())
def make_process_responses_job(self, interface):
async def job():
try:
await self.process_responses(interface)
except CancelledError:
pass
except:
await self.connection_down(interface.server, "exp in process_responses")
return asyncio.ensure_future(job())
def make_process_pending_sends_job(self):
async def job():
try:
while True:
await self.process_pending_sends()
except CancelledError:
pass
except:
await self.connection_down(interface.server, "exp in process_pending_sends_job")
return asyncio.ensure_future(job())
def init_headers_file(self):
b = self.blockchains[0]
@ -977,23 +973,116 @@ class Network(util.DaemonThread):
with b.lock:
b.update_size()
def run(self):
self.init_headers_file()
while self.is_running():
self.maintain_sockets()
self.wait_on_sockets()
self.maintain_requests()
self.run_jobs() # Synchronizer and Verifier
self.process_pending_sends()
self.stop_network()
self.on_stop()
def boot_interface(self, interface):
async def job():
try:
interface.jobs = [self.make_send_requests_job(interface)] # we need this job to process the request queued below
await asyncio.wait_for(self.queue_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION], interface), 10)
handshakeResult = await asyncio.wait_for(interface.get_response(), 10)
if not interface.is_running():
self.print_error("WARNING: quitting bootjob instead of handling CancelledError!")
return
await self.queue_request('blockchain.headers.subscribe', [], interface)
if interface.server == self.default_server:
await asyncio.wait_for(self.switch_to_interface(interface.server), 5)
interface.jobs += [self.make_process_responses_job(interface)]
gathered = asyncio.gather(*interface.jobs)
interface.launched.set_result(handshakeResult)
while interface.is_running():
try:
await asyncio.wait_for(asyncio.shield(gathered), 1)
except TimeoutError:
pass
if interface.has_timed_out():
self.print_error(interface.server, "timed out")
await self.connection_down(interface.server, "time out in ping_job")
break
elif interface.ping_required():
params = [ELECTRUM_VERSION, PROTOCOL_VERSION]
await self.queue_request('server.version', params, interface)
#self.notify('interfaces')
except TimeoutError:
asyncio.ensure_future(self.connection_down(interface.server, "timeout in boot_interface while getting response"))
except CancelledError:
pass
except:
traceback.print_exc()
finally:
for i in interface.jobs: i.cancel()
if not interface.launched.done(): interface.launched.set_result(None)
interface.boot_job = asyncio.ensure_future(job())
interface.boot_job.server = interface.server
def on_notify_header(self, interface, header):
async def maintain_interfaces(self):
if self.stopped:
return
now = time.time()
# nodes
if len(self.interfaces) + len(self.connecting) < self.num_server:
# ensure future so that servers can be connected to in parallel
asyncio.ensure_future(self.start_random_interface())
if now - self.nodes_retry_time > NODES_RETRY_INTERVAL:
self.print_error('network: retrying connections')
self.disconnected_servers = {}
self.nodes_retry_time = now
# main interface
if not self.is_connected():
if self.auto_connect:
if not self.is_connecting():
await self.switch_to_random_interface()
else:
if self.default_server in self.disconnected_servers:
if now - self.server_retry_time > SERVER_RETRY_INTERVAL:
del self.disconnected_servers[self.default_server]
self.server_retry_time = now
else:
await self.switch_to_interface(self.default_server)
else:
if self.config.is_fee_estimates_update_required():
await self.request_fee_estimates()
def run(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) # this does not set the loop on the qt thread
self.loop = loop # so we store it in the instance too
self.init_headers_file()
self.pending_sends = asyncio.Queue()
self.restartLock = asyncio.Lock()
self.start_network(deserialize_server(self.default_server)[2], deserialize_proxy(self.config.get('proxy')))
self.process_pending_sends_job = self.make_process_pending_sends_job()
run_future = asyncio.Future()
self.run_forever_coroutines()
asyncio.ensure_future(self.run_async(run_future))
loop.run_until_complete(run_future)
loop.run_until_complete(self.forever_coroutines_task)
run_future.exception()
# this would provoke errors of unfinished futures
# since we havn't stopped network.
# (fast shutdown is more important)
#loop.close()
async def run_async(self, future):
try:
while self.is_running():
await self.maintain_interfaces()
self.run_jobs()
await asyncio.sleep(1)
# not shutting down network if we are going to quit anyway, the OS can clean up
#await self.stop_network()
self.on_stop()
future.set_result("run_async done")
except BaseException as e:
future.set_exception(e)
async def on_notify_header(self, interface, header):
height = header.get('block_height')
if not height:
return
if height < self.max_checkpoint():
self.connection_down(interface.server)
await self.connection_down(interface.server, "height under max checkpoint in on_notify_header")
return
interface.tip_header = header
interface.tip = height
@ -1002,7 +1091,7 @@ class Network(util.DaemonThread):
b = blockchain.check_header(header)
if b:
interface.blockchain = b
self.switch_lagging_interface()
await self.switch_lagging_interface()
self.notify('updated')
self.notify('interfaces')
return
@ -1010,7 +1099,7 @@ class Network(util.DaemonThread):
if b:
interface.blockchain = b
b.save_header(header)
self.switch_lagging_interface()
await self.switch_lagging_interface()
self.notify('updated')
self.notify('interfaces')
return
@ -1019,17 +1108,14 @@ class Network(util.DaemonThread):
interface.mode = 'backward'
interface.bad = height
interface.bad_header = header
self.request_header(interface, min(tip +1, height - 1))
await self.request_header(interface, min(tip + 1, height - 1))
else:
chain = self.blockchains[0]
if chain.catch_up is None:
chain.catch_up = interface
interface.mode = 'catch_up'
interface.blockchain = chain
self.print_error("switching to catchup mode", tip, self.blockchains)
self.request_header(interface, 0)
else:
self.print_error("chain already catching up with", chain.catch_up.server)
await self.request_header(interface, 0)
def blockchain(self):
if self.interface and self.interface.blockchain is not None:
@ -1044,6 +1130,7 @@ class Network(util.DaemonThread):
out[k] = r
return out
# called from the Qt thread
def follow_chain(self, index):
blockchain = self.blockchains.get(index)
if blockchain:
@ -1051,16 +1138,19 @@ class Network(util.DaemonThread):
self.config.set_key('blockchain_index', index)
for i in self.interfaces.values():
if i.blockchain == blockchain:
self.switch_to_interface(i.server)
asyncio.run_coroutine_threadsafe(self.switch_to_interface(i.server), self.loop)
break
else:
raise BaseException('blockchain not found', index)
if self.interface:
server = self.interface.server
host, port, protocol, proxy, auto_connect = self.get_parameters()
host, port, protocol = server.split(':')
self.set_parameters(host, port, protocol, proxy, auto_connect)
# commented out on migration to asyncio. not clear if it
# relies on the coroutine to be done:
#if self.interface:
# server = self.interface.server
# host, port, protocol, proxy, auto_connect = self.get_parameters()
# host, port, protocol = server.split(':')
# self.set_parameters(host, port, protocol, proxy, auto_connect)
def get_local_height(self):
return self.blockchain().height()
@ -1094,3 +1184,34 @@ class Network(util.DaemonThread):
def max_checkpoint(self):
return max(0, len(constants.net.CHECKPOINTS) * 2016 - 1)
async def send_async(self, messages, callback=None):
""" if callback is None, it returns the result """
chosenCallback = callback
if callback is None:
queue = asyncio.Queue()
chosenCallback = queue.put
assert type(messages[0]) is tuple and len(messages[0]) == 2, repr(messages) + " does not contain a pair-tuple in first position"
await self.pending_sends.put((messages, chosenCallback))
if callback is None:
#assert queue.qsize() == 1, "queue does not have a single result, it has length " + str(queue.qsize())
return await asyncio.wait_for(queue.get(), 5)
async def asynchronous_get(self, request):
assert type(request) is tuple
assert type(request[1]) is list
res = await self.send_async([request])
try:
return res.get("result")
except:
print("asynchronous_get could not get result from", res)
raise BaseException("Could not get result: " + repr(res))
async def broadcast_async(self, tx):
tx_hash = tx.txid()
try:
return True, await self.asynchronous_get(('blockchain.transaction.broadcast', [str(tx)]))
except BaseException as e:
traceback.print_exc()
print("previous trace was captured and printed in broadcast_async")
return False, str(e)

85
lib/ssl_in_socks.py Normal file
View File

@ -0,0 +1,85 @@
import traceback
import ssl
from asyncio.sslproto import SSLProtocol
import aiosocks
import asyncio
from . import interface
class AppProto(asyncio.Protocol):
def __init__(self, receivedQueue, connUpLock):
self.buf = bytearray()
self.receivedQueue = receivedQueue
self.connUpLock = connUpLock
def connection_made(self, transport):
self.connUpLock.release()
def data_received(self, data):
self.buf.extend(data)
NEWLINE = b"\n"[0]
for idx, val in enumerate(self.buf):
if NEWLINE == val:
asyncio.ensure_future(self.receivedQueue.put(bytes(self.buf[:idx+1])))
self.buf = self.buf[idx+1:]
def makeProtocolFactory(receivedQueue, connUpLock, ca_certs):
class MySSLProtocol(SSLProtocol):
def __init__(self):
context = interface.get_ssl_context(\
cert_reqs=ssl.CERT_REQUIRED if ca_certs is not None else ssl.CERT_NONE,\
ca_certs=ca_certs)
proto = AppProto(receivedQueue, connUpLock)
super().__init__(asyncio.get_event_loop(), proto, context, None)
return MySSLProtocol
class ReaderEmulator:
def __init__(self, receivedQueue):
self.receivedQueue = receivedQueue
async def read(self, _bufferSize):
return await self.receivedQueue.get()
class WriterEmulator:
def __init__(self, transport):
self.transport = transport
def write(self, data):
self.transport.write(data)
async def drain(self):
pass
def close(self):
self.transport.close()
async def sslInSocksReaderWriter(socksAddr, socksAuth, host, port, ca_certs):
receivedQueue = asyncio.Queue()
connUpLock = asyncio.Lock()
await connUpLock.acquire()
transport, protocol = await aiosocks.create_connection(\
makeProtocolFactory(receivedQueue, connUpLock, ca_certs),\
proxy=socksAddr,\
proxy_auth=socksAuth, dst=(host, port))
await connUpLock.acquire()
return ReaderEmulator(receivedQueue), WriterEmulator(protocol._app_transport)
if __name__ == "__main__":
async def l(fut):
try:
# aiosocks.Socks4Addr("127.0.0.1", 9050), None, "songbird.bauerj.eu", 50002, None)
args = aiosocks.Socks4Addr("127.0.0.1", 9050), None, "electrum.akinbo.org", 51002, None
reader, writer = await sslInSocksReaderWriter(*args)
writer.write(b'{"id":0,"method":"server.version","args":["3.0.2", "1.1"]}\n')
await writer.drain()
print(await reader.read(4096))
writer.write(b'{"id":0,"method":"server.version","args":["3.0.2", "1.1"]}\n')
await writer.drain()
print(await reader.read(4096))
writer.close()
fut.set_result("finished")
except BaseException as e:
fut.set_exception(e)
def f():
loop = asyncio.get_event_loop()
fut = asyncio.Future()
asyncio.ensure_future(l(fut))
loop.run_until_complete(fut)
print(fut.result())
loop.close()
f()

View File

@ -29,6 +29,7 @@ import traceback
import urllib
import threading
import hmac
import asyncio
from .i18n import _
@ -148,6 +149,24 @@ class PrintError(object):
def print_msg(self, *msg):
print_msg("[%s]" % self.diagnostic_name(), *msg)
class ForeverCoroutineJob(PrintError):
"""A job that is run from a thread's main loop. run() is
called from that thread's context.
"""
async def run(self, is_running):
"""Called once from the thread"""
pass
class CoroutineJob(PrintError):
"""A job that is run periodically from a thread's main loop. run() is
called from that thread's context.
"""
async def run(self):
"""Called periodically from the thread"""
pass
class ThreadJob(PrintError):
"""A job that is run periodically from a thread's main loop. run() is
called from that thread's context.
@ -192,6 +211,44 @@ class DaemonThread(threading.Thread, PrintError):
self.running_lock = threading.Lock()
self.job_lock = threading.Lock()
self.jobs = []
self.coroutines = []
self.forever_coroutines_task = None
def add_coroutines(self, jobs):
for i in jobs: assert isinstance(i, CoroutineJob), i.__class__.__name__ + " does not inherit from CoroutineJob"
self.coroutines.extend(jobs)
def set_forever_coroutines(self, jobs):
for i in jobs: assert isinstance(i, ForeverCoroutineJob), i.__class__.__name__ + " does not inherit from ForeverCoroutineJob"
async def put():
await self.forever_coroutines_queue.put(jobs)
asyncio.run_coroutine_threadsafe(put(), self.loop)
def run_forever_coroutines(self):
self.forever_coroutines_queue = asyncio.Queue() # making queue here because __init__ is called from non-network thread
self.loop = asyncio.get_event_loop()
async def getFromQueueAndStart():
jobs = []
while True:
try:
jobs = await asyncio.wait_for(self.forever_coroutines_queue.get(), 1)
break
except asyncio.TimeoutError:
if not self.is_running(): break
continue
await asyncio.gather(*[i.run(self.is_running) for i in jobs])
self.forever_coroutines_task = asyncio.ensure_future(getFromQueueAndStart())
return self.forever_coroutines_task
async def run_coroutines(self):
for coroutine in self.coroutines:
assert isinstance(coroutine, CoroutineJob)
await coroutine.run()
def remove_coroutines(self, jobs):
for i in jobs: assert isinstance(i, CoroutineJob)
for job in jobs:
self.coroutines.remove(job)
def add_jobs(self, jobs):
with self.job_lock:
@ -203,6 +260,7 @@ class DaemonThread(threading.Thread, PrintError):
# malformed or malicious server responses
with self.job_lock:
for job in self.jobs:
assert isinstance(job, ThreadJob)
try:
job.run()
except Exception as e: