From 416834185700483136b3ea1ef9895d04c86eac33 Mon Sep 17 00:00:00 2001
From: Neil Booth
Date: Sat, 14 Jul 2018 23:35:54 +0800
Subject: [PATCH] Implement a merkle cache with tests

---
 electrumx/lib/merkle.py  |  81 +++++++++++++++++-
 tests/lib/test_merkle.py | 177 +++++++++++++++++++++++++++++----------
 2 files changed, 210 insertions(+), 48 deletions(-)

diff --git a/electrumx/lib/merkle.py b/electrumx/lib/merkle.py
index 731b8d0..65c4089 100644
--- a/electrumx/lib/merkle.py
+++ b/electrumx/lib/merkle.py
@@ -121,7 +121,8 @@ class Merkle(object):
         return [root(hashes[n: n + size], depth_higher)
                 for n in range(0, len(hashes), size)]
 
-    def branch_from_level(self, level, leaf_hashes, index, depth_higher):
+    def branch_and_root_from_level(self, level, leaf_hashes, index,
+                                   depth_higher):
         '''Return a (merkle branch, merkle_root) pair when a merkle-tree
         has a level cached.
 
@@ -142,7 +143,7 @@ class Merkle(object):
         if not isinstance(level, list):
             raise TypeError("level must be a list")
         if not isinstance(leaf_hashes, list):
-            raise TypeError("level must be a list")
+            raise TypeError("leaf_hashes must be a list")
         leaf_index = (index >> depth_higher) << depth_higher
         leaf_branch, leaf_root = self.branch_and_root(
             leaf_hashes, index - leaf_index, depth_higher)
@@ -152,3 +153,79 @@ class Merkle(object):
         if leaf_root != level[index]:
             raise ValueError('leaf hashes inconsistent with level')
         return leaf_branch + level_branch, root
+
+
+class MerkleCache(object):
+    '''A cache to calculate merkle branches efficiently.'''
+
+    def __init__(self, merkle, source, length):
+        '''Initialise a cache of length hashes taken from source.'''
+        self.merkle = merkle
+        self.source = source
+        self.length = length
+        self.depth_higher = merkle.tree_depth(length) // 2
+        self.level = self._level(source.hashes(0, length))
+
+    def _segment_length(self):
+        return 1 << self.depth_higher
+
+    def _leaf_start(self, index):
+        '''Return the index of the first hash of the segment (the run
+        of hashes under a single cached level entry) containing index.
+        '''
+        depth_higher = self.depth_higher
+        return (index >> depth_higher) << depth_higher
+
+    def _level(self, hashes):
+        return self.merkle.level(hashes, self.depth_higher)
+
+    def _extend_to(self, length):
+        '''Extend the length of the cache if necessary.'''
+        if length <= self.length:
+            return
+        # Start from the beginning of any final partial segment.
+        # Retain the value of depth_higher; in practice this is fine
+        start = self._leaf_start(self.length)
+        hashes = self.source.hashes(start, length - start)
+        self.level[start >> self.depth_higher:] = self._level(hashes)
+        self.length = length
+
+    def _level_for(self, length):
+        '''Return the level (a list of segment roots) for a truncation
+        of the hashes to the given length. Length may be an extension,
+        in which case extra hashes are requested from the source.'''
+        if length > self.length:
+            hashes = self.source.hashes(self.length, length - self.length)
+            return self.level + self._level(hashes)
+        if length < self.length:
+            level = self.level[:length >> self.depth_higher]
+            leaf_start = self._leaf_start(length)
+            count = min(self._segment_length(), length - leaf_start)
+            hashes = self.source.hashes(leaf_start, count)
+            level += self._level(hashes)
+            return level
+        return self.level
+
+    def branch_and_root(self, length, index):
+        '''Return a merkle branch and root. Length is the number of
+        hashes used to calculate the merkle root, index is the position
+        of the hash to calculate the branch of.
+
+        index must be less than length, which must be at least 1.'''
+        if not isinstance(length, int):
+            raise TypeError('length must be an integer')
+        if not isinstance(index, int):
+            raise TypeError('index must be an integer')
+        if length <= 0:
+            raise ValueError('length must be positive')
+        if index >= length:
+            raise ValueError('index must be less than length')
+        self._extend_to(length)
+        leaf_start = self._leaf_start(index)
+        count = min(self._segment_length(), length - leaf_start)
+        leaf_hashes = self.source.hashes(leaf_start, count)
+        if length < self._segment_length():
+            return self.merkle.branch_and_root(leaf_hashes, index)
+        level = self._level_for(length)
+        return self.merkle.branch_and_root_from_level(
+            level, leaf_hashes, index, self.depth_higher)
diff --git a/tests/lib/test_merkle.py b/tests/lib/test_merkle.py
index 2d8e4e5..c860020 100644
--- a/tests/lib/test_merkle.py
+++ b/tests/lib/test_merkle.py
@@ -1,10 +1,11 @@
+import os
 import pytest
 
-from electrumx.lib.merkle import Merkle
+from electrumx.lib.merkle import Merkle, MerkleCache
 
 
-Merkle = Merkle()
-hashes = [Merkle.hash_func(bytes([x])) for x in range(8)]
+merkle = Merkle()
+hashes = [merkle.hash_func(bytes([x])) for x in range(8)]
 roots = [
     b'\x14\x06\xe0X\x81\xe2\x996wf\xd3\x13\xe2l\x05VN\xc9\x1b\xf7!\xd3\x17&\xbdnF\xe6\x06\x89S\x9a',
     b'K\xbe\x83\xbc8\xeb\xe2\xbc\xc7R\r#A9\xdf\x1c\x0e\xb9\xff\xa5\x1f\x83\xea\xb1\xc5\x12\x9b[\x90kvU',
@@ -19,125 +20,209 @@ roots = [
 
 
 def test_branch_length():
-    assert Merkle.branch_length(1) == 0
-    assert Merkle.branch_length(2) == 1
+    assert merkle.branch_length(1) == 0
+    assert merkle.branch_length(2) == 1
     for n in range(3, 5):
-        assert Merkle.branch_length(n) == 2
+        assert merkle.branch_length(n) == 2
     for n in range(5, 9):
-        assert Merkle.branch_length(n) == 3
+        assert merkle.branch_length(n) == 3
 
 
 def test_branch_length_bad():
     with pytest.raises(TypeError):
-        Merkle.branch_length(1.0)
+        merkle.branch_length(1.0)
     for n in (-1, 0):
         with pytest.raises(ValueError):
-            Merkle.branch_length(n)
+            merkle.branch_length(n)
 
 
 def test_tree_depth():
     for n in range(1, 10):
-        assert Merkle.tree_depth(n) == Merkle.branch_length(n) + 1
+        assert merkle.tree_depth(n) == merkle.branch_length(n) + 1
 
 
 def test_root():
     for n in range(len(hashes)):
-        assert Merkle.root(hashes[:n + 1]) == roots[n]
+        assert merkle.root(hashes[:n + 1]) == roots[n]
 
 
 def test_root_bad():
     with pytest.raises(TypeError):
-        Merkle.root(0)
+        merkle.root(0)
     with pytest.raises(ValueError):
-        Merkle.root([])
+        merkle.root([])
 
 
 def test_branch_and_root_from_proof():
     for n in range(len(hashes)):
         for m in range(n + 1):
-            branch, root = Merkle.branch_and_root(hashes[:n + 1], m)
+            branch, root = merkle.branch_and_root(hashes[:n + 1], m)
             assert root == roots[n]
-            root = Merkle.root_from_proof(hashes[m], branch, m)
+            root = merkle.root_from_proof(hashes[m], branch, m)
             assert root == roots[n]
 
 
 def test_branch_bad():
     with pytest.raises(TypeError):
-        Merkle.branch_and_root(0, 0)
+        merkle.branch_and_root(0, 0)
     with pytest.raises(ValueError):
-        Merkle.branch_and_root([], 0)
+        merkle.branch_and_root([], 0)
     with pytest.raises(TypeError):
-        Merkle.branch_and_root(hashes, 0.0)
+        merkle.branch_and_root(hashes, 0.0)
     with pytest.raises(ValueError):
-        Merkle.branch_and_root(hashes[:2], -1)
+        merkle.branch_and_root(hashes[:2], -1)
     with pytest.raises(ValueError):
-        Merkle.branch_and_root(hashes[:2], 2)
-    Merkle.branch_and_root(hashes, 0, 3)
+        merkle.branch_and_root(hashes[:2], 2)
+    merkle.branch_and_root(hashes, 0, 3)
     with pytest.raises(TypeError):
-        Merkle.branch_and_root(hashes, 0, 3.0)
+        merkle.branch_and_root(hashes, 0, 3.0)
     with pytest.raises(ValueError):
-        Merkle.branch_and_root(hashes, 0, 2)
+        merkle.branch_and_root(hashes, 0, 2)
 
 
 def test_root_from_proof_bad():
     with pytest.raises(TypeError):
-        Merkle.root_from_proof(0, hashes[:2], 0)
+        merkle.root_from_proof(0, hashes[:2], 0)
     with pytest.raises(TypeError):
-        Merkle.root_from_proof(hashes[0], hashes[0], 0)
+        merkle.root_from_proof(hashes[0], hashes[0], 0)
     with pytest.raises(ValueError):
-        Merkle.root_from_proof(hashes[0], hashes[:3], -1)
+        merkle.root_from_proof(hashes[0], hashes[:3], -1)
     with pytest.raises(ValueError):
-        Merkle.root_from_proof(hashes[0], hashes[:3], 8)
+        merkle.root_from_proof(hashes[0], hashes[:3], 8)
 
 
 def test_level():
     for n in range(len(hashes)):
-        depth = Merkle.tree_depth(n + 1)
+        depth = merkle.tree_depth(n + 1)
         for depth_higher in range(0, depth):
-            level = Merkle.level(hashes[:n + 1], depth_higher)
+            level = merkle.level(hashes[:n + 1], depth_higher)
             if depth_higher == 0:
                 assert level == hashes[:n + 1]
             if depth_higher == depth:
                 assert level == [roots[n]]
             # Check raising from level to root works
-            assert Merkle.root(level) == roots[n]
+            assert merkle.root(level) == roots[n]
 
 
-def test_branch_from_level():
+def test_branch_and_root_from_level():
     # For all sub-trees
     for n in range(0, len(hashes)):
         part = hashes[:n + 1]
         # For all depths in sub-tree
-        for depth_higher in range(0, Merkle.tree_depth(len(part))):
-            level = Merkle.level(part, depth_higher)
+        for depth_higher in range(0, merkle.tree_depth(len(part))):
+            level = merkle.level(part, depth_higher)
             # For each hash in sub-tree
             for index, hash in enumerate(part):
                 leaf_index = (index >> depth_higher) << depth_higher
                 leaf_hashes = part[leaf_index:
                                    leaf_index + (1 << depth_higher)]
-                branch = Merkle.branch_and_root(part, index)
-                branch2 = Merkle.branch_from_level(level, leaf_hashes,
-                                                   index, depth_higher)
+                branch = merkle.branch_and_root(part, index)
+                branch2 = merkle.branch_and_root_from_level(
+                    level, leaf_hashes, index, depth_higher)
                 assert branch == branch2
 
 
-def test_branch_from_level_bad():
+def test_branch_and_root_from_level_bad():
     with pytest.raises(TypeError):
-        Merkle.branch_from_level(hashes[0], hashes, 0, 0)
+        merkle.branch_and_root_from_level(hashes[0], hashes, 0, 0)
     with pytest.raises(TypeError):
-        Merkle.branch_from_level(hashes, hashes[0], 0, 0)
-    Merkle.branch_from_level(hashes, [hashes[0]], 0, 0)
+        merkle.branch_and_root_from_level(hashes, hashes[0], 0, 0)
+    merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, 0)
     with pytest.raises(ValueError):
-        Merkle.branch_from_level(hashes, [hashes[0]], -1, 0)
+        merkle.branch_and_root_from_level(hashes, [hashes[0]], -1, 0)
     with pytest.raises(TypeError):
-        Merkle.branch_from_level(hashes, hashes, 0.0, 0)
+        merkle.branch_and_root_from_level(hashes, hashes, 0.0, 0)
     with pytest.raises(ValueError):
-        Merkle.branch_from_level(hashes, [hashes[0]], 0, -1)
+        merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, -1)
     with pytest.raises(ValueError):
-        Merkle.branch_from_level(hashes, [hashes[0]], 0, 1)
+        merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, 1)
     with pytest.raises(ValueError):
         # Inconsistent hash
-        Merkle.branch_from_level(hashes, [hashes[1]], 0, 0)
+        merkle.branch_and_root_from_level(hashes, [hashes[1]], 0, 0)
     with pytest.raises(ValueError):
         # Inconsistent hash
-        Merkle.branch_from_level(hashes, [hashes[0]], 1, 0)
+        merkle.branch_and_root_from_level(hashes, [hashes[0]], 1, 0)
+
+
+class Source(object):
+
+    def __init__(self, length):
+        self._hashes = [os.urandom(32) for _ in range(length)]
+
+    def hashes(self, start, length):
+        assert start >= 0
+        assert start + length <= len(self._hashes)
+        return self._hashes[start: start + length]
+
+
+def test_merkle_cache():
+    lengths = (*range(1, 18), 31, 32, 33, 57)
+    source = Source(max(lengths))
+    for length in lengths:
+        cache = MerkleCache(merkle, source, length)
+        # Simulate all possible checkpoints
+        for cp_length in range(1, length + 1):
+            cp_hashes = source.hashes(0, cp_length)
+            # All possible indices
+            for index in range(cp_length):
+                # Compare correct answer with cache
+                branch, root = merkle.branch_and_root(cp_hashes, index)
+                branch2, root2 = cache.branch_and_root(cp_length, index)
+                assert branch == branch2
+                assert root == root2
+
+
+def test_merkle_cache_extension():
+    source = Source(64)
+    for length in range(14, 18):
+        for cp_length in range(30, 36):
+            cache = MerkleCache(merkle, source, length)
+            cp_hashes = source.hashes(0, cp_length)
+            # All possible indices
+            for index in range(cp_length):
+                # Compare correct answer with cache
+                branch, root = merkle.branch_and_root(cp_hashes, index)
+                branch2, root2 = cache.branch_and_root(cp_length, index)
+                assert branch == branch2
+                assert root == root2
+
+
+def test_merkle_cache_bad():
+    length = 23
+    source = Source(length)
+    cache = MerkleCache(merkle, source, length)
+    cache.branch_and_root(5, 3)
+    with pytest.raises(TypeError):
+        cache.branch_and_root(5.0, 3)
+    with pytest.raises(TypeError):
+        cache.branch_and_root(5, 3.0)
+    with pytest.raises(ValueError):
+        cache.branch_and_root(0, -1)
+    with pytest.raises(ValueError):
+        cache.branch_and_root(3, 3)
+
+def test_bad_extension():
+    length = 5
+    source = Source(length)
+    cache = MerkleCache(merkle, source, length)
+    level = cache.level
+    with pytest.raises(AssertionError):
+        cache.branch_and_root(8, 0)
+    # The bad extension should not destroy the cache
+    assert cache.level == level
+    assert cache.length == length
+
+
+def time_it():
+    source = Source(500000)
+    import time
+    cache = MerkleCache(merkle, source, 500000)
+    cp_length = 492000
+    cp_hashes = source.hashes(0, cp_length)
+    brs2 = []
+    t1 = time.time()
+    for index in range(5, 400000, 500):
+        brs2.append(cache.branch_and_root(cp_length, index))
+    t2 = time.time()
+    print(t2 - t1)
+    assert False