From 9c25685eb99856085f0031c5ba7df27aecf0abac Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Wed, 6 Sep 2017 17:11:18 +0900 Subject: [PATCH] Handle client protocol range requests. Add more tests. --- lib/util.py | 24 ++++++++++++++++++++++++ server/session.py | 21 ++++++++++----------- tests/lib/test_util.py | 17 +++++++++++++++++ 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/lib/util.py b/lib/util.py index 697569d..f3887ee 100644 --- a/lib/util.py +++ b/lib/util.py @@ -278,3 +278,27 @@ def protocol_tuple(s): return tuple(int(part) for part in s.split('.')) except Exception: return (0, ) + +def protocol_version(client_req, server_min, server_max): + '''Given a client protocol request, return the protocol version + to use as a tuple. + + If a mutually acceptable protocol version does not exist, return None. + ''' + if isinstance(client_req, list) and len(client_req) == 2: + client_min, client_max = client_req + elif client_req is None: + client_min = client_max = server_min + else: + client_min = client_max = client_req + + client_min = protocol_tuple(client_min) + client_max = protocol_tuple(client_max) + server_min = protocol_tuple(server_min) + server_max = protocol_tuple(server_max) + + result = min(client_max, server_max) + if result < max(client_min, server_min) or result == (0, ): + result = None + + return result diff --git a/server/session.py b/server/session.py index ae75377..e6760d5 100644 --- a/server/session.py +++ b/server/session.py @@ -367,19 +367,18 @@ class ElectrumX(SessionBase): return message - def set_protocol_handlers(self, version_str): - controller = self.controller - if version_str is None: - version_str = version.PROTOCOL_MIN - ptuple = util.protocol_tuple(version_str) - # Disconnect if requested protocol version in unsupported - if (ptuple < util.protocol_tuple(version.PROTOCOL_MIN) - or ptuple > util.protocol_tuple(version.PROTOCOL_MAX)): - self.log_info('unsupported protocol version {}' - .format(version_str)) + def set_protocol_handlers(self, version_req): + # Find the highest common protocol version. Disconnect if + # that protocol version in unsupported. + ptuple = util.protocol_version(version_req, version.PROTOCOL_MIN, + version.PROTOCOL_MAX) + if ptuple is None: + self.log_info('unsupported protocol version request {}' + .format(version_req)) raise RPCError('unsupported protocol version: {}' - .format(version_str), JSONRPC.FATAL_ERROR) + .format(version_req), JSONRPC.FATAL_ERROR) + controller = self.controller handlers = { 'blockchain.address.get_balance': controller.address_get_balance, 'blockchain.address.get_history': controller.address_get_history, diff --git a/tests/lib/test_util.py b/tests/lib/test_util.py index cd9d14d..02780b4 100644 --- a/tests/lib/test_util.py +++ b/tests/lib/test_util.py @@ -93,3 +93,20 @@ def test_protocol_tuple(): assert util.protocol_tuple("0.1") == (0, 1) assert util.protocol_tuple("0.10") == (0, 10) assert util.protocol_tuple("2.5.3") == (2, 5, 3) + +def test_protocol_version(): + assert util.protocol_version(None, "1.0", "1.0") == (1, 0) + assert util.protocol_version("0.10", "0.10", "1.1") == (0, 10) + + assert util.protocol_version("1.0", "1.0", "1.0") == (1, 0) + assert util.protocol_version("1.0", "1.0", "1.1") == (1, 0) + assert util.protocol_version("1.1", "1.0", "1.1") == (1, 1) + assert util.protocol_version("1.2", "1.0", "1.1") is None + assert util.protocol_version("0.9", "1.0", "1.1") is None + + assert util.protocol_version(["0.9", "1.0"], "1.0", "1.1") == (1, 0) + assert util.protocol_version(["0.9", "1.1"], "1.0", "1.1") == (1, 1) + assert util.protocol_version(["1.1", "0.9"], "1.0", "1.1") is None + assert util.protocol_version(["0.8", "0.9"], "1.0", "1.1") is None + assert util.protocol_version(["1.1", "1.2"], "1.0", "1.1") == (1, 1) + assert util.protocol_version(["1.2", "1.3"], "1.0", "1.1") is None