'''Tests for electrumx.server.mempool.MemPool.

A synthetic MemPoolAPI implementation (class API) fabricates a random DB
UTXO set and a random mempool of transactions spending from it, and
independently computes the balance deltas, spends, summaries, touched
addresses and UTXOs that the real MemPool is expected to report.  Each
test runs MemPool.keep_synchronized against the fake API and compares.
'''
import logging
import os
from collections import defaultdict
from functools import partial
from random import randrange, choice

import pytest
from aiorpcx import Event, TaskGroup, sleep, spawn, ignore_after

from electrumx.server.mempool import MemPool, MemPoolAPI
from electrumx.lib.coins import BitcoinCash
from electrumx.lib.hash import HASHX_LEN, hex_str_to_hash, hash_to_hex_str
from electrumx.lib.tx import Tx, TxInput, TxOutput
from electrumx.lib.util import make_logger


coin = BitcoinCash
tx_hash_fn = coin.DESERIALIZER.TX_HASH_FN


def random_tx(hash160s, utxos):
    '''Create a random TX paying to some of the hash160s using some of the
    UTXOS.  Return the TX.  UTXOs is updated for the effects of the TX.
    '''
    inputs = []
    n_inputs = min(randrange(1, 4), len(utxos))
    input_value = 0
    # Create inputs spending random UTXOs.  Total the input value.
    for n in range(n_inputs):
        prevout = choice(list(utxos))
        hashX, value = utxos.pop(prevout)
        inputs.append(TxInput(prevout[0], prevout[1], b'', 4294967295))
        input_value += value

    # Seldom create a zero-fee transaction
    fee = min(input_value, randrange(500))
    input_value -= fee
    outputs = []
    n_outputs = randrange(1, 4)
    for n in range(n_outputs):
        value = randrange(input_value)
        input_value -= value
        pk_script = coin.hash160_to_P2PKH_script(choice(hash160s))
        outputs.append(TxOutput(value, pk_script))

    tx = Tx(2, inputs, outputs, 0)
    tx_bytes = tx.serialize()
    tx_hash = tx_hash_fn(tx_bytes)
    # Record the new outputs in the UTXO set so later TXs can spend them
    for n, output in enumerate(tx.outputs):
        utxos[(tx_hash, n)] = (coin.hashX_from_script(output.pk_script),
                               output.value)
    return tx, tx_hash, tx_bytes


class API(MemPoolAPI):
    '''A fake MemPoolAPI over a random DB UTXO set and mempool, with
    helpers that independently compute the answers MemPool should give.'''

    def __init__(self):
        self._height = 0
        self._cached_height = self._height
        # Create a pool of hash160s.  Map them to their script hashes
        # Create a bunch of UTXOs paying to those script hashes
        # Create a bunch of TXs that spend from the UTXO set and create
        # new outputs, which are added to the UTXO set for later TXs to
        # spend
        self.db_utxos = {}
        self.on_mempool_calls = []
        self.hashXs = []
        # Maps of mempool txs from tx_hash to raw and Tx object forms
        self.raw_txs = {}
        self.txs = {}
        self.ordered_adds = []

    def initialize(self, addr_count=100, db_utxo_count=100, mempool_size=50):
        '''Populate the fake DB UTXO set and the fake mempool.'''
        hash160s = [os.urandom(20) for n in range(addr_count)]
        self.hashXs = [coin.hash160_to_P2PKH_hashX(hash160)
                       for hash160 in hash160s]
        prevouts = [(os.urandom(32), randrange(0, 10))
                    for n in range(db_utxo_count)]
        random_value = partial(randrange, coin.VALUE_PER_COIN * 10)
        self.db_utxos = {prevout: (choice(self.hashXs), random_value())
                         for prevout in prevouts}

        unspent_utxos = self.db_utxos.copy()
        for n in range(mempool_size):
            tx, tx_hash, raw_tx = random_tx(hash160s, unspent_utxos)
            self.raw_txs[tx_hash] = raw_tx
            self.txs[tx_hash] = tx
            self.ordered_adds.append(tx_hash)

    def mempool_utxos(self):
        # Return the UTXOs created by mempool txs as {prevout: (hashX, value)}
        utxos = {}
        for tx_hash, tx in self.txs.items():
            for n, output in enumerate(tx.outputs):
                hashX = coin.hashX_from_script(output.pk_script)
                utxos[(tx_hash, n)] = (hashX, output.value)
        return utxos

    def mempool_spends(self):
        # Return all prevouts spent by mempool txs
        return [(input.prev_hash, input.prev_idx)
                for tx in self.txs.values() for input in tx.inputs]

    def balance_deltas(self):
        # Return mempool balance deltas indexed by hashX
        deltas = defaultdict(int)
        utxos = self.mempool_utxos()
        for tx_hash, tx in self.txs.items():
            for n, input in enumerate(tx.inputs):
                prevout = (input.prev_hash, input.prev_idx)
                if prevout in utxos:
                    utxos.pop(prevout)
                else:
                    hashX, value = self.db_utxos[prevout]
                    deltas[hashX] -= value
        # Whatever remains unspent in the mempool adds to the balance
        for hashX, value in utxos.values():
            deltas[hashX] += value
        return deltas

    def spends(self):
        # Return spends indexed by hashX
        spends = defaultdict(list)
        utxos = self.mempool_utxos()
        for tx_hash, tx in self.txs.items():
            for n, input in enumerate(tx.inputs):
                prevout = (input.prev_hash, input.prev_idx)
                if prevout in utxos:
                    hashX, value = utxos.pop(prevout)
                else:
                    hashX, value = self.db_utxos[prevout]
                spends[hashX].append(prevout)
        return spends

    def summaries(self):
        # Return lists of (tx_hash, fee, has_unconfirmed_inputs) by hashX
        summaries = defaultdict(list)
        utxos = self.mempool_utxos()
        for tx_hash, tx in self.txs.items():
            fee = 0
            hashXs = set()
            has_ui = False
            for n, input in enumerate(tx.inputs):
                has_ui = has_ui or (input.prev_hash in self.txs)
                prevout = (input.prev_hash, input.prev_idx)
                if prevout in utxos:
                    hashX, value = utxos[prevout]
                else:
                    hashX, value = self.db_utxos[prevout]
                hashXs.add(hashX)
                fee += value

            for output in tx.outputs:
                hashXs.add(coin.hashX_from_script(output.pk_script))
                fee -= output.value

            summary = (tx_hash, fee, has_ui)
            for hashX in hashXs:
                summaries[hashX].append(summary)
        return summaries

    def touched(self, tx_hashes):
        # Return the set of hashXs touched by the given mempool txs
        touched = set()
        utxos = self.mempool_utxos()
        for tx_hash in tx_hashes:
            tx = self.txs[tx_hash]
            for n, input in enumerate(tx.inputs):
                prevout = (input.prev_hash, input.prev_idx)
                if prevout in utxos:
                    hashX, value = utxos[prevout]
                else:
                    hashX, value = self.db_utxos[prevout]
                touched.add(hashX)

            for output in tx.outputs:
                touched.add(coin.hashX_from_script(output.pk_script))
        return touched

    def UTXOs(self):
        # Return lists of UTXO 5-tuples by hashX
        utxos = defaultdict(list)
        for tx_hash, tx in self.txs.items():
            for n, output in enumerate(tx.outputs):
                hashX = coin.hashX_from_script(output.pk_script)
                utxos[hashX].append((-1, n, tx_hash, 0, output.value))
        return utxos

    async def height(self):
        await sleep(0)
        self._cached_height = self._height
        return self._height

    def cached_height(self):
        return self._cached_height

    async def mempool_hashes(self):
        '''Query bitcoind for the hashes of all transactions in its
        mempool, returned as a list.'''
        await sleep(0)
        return [hash_to_hex_str(hash) for hash in self.txs]

    async def raw_transactions(self, hex_hashes):
        '''Query bitcoind for the serialized raw transactions with the given
        hashes.  Missing transactions are returned as None.

        hex_hashes is an iterable of hexadecimal hash strings.'''
        await sleep(0)
        hashes = [hex_str_to_hash(hex_hash) for hex_hash in hex_hashes]
        return [self.raw_txs.get(hash) for hash in hashes]

    async def lookup_utxos(self, prevouts):
        '''Return a list of (hashX, value) pairs each prevout if unspent,
        otherwise return None if spent or not found.

        prevouts - an iterable of (hash, index) pairs
        '''
        await sleep(0)
        return [self.db_utxos.get(prevout) for prevout in prevouts]

    async def on_mempool(self, touched, height):
        '''Called each time the mempool is synchronized.  touched is a set of
        hashXs touched since the previous call.  height is the
        daemon's height at the time the mempool was obtained.'''
        self.on_mempool_calls.append((touched, height))
        await sleep(0)


class DropAPI(API):
    '''An API whose first raw_transactions() call drops the last
    drop_count mempool txs, simulating the daemon losing them between
    returning hashes and serving the raw transactions.'''

    def __init__(self, drop_count):
        super().__init__()
        self.drop_count = drop_count
        self.dropped = False

    async def raw_transactions(self, hex_hashes):
        if not self.dropped:
            self.dropped = True
            for hash in self.ordered_adds[-self.drop_count:]:
                del self.raw_txs[hash]
                del self.txs[hash]
        return await super().raw_transactions(hex_hashes)


def in_caplog(caplog, message):
    # True if any captured log record contains the given message text
    return any(message in record.message for record in caplog.records)


@pytest.mark.asyncio
async def test_keep_synchronized(caplog):
    api = API()
    mempool = MemPool(coin, api)
    event = Event()
    with caplog.at_level(logging.INFO):
        async with TaskGroup() as group:
            await group.spawn(mempool.keep_synchronized, event)
            await event.wait()
            await group.cancel_remaining()

    assert in_caplog(caplog, 'beginning processing of daemon mempool')
    assert in_caplog(caplog, 'compact fee histogram')
    assert in_caplog(caplog, 'synced in ')
    assert in_caplog(caplog, '0 txs touching 0 addresses')
    assert not in_caplog(caplog, 'txs dropped')


@pytest.mark.asyncio
async def test_balance_delta():
    api = API()
    api.initialize()
    mempool = MemPool(coin, api)
    event = Event()
    async with TaskGroup() as group:
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        await group.cancel_remaining()

    # Check the default dict is handled properly
    prior_len = len(mempool.hashXs)
    assert await mempool.balance_delta(os.urandom(HASHX_LEN)) == 0
    assert prior_len == len(mempool.hashXs)

    # Test all hashXs
    deltas = api.balance_deltas()
    for hashX in api.hashXs:
        expected = deltas.get(hashX, 0)
        assert await mempool.balance_delta(hashX) == expected


@pytest.mark.asyncio
async def test_compact_fee_histogram():
    api = API()
    api.initialize()
    mempool = MemPool(coin, api)
    event = Event()
    async with TaskGroup() as group:
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        await group.cancel_remaining()

    histogram = await mempool.compact_fee_histogram()
    assert histogram == []
    bin_size = 1000
    mempool._update_histogram(bin_size)
    histogram = await mempool.compact_fee_histogram()
    assert len(histogram) > 0
    rates, sizes = zip(*histogram)
    # Rates must be strictly decreasing; bins must be near-full
    assert all(rates[n] < rates[n - 1] for n in range(1, len(rates)))
    assert all(size > bin_size * 0.95 for size in sizes)


@pytest.mark.asyncio
async def test_potential_spends():
    api = API()
    api.initialize()
    mempool = MemPool(coin, api)
    event = Event()
    async with TaskGroup() as group:
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        await group.cancel_remaining()

    # Check the default dict is handled properly
    prior_len = len(mempool.hashXs)
    assert await mempool.potential_spends(os.urandom(HASHX_LEN)) == set()
    assert prior_len == len(mempool.hashXs)

    # Test all hashXs
    spends = api.spends()
    for hashX in api.hashXs:
        ps = await mempool.potential_spends(hashX)
        assert all(spend in ps for spend in spends[hashX])


async def _test_summaries(mempool, api):
    # Test all hashXs
    summaries = api.summaries()
    for hashX in api.hashXs:
        mempool_result = await mempool.transaction_summaries(hashX)
        mempool_result = [(item.hash, item.fee, item.has_unconfirmed_inputs)
                          for item in mempool_result]
        our_result = summaries.get(hashX, [])
        assert set(our_result) == set(mempool_result)


@pytest.mark.asyncio
async def test_transaction_summaries(caplog):
    api = API()
    api.initialize()
    mempool = MemPool(coin, api)
    event = Event()
    with caplog.at_level(logging.INFO):
        async with TaskGroup() as group:
            await group.spawn(mempool.keep_synchronized, event)
            await event.wait()
            await group.cancel_remaining()

    # Check the default dict is handled properly
    prior_len = len(mempool.hashXs)
    assert await mempool.transaction_summaries(os.urandom(HASHX_LEN)) == []
    assert prior_len == len(mempool.hashXs)

    await _test_summaries(mempool, api)
    assert not in_caplog(caplog, 'txs dropped')


@pytest.mark.asyncio
async def test_unordered_UTXOs():
    api = API()
    api.initialize()
    mempool = MemPool(coin, api)
    event = Event()
    async with TaskGroup() as group:
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        await group.cancel_remaining()

    # Check the default dict is handled properly
    prior_len = len(mempool.hashXs)
    assert await mempool.unordered_UTXOs(os.urandom(HASHX_LEN)) == []
    assert prior_len == len(mempool.hashXs)

    # Test all hashXs
    utxos = api.UTXOs()
    for hashX in api.hashXs:
        mempool_result = await mempool.unordered_UTXOs(hashX)
        our_result = utxos.get(hashX, [])
        assert set(our_result) == set(mempool_result)


@pytest.mark.asyncio
async def test_mempool_removals():
    api = API()
    api.initialize()
    mempool = MemPool(coin, api, refresh_secs=0.01)
    event = Event()
    async with TaskGroup() as group:
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        # Remove half the TXs from the mempool
        start = len(api.ordered_adds) // 2
        for tx_hash in api.ordered_adds[start:]:
            del api.txs[tx_hash]
            del api.raw_txs[tx_hash]
        await event.wait()
        await _test_summaries(mempool, api)
        # Removed hashXs should have key destroyed
        assert all(mempool.hashXs.values())
        # Remove the rest
        api.txs.clear()
        api.raw_txs.clear()
        await event.wait()
        await _test_summaries(mempool, api)
        assert not mempool.hashXs
        assert not mempool.txs
        await group.cancel_remaining()


@pytest.mark.asyncio
async def test_daemon_drops_txs():
    # Tests things work if the daemon drops some transactions between
    # returning their hashes and the mempool requesting the raw txs
    api = DropAPI(10)
    api.initialize()
    mempool = MemPool(coin, api, refresh_secs=0.01)
    event = Event()
    async with TaskGroup() as group:
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        await _test_summaries(mempool, api)
        await group.cancel_remaining()


@pytest.mark.asyncio
async def test_notifications():
    # Tests notifications over a cycle of:
    # 1) A first batch of txs come in
    # 2) A second batch of txs come in
    # 3) A block comes in confirming the first batch only
    api = API()
    api.initialize()
    mempool = MemPool(coin, api, refresh_secs=0.001, log_status_secs=0)
    event = Event()

    n = len(api.ordered_adds) // 2
    raw_txs = api.raw_txs.copy()
    txs = api.txs.copy()
    first_hashes = api.ordered_adds[:n]
    first_touched = api.touched(first_hashes)
    second_hashes = api.ordered_adds[n:]
    second_touched = api.touched(second_hashes)

    async with TaskGroup() as group:
        # First batch enters the mempool
        api.raw_txs = {hash: raw_txs[hash] for hash in first_hashes}
        api.txs = {hash: txs[hash] for hash in first_hashes}
        first_utxos = api.mempool_utxos()
        first_spends = api.mempool_spends()
        await group.spawn(mempool.keep_synchronized, event)
        await event.wait()
        assert len(api.on_mempool_calls) == 1
        touched, height = api.on_mempool_calls[0]
        assert height == api._height == api._cached_height
        assert touched == first_touched
        # Second batch enters the mempool
        api.raw_txs = raw_txs
        api.txs = txs
        await event.wait()
        assert len(api.on_mempool_calls) == 2
        touched, height = api.on_mempool_calls[1]
        assert height == api._height == api._cached_height
        # Touched is incremental
        assert touched == second_touched
        # Block found; first half confirm
        new_height = 2
        api._height = new_height
        api.db_utxos.update(first_utxos)
        for spend in first_spends:
            del api.db_utxos[spend]
        api.raw_txs = {hash: raw_txs[hash] for hash in second_hashes}
        api.txs = {hash: txs[hash] for hash in second_hashes}
        await event.wait()
        assert len(api.on_mempool_calls) == 3
        touched, height = api.on_mempool_calls[2]
        assert height == api._height == api._cached_height == new_height
        assert touched == first_touched
        await group.cancel_remaining()


@pytest.mark.asyncio
async def test_dropped_txs(caplog):
    api = API()
    api.initialize()
    mempool = MemPool(coin, api)
    event = Event()
    # Remove a single TX_HASH that is used in another mempool tx
    for prev_hash, prev_idx in api.mempool_spends():
        if prev_hash in api.txs:
            del api.txs[prev_hash]

    with caplog.at_level(logging.INFO):
        async with TaskGroup() as group:
            await group.spawn(mempool.keep_synchronized, event)
            await event.wait()
            await group.cancel_remaining()

    assert in_caplog(caplog, 'txs dropped')