Implement a merkle cache with tests
This commit is contained in:
parent
997a1be377
commit
4168341857
@ -121,7 +121,8 @@ class Merkle(object):
|
||||
return [root(hashes[n: n + size], depth_higher)
|
||||
for n in range(0, len(hashes), size)]
|
||||
|
||||
def branch_from_level(self, level, leaf_hashes, index, depth_higher):
|
||||
def branch_and_root_from_level(self, level, leaf_hashes, index,
|
||||
depth_higher):
|
||||
'''Return a (merkle branch, merkle_root) pair when a merkle-tree has a
|
||||
level cached.
|
||||
|
||||
@ -142,7 +143,7 @@ class Merkle(object):
|
||||
if not isinstance(level, list):
|
||||
raise TypeError("level must be a list")
|
||||
if not isinstance(leaf_hashes, list):
|
||||
raise TypeError("level must be a list")
|
||||
raise TypeError("leaf_hashes must be a list")
|
||||
leaf_index = (index >> depth_higher) << depth_higher
|
||||
leaf_branch, leaf_root = self.branch_and_root(
|
||||
leaf_hashes, index - leaf_index, depth_higher)
|
||||
@ -152,3 +153,79 @@ class Merkle(object):
|
||||
if leaf_root != level[index]:
|
||||
raise ValueError('leaf hashes inconsistent with level')
|
||||
return leaf_branch + level_branch, root
|
||||
|
||||
|
||||
class MerkleCache(object):
    '''A cache to calculate merkle branches efficiently.

    One level of the merkle tree, depth_higher rows above the leaf
    hashes, is kept in memory.  A branch is then assembled from a
    single segment of leaf hashes fetched from the source plus the
    cached level, instead of from every leaf hash.
    '''

    def __init__(self, merkle, source, length):
        '''Initialise a cache of length hashes taken from source.

        merkle -- object providing tree_depth(), level(),
                  branch_and_root() and branch_and_root_from_level()
        source -- object whose hashes(start, count) method returns a
                  list of count hashes beginning at position start
        length -- the number of hashes the cache initially covers
        '''
        self.merkle = merkle
        self.source = source
        self.length = length
        # Cache the level half-way up the tree for the initial length.
        self.depth_higher = merkle.tree_depth(length) // 2
        self.level = self._level(source.hashes(0, length))

    def _segment_length(self):
        '''Return the number of leaf hashes treated as one segment,
        i.e. as lying beneath a single entry of the cached level.'''
        return 1 << self.depth_higher

    def _leaf_start(self, index):
        '''Return the index of the first leaf hash of the segment
        containing the hash at the given index.'''
        depth_higher = self.depth_higher
        return (index >> depth_higher) << depth_higher

    def _level(self, hashes):
        '''Return the cached-depth level computed from hashes, which
        must begin on a segment boundary.'''
        return self.merkle.level(hashes, self.depth_higher)

    def _extend_to(self, length):
        '''Extend the length of the cache if necessary.'''
        if length <= self.length:
            return
        # Start from the beginning of any final partial segment.
        # Retain the value of depth_higher; in practice this is fine
        start = self._leaf_start(self.length)
        # Fetch before touching any state so that a failing source
        # leaves the cache intact.
        hashes = self.source.hashes(start, length - start)
        # Slice assignment: replace the (possibly partial) final entry
        # and append the new ones.  Assigning to a single index would
        # store a nested list, or raise IndexError at the end.
        self.level[start >> self.depth_higher:] = self._level(hashes)
        self.length = length

    def _level_for(self, length):
        '''Return the level, as a list of hashes, for the first length
        hashes of the source.

        If length is shorter than the cache the cached level is
        truncated and its final entry recalculated.  If length is an
        extension, the extra hashes are requested from the source.
        The cache itself is never modified.
        '''
        if length == self.length:
            return self.level
        if length > self.length:
            # A final partial segment of the cache must be recomputed
            # together with the new hashes, not merely appended to.
            start = self._leaf_start(self.length)
            hashes = self.source.hashes(start, length - start)
            level = self.level[:start >> self.depth_higher]
            return level + self._level(hashes)
        # Truncation: keep the entries for complete segments (the slice
        # is a copy) and recompute the final partial one, if any.
        level = self.level[:length >> self.depth_higher]
        leaf_start = self._leaf_start(length)
        count = min(self._segment_length(), length - leaf_start)
        hashes = self.source.hashes(leaf_start, count)
        level += self._level(hashes)
        return level

    def branch_and_root(self, length, index):
        '''Return a merkle branch and root.  Length is the number of
        hashes used to calculate the merkle root, index is the position
        of the hash to calculate the branch of.

        index must be non-negative and less than length, which must be
        at least 1.

        Raises TypeError if length or index is not an integer, and
        ValueError if either is out of range.
        '''
        if not isinstance(length, int):
            raise TypeError('length must be an integer')
        if not isinstance(index, int):
            raise TypeError('index must be an integer')
        if length <= 0:
            raise ValueError('length must be positive')
        if index < 0:
            # Reject before extending the cache or querying the source
            # with a negative start position.
            raise ValueError('index must be non-negative')
        if index >= length:
            raise ValueError('index must be less than length')
        self._extend_to(length)
        leaf_start = self._leaf_start(index)
        count = min(self._segment_length(), length - leaf_start)
        leaf_hashes = self.source.hashes(leaf_start, count)
        if length < self._segment_length():
            # Everything fits in one segment; the cached level is not
            # needed.
            return self.merkle.branch_and_root(leaf_hashes, index)
        level = self._level_for(length)
        return self.merkle.branch_and_root_from_level(
            level, leaf_hashes, index, self.depth_higher)
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from electrumx.lib.merkle import Merkle
|
||||
from electrumx.lib.merkle import Merkle, MerkleCache
|
||||
|
||||
|
||||
Merkle = Merkle()
|
||||
hashes = [Merkle.hash_func(bytes([x])) for x in range(8)]
|
||||
merkle = Merkle()
|
||||
hashes = [merkle.hash_func(bytes([x])) for x in range(8)]
|
||||
roots = [
|
||||
b'\x14\x06\xe0X\x81\xe2\x996wf\xd3\x13\xe2l\x05VN\xc9\x1b\xf7!\xd3\x17&\xbdnF\xe6\x06\x89S\x9a',
|
||||
b'K\xbe\x83\xbc8\xeb\xe2\xbc\xc7R\r#A9\xdf\x1c\x0e\xb9\xff\xa5\x1f\x83\xea\xb1\xc5\x12\x9b[\x90kvU',
|
||||
@ -19,125 +20,209 @@ roots = [
|
||||
|
||||
|
||||
def test_branch_length():
|
||||
assert Merkle.branch_length(1) == 0
|
||||
assert Merkle.branch_length(2) == 1
|
||||
assert merkle.branch_length(1) == 0
|
||||
assert merkle.branch_length(2) == 1
|
||||
for n in range(3, 5):
|
||||
assert Merkle.branch_length(n) == 2
|
||||
assert merkle.branch_length(n) == 2
|
||||
for n in range(5, 9):
|
||||
assert Merkle.branch_length(n) == 3
|
||||
assert merkle.branch_length(n) == 3
|
||||
|
||||
|
||||
def test_branch_length_bad():
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.branch_length(1.0)
|
||||
merkle.branch_length(1.0)
|
||||
for n in (-1, 0):
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_length(n)
|
||||
merkle.branch_length(n)
|
||||
|
||||
|
||||
def test_tree_depth():
|
||||
for n in range(1, 10):
|
||||
assert Merkle.tree_depth(n) == Merkle.branch_length(n) + 1
|
||||
assert merkle.tree_depth(n) == merkle.branch_length(n) + 1
|
||||
|
||||
|
||||
def test_root():
|
||||
for n in range(len(hashes)):
|
||||
assert Merkle.root(hashes[:n + 1]) == roots[n]
|
||||
assert merkle.root(hashes[:n + 1]) == roots[n]
|
||||
|
||||
|
||||
def test_root_bad():
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.root(0)
|
||||
merkle.root(0)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.root([])
|
||||
merkle.root([])
|
||||
|
||||
|
||||
def test_branch_and_root_from_proof():
|
||||
for n in range(len(hashes)):
|
||||
for m in range(n + 1):
|
||||
branch, root = Merkle.branch_and_root(hashes[:n + 1], m)
|
||||
branch, root = merkle.branch_and_root(hashes[:n + 1], m)
|
||||
assert root == roots[n]
|
||||
root = Merkle.root_from_proof(hashes[m], branch, m)
|
||||
root = merkle.root_from_proof(hashes[m], branch, m)
|
||||
assert root == roots[n]
|
||||
|
||||
|
||||
def test_branch_bad():
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.branch_and_root(0, 0)
|
||||
merkle.branch_and_root(0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_and_root([], 0)
|
||||
merkle.branch_and_root([], 0)
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.branch_and_root(hashes, 0.0)
|
||||
merkle.branch_and_root(hashes, 0.0)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_and_root(hashes[:2], -1)
|
||||
merkle.branch_and_root(hashes[:2], -1)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_and_root(hashes[:2], 2)
|
||||
Merkle.branch_and_root(hashes, 0, 3)
|
||||
merkle.branch_and_root(hashes[:2], 2)
|
||||
merkle.branch_and_root(hashes, 0, 3)
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.branch_and_root(hashes, 0, 3.0)
|
||||
merkle.branch_and_root(hashes, 0, 3.0)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_and_root(hashes, 0, 2)
|
||||
merkle.branch_and_root(hashes, 0, 2)
|
||||
|
||||
|
||||
def test_root_from_proof_bad():
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.root_from_proof(0, hashes[:2], 0)
|
||||
merkle.root_from_proof(0, hashes[:2], 0)
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.root_from_proof(hashes[0], hashes[0], 0)
|
||||
merkle.root_from_proof(hashes[0], hashes[0], 0)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.root_from_proof(hashes[0], hashes[:3], -1)
|
||||
merkle.root_from_proof(hashes[0], hashes[:3], -1)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.root_from_proof(hashes[0], hashes[:3], 8)
|
||||
merkle.root_from_proof(hashes[0], hashes[:3], 8)
|
||||
|
||||
|
||||
def test_level():
|
||||
for n in range(len(hashes)):
|
||||
depth = Merkle.tree_depth(n + 1)
|
||||
depth = merkle.tree_depth(n + 1)
|
||||
for depth_higher in range(0, depth):
|
||||
level = Merkle.level(hashes[:n + 1], depth_higher)
|
||||
level = merkle.level(hashes[:n + 1], depth_higher)
|
||||
if depth_higher == 0:
|
||||
assert level == hashes[:n + 1]
|
||||
if depth_higher == depth:
|
||||
assert level == [roots[n]]
|
||||
# Check raising from level to root works
|
||||
assert Merkle.root(level) == roots[n]
|
||||
assert merkle.root(level) == roots[n]
|
||||
|
||||
|
||||
def test_branch_from_level():
|
||||
def test_branch_and_root_from_level():
|
||||
# For all sub-trees
|
||||
for n in range(0, len(hashes)):
|
||||
part = hashes[:n + 1]
|
||||
# For all depths in sub-tree
|
||||
for depth_higher in range(0, Merkle.tree_depth(len(part))):
|
||||
level = Merkle.level(part, depth_higher)
|
||||
for depth_higher in range(0, merkle.tree_depth(len(part))):
|
||||
level = merkle.level(part, depth_higher)
|
||||
# For each hash in sub-tree
|
||||
for index, hash in enumerate(part):
|
||||
leaf_index = (index >> depth_higher) << depth_higher
|
||||
leaf_hashes = part[leaf_index:
|
||||
leaf_index + (1 << depth_higher)]
|
||||
branch = Merkle.branch_and_root(part, index)
|
||||
branch2 = Merkle.branch_from_level(level, leaf_hashes,
|
||||
index, depth_higher)
|
||||
branch = merkle.branch_and_root(part, index)
|
||||
branch2 = merkle.branch_and_root_from_level(
|
||||
level, leaf_hashes, index, depth_higher)
|
||||
assert branch == branch2
|
||||
|
||||
|
||||
def test_branch_from_level_bad():
|
||||
def test_branch_and_root_from_level_bad():
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.branch_from_level(hashes[0], hashes, 0, 0)
|
||||
merkle.branch_and_root_from_level(hashes[0], hashes, 0, 0)
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.branch_from_level(hashes, hashes[0], 0, 0)
|
||||
Merkle.branch_from_level(hashes, [hashes[0]], 0, 0)
|
||||
merkle.branch_and_root_from_level(hashes, hashes[0], 0, 0)
|
||||
merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_from_level(hashes, [hashes[0]], -1, 0)
|
||||
merkle.branch_and_root_from_level(hashes, [hashes[0]], -1, 0)
|
||||
with pytest.raises(TypeError):
|
||||
Merkle.branch_from_level(hashes, hashes, 0.0, 0)
|
||||
merkle.branch_and_root_from_level(hashes, hashes, 0.0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_from_level(hashes, [hashes[0]], 0, -1)
|
||||
merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, -1)
|
||||
with pytest.raises(ValueError):
|
||||
Merkle.branch_from_level(hashes, [hashes[0]], 0, 1)
|
||||
merkle.branch_and_root_from_level(hashes, [hashes[0]], 0, 1)
|
||||
with pytest.raises(ValueError):
|
||||
# Inconsistent hash
|
||||
Merkle.branch_from_level(hashes, [hashes[1]], 0, 0)
|
||||
merkle.branch_and_root_from_level(hashes, [hashes[1]], 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
# Inconsistent hash
|
||||
Merkle.branch_from_level(hashes, [hashes[0]], 1, 0)
|
||||
merkle.branch_and_root_from_level(hashes, [hashes[0]], 1, 0)
|
||||
|
||||
|
||||
class Source(object):
    '''An in-memory hash source for the cache tests, holding a fixed
    number of random 32-byte values.'''

    def __init__(self, length):
        pool = []
        for _ in range(length):
            pool.append(os.urandom(32))
        self._hashes = pool

    def hashes(self, start, length):
        '''Return length hashes beginning at position start.

        Requests outside the stored range trip an assertion.
        '''
        assert start >= 0
        end = start + length
        assert end <= len(self._hashes)
        return self._hashes[start:end]
|
||||
|
||||
|
||||
def test_merkle_cache():
    '''For a range of cache lengths, check every (checkpoint length,
    index) pair against a direct calculation over the full hash list.'''
    lengths = tuple(range(1, 18)) + (31, 32, 33, 57)
    source = Source(max(lengths))
    for length in lengths:
        cache = MerkleCache(merkle, source, length)
        # Simulate all possible checkpoints
        for cp_length in range(1, length + 1):
            cp_hashes = source.hashes(0, cp_length)
            # All possible indices
            for index in range(cp_length):
                # Compare correct answer with cache
                want_branch, want_root = merkle.branch_and_root(
                    cp_hashes, index)
                got_branch, got_root = cache.branch_and_root(
                    cp_length, index)
                assert want_branch == got_branch
                assert want_root == got_root
|
||||
|
||||
|
||||
def merkle_cache_extension():
    '''Check branches and roots when the requested checkpoint length
    exceeds the length the cache was created with.

    NOTE(review): this function lacks the test_ prefix, so pytest's
    default discovery does not collect it and it never runs.
    '''
    source = Source(64)
    for length in range(14, 18):
        for cp_length in range(30, 36):
            # A fresh cache each time so every request starts as an
            # extension from the initial length.
            cache = MerkleCache(merkle, source, length)
            cp_hashes = source.hashes(0, cp_length)
            # All possible indices
            for index in range(cp_length):
                # Compare correct answer with cache
                want_branch, want_root = merkle.branch_and_root(
                    cp_hashes, index)
                got_branch, got_root = cache.branch_and_root(
                    cp_length, index)
                assert want_branch == got_branch
                assert want_root == got_root
|
||||
|
||||
|
||||
def test_markle_cache_bad():
    '''Check that MerkleCache.branch_and_root rejects arguments of the
    wrong type or out of range.'''
    length = 23
    source = Source(length)
    cache = MerkleCache(merkle, source, length)
    # A valid request must succeed before the bad ones are tried.
    cache.branch_and_root(5, 3)
    # Non-integer length, then non-integer index
    for bad_length, bad_index in ((5.0, 3), (5, 3.0)):
        with pytest.raises(TypeError):
            cache.branch_and_root(bad_length, bad_index)
    # Non-positive length, then index not less than length
    for bad_length, bad_index in ((0, -1), (3, 3)):
        with pytest.raises(ValueError):
            cache.branch_and_root(bad_length, bad_index)
|
||||
|
||||
def test_bad_extension():
    '''Check that a request the source cannot satisfy raises and
    leaves the cache's level and length untouched.'''
    length = 5
    source = Source(length)
    cache = MerkleCache(merkle, source, length)
    # Snapshot a copy.  Keeping a reference to cache.level itself would
    # make the comparison below pass even if the list had been
    # modified in place.
    level = list(cache.level)
    # The source only holds 5 hashes, so its range assertion fires.
    with pytest.raises(AssertionError):
        cache.branch_and_root(8, 0)
    # The bad extension should not destroy the cache
    assert cache.level == level
    assert cache.length == length
|
||||
|
||||
|
||||
def time_it():
    '''Manual benchmark: time a series of branch calculations through
    the cache and print the elapsed seconds.

    Not collected by pytest (no test_ prefix); call it by hand.
    '''
    import time

    source = Source(500000)
    cp_length = 492000
    # MerkleCache requires the initial length; omitting it raised
    # TypeError before any timing could happen.
    cache = MerkleCache(merkle, source, cp_length)
    t1 = time.perf_counter()
    brs2 = [cache.branch_and_root(cp_length, index)
            for index in range(5, 400000, 500)]
    t2 = time.perf_counter()
    print(t2 - t1)
    # Deliberate failure kept from the original; presumably there so
    # the printed timing is shown when run under pytest -- TODO confirm.
    assert False
|
||||
|
||||
Loading…
Reference in New Issue
Block a user