Make the merkle cache and read_headers async
read_headers runs in a thread to avoid blocking
commit 1efc8cb8ec  (parent db5d516756)
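The core pattern here is the executor offload: the blocking disk read is wrapped in a plain function and handed to a worker thread, so the event loop keeps serving other coroutines while the read completes. A minimal sketch of a run_in_thread helper in that spirit, assuming it simply wraps loop.run_in_executor (ElectrumX's actual helper may live elsewhere and differ in detail):

import asyncio

async def run_in_thread(func, *args):
    # Run blocking `func` in the default thread pool and await the
    # result without blocking the event loop.
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, func, *args)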
@@ -158,13 +158,16 @@ class Merkle(object):
 class MerkleCache(object):
     '''A cache to calculate merkle branches efficiently.'''
 
-    def __init__(self, merkle, source, length):
-        '''Initialise a cache of length hashes taken from source.'''
+    def __init__(self, merkle, source_func):
+        '''Initialise a cache of hashes taken from source_func:
+
+              async def source_func(index, count):
+                  ...
+        '''
         self.merkle = merkle
-        self.source = source
-        self.length = length
-        self.depth_higher = merkle.tree_depth(length) // 2
-        self.level = self._level(source.hashes(0, length))
+        self.source_func = source_func
+        self.length = 0
+        self.depth_higher = 0
 
     def _segment_length(self):
         return 1 << self.depth_higher
@@ -179,18 +182,18 @@ class MerkleCache(object):
     def _level(self, hashes):
         return self.merkle.level(hashes, self.depth_higher)
 
-    def _extend_to(self, length):
+    async def _extend_to(self, length):
         '''Extend the length of the cache if necessary.'''
         if length <= self.length:
             return
         # Start from the beginning of any final partial segment.
         # Retain the value of depth_higher; in practice this is fine
         start = self._leaf_start(self.length)
-        hashes = self.source.hashes(start, length - start)
+        hashes = await self.source_func(start, length - start)
         self.level[start >> self.depth_higher:] = self._level(hashes)
         self.length = length
 
-    def _level_for(self, length):
+    async def _level_for(self, length):
         '''Return a (level_length, final_hash) pair for a truncation
         of the hashes to the given length.'''
         if length == self.length:
@@ -198,10 +201,16 @@ class MerkleCache(object):
         level = self.level[:length >> self.depth_higher]
         leaf_start = self._leaf_start(length)
         count = min(self._segment_length(), length - leaf_start)
-        hashes = self.source.hashes(leaf_start, count)
+        hashes = await self.source_func(leaf_start, count)
         level += self._level(hashes)
         return level
 
+    async def initialize(self, length):
+        '''Call to initialize the cache to a source of given length.'''
+        self.length = length
+        self.depth_higher = self.merkle.tree_depth(length) // 2
+        self.level = self._level(await self.source_func(0, length))
+
     def truncate(self, length):
         '''Truncate the cache so it covers no more than length underlying
         hashes.'''
@@ -215,7 +224,7 @@ class MerkleCache(object):
         self.length = length
         self.level[length >> self.depth_higher:] = []
 
-    def branch_and_root(self, length, index):
+    async def branch_and_root(self, length, index):
         '''Return a merkle branch and root. Length is the number of
         hashes used to calculate the merkle root, index is the position
         of the hash to calculate the branch of.
@@ -229,12 +238,12 @@ class MerkleCache(object):
             raise ValueError('length must be positive')
         if index >= length:
             raise ValueError('index must be less than length')
-        self._extend_to(length)
+        await self._extend_to(length)
         leaf_start = self._leaf_start(index)
         count = min(self._segment_length(), length - leaf_start)
-        leaf_hashes = self.source.hashes(leaf_start, count)
+        leaf_hashes = await self.source_func(leaf_start, count)
         if length < self._segment_length():
             return self.merkle.branch_and_root(leaf_hashes, index)
-        level = self._level_for(length)
+        level = await self._level_for(length)
         return self.merkle.branch_and_root_from_level(
             level, leaf_hashes, index, self.depth_higher)
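Net effect of the MerkleCache changes above: construction is now cheap and synchronous, the source is any async callable rather than an object with a hashes() method, and population happens in an explicit awaitable initialize() step. A usage sketch of the new interface (db, length and index are illustrative stand-ins; the block processor below passes its fs_block_hashes method directly):

async def source_func(index, count):
    # Any coroutine returning `count` hashes starting at `index` will do.
    return await db.fs_block_hashes(index, count)

cache = MerkleCache(merkle, source_func)
await cache.initialize(length)            # one-time async population
branch, root = await cache.branch_and_root(length, index)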
@@ -139,12 +139,6 @@ class Prefetcher(object):
         return True
 
 
-class HeaderSource(object):
-
-    def __init__(self, db):
-        self.hashes = db.fs_block_hashes
-
-
 class ChainError(Exception):
     '''Raised on error processing blocks.'''
 
@@ -174,7 +168,7 @@ class BlockProcessor(electrumx.server.db.DB):
 
         # Header merkle cache
         self.merkle = Merkle()
-        self.header_mc = None
+        self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
 
         # Caches of unflushed items.
         self.headers = []
@@ -251,9 +245,7 @@ class BlockProcessor(electrumx.server.db.DB):
             await self.run_in_thread_shielded(self.backup_blocks, raw_blocks)
             last -= len(raw_blocks)
         # Truncate header_mc: header count is 1 more than the height.
-        # Note header_mc is None if the reorg happens at startup.
-        if self.header_mc:
-            self.header_mc.truncate(self.height + 1)
+        self.header_mc.truncate(self.height + 1)
         await self.prefetcher.reset_height(self.height)
 
     async def reorg_hashes(self, count):
@@ -269,7 +261,7 @@ class BlockProcessor(electrumx.server.db.DB):
         self.logger.info(f'chain was reorganised replacing {count:,d} '
                          f'block{s} at heights {start:,d}-{last:,d}')
 
-        return start, last, self.fs_block_hashes(start, count)
+        return start, last, await self.fs_block_hashes(start, count)
 
     async def calc_reorg_range(self, count):
         '''Calculate the reorg range'''
@@ -287,7 +279,7 @@ class BlockProcessor(electrumx.server.db.DB):
             start = self.height - 1
             count = 1
             while start > 0:
-                hashes = self.fs_block_hashes(start, count)
+                hashes = await self.fs_block_hashes(start, count)
                 hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
                 d_hex_hashes = await self.daemon.block_hex_hashes(start, count)
                 n = diff_pos(hex_hashes, d_hex_hashes)
@@ -774,7 +766,7 @@ class BlockProcessor(electrumx.server.db.DB):
         await self.open_for_serving()
 
         # Populate the header merkle cache
         length = max(1, self.height - self.env.reorg_limit)
-        self.header_mc = MerkleCache(self.merkle, HeaderSource(self), length)
+        await self.header_mc.initialize(length)
         self.logger.info('populated header merkle cache')
 
     async def _first_open_dbs(self):
@@ -45,12 +45,12 @@ class ChainState(object):
             'db_height': self.db_height(),
         }
 
-    def header_branch_and_root(self, length, height):
-        return self._bp.header_mc.branch_and_root(length, height)
+    async def header_branch_and_root(self, length, height):
+        return await self._bp.header_mc.branch_and_root(length, height)
 
     async def raw_header(self, height):
         '''Return the binary header at the given height.'''
-        header, n = self._bp.read_headers(height, 1)
+        header, n = await self.read_headers(height, 1)
         if n != 1:
             raise IndexError(f'height {height:,d} out of range')
         return header
@@ -182,7 +182,7 @@ class DB(object):
         offset = prior_tx_count * 32
         self.hashes_file.write(offset, hashes)
 
-    def read_headers(self, start_height, count):
+    async def read_headers(self, start_height, count):
         '''Requires start_height >= 0, count >= 0. Reads as many headers as
         are available starting at start_height up to count. This
         would be zero if start_height is beyond self.db_height, for
@@ -191,16 +191,20 @@ class DB(object):
         Returns a (binary, n) pair where binary is the concatenated
         binary headers, and n is the count of headers returned.
         '''
-        # Read some from disk
         if start_height < 0 or count < 0:
-            raise self.DBError('{:,d} headers starting at {:,d} not on disk'
-                               .format(count, start_height))
-        disk_count = max(0, min(count, self.db_height + 1 - start_height))
-        if disk_count:
-            offset = self.header_offset(start_height)
-            size = self.header_offset(start_height + disk_count) - offset
-            return self.headers_file.read(offset, size), disk_count
-        return b'', 0
+            raise self.DBError(f'{count:,d} headers starting at '
+                               f'{start_height:,d} not on disk')
+
+        def read_headers():
+            # Read some from disk
+            disk_count = max(0, min(count, self.db_height + 1 - start_height))
+            if disk_count:
+                offset = self.header_offset(start_height)
+                size = self.header_offset(start_height + disk_count) - offset
+                return self.headers_file.read(offset, size), disk_count
+            return b'', 0
+
+        return await run_in_thread(read_headers)
 
     def fs_tx_hash(self, tx_num):
         '''Return a pair (tx_hash, tx_height) for the given tx number.
@@ -213,8 +217,8 @@ class DB(object):
         tx_hash = self.hashes_file.read(tx_num * 32, 32)
         return tx_hash, tx_height
 
-    def fs_block_hashes(self, height, count):
-        headers_concat, headers_count = self.read_headers(height, count)
+    async def fs_block_hashes(self, height, count):
+        headers_concat, headers_count = await self.read_headers(height, count)
         if headers_count != count:
             raise self.DBError('only got {:,d} headers starting at {:,d}, not '
                                '{:,d}'.format(headers_count, height, count))
@@ -908,14 +908,14 @@ class ElectrumX(SessionBase):
         hashX = scripthash_to_hashX(scripthash)
         return await self.hashX_subscribe(hashX, scripthash)
 
-    def _merkle_proof(self, cp_height, height):
+    async def _merkle_proof(self, cp_height, height):
         max_height = self.db_height()
         if not height <= cp_height <= max_height:
             raise RPCError(BAD_REQUEST,
                            f'require header height {height:,d} <= '
                            f'cp_height {cp_height:,d} <= '
                            f'chain height {max_height:,d}')
-        branch, root = self.chain_state.header_branch_and_root(
+        branch, root = await self.chain_state.header_branch_and_root(
             cp_height + 1, height)
         return {
             'branch': [hash_to_hex_str(elt) for elt in branch],
@@ -931,7 +931,7 @@ class ElectrumX(SessionBase):
         if cp_height == 0:
             return raw_header_hex
         result = {'header': raw_header_hex}
-        result.update(self._merkle_proof(cp_height, height))
+        result.update(await self._merkle_proof(cp_height, height))
         return result
 
     async def block_header_13(self, height):
@@ -953,11 +953,12 @@ class ElectrumX(SessionBase):
 
         max_size = self.MAX_CHUNK_SIZE
         count = min(count, max_size)
-        headers, count = self.chain_state.read_headers(start_height, count)
+        headers, count = await self.chain_state.read_headers(start_height,
+                                                             count)
         result = {'hex': headers.hex(), 'count': count, 'max': max_size}
         if count and cp_height:
             last_height = start_height + count - 1
-            result.update(self._merkle_proof(cp_height, last_height))
+            result.update(await self._merkle_proof(cp_height, last_height))
         return result
 
     async def block_headers_12(self, start_height, count):
@@ -970,7 +971,7 @@ class ElectrumX(SessionBase):
         index = non_negative_integer(index)
         size = self.coin.CHUNK_SIZE
         start_height = index * size
-        headers, count = self.chain_state.read_headers(start_height, size)
+        headers, _ = await self.chain_state.read_headers(start_height, size)
         return headers.hex()
 
     async def block_get_header(self, height):
@@ -149,72 +149,83 @@ class Source(object):
     def __init__(self, length):
         self._hashes = [os.urandom(32) for _ in range(length)]
 
-    def hashes(self, start, count):
+    async def hashes(self, start, count):
         assert start >= 0
         assert start + count <= len(self._hashes)
         return self._hashes[start: start + count]
 
 
-def test_merkle_cache():
+@pytest.mark.asyncio
+async def test_merkle_cache():
     lengths = (*range(1, 18), 31, 32, 33, 57)
-    source = Source(max(lengths))
+    source = Source(max(lengths)).hashes
     for length in lengths:
-        cache = MerkleCache(merkle, source, length)
+        cache = MerkleCache(merkle, source)
+        await cache.initialize(length)
         # Simulate all possible checkpoints
         for cp_length in range(1, length + 1):
-            cp_hashes = source.hashes(0, cp_length)
+            cp_hashes = await source(0, cp_length)
             # All possible indices
             for index in range(cp_length):
                 # Compare correct answer with cache
                 branch, root = merkle.branch_and_root(cp_hashes, index)
-                branch2, root2 = cache.branch_and_root(cp_length, index)
+                branch2, root2 = await cache.branch_and_root(cp_length, index)
                 assert branch == branch2
                 assert root == root2
 
 
-def test_merkle_cache_extension():
-    source = Source(64)
+@pytest.mark.asyncio
+async def test_merkle_cache_extension():
+    source = Source(64).hashes
     for length in range(14, 18):
         for cp_length in range(30, 36):
-            cache = MerkleCache(merkle, source, length)
-            cp_hashes = source.hashes(0, cp_length)
+            cache = MerkleCache(merkle, source)
+            await cache.initialize(length)
+            cp_hashes = await source(0, cp_length)
             # All possible indices
             for index in range(cp_length):
                 # Compare correct answer with cache
                 branch, root = merkle.branch_and_root(cp_hashes, index)
-                branch2, root2 = cache.branch_and_root(cp_length, index)
+                branch2, root2 = await cache.branch_and_root(cp_length, index)
                 assert branch == branch2
                 assert root == root2
 
 
-def test_merkle_cache_truncation():
+@pytest.mark.asyncio
+async def test_merkle_cache_truncation():
     max_length = 33
-    source = Source(max_length)
+    source = Source(max_length).hashes
     for length in range(max_length - 2, max_length + 1):
         for trunc_length in range(1, 20, 3):
-            cache = MerkleCache(merkle, source, length)
+            cache = MerkleCache(merkle, source)
+            await cache.initialize(length)
             cache.truncate(trunc_length)
            assert cache.length <= trunc_length
            for cp_length in range(1, length + 1, 3):
-                cp_hashes = source.hashes(0, cp_length)
+                cp_hashes = await source(0, cp_length)
                 # All possible indices
                 for index in range(cp_length):
                     # Compare correct answer with cache
                     branch, root = merkle.branch_and_root(cp_hashes, index)
-                    branch2, root2 = cache.branch_and_root(cp_length, index)
+                    branch2, root2 = await cache.branch_and_root(cp_length,
+                                                                 index)
                     assert branch == branch2
                     assert root == root2
 
     # Truncation is a no-op if longer
-    cache = MerkleCache(merkle, source, 10)
+    cache = MerkleCache(merkle, source)
+    await cache.initialize(10)
     level = cache.level.copy()
     for length in range(10, 13):
         cache.truncate(length)
         assert cache.level == level
     assert cache.length == 10
 
 
-def test_truncation_bad():
-    cache = MerkleCache(merkle, Source(10), 10)
+@pytest.mark.asyncio
+async def test_truncation_bad():
+    cache = MerkleCache(merkle, Source(10).hashes)
+    await cache.initialize(10)
     with pytest.raises(TypeError):
         cache.truncate(1.0)
     for n in (-1, 0):
@@ -222,43 +233,48 @@ def test_truncation_bad():
             cache.truncate(n)
 
 
-def test_markle_cache_bad():
+@pytest.mark.asyncio
+async def test_markle_cache_bad():
     length = 23
-    source = Source(length)
-    cache = MerkleCache(merkle, source, length)
-    cache.branch_and_root(5, 3)
+    source = Source(length).hashes
+    cache = MerkleCache(merkle, source)
+    await cache.initialize(length)
+    await cache.branch_and_root(5, 3)
     with pytest.raises(TypeError):
-        cache.branch_and_root(5.0, 3)
+        await cache.branch_and_root(5.0, 3)
     with pytest.raises(TypeError):
-        cache.branch_and_root(5, 3.0)
+        await cache.branch_and_root(5, 3.0)
     with pytest.raises(ValueError):
-        cache.branch_and_root(0, -1)
+        await cache.branch_and_root(0, -1)
     with pytest.raises(ValueError):
-        cache.branch_and_root(3, 3)
+        await cache.branch_and_root(3, 3)
 
 
-def test_bad_extension():
+@pytest.mark.asyncio
+async def test_bad_extension():
     length = 5
-    source = Source(length)
-    cache = MerkleCache(merkle, source, length)
+    source = Source(length).hashes
+    cache = MerkleCache(merkle, source)
+    await cache.initialize(length)
     level = cache.level.copy()
     with pytest.raises(AssertionError):
-        cache.branch_and_root(8, 0)
+        await cache.branch_and_root(8, 0)
     # The bad extension should not destroy the cache
     assert cache.level == level
     assert cache.length == length
 
 
-def time_it():
-    source = Source(500000)
-    cp_length = 492000
+async def time_it():
+    source = Source(500000).hashes
     import time
-    cp_hashes = source.hashes(0, cp_length)
+    cache = MerkleCache(merkle, source)
+    cp_length = 492000
+    await cache.initialize(cp_length)
+    cp_hashes = await source(0, cp_length)
     brs2 = []
     t1 = time.time()
     for index in range(5, 400000, 500):
-        brs2.append(cache.branch_and_root(cp_length, index))
+        brs2.append(await cache.branch_and_root(cp_length, index))
     t2 = time.time()
     print(t2 - t1)
     assert False
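Note that the rewritten tests depend on the pytest-asyncio plugin for the @pytest.mark.asyncio marker. time_it carries no marker, so when used for ad-hoc profiling it now has to be driven by an event loop explicitly; for example, with the Python 3.6-era asyncio idiom this codebase targets:

import asyncio

# time_it() is now a coroutine; run it to completion by hand.
asyncio.get_event_loop().run_until_complete(time_it())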