Make the merkle cache and read_headers async

read_headers now runs in a worker thread so blocking disk I/O does not stall the event loop.
Neil Booth 2018-08-06 21:27:33 +09:00
parent db5d516756
commit 1efc8cb8ec
6 changed files with 105 additions and 83 deletions
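
The change follows one pattern throughout: hand blocking disk I/O to a worker thread and await the result, so the asyncio event loop keeps serving sessions meanwhile. A minimal, self-contained sketch of the idea, assuming run_in_thread is a thin wrapper over the loop's default executor (everything except the name run_in_thread is illustrative, not the project's exact code):

    import asyncio
    import time

    def run_in_thread(func, *args):
        # Run a blocking callable in the default thread pool; the
        # returned future can be awaited.
        return asyncio.get_event_loop().run_in_executor(None, func, *args)

    def blocking_read(count):
        # Stand-in for a synchronous disk read (hypothetical).
        time.sleep(0.01)
        return b'\x00' * 80 * count

    async def main():
        headers = await run_in_thread(blocking_read, 10)
        print(f'read {len(headers)} bytes without blocking the loop')

    asyncio.get_event_loop().run_until_complete(main())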


@@ -158,13 +158,16 @@ class Merkle(object):
 class MerkleCache(object):
     '''A cache to calculate merkle branches efficiently.'''
 
-    def __init__(self, merkle, source, length):
-        '''Initialise a cache of length hashes taken from source.'''
+    def __init__(self, merkle, source_func):
+        '''Initialise a cache of hashes taken from source_func:
+
+           async def source_func(index, count):
+               ...
+        '''
         self.merkle = merkle
-        self.source = source
-        self.length = length
-        self.depth_higher = merkle.tree_depth(length) // 2
-        self.level = self._level(source.hashes(0, length))
+        self.source_func = source_func
+        self.length = 0
+        self.depth_higher = 0
 
     def _segment_length(self):
         return 1 << self.depth_higher
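
The new source_func contract replaces the old source object: any coroutine with the signature (index, count) that returns that many hashes can back the cache. A minimal, hypothetical in-memory source satisfying the contract:

    import os

    HASHES = [os.urandom(32) for _ in range(1000)]   # illustrative data

    async def source_func(index, count):
        # Return `count` leaf hashes starting at `index`, matching the
        # signature in the docstring above; a real source reads from disk.
        return HASHES[index:index + count]
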
@@ -179,18 +182,18 @@ class MerkleCache(object):
     def _level(self, hashes):
         return self.merkle.level(hashes, self.depth_higher)
 
-    def _extend_to(self, length):
+    async def _extend_to(self, length):
         '''Extend the length of the cache if necessary.'''
         if length <= self.length:
             return
         # Start from the beginning of any final partial segment.
         # Retain the value of depth_higher; in practice this is fine
         start = self._leaf_start(self.length)
-        hashes = self.source.hashes(start, length - start)
+        hashes = await self.source_func(start, length - start)
         self.level[start >> self.depth_higher:] = self._level(hashes)
         self.length = length
 
-    def _level_for(self, length):
+    async def _level_for(self, length):
         '''Return the level for a truncation of the hashes
         to the given length.'''
         if length == self.length:
@@ -198,10 +201,16 @@ class MerkleCache(object):
         level = self.level[:length >> self.depth_higher]
         leaf_start = self._leaf_start(length)
         count = min(self._segment_length(), length - leaf_start)
-        hashes = self.source.hashes(leaf_start, count)
+        hashes = await self.source_func(leaf_start, count)
         level += self._level(hashes)
         return level
 
+    async def initialize(self, length):
+        '''Call to initialize the cache to a source of given length.'''
+        self.length = length
+        self.depth_higher = self.merkle.tree_depth(length) // 2
+        self.level = self._level(await self.source_func(0, length))
+
     def truncate(self, length):
         '''Truncate the cache so it covers no more than length underlying
         hashes.'''
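
Because __init__ can no longer call the now-asynchronous source, construction is split in two: a cheap synchronous constructor plus an awaited initialize() once an event loop is running. Hypothetical usage, assuming the merkle object, MerkleCache and the source_func sketched above are in scope:

    import asyncio

    async def demo():
        cache = MerkleCache(merkle, source_func)
        await cache.initialize(1000)    # one-off async population
        return await cache.branch_and_root(1000, 0)

    branch, root = asyncio.get_event_loop().run_until_complete(demo())
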
@@ -215,7 +224,7 @@ class MerkleCache(object):
         self.length = length
         self.level[length >> self.depth_higher:] = []
 
-    def branch_and_root(self, length, index):
+    async def branch_and_root(self, length, index):
         '''Return a merkle branch and root. Length is the number of
         hashes used to calculate the merkle root, index is the position
         of the hash to calculate the branch of.
@@ -229,12 +238,12 @@ class MerkleCache(object):
             raise ValueError('length must be positive')
         if index >= length:
             raise ValueError('index must be less than length')
-        self._extend_to(length)
+        await self._extend_to(length)
         leaf_start = self._leaf_start(index)
         count = min(self._segment_length(), length - leaf_start)
-        leaf_hashes = self.source.hashes(leaf_start, count)
+        leaf_hashes = await self.source_func(leaf_start, count)
         if length < self._segment_length():
             return self.merkle.branch_and_root(leaf_hashes, index)
-        level = self._level_for(length)
+        level = await self._level_for(length)
         return self.merkle.branch_and_root_from_level(
             level, leaf_hashes, index, self.depth_higher)
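
The (branch, root) pair that branch_and_root produces verifies in the standard way: fold the leaf hash with each sibling, letting the index's low bit pick the side at each level. An independent illustration with Bitcoin-style double-SHA256 (electrumx's Merkle class encapsulates the equivalent logic; this function is not from the codebase):

    import hashlib

    def double_sha256(data):
        return hashlib.sha256(hashlib.sha256(data).digest()).digest()

    def verify_branch(leaf_hash, branch, index, root):
        # Combine the leaf with each sibling hash; index bits choose
        # whether the sibling goes on the left or the right.
        h = leaf_hash
        for sibling in branch:
            if index & 1:
                h = double_sha256(sibling + h)
            else:
                h = double_sha256(h + sibling)
            index >>= 1
        return h == root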


@@ -139,12 +139,6 @@ class Prefetcher(object):
         return True
 
 
-class HeaderSource(object):
-
-    def __init__(self, db):
-        self.hashes = db.fs_block_hashes
-
-
 class ChainError(Exception):
     '''Raised on error processing blocks.'''
@@ -174,7 +168,7 @@ class BlockProcessor(electrumx.server.db.DB):
         # Header merkle cache
         self.merkle = Merkle()
-        self.header_mc = None
+        self.header_mc = MerkleCache(self.merkle, self.fs_block_hashes)
 
         # Caches of unflushed items.
         self.headers = []
@@ -251,9 +245,7 @@ class BlockProcessor(electrumx.server.db.DB):
             await self.run_in_thread_shielded(self.backup_blocks, raw_blocks)
             last -= len(raw_blocks)
         # Truncate header_mc: header count is 1 more than the height.
-        # Note header_mc is None if the reorg happens at startup.
-        if self.header_mc:
-            self.header_mc.truncate(self.height + 1)
+        self.header_mc.truncate(self.height + 1)
         await self.prefetcher.reset_height(self.height)
 
     async def reorg_hashes(self, count):
@@ -269,7 +261,7 @@ class BlockProcessor(electrumx.server.db.DB):
         self.logger.info(f'chain was reorganised replacing {count:,d} '
                          f'block{s} at heights {start:,d}-{last:,d}')
 
-        return start, last, self.fs_block_hashes(start, count)
+        return start, last, await self.fs_block_hashes(start, count)
 
     async def calc_reorg_range(self, count):
         '''Calculate the reorg range'''
@@ -287,7 +279,7 @@ class BlockProcessor(electrumx.server.db.DB):
             start = self.height - 1
             count = 1
             while start > 0:
-                hashes = self.fs_block_hashes(start, count)
+                hashes = await self.fs_block_hashes(start, count)
                 hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
                 d_hex_hashes = await self.daemon.block_hex_hashes(start, count)
                 n = diff_pos(hex_hashes, d_hex_hashes)
@@ -774,7 +766,7 @@ class BlockProcessor(electrumx.server.db.DB):
         await self.open_for_serving()
         # Populate the header merkle cache
         length = max(1, self.height - self.env.reorg_limit)
-        self.header_mc = MerkleCache(self.merkle, HeaderSource(self), length)
+        await self.header_mc.initialize(length)
         self.logger.info('populated header merkle cache')
 
     async def _first_open_dbs(self):


@@ -45,12 +45,12 @@ class ChainState(object):
             'db_height': self.db_height(),
         }
 
-    def header_branch_and_root(self, length, height):
-        return self._bp.header_mc.branch_and_root(length, height)
+    async def header_branch_and_root(self, length, height):
+        return await self._bp.header_mc.branch_and_root(length, height)
 
     async def raw_header(self, height):
         '''Return the binary header at the given height.'''
-        header, n = self._bp.read_headers(height, 1)
+        header, n = await self.read_headers(height, 1)
         if n != 1:
             raise IndexError(f'height {height:,d} out of range')
         return header


@@ -182,7 +182,7 @@ class DB(object):
         offset = prior_tx_count * 32
         self.hashes_file.write(offset, hashes)
 
-    def read_headers(self, start_height, count):
+    async def read_headers(self, start_height, count):
         '''Requires start_height >= 0, count >= 0. Reads as many headers as
         are available starting at start_height up to count. This
         would be zero if start_height is beyond self.db_height, for
@@ -191,16 +191,20 @@ class DB(object):
         Returns a (binary, n) pair where binary is the concatenated
         binary headers, and n is the count of headers returned.
         '''
-        # Read some from disk
         if start_height < 0 or count < 0:
-            raise self.DBError('{:,d} headers starting at {:,d} not on disk'
-                               .format(count, start_height))
-        disk_count = max(0, min(count, self.db_height + 1 - start_height))
-        if disk_count:
-            offset = self.header_offset(start_height)
-            size = self.header_offset(start_height + disk_count) - offset
-            return self.headers_file.read(offset, size), disk_count
-        return b'', 0
+            raise self.DBError(f'{count:,d} headers starting at '
+                               f'{start_height:,d} not on disk')
+
+        def read_headers():
+            # Read some from disk
+            disk_count = max(0, min(count, self.db_height + 1 - start_height))
+            if disk_count:
+                offset = self.header_offset(start_height)
+                size = self.header_offset(start_height + disk_count) - offset
+                return self.headers_file.read(offset, size), disk_count
+            return b'', 0
+
+        return await run_in_thread(read_headers)
 
     def fs_tx_hash(self, tx_num):
         '''Return a pair (tx_hash, tx_height) for the given tx number.
@@ -213,8 +217,8 @@ class DB(object):
         tx_hash = self.hashes_file.read(tx_num * 32, 32)
         return tx_hash, tx_height
 
-    def fs_block_hashes(self, height, count):
-        headers_concat, headers_count = self.read_headers(height, count)
+    async def fs_block_hashes(self, height, count):
+        headers_concat, headers_count = await self.read_headers(height, count)
         if headers_count != count:
             raise self.DBError('only got {:,d} headers starting at {:,d}, not '
                                '{:,d}'.format(headers_count, height, count))
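
Note how the argument check stays outside the nested closure: bad arguments raise immediately in the caller's context, and only the pure disk read is pushed to the thread. Callers of read_headers and fs_block_hashes now simply await their results; a hypothetical call site:

    async def dump_first_headers(db):
        # Hypothetical helper: read up to 10 headers from height 0.
        headers, n = await db.read_headers(0, 10)
        print(f'{n} headers, {len(headers)} bytes')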


@@ -908,14 +908,14 @@ class ElectrumX(SessionBase):
         hashX = scripthash_to_hashX(scripthash)
         return await self.hashX_subscribe(hashX, scripthash)
 
-    def _merkle_proof(self, cp_height, height):
+    async def _merkle_proof(self, cp_height, height):
         max_height = self.db_height()
         if not height <= cp_height <= max_height:
             raise RPCError(BAD_REQUEST,
                            f'require header height {height:,d} <= '
                            f'cp_height {cp_height:,d} <= '
                            f'chain height {max_height:,d}')
-        branch, root = self.chain_state.header_branch_and_root(
+        branch, root = await self.chain_state.header_branch_and_root(
             cp_height + 1, height)
         return {
             'branch': [hash_to_hex_str(elt) for elt in branch],
@@ -931,7 +931,7 @@ class ElectrumX(SessionBase):
         if cp_height == 0:
             return raw_header_hex
         result = {'header': raw_header_hex}
-        result.update(self._merkle_proof(cp_height, height))
+        result.update(await self._merkle_proof(cp_height, height))
         return result
 
     async def block_header_13(self, height):
@@ -953,11 +953,12 @@ class ElectrumX(SessionBase):
         max_size = self.MAX_CHUNK_SIZE
         count = min(count, max_size)
-        headers, count = self.chain_state.read_headers(start_height, count)
+        headers, count = await self.chain_state.read_headers(start_height,
+                                                             count)
         result = {'hex': headers.hex(), 'count': count, 'max': max_size}
         if count and cp_height:
             last_height = start_height + count - 1
-            result.update(self._merkle_proof(cp_height, last_height))
+            result.update(await self._merkle_proof(cp_height, last_height))
         return result
 
     async def block_headers_12(self, start_height, count):
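
The shape of a blockchain.block.headers result is unchanged by this commit; when cp_height is supplied, the keys returned by _merkle_proof are merged in. An illustrative result with invented, shortened values (the 'root' key is assumed, since the _merkle_proof hunk above is truncated after 'branch'):

    example_result = {
        'hex': '0100...00',   # `count` concatenated raw headers (shortened)
        'count': 2016,
        'max': 2016,          # the server's MAX_CHUNK_SIZE
        # merged in only when cp_height was given:
        'branch': ['ab...', 'cd...'],
        'root': 'ef...',
    }
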
@@ -970,7 +971,7 @@ class ElectrumX(SessionBase):
         index = non_negative_integer(index)
         size = self.coin.CHUNK_SIZE
         start_height = index * size
-        headers, count = self.chain_state.read_headers(start_height, size)
+        headers, _ = await self.chain_state.read_headers(start_height, size)
         return headers.hex()
 
     async def block_get_header(self, height):


@@ -149,72 +149,83 @@ class Source(object):
     def __init__(self, length):
         self._hashes = [os.urandom(32) for _ in range(length)]
 
-    def hashes(self, start, count):
+    async def hashes(self, start, count):
         assert start >= 0
         assert start + count <= len(self._hashes)
         return self._hashes[start: start + count]
 
 
-def test_merkle_cache():
+@pytest.mark.asyncio
+async def test_merkle_cache():
     lengths = (*range(1, 18), 31, 32, 33, 57)
-    source = Source(max(lengths))
+    source = Source(max(lengths)).hashes
     for length in lengths:
-        cache = MerkleCache(merkle, source, length)
+        cache = MerkleCache(merkle, source)
+        await cache.initialize(length)
         # Simulate all possible checkpoints
         for cp_length in range(1, length + 1):
-            cp_hashes = source.hashes(0, cp_length)
+            cp_hashes = await source(0, cp_length)
            # All possible indices
            for index in range(cp_length):
                # Compare correct answer with cache
                branch, root = merkle.branch_and_root(cp_hashes, index)
-                branch2, root2 = cache.branch_and_root(cp_length, index)
+                branch2, root2 = await cache.branch_and_root(cp_length, index)
                assert branch == branch2
                assert root == root2
 
 
-def test_merkle_cache_extension():
-    source = Source(64)
+@pytest.mark.asyncio
+async def test_merkle_cache_extension():
+    source = Source(64).hashes
     for length in range(14, 18):
         for cp_length in range(30, 36):
-            cache = MerkleCache(merkle, source, length)
-            cp_hashes = source.hashes(0, cp_length)
+            cache = MerkleCache(merkle, source)
+            await cache.initialize(length)
+            cp_hashes = await source(0, cp_length)
             # All possible indices
             for index in range(cp_length):
                 # Compare correct answer with cache
                 branch, root = merkle.branch_and_root(cp_hashes, index)
-                branch2, root2 = cache.branch_and_root(cp_length, index)
+                branch2, root2 = await cache.branch_and_root(cp_length, index)
                 assert branch == branch2
                 assert root == root2
 
 
-def test_merkle_cache_truncation():
+@pytest.mark.asyncio
+async def test_merkle_cache_truncation():
     max_length = 33
-    source = Source(max_length)
+    source = Source(max_length).hashes
     for length in range(max_length - 2, max_length + 1):
         for trunc_length in range(1, 20, 3):
-            cache = MerkleCache(merkle, source, length)
+            cache = MerkleCache(merkle, source)
+            await cache.initialize(length)
             cache.truncate(trunc_length)
             assert cache.length <= trunc_length
             for cp_length in range(1, length + 1, 3):
-                cp_hashes = source.hashes(0, cp_length)
+                cp_hashes = await source(0, cp_length)
                 # All possible indices
                 for index in range(cp_length):
                     # Compare correct answer with cache
                     branch, root = merkle.branch_and_root(cp_hashes, index)
-                    branch2, root2 = cache.branch_and_root(cp_length, index)
+                    branch2, root2 = await cache.branch_and_root(cp_length,
+                                                                 index)
                     assert branch == branch2
                     assert root == root2
 
     # Truncation is a no-op if longer
-    cache = MerkleCache(merkle, source, 10)
+    cache = MerkleCache(merkle, source)
+    await cache.initialize(10)
     level = cache.level.copy()
     for length in range(10, 13):
        cache.truncate(length)
        assert cache.level == level
        assert cache.length == 10
 
 
-def test_truncation_bad():
-    cache = MerkleCache(merkle, Source(10), 10)
+@pytest.mark.asyncio
+async def test_truncation_bad():
+    cache = MerkleCache(merkle, Source(10).hashes)
+    await cache.initialize(10)
     with pytest.raises(TypeError):
         cache.truncate(1.0)
     for n in (-1, 0):
@@ -222,43 +233,48 @@ def test_truncation_bad():
             cache.truncate(n)
 
 
-def test_markle_cache_bad():
+@pytest.mark.asyncio
+async def test_merkle_cache_bad():
     length = 23
-    source = Source(length)
-    cache = MerkleCache(merkle, source, length)
-    cache.branch_and_root(5, 3)
+    source = Source(length).hashes
+    cache = MerkleCache(merkle, source)
+    await cache.initialize(length)
+    await cache.branch_and_root(5, 3)
     with pytest.raises(TypeError):
-        cache.branch_and_root(5.0, 3)
+        await cache.branch_and_root(5.0, 3)
     with pytest.raises(TypeError):
-        cache.branch_and_root(5, 3.0)
+        await cache.branch_and_root(5, 3.0)
     with pytest.raises(ValueError):
-        cache.branch_and_root(0, -1)
+        await cache.branch_and_root(0, -1)
     with pytest.raises(ValueError):
-        cache.branch_and_root(3, 3)
+        await cache.branch_and_root(3, 3)
 
 
-def test_bad_extension():
+@pytest.mark.asyncio
+async def test_bad_extension():
     length = 5
-    source = Source(length)
-    cache = MerkleCache(merkle, source, length)
+    source = Source(length).hashes
+    cache = MerkleCache(merkle, source)
+    await cache.initialize(length)
     level = cache.level.copy()
     with pytest.raises(AssertionError):
-        cache.branch_and_root(8, 0)
+        await cache.branch_and_root(8, 0)
     # The bad extension should not destroy the cache
     assert cache.level == level
     assert cache.length == length
 
 
-def time_it():
-    source = Source(500000)
+async def time_it():
+    source = Source(500000).hashes
+    cp_length = 492000
     import time
     cache = MerkleCache(merkle, source)
-    cp_length = 492000
-    cp_hashes = source.hashes(0, cp_length)
+    await cache.initialize(cp_length)
+    cp_hashes = await source(0, cp_length)
     brs2 = []
     t1 = time.time()
     for index in range(5, 400000, 500):
-        brs2.append(cache.branch_and_root(cp_length, index))
+        brs2.append(await cache.branch_and_root(cp_length, index))
     t2 = time.time()
     print(t2 - t1)
     assert False
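
The @pytest.mark.asyncio decorators above require the pytest-asyncio plugin; time_it, now a plain coroutine, needs an explicit event loop when run by hand. A minimal driver, assuming time_it is importable from this test module:

    import asyncio

    if __name__ == '__main__':
        # pytest collects only test_-prefixed functions, so drive the
        # benchmark coroutine manually.
        asyncio.get_event_loop().run_until_complete(time_it())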