From 1efc8cb8ec2ecf21e516dd7019d75e9ac2987610 Mon Sep 17 00:00:00 2001 From: Neil Booth Date: Mon, 6 Aug 2018 21:27:33 +0900 Subject: [PATCH] Make the merkle cache and read_headers async read_headers runs in a thread to avoid blocking --- electrumx/lib/merkle.py | 37 +++++++----- electrumx/server/block_processor.py | 18 ++---- electrumx/server/chain_state.py | 4 +- electrumx/server/db.py | 28 +++++---- electrumx/server/session.py | 13 +++-- tests/lib/test_merkle.py | 88 +++++++++++++++++------------ 6 files changed, 105 insertions(+), 83 deletions(-) diff --git a/electrumx/lib/merkle.py b/electrumx/lib/merkle.py index 215879e..d8e5971 100644 --- a/electrumx/lib/merkle.py +++ b/electrumx/lib/merkle.py @@ -158,13 +158,16 @@ class Merkle(object): class MerkleCache(object): '''A cache to calculate merkle branches efficiently.''' - def __init__(self, merkle, source, length): - '''Initialise a cache of length hashes taken from source.''' + def __init__(self, merkle, source_func): + '''Initialise a cache of hashes taken from source_func: + + async def source_func(index, count): + ... + ''' self.merkle = merkle - self.source = source - self.length = length - self.depth_higher = merkle.tree_depth(length) // 2 - self.level = self._level(source.hashes(0, length)) + self.source_func = source_func + self.length = 0 + self.depth_higher = 0 def _segment_length(self): return 1 << self.depth_higher @@ -179,18 +182,18 @@ class MerkleCache(object): def _level(self, hashes): return self.merkle.level(hashes, self.depth_higher) - def _extend_to(self, length): + async def _extend_to(self, length): '''Extend the length of the cache if necessary.''' if length <= self.length: return # Start from the beginning of any final partial segment. 
# Retain the value of depth_higher; in practice this is fine start = self._leaf_start(self.length) - hashes = self.source.hashes(start, length - start) + hashes = await self.source_func(start, length - start) self.level[start >> self.depth_higher:] = self._level(hashes) self.length = length - def _level_for(self, length): + async def _level_for(self, length): '''Return a (level_length, final_hash) pair for a truncation of the hashes to the given length.''' if length == self.length: @@ -198,10 +201,16 @@ class MerkleCache(object): level = self.level[:length >> self.depth_higher] leaf_start = self._leaf_start(length) count = min(self._segment_length(), length - leaf_start) - hashes = self.source.hashes(leaf_start, count) + hashes = await self.source_func(leaf_start, count) level += self._level(hashes) return level + async def initialize(self, length): + '''Call to initialize the cache to a source of given length.''' + self.length = length + self.depth_higher = self.merkle.tree_depth(length) // 2 + self.level = self._level(await self.source_func(0, length)) + def truncate(self, length): '''Truncate the cache so it covers no more than length underlying hashes.''' @@ -215,7 +224,7 @@ class MerkleCache(object): self.length = length self.level[length >> self.depth_higher:] = [] - def branch_and_root(self, length, index): + async def branch_and_root(self, length, index): '''Return a merkle branch and root. Length is the number of hashes used to calculate the merkle root, index is the position of the hash to calculate the branch of. 
@@ -229,12 +238,12 @@ class MerkleCache(object): raise ValueError('length must be positive') if index >= length: raise ValueError('index must be less than length') - self._extend_to(length) + await self._extend_to(length) leaf_start = self._leaf_start(index) count = min(self._segment_length(), length - leaf_start) - leaf_hashes = self.source.hashes(leaf_start, count) + leaf_hashes = await self.source_func(leaf_start, count) if length < self._segment_length(): return self.merkle.branch_and_root(leaf_hashes, index) - level = self._level_for(length) + level = await self._level_for(length) return self.merkle.branch_and_root_from_level( level, leaf_hashes, index, self.depth_higher) diff --git a/electrumx/server/block_processor.py b/electrumx/server/block_processor.py index 1645bb7..08a72a6 100644 --- a/electrumx/server/block_processor.py +++ b/electrumx/server/block_processor.py @@ -139,12 +139,6 @@ class Prefetcher(object): return True -class HeaderSource(object): - - def __init__(self, db): - self.hashes = db.fs_block_hashes - - class ChainError(Exception): '''Raised on error processing blocks.''' @@ -174,7 +168,7 @@ class BlockProcessor(electrumx.server.db.DB): # Header merkle cache self.merkle = Merkle() - self.header_mc = None + self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes) # Caches of unflushed items. self.headers = [] @@ -251,9 +245,7 @@ class BlockProcessor(electrumx.server.db.DB): await self.run_in_thread_shielded(self.backup_blocks, raw_blocks) last -= len(raw_blocks) # Truncate header_mc: header count is 1 more than the height. - # Note header_mc is None if the reorg happens at startup. 
- if self.header_mc: - self.header_mc.truncate(self.height + 1) + self.header_mc.truncate(self.height + 1) await self.prefetcher.reset_height(self.height) async def reorg_hashes(self, count): @@ -269,7 +261,7 @@ class BlockProcessor(electrumx.server.db.DB): self.logger.info(f'chain was reorganised replacing {count:,d} ' f'block{s} at heights {start:,d}-{last:,d}') - return start, last, self.fs_block_hashes(start, count) + return start, last, await self.fs_block_hashes(start, count) async def calc_reorg_range(self, count): '''Calculate the reorg range''' @@ -287,7 +279,7 @@ class BlockProcessor(electrumx.server.db.DB): start = self.height - 1 count = 1 while start > 0: - hashes = self.fs_block_hashes(start, count) + hashes = await self.fs_block_hashes(start, count) hex_hashes = [hash_to_hex_str(hash) for hash in hashes] d_hex_hashes = await self.daemon.block_hex_hashes(start, count) n = diff_pos(hex_hashes, d_hex_hashes) @@ -774,7 +766,7 @@ class BlockProcessor(electrumx.server.db.DB): await self.open_for_serving() # Populate the header merkle cache length = max(1, self.height - self.env.reorg_limit) - self.header_mc = MerkleCache(self.merkle, HeaderSource(self), length) + await self.header_mc.initialize(length) self.logger.info('populated header merkle cache') async def _first_open_dbs(self): diff --git a/electrumx/server/chain_state.py b/electrumx/server/chain_state.py index 58534db..06c2d8c 100644 --- a/electrumx/server/chain_state.py +++ b/electrumx/server/chain_state.py @@ -45,12 +45,12 @@ class ChainState(object): 'db_height': self.db_height(), } - def header_branch_and_root(self, length, height): + async def header_branch_and_root(self, length, height): return self._bp.header_mc.branch_and_root(length, height) async def raw_header(self, height): '''Return the binary header at the given height.''' - header, n = self._bp.read_headers(height, 1) + header, n = await self.read_headers(height, 1) if n != 1: raise IndexError(f'height {height:,d} out of range') 
return header diff --git a/electrumx/server/db.py b/electrumx/server/db.py index 3eeadb4..a6177a3 100644 --- a/electrumx/server/db.py +++ b/electrumx/server/db.py @@ -182,7 +182,7 @@ class DB(object): offset = prior_tx_count * 32 self.hashes_file.write(offset, hashes) - def read_headers(self, start_height, count): + async def read_headers(self, start_height, count): '''Requires start_height >= 0, count >= 0. Reads as many headers as are available starting at start_height up to count. This would be zero if start_height is beyond self.db_height, for @@ -191,16 +191,20 @@ class DB(object): Returns a (binary, n) pair where binary is the concatenated binary headers, and n is the count of headers returned. ''' - # Read some from disk if start_height < 0 or count < 0: - raise self.DBError('{:,d} headers starting at {:,d} not on disk' - .format(count, start_height)) - disk_count = max(0, min(count, self.db_height + 1 - start_height)) - if disk_count: - offset = self.header_offset(start_height) - size = self.header_offset(start_height + disk_count) - offset - return self.headers_file.read(offset, size), disk_count - return b'', 0 + raise self.DBError(f'{count:,d} headers starting at ' + f'{start_height:,d} not on disk') + + def read_headers(): + # Read some from disk + disk_count = max(0, min(count, self.db_height + 1 - start_height)) + if disk_count: + offset = self.header_offset(start_height) + size = self.header_offset(start_height + disk_count) - offset + return self.headers_file.read(offset, size), disk_count + return b'', 0 + + return await run_in_thread(read_headers) def fs_tx_hash(self, tx_num): '''Return a par (tx_hash, tx_height) for the given tx number. 
@@ -213,8 +217,8 @@ class DB(object): tx_hash = self.hashes_file.read(tx_num * 32, 32) return tx_hash, tx_height - def fs_block_hashes(self, height, count): - headers_concat, headers_count = self.read_headers(height, count) + async def fs_block_hashes(self, height, count): + headers_concat, headers_count = await self.read_headers(height, count) if headers_count != count: raise self.DBError('only got {:,d} headers starting at {:,d}, not ' '{:,d}'.format(headers_count, height, count)) diff --git a/electrumx/server/session.py b/electrumx/server/session.py index 33e6cea..afa61f5 100644 --- a/electrumx/server/session.py +++ b/electrumx/server/session.py @@ -908,14 +908,14 @@ class ElectrumX(SessionBase): hashX = scripthash_to_hashX(scripthash) return await self.hashX_subscribe(hashX, scripthash) - def _merkle_proof(self, cp_height, height): + async def _merkle_proof(self, cp_height, height): max_height = self.db_height() if not height <= cp_height <= max_height: raise RPCError(BAD_REQUEST, f'require header height {height:,d} <= ' f'cp_height {cp_height:,d} <= ' f'chain height {max_height:,d}') - branch, root = self.chain_state.header_branch_and_root( + branch, root = await self.chain_state.header_branch_and_root( cp_height + 1, height) return { 'branch': [hash_to_hex_str(elt) for elt in branch], @@ -931,7 +931,7 @@ class ElectrumX(SessionBase): if cp_height == 0: return raw_header_hex result = {'header': raw_header_hex} - result.update(self._merkle_proof(cp_height, height)) + result.update(await self._merkle_proof(cp_height, height)) return result async def block_header_13(self, height): @@ -953,11 +953,12 @@ class ElectrumX(SessionBase): max_size = self.MAX_CHUNK_SIZE count = min(count, max_size) - headers, count = self.chain_state.read_headers(start_height, count) + headers, count = await self.chain_state.read_headers(start_height, + count) result = {'hex': headers.hex(), 'count': count, 'max': max_size} if count and cp_height: last_height = start_height + count - 1 - 
result.update(self._merkle_proof(cp_height, last_height)) + result.update(await self._merkle_proof(cp_height, last_height)) return result async def block_headers_12(self, start_height, count): @@ -970,7 +971,7 @@ class ElectrumX(SessionBase): index = non_negative_integer(index) size = self.coin.CHUNK_SIZE start_height = index * size - headers, count = self.chain_state.read_headers(start_height, size) + headers, _ = await self.chain_state.read_headers(start_height, size) return headers.hex() async def block_get_header(self, height): diff --git a/tests/lib/test_merkle.py b/tests/lib/test_merkle.py index dd9d9ff..35b9c65 100644 --- a/tests/lib/test_merkle.py +++ b/tests/lib/test_merkle.py @@ -149,72 +149,83 @@ class Source(object): def __init__(self, length): self._hashes = [os.urandom(32) for _ in range(length)] - def hashes(self, start, count): + async def hashes(self, start, count): assert start >= 0 assert start + count <= len(self._hashes) return self._hashes[start: start + count] -def test_merkle_cache(): +@pytest.mark.asyncio +async def test_merkle_cache(): lengths = (*range(1, 18), 31, 32, 33, 57) - source = Source(max(lengths)) + source = Source(max(lengths)).hashes for length in lengths: - cache = MerkleCache(merkle, source, length) + cache = MerkleCache(merkle, source) + await cache.initialize(length) # Simulate all possible checkpoints for cp_length in range(1, length + 1): - cp_hashes = source.hashes(0, cp_length) + cp_hashes = await source(0, cp_length) # All possible indices for index in range(cp_length): # Compare correct answer with cache branch, root = merkle.branch_and_root(cp_hashes, index) - branch2, root2 = cache.branch_and_root(cp_length, index) + branch2, root2 = await cache.branch_and_root(cp_length, index) assert branch == branch2 assert root == root2 -def test_merkle_cache_extension(): - source = Source(64) +@pytest.mark.asyncio +async def test_merkle_cache_extension(): + source = Source(64).hashes for length in range(14, 18): for cp_length 
in range(30, 36): - cache = MerkleCache(merkle, source, length) - cp_hashes = source.hashes(0, cp_length) + cache = MerkleCache(merkle, source) + await cache.initialize(length) + cp_hashes = await source(0, cp_length) # All possible indices for index in range(cp_length): # Compare correct answer with cache branch, root = merkle.branch_and_root(cp_hashes, index) - branch2, root2 = cache.branch_and_root(cp_length, index) + branch2, root2 = await cache.branch_and_root(cp_length, index) assert branch == branch2 assert root == root2 -def test_merkle_cache_truncation(): +@pytest.mark.asyncio +async def test_merkle_cache_truncation(): max_length = 33 - source = Source(max_length) + source = Source(max_length).hashes for length in range(max_length - 2, max_length + 1): for trunc_length in range(1, 20, 3): - cache = MerkleCache(merkle, source, length) + cache = MerkleCache(merkle, source) + await cache.initialize(length) cache.truncate(trunc_length) assert cache.length <= trunc_length for cp_length in range(1, length + 1, 3): - cp_hashes = source.hashes(0, cp_length) + cp_hashes = await source(0, cp_length) # All possible indices for index in range(cp_length): # Compare correct answer with cache branch, root = merkle.branch_and_root(cp_hashes, index) - branch2, root2 = cache.branch_and_root(cp_length, index) + branch2, root2 = await cache.branch_and_root(cp_length, + index) assert branch == branch2 assert root == root2 # Truncation is a no-op if longer - cache = MerkleCache(merkle, source, 10) + cache = MerkleCache(merkle, source) + await cache.initialize(10) level = cache.level.copy() for length in range(10, 13): cache.truncate(length) assert cache.level == level assert cache.length == 10 -def test_truncation_bad(): - cache = MerkleCache(merkle, Source(10), 10) + +@pytest.mark.asyncio +async def test_truncation_bad(): + cache = MerkleCache(merkle, Source(10).hashes) + await cache.initialize(10) with pytest.raises(TypeError): cache.truncate(1.0) for n in (-1, 0): @@ -222,43 
+233,48 @@ def test_truncation_bad(): cache.truncate(n) -def test_markle_cache_bad(): +@pytest.mark.asyncio +async def test_markle_cache_bad(): length = 23 - source = Source(length) - cache = MerkleCache(merkle, source, length) - cache.branch_and_root(5, 3) + source = Source(length).hashes + cache = MerkleCache(merkle, source) + await cache.initialize(length) + await cache.branch_and_root(5, 3) with pytest.raises(TypeError): - cache.branch_and_root(5.0, 3) + await cache.branch_and_root(5.0, 3) with pytest.raises(TypeError): - cache.branch_and_root(5, 3.0) + await cache.branch_and_root(5, 3.0) with pytest.raises(ValueError): - cache.branch_and_root(0, -1) + await cache.branch_and_root(0, -1) with pytest.raises(ValueError): - cache.branch_and_root(3, 3) + await cache.branch_and_root(3, 3) -def test_bad_extension(): +@pytest.mark.asyncio +async def test_bad_extension(): length = 5 - source = Source(length) - cache = MerkleCache(merkle, source, length) + source = Source(length).hashes + cache = MerkleCache(merkle, source) + await cache.initialize(length) level = cache.level.copy() with pytest.raises(AssertionError): - cache.branch_and_root(8, 0) + await cache.branch_and_root(8, 0) # The bad extension should not destroy the cache assert cache.level == level assert cache.length == length -def time_it(): - source = Source(500000) +async def time_it(): + source = Source(500000).hashes + cp_length = 492000 import time cache = MerkleCache(merkle, source) - cp_length = 492000 - cp_hashes = source.hashes(0, cp_length) + await cache.initialize(cp_length) + cp_hashes = await source(0, cp_length) brs2 = [] t1 = time.time() for index in range(5, 400000, 500): - brs2.append(cache.branch_and_root(cp_length, index)) + brs2.append(await cache.branch_and_root(cp_length, index)) t2 = time.time() print(t2 - t1) assert False