Make the merkle cache and read_headers async

read_headers runs in a thread to avoid blocking
This commit is contained in:
Neil Booth 2018-08-06 21:27:33 +09:00
parent db5d516756
commit 1efc8cb8ec
6 changed files with 105 additions and 83 deletions

View File

@ -158,13 +158,16 @@ class Merkle(object):
class MerkleCache(object): class MerkleCache(object):
'''A cache to calculate merkle branches efficiently.''' '''A cache to calculate merkle branches efficiently.'''
def __init__(self, merkle, source, length): def __init__(self, merkle, source_func):
'''Initialise a cache of length hashes taken from source.''' '''Initialise a cache of hashes taken from source_func:
async def source_func(index, count):
...
'''
self.merkle = merkle self.merkle = merkle
self.source = source self.source_func = source_func
self.length = length self.length = 0
self.depth_higher = merkle.tree_depth(length) // 2 self.depth_higher = 0
self.level = self._level(source.hashes(0, length))
def _segment_length(self): def _segment_length(self):
return 1 << self.depth_higher return 1 << self.depth_higher
@ -179,18 +182,18 @@ class MerkleCache(object):
def _level(self, hashes): def _level(self, hashes):
return self.merkle.level(hashes, self.depth_higher) return self.merkle.level(hashes, self.depth_higher)
def _extend_to(self, length): async def _extend_to(self, length):
'''Extend the length of the cache if necessary.''' '''Extend the length of the cache if necessary.'''
if length <= self.length: if length <= self.length:
return return
# Start from the beginning of any final partial segment. # Start from the beginning of any final partial segment.
# Retain the value of depth_higher; in practice this is fine # Retain the value of depth_higher; in practice this is fine
start = self._leaf_start(self.length) start = self._leaf_start(self.length)
hashes = self.source.hashes(start, length - start) hashes = await self.source_func(start, length - start)
self.level[start >> self.depth_higher:] = self._level(hashes) self.level[start >> self.depth_higher:] = self._level(hashes)
self.length = length self.length = length
def _level_for(self, length): async def _level_for(self, length):
'''Return a (level_length, final_hash) pair for a truncation '''Return a (level_length, final_hash) pair for a truncation
of the hashes to the given length.''' of the hashes to the given length.'''
if length == self.length: if length == self.length:
@ -198,10 +201,16 @@ class MerkleCache(object):
level = self.level[:length >> self.depth_higher] level = self.level[:length >> self.depth_higher]
leaf_start = self._leaf_start(length) leaf_start = self._leaf_start(length)
count = min(self._segment_length(), length - leaf_start) count = min(self._segment_length(), length - leaf_start)
hashes = self.source.hashes(leaf_start, count) hashes = await self.source_func(leaf_start, count)
level += self._level(hashes) level += self._level(hashes)
return level return level
async def initialize(self, length):
'''Call to initialize the cache to a source of given length.'''
self.length = length
self.depth_higher = self.merkle.tree_depth(length) // 2
self.level = self._level(await self.source_func(0, length))
def truncate(self, length): def truncate(self, length):
'''Truncate the cache so it covers no more than length underlying '''Truncate the cache so it covers no more than length underlying
hashes.''' hashes.'''
@ -215,7 +224,7 @@ class MerkleCache(object):
self.length = length self.length = length
self.level[length >> self.depth_higher:] = [] self.level[length >> self.depth_higher:] = []
def branch_and_root(self, length, index): async def branch_and_root(self, length, index):
'''Return a merkle branch and root. Length is the number of '''Return a merkle branch and root. Length is the number of
hashes used to calculate the merkle root, index is the position hashes used to calculate the merkle root, index is the position
of the hash to calculate the branch of. of the hash to calculate the branch of.
@ -229,12 +238,12 @@ class MerkleCache(object):
raise ValueError('length must be positive') raise ValueError('length must be positive')
if index >= length: if index >= length:
raise ValueError('index must be less than length') raise ValueError('index must be less than length')
self._extend_to(length) await self._extend_to(length)
leaf_start = self._leaf_start(index) leaf_start = self._leaf_start(index)
count = min(self._segment_length(), length - leaf_start) count = min(self._segment_length(), length - leaf_start)
leaf_hashes = self.source.hashes(leaf_start, count) leaf_hashes = await self.source_func(leaf_start, count)
if length < self._segment_length(): if length < self._segment_length():
return self.merkle.branch_and_root(leaf_hashes, index) return self.merkle.branch_and_root(leaf_hashes, index)
level = self._level_for(length) level = await self._level_for(length)
return self.merkle.branch_and_root_from_level( return self.merkle.branch_and_root_from_level(
level, leaf_hashes, index, self.depth_higher) level, leaf_hashes, index, self.depth_higher)

View File

@ -139,12 +139,6 @@ class Prefetcher(object):
return True return True
class HeaderSource(object):
def __init__(self, db):
self.hashes = db.fs_block_hashes
class ChainError(Exception): class ChainError(Exception):
'''Raised on error processing blocks.''' '''Raised on error processing blocks.'''
@ -174,7 +168,7 @@ class BlockProcessor(electrumx.server.db.DB):
# Header merkle cache # Header merkle cache
self.merkle = Merkle() self.merkle = Merkle()
self.header_mc = None self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
# Caches of unflushed items. # Caches of unflushed items.
self.headers = [] self.headers = []
@ -251,9 +245,7 @@ class BlockProcessor(electrumx.server.db.DB):
await self.run_in_thread_shielded(self.backup_blocks, raw_blocks) await self.run_in_thread_shielded(self.backup_blocks, raw_blocks)
last -= len(raw_blocks) last -= len(raw_blocks)
# Truncate header_mc: header count is 1 more than the height. # Truncate header_mc: header count is 1 more than the height.
# Note header_mc is None if the reorg happens at startup. self.header_mc.truncate(self.height + 1)
if self.header_mc:
self.header_mc.truncate(self.height + 1)
await self.prefetcher.reset_height(self.height) await self.prefetcher.reset_height(self.height)
async def reorg_hashes(self, count): async def reorg_hashes(self, count):
@ -269,7 +261,7 @@ class BlockProcessor(electrumx.server.db.DB):
self.logger.info(f'chain was reorganised replacing {count:,d} ' self.logger.info(f'chain was reorganised replacing {count:,d} '
f'block{s} at heights {start:,d}-{last:,d}') f'block{s} at heights {start:,d}-{last:,d}')
return start, last, self.fs_block_hashes(start, count) return start, last, await self.fs_block_hashes(start, count)
async def calc_reorg_range(self, count): async def calc_reorg_range(self, count):
'''Calculate the reorg range''' '''Calculate the reorg range'''
@ -287,7 +279,7 @@ class BlockProcessor(electrumx.server.db.DB):
start = self.height - 1 start = self.height - 1
count = 1 count = 1
while start > 0: while start > 0:
hashes = self.fs_block_hashes(start, count) hashes = await self.fs_block_hashes(start, count)
hex_hashes = [hash_to_hex_str(hash) for hash in hashes] hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
d_hex_hashes = await self.daemon.block_hex_hashes(start, count) d_hex_hashes = await self.daemon.block_hex_hashes(start, count)
n = diff_pos(hex_hashes, d_hex_hashes) n = diff_pos(hex_hashes, d_hex_hashes)
@ -774,7 +766,7 @@ class BlockProcessor(electrumx.server.db.DB):
await self.open_for_serving() await self.open_for_serving()
# Populate the header merkle cache # Populate the header merkle cache
length = max(1, self.height - self.env.reorg_limit) length = max(1, self.height - self.env.reorg_limit)
self.header_mc = MerkleCache(self.merkle, HeaderSource(self), length) await self.header_mc.initialize(length)
self.logger.info('populated header merkle cache') self.logger.info('populated header merkle cache')
async def _first_open_dbs(self): async def _first_open_dbs(self):

View File

@ -45,12 +45,12 @@ class ChainState(object):
'db_height': self.db_height(), 'db_height': self.db_height(),
} }
def header_branch_and_root(self, length, height): async def header_branch_and_root(self, length, height):
return self._bp.header_mc.branch_and_root(length, height) return self._bp.header_mc.branch_and_root(length, height)
async def raw_header(self, height): async def raw_header(self, height):
'''Return the binary header at the given height.''' '''Return the binary header at the given height.'''
header, n = self._bp.read_headers(height, 1) header, n = await self.read_headers(height, 1)
if n != 1: if n != 1:
raise IndexError(f'height {height:,d} out of range') raise IndexError(f'height {height:,d} out of range')
return header return header

View File

@ -182,7 +182,7 @@ class DB(object):
offset = prior_tx_count * 32 offset = prior_tx_count * 32
self.hashes_file.write(offset, hashes) self.hashes_file.write(offset, hashes)
def read_headers(self, start_height, count): async def read_headers(self, start_height, count):
'''Requires start_height >= 0, count >= 0. Reads as many headers as '''Requires start_height >= 0, count >= 0. Reads as many headers as
are available starting at start_height up to count. This are available starting at start_height up to count. This
would be zero if start_height is beyond self.db_height, for would be zero if start_height is beyond self.db_height, for
@ -191,16 +191,20 @@ class DB(object):
Returns a (binary, n) pair where binary is the concatenated Returns a (binary, n) pair where binary is the concatenated
binary headers, and n is the count of headers returned. binary headers, and n is the count of headers returned.
''' '''
# Read some from disk
if start_height < 0 or count < 0: if start_height < 0 or count < 0:
raise self.DBError('{:,d} headers starting at {:,d} not on disk' raise self.DBError(f'{count:,d} headers starting at '
.format(count, start_height)) f'{start_height:,d} not on disk')
disk_count = max(0, min(count, self.db_height + 1 - start_height))
if disk_count: def read_headers():
offset = self.header_offset(start_height) # Read some from disk
size = self.header_offset(start_height + disk_count) - offset disk_count = max(0, min(count, self.db_height + 1 - start_height))
return self.headers_file.read(offset, size), disk_count if disk_count:
return b'', 0 offset = self.header_offset(start_height)
size = self.header_offset(start_height + disk_count) - offset
return self.headers_file.read(offset, size), disk_count
return b'', 0
return await run_in_thread(read_headers)
def fs_tx_hash(self, tx_num): def fs_tx_hash(self, tx_num):
'''Return a pair (tx_hash, tx_height) for the given tx number. '''Return a pair (tx_hash, tx_height) for the given tx number.
@ -213,8 +217,8 @@ class DB(object):
tx_hash = self.hashes_file.read(tx_num * 32, 32) tx_hash = self.hashes_file.read(tx_num * 32, 32)
return tx_hash, tx_height return tx_hash, tx_height
def fs_block_hashes(self, height, count): async def fs_block_hashes(self, height, count):
headers_concat, headers_count = self.read_headers(height, count) headers_concat, headers_count = await self.read_headers(height, count)
if headers_count != count: if headers_count != count:
raise self.DBError('only got {:,d} headers starting at {:,d}, not ' raise self.DBError('only got {:,d} headers starting at {:,d}, not '
'{:,d}'.format(headers_count, height, count)) '{:,d}'.format(headers_count, height, count))

View File

@ -908,14 +908,14 @@ class ElectrumX(SessionBase):
hashX = scripthash_to_hashX(scripthash) hashX = scripthash_to_hashX(scripthash)
return await self.hashX_subscribe(hashX, scripthash) return await self.hashX_subscribe(hashX, scripthash)
def _merkle_proof(self, cp_height, height): async def _merkle_proof(self, cp_height, height):
max_height = self.db_height() max_height = self.db_height()
if not height <= cp_height <= max_height: if not height <= cp_height <= max_height:
raise RPCError(BAD_REQUEST, raise RPCError(BAD_REQUEST,
f'require header height {height:,d} <= ' f'require header height {height:,d} <= '
f'cp_height {cp_height:,d} <= ' f'cp_height {cp_height:,d} <= '
f'chain height {max_height:,d}') f'chain height {max_height:,d}')
branch, root = self.chain_state.header_branch_and_root( branch, root = await self.chain_state.header_branch_and_root(
cp_height + 1, height) cp_height + 1, height)
return { return {
'branch': [hash_to_hex_str(elt) for elt in branch], 'branch': [hash_to_hex_str(elt) for elt in branch],
@ -931,7 +931,7 @@ class ElectrumX(SessionBase):
if cp_height == 0: if cp_height == 0:
return raw_header_hex return raw_header_hex
result = {'header': raw_header_hex} result = {'header': raw_header_hex}
result.update(self._merkle_proof(cp_height, height)) result.update(await self._merkle_proof(cp_height, height))
return result return result
async def block_header_13(self, height): async def block_header_13(self, height):
@ -953,11 +953,12 @@ class ElectrumX(SessionBase):
max_size = self.MAX_CHUNK_SIZE max_size = self.MAX_CHUNK_SIZE
count = min(count, max_size) count = min(count, max_size)
headers, count = self.chain_state.read_headers(start_height, count) headers, count = await self.chain_state.read_headers(start_height,
count)
result = {'hex': headers.hex(), 'count': count, 'max': max_size} result = {'hex': headers.hex(), 'count': count, 'max': max_size}
if count and cp_height: if count and cp_height:
last_height = start_height + count - 1 last_height = start_height + count - 1
result.update(self._merkle_proof(cp_height, last_height)) result.update(await self._merkle_proof(cp_height, last_height))
return result return result
async def block_headers_12(self, start_height, count): async def block_headers_12(self, start_height, count):
@ -970,7 +971,7 @@ class ElectrumX(SessionBase):
index = non_negative_integer(index) index = non_negative_integer(index)
size = self.coin.CHUNK_SIZE size = self.coin.CHUNK_SIZE
start_height = index * size start_height = index * size
headers, count = self.chain_state.read_headers(start_height, size) headers, _ = await self.chain_state.read_headers(start_height, size)
return headers.hex() return headers.hex()
async def block_get_header(self, height): async def block_get_header(self, height):

View File

@ -149,72 +149,83 @@ class Source(object):
def __init__(self, length): def __init__(self, length):
self._hashes = [os.urandom(32) for _ in range(length)] self._hashes = [os.urandom(32) for _ in range(length)]
def hashes(self, start, count): async def hashes(self, start, count):
assert start >= 0 assert start >= 0
assert start + count <= len(self._hashes) assert start + count <= len(self._hashes)
return self._hashes[start: start + count] return self._hashes[start: start + count]
def test_merkle_cache(): @pytest.mark.asyncio
async def test_merkle_cache():
lengths = (*range(1, 18), 31, 32, 33, 57) lengths = (*range(1, 18), 31, 32, 33, 57)
source = Source(max(lengths)) source = Source(max(lengths)).hashes
for length in lengths: for length in lengths:
cache = MerkleCache(merkle, source, length) cache = MerkleCache(merkle, source)
await cache.initialize(length)
# Simulate all possible checkpoints # Simulate all possible checkpoints
for cp_length in range(1, length + 1): for cp_length in range(1, length + 1):
cp_hashes = source.hashes(0, cp_length) cp_hashes = await source(0, cp_length)
# All possible indices # All possible indices
for index in range(cp_length): for index in range(cp_length):
# Compare correct answer with cache # Compare correct answer with cache
branch, root = merkle.branch_and_root(cp_hashes, index) branch, root = merkle.branch_and_root(cp_hashes, index)
branch2, root2 = cache.branch_and_root(cp_length, index) branch2, root2 = await cache.branch_and_root(cp_length, index)
assert branch == branch2 assert branch == branch2
assert root == root2 assert root == root2
def test_merkle_cache_extension(): @pytest.mark.asyncio
source = Source(64) async def test_merkle_cache_extension():
source = Source(64).hashes
for length in range(14, 18): for length in range(14, 18):
for cp_length in range(30, 36): for cp_length in range(30, 36):
cache = MerkleCache(merkle, source, length) cache = MerkleCache(merkle, source)
cp_hashes = source.hashes(0, cp_length) await cache.initialize(length)
cp_hashes = await source(0, cp_length)
# All possible indices # All possible indices
for index in range(cp_length): for index in range(cp_length):
# Compare correct answer with cache # Compare correct answer with cache
branch, root = merkle.branch_and_root(cp_hashes, index) branch, root = merkle.branch_and_root(cp_hashes, index)
branch2, root2 = cache.branch_and_root(cp_length, index) branch2, root2 = await cache.branch_and_root(cp_length, index)
assert branch == branch2 assert branch == branch2
assert root == root2 assert root == root2
def test_merkle_cache_truncation(): @pytest.mark.asyncio
async def test_merkle_cache_truncation():
max_length = 33 max_length = 33
source = Source(max_length) source = Source(max_length).hashes
for length in range(max_length - 2, max_length + 1): for length in range(max_length - 2, max_length + 1):
for trunc_length in range(1, 20, 3): for trunc_length in range(1, 20, 3):
cache = MerkleCache(merkle, source, length) cache = MerkleCache(merkle, source)
await cache.initialize(length)
cache.truncate(trunc_length) cache.truncate(trunc_length)
assert cache.length <= trunc_length assert cache.length <= trunc_length
for cp_length in range(1, length + 1, 3): for cp_length in range(1, length + 1, 3):
cp_hashes = source.hashes(0, cp_length) cp_hashes = await source(0, cp_length)
# All possible indices # All possible indices
for index in range(cp_length): for index in range(cp_length):
# Compare correct answer with cache # Compare correct answer with cache
branch, root = merkle.branch_and_root(cp_hashes, index) branch, root = merkle.branch_and_root(cp_hashes, index)
branch2, root2 = cache.branch_and_root(cp_length, index) branch2, root2 = await cache.branch_and_root(cp_length,
index)
assert branch == branch2 assert branch == branch2
assert root == root2 assert root == root2
# Truncation is a no-op if longer # Truncation is a no-op if longer
cache = MerkleCache(merkle, source, 10) cache = MerkleCache(merkle, source)
await cache.initialize(10)
level = cache.level.copy() level = cache.level.copy()
for length in range(10, 13): for length in range(10, 13):
cache.truncate(length) cache.truncate(length)
assert cache.level == level assert cache.level == level
assert cache.length == 10 assert cache.length == 10
def test_truncation_bad():
cache = MerkleCache(merkle, Source(10), 10) @pytest.mark.asyncio
async def test_truncation_bad():
cache = MerkleCache(merkle, Source(10).hashes)
await cache.initialize(10)
with pytest.raises(TypeError): with pytest.raises(TypeError):
cache.truncate(1.0) cache.truncate(1.0)
for n in (-1, 0): for n in (-1, 0):
@ -222,43 +233,48 @@ def test_truncation_bad():
cache.truncate(n) cache.truncate(n)
def test_markle_cache_bad(): @pytest.mark.asyncio
async def test_markle_cache_bad():
length = 23 length = 23
source = Source(length) source = Source(length).hashes
cache = MerkleCache(merkle, source, length) cache = MerkleCache(merkle, source)
cache.branch_and_root(5, 3) await cache.initialize(length)
await cache.branch_and_root(5, 3)
with pytest.raises(TypeError): with pytest.raises(TypeError):
cache.branch_and_root(5.0, 3) await cache.branch_and_root(5.0, 3)
with pytest.raises(TypeError): with pytest.raises(TypeError):
cache.branch_and_root(5, 3.0) await cache.branch_and_root(5, 3.0)
with pytest.raises(ValueError): with pytest.raises(ValueError):
cache.branch_and_root(0, -1) await cache.branch_and_root(0, -1)
with pytest.raises(ValueError): with pytest.raises(ValueError):
cache.branch_and_root(3, 3) await cache.branch_and_root(3, 3)
def test_bad_extension(): @pytest.mark.asyncio
async def test_bad_extension():
length = 5 length = 5
source = Source(length) source = Source(length).hashes
cache = MerkleCache(merkle, source, length) cache = MerkleCache(merkle, source)
await cache.initialize(length)
level = cache.level.copy() level = cache.level.copy()
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
cache.branch_and_root(8, 0) await cache.branch_and_root(8, 0)
# The bad extension should not destroy the cache # The bad extension should not destroy the cache
assert cache.level == level assert cache.level == level
assert cache.length == length assert cache.length == length
def time_it(): async def time_it():
source = Source(500000) source = Source(500000).hashes
cp_length = 492000
import time import time
cache = MerkleCache(merkle, source) cache = MerkleCache(merkle, source)
cp_length = 492000 await cache.initialize(cp_length)
cp_hashes = source.hashes(0, cp_length) cp_hashes = await source(0, cp_length)
brs2 = [] brs2 = []
t1 = time.time() t1 = time.time()
for index in range(5, 400000, 500): for index in range(5, 400000, 500):
brs2.append(cache.branch_and_root(cp_length, index)) brs2.append(await cache.branch_and_root(cp_length, index))
t2 = time.time() t2 = time.time()
print(t2 - t1) print(t2 - t1)
assert False assert False