From 0963ce5230cd51840f586eac744fcb1fdea41c37 Mon Sep 17 00:00:00 2001
From: Neil Booth
Date: Mon, 23 Jul 2018 07:03:21 +0800
Subject: [PATCH] Completely overhaul mempool sync logic

- highly concurrent and a lot more efficient than previously
- initial mempool sync should be much faster (feedback please)
- mempool processing no longer blocks client session handling
- uses less memory to store the mempool
- fixes an obscure bug where sometimes txs were dropped
- more robust, clean and easy to understand

Fixes #433
---
 electrumx/server/mempool.py | 351 ++++++++++++++++--------------------
 electrumx/server/session.py |   6 +-
 2 files changed, 162 insertions(+), 195 deletions(-)

diff --git a/electrumx/server/mempool.py b/electrumx/server/mempool.py
index d2cc41c..ec54e6b 100644
--- a/electrumx/server/mempool.py
+++ b/electrumx/server/mempool.py
@@ -15,13 +15,12 @@ from collections import defaultdict
 import attr

 from electrumx.lib.hash import hash_to_hex_str, hex_str_to_hash
-from electrumx.lib.util import class_logger
+from electrumx.lib.util import class_logger, chunks
 from electrumx.server.db import UTXO


 @attr.s(slots=True)
 class MemPoolTx(object):
-    hash = attr.ib()
     in_pairs = attr.ib()
     out_pairs = attr.ib()
     fee = attr.ib()
@@ -36,10 +35,10 @@ class MemPool(object):
     To that end we maintain the following maps:

-       tx_hash -> MemPoolTx
+       tx_hash -> MemPoolTx
        hashX -> set of all tx hashes in which the hashX appears

-    A pair is a (hashX, value) tuple.  tx hashes are hex strings.
+    A pair is a (hashX, value) tuple.  tx hashes are binary, not hex strings.
     '''

     def __init__(self, coin, tasks, daemon, notifications, lookup_utxos):
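Aside for reviewers: the memory saving comes from dropping the per-tx hash
attribute (the key of self.txs already carries it) and from slots=True, which
gives attrs instances no per-instance __dict__. A toy sketch of the two maps
under those assumptions (a standalone class mirroring the one above, made-up
hashes and values, assuming the attrs package is installed as the module does):

    import attr

    @attr.s(slots=True)  # slots=True: instances carry no __dict__
    class MemPoolTx(object):
        in_pairs = attr.ib()   # becomes [(hashX, value), ...] once accepted
        out_pairs = attr.ib()  # [(hashX, value), ...]
        fee = attr.ib()
        size = attr.ib()

    tx_hash = bytes(32)        # binary tx hash, not a hex string
    hashX = b'\x0b' * 11       # stand-in for a script hashX

    txs = {tx_hash: MemPoolTx([], [(hashX, 50000)], 0, 250)}
    hashXs = {hashX: {tx_hash}}   # hashX -> set of tx hashes touching it

    # The hash is stored once, as the dict key, not again on the tx object
    for h in hashXs[hashX]:
        print(txs[h].size)        # 250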
@@ -61,201 +60,164 @@ class MemPool(object):
                              f'touching {len(self.hashXs):,d} addresses')
             await asyncio.sleep(120)

-    async def _synchronize_forever(self):
-        while True:
-            await asyncio.sleep(5)
-            await self._synchronize(False)
+    def _accept_transactions(self, tx_map, utxo_map, touched):
+        '''Accept transactions in tx_map to the mempool if all their inputs
+        can be found in the existing mempool or a utxo_map from the
+        DB.

-    async def _refresh_hashes(self):
-        '''Return a (hash set, height) pair when we're sure which height they
-        are for.'''
-        while True:
-            height = self.daemon.cached_height()
-            hashes = await self.daemon.mempool_hashes()
-            if height == await self.daemon.height():
-                return set(hashes), height
-
-    async def _synchronize(self, first_time):
-        '''Asynchronously maintain mempool status with daemon.
-
-        Processes the mempool each time the mempool refresh event is
-        signalled.
+        Returns an (unprocessed tx_map, unspent utxo_map) pair.
         '''
-        unprocessed = {}
-        unfetched = set()
-        touched = set()
-        txs = self.txs
-        next_refresh = 0
-        fetch_size = 800
-        process_some = self._async_process_some(fetch_size // 2)
-
-        while True:
-            now = time.time()
-            # If processing a large mempool, a block being found might
-            # shrink our work considerably, so refresh our view every 20s
-            if now > next_refresh:
-                hashes, height = await self._refresh_hashes()
-                self._resync_hashes(hashes, unprocessed, unfetched, touched)
-                next_refresh = time.time() + 20
-
-            # Log progress of initial sync
-            todo = len(unfetched) + len(unprocessed)
-            if first_time:
-                pct = (len(txs) - todo) * 100 // len(txs) if txs else 0
-                self.logger.info(f'catchup {pct:d}% complete '
-                                 f'({todo:,d} txs left)')
-            if not todo:
-                break
-
-            # FIXME: parallelize
-            if unfetched:
-                count = min(len(unfetched), fetch_size)
-                hex_hashes = [unfetched.pop() for n in range(count)]
-                unprocessed.update(await self._fetch_raw_txs(hex_hashes))
-            if unprocessed:
-                await process_some(unprocessed, touched)
-
-        await self.notifications.on_mempool(touched, height)
-
-    def _resync_hashes(self, hashes, unprocessed, unfetched, touched):
-        '''Re-sync self.txs with the list of hashes in the daemon's mempool.
-
-        Additionally, remove gone hashes from unprocessed and
-        unfetched.  Add new ones to unprocessed.
-        '''
-        txs = self.txs
         hashXs = self.hashXs
-        fee_hist = self.fee_histogram
-        gone = set(txs).difference(hashes)
-        for hex_hash in gone:
-            unfetched.discard(hex_hash)
-            unprocessed.pop(hex_hash, None)
-            tx = txs.pop(hex_hash)
-            if tx:
-                fee_rate = tx.fee // tx.size
-                fee_hist[fee_rate] -= tx.size
-                if fee_hist[fee_rate] == 0:
-                    fee_hist.pop(fee_rate)
-                tx_hashXs = set(hashX for hashX, value in tx.in_pairs)
-                tx_hashXs.update(hashX for hashX, value in tx.out_pairs)
-                for hashX in tx_hashXs:
-                    hashXs[hashX].remove(hex_hash)
-                    if not hashXs[hashX]:
-                        del hashXs[hashX]
-                touched.update(tx_hashXs)
-
-        new = hashes.difference(txs)
-        unfetched.update(new)
-        for hex_hash in new:
-            txs[hex_hash] = None
-
-    def _async_process_some(self, limit):
-        pending = []
         txs = self.txs
-
-        async def process(unprocessed, touched):
-            nonlocal pending
-
-            raw_txs = {}
-
-            while unprocessed and len(raw_txs) < limit:
-                hex_hash, raw_tx = unprocessed.popitem()
-                raw_txs[hex_hash] = raw_tx
-
-            if unprocessed:
-                deferred = []
-            else:
-                deferred = pending
-                pending = []
-
-            deferred = await self._process_raw_txs(raw_txs, deferred, touched)
-            pending.extend(deferred)
-
-        return process
-
-    async def _fetch_raw_txs(self, hex_hashes):
-        '''Fetch a list of mempool transactions.'''
-        raw_txs = await self.daemon.getrawtransactions(hex_hashes)
-
-        # Skip hashes the daemon has dropped.  Either they were
-        # evicted or they got in a block.
-        return {hh: raw for hh, raw in zip(hex_hashes, raw_txs) if raw}
-
-    async def _process_raw_txs(self, raw_tx_map, pending, touched):
-        '''Process the dictionary of raw transactions and return a dictionary
-        of updates to apply to self.txs.
-        '''
-        def deserialize_txs():
-            script_hashX = self.coin.hashX_from_script
-            deserializer = self.coin.DESERIALIZER
-
-            # Deserialize each tx and put it in a pending list
-            for tx_hash, raw_tx in raw_tx_map.items():
-                tx, tx_size = deserializer(raw_tx).read_tx_and_vsize()
-
-                # Convert the tx outputs into (hashX, value) pairs
-                txout_pairs = [(script_hashX(txout.pk_script), txout.value)
-                               for txout in tx.outputs]
-
-                # Convert the tx inputs to (prev_hex_hash, prev_idx) pairs
-                txin_pairs = [(hash_to_hex_str(txin.prev_hash), txin.prev_idx)
-                              for txin in tx.inputs]
-
-                pending.append(MemPoolTx(tx_hash, txin_pairs, txout_pairs,
-                                         0, tx_size))
-
-        # Do this potentially slow operation in a thread so as not to
-        # block
-        await self.tasks.run_in_thread(deserialize_txs)
-
-        # The transaction inputs can be from other mempool
-        # transactions (which may or may not be processed yet) or are
-        # otherwise presumably in the DB.
-        txs = self.txs
-        db_prevouts = [(hex_str_to_hash(prev_hash), prev_idx)
-                       for tx in pending
-                       for (prev_hash, prev_idx) in tx.in_pairs
-                       if prev_hash not in txs]
-
-        # If a lookup fails, it returns a None entry
-        db_utxos = await self.lookup_utxos(db_prevouts)
-        db_utxo_map = {(hash_to_hex_str(prev_hash), prev_idx): db_utxo
-                       for (prev_hash, prev_idx), db_utxo
-                       in zip(db_prevouts, db_utxos)}
-
-        deferred = []
-        hashXs = self.hashXs
         fee_hist = self.fee_histogram
+        init_count = len(utxo_map)

-        for tx in pending:
-            if tx.hash not in txs:
-                continue
-
+        deferred = {}
+        unspent = set(utxo_map)
+        # Try to find all previns so we can accept the TX
+        for hash, tx in tx_map.items():
             in_pairs = []
             try:
                 for previn in tx.in_pairs:
-                    utxo = db_utxo_map.get(previn)
+                    utxo = utxo_map.get(previn)
                     if not utxo:
                         prev_hash, prev_index = previn
-                        # This can raise a KeyError or TypeError
-                        utxo = txs[prev_hash][1][prev_index]
+                        # Raises KeyError if prev_hash is not in txs
+                        utxo = txs[prev_hash].out_pairs[prev_index]
                     in_pairs.append(utxo)
-            except (KeyError, TypeError):
-                deferred.append(tx)
+            except KeyError:
+                deferred[hash] = tx
                 continue

+            # Spend the previns
+            unspent.difference_update(tx.in_pairs)
+
+            # Convert in_pairs and add the TX to the mempool
             tx.in_pairs = in_pairs
             # Compute fee
             tx.fee = (sum(v for hashX, v in tx.in_pairs) -
                       sum(v for hashX, v in tx.out_pairs))
             fee_rate = tx.fee // tx.size
             fee_hist[fee_rate] += tx.size
-            txs[tx.hash] = tx
+            txs[hash] = tx
             for hashX, value in itertools.chain(tx.in_pairs, tx.out_pairs):
                 touched.add(hashX)
-                hashXs[hashX].add(tx.hash)
+                hashXs[hashX].add(hash)

-        return deferred
+        return deferred, {previn: utxo_map[previn] for previn in unspent}
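Aside for reviewers: the deferral contract above is easiest to see on a toy
chain. A hypothetical, self-contained sketch (a simplified accept() standing
in for _accept_transactions, with fees and touched-tracking stripped out)
showing a child tx that arrives before its parent being deferred on pass one
and accepted on pass two:

    # accepted: tx_hash -> out_pairs; tx_map: tx_hash -> list of prevouts
    def accept(tx_map, utxo_map, accepted):
        deferred = {}
        for tx_hash, prevouts in tx_map.items():
            try:
                # Each input must resolve against a DB UTXO or an
                # already-accepted mempool tx's outputs
                ins = [utxo_map.get(p) or accepted[p[0]][p[1]]
                       for p in prevouts]
            except KeyError:
                deferred[tx_hash] = prevouts   # parent not processed yet
                continue
            # Toy output: everything minus a made-up 1000-unit fee
            accepted[tx_hash] = [('hashX', sum(v for _, v in ins) - 1000)]
        return deferred

    utxo_map = {(b'db_tx', 0): ('hashX', 100000)}
    parent, child = b'parent', b'child'
    tx_map = {child: [(parent, 0)], parent: [(b'db_tx', 0)]}

    accepted = {}
    while tx_map:
        remaining = accept(tx_map, utxo_map, accepted)
        if len(remaining) == len(tx_map):
            break               # no progress this pass: drop the rest
        tx_map = remaining
    assert parent in accepted and child in accepted

The same fixed-point idea drives the straggler loop in _process_mempool below:
keep re-running acceptance until a pass accepts nothing new.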
+
+    async def _refresh_hashes(self, single_pass):
+        '''Refresh our view of the daemon's mempool hashes, processing the
+        mempool once we are sure which height the hashes are for.
+        Loops forever unless single_pass is true.'''
+        refresh_event = asyncio.Event()
+        loop = self.tasks.loop
+        while True:
+            height = self.daemon.cached_height()
+            hex_hashes = await self.daemon.mempool_hashes()
+            if height != await self.daemon.height():
+                continue
+            loop.call_later(5, refresh_event.set)
+            hashes = set(hex_str_to_hash(hh) for hh in hex_hashes)
+            touched = await self._process_mempool(hashes)
+            await self.notifications.on_mempool(touched, height)
+            if single_pass:
+                return
+            await refresh_event.wait()
+            refresh_event.clear()
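Aside for reviewers: the height double-check above is what makes the snapshot
trustworthy - if a block arrives between cached_height() and mempool_hashes(),
the pass is retried rather than processed against the wrong height. A minimal
sketch of just that pattern, with a hypothetical FakeDaemon standing in for
the real daemon interface:

    import asyncio

    async def consistent_mempool_snapshot(daemon):
        # Retry until the chain tip did not move while the mempool
        # hashes were being fetched
        while True:
            height = daemon.cached_height()
            hashes = await daemon.mempool_hashes()
            if height == await daemon.height():
                return set(hashes), height

    class FakeDaemon:
        def __init__(self, height=500000):
            self._height = height
        def cached_height(self):
            return self._height
        async def height(self):
            return self._height
        async def mempool_hashes(self):
            return ['ab' * 32, 'cd' * 32]

    print(asyncio.run(consistent_mempool_snapshot(FakeDaemon())))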
+
+    async def _process_mempool(self, all_hashes):
+        # Re-sync with the new set of hashes
+        txs = self.txs
+        hashXs = self.hashXs
+        touched = set()
+        fee_hist = self.fee_histogram
+
+        # First handle txs that have disappeared
+        for tx_hash in set(txs).difference(all_hashes):
+            tx = txs.pop(tx_hash)
+            fee_rate = tx.fee // tx.size
+            fee_hist[fee_rate] -= tx.size
+            if fee_hist[fee_rate] == 0:
+                fee_hist.pop(fee_rate)
+            tx_hashXs = set(hashX for hashX, value in tx.in_pairs)
+            tx_hashXs.update(hashX for hashX, value in tx.out_pairs)
+            for hashX in tx_hashXs:
+                hashXs[hashX].remove(tx_hash)
+                if not hashXs[hashX]:
+                    del hashXs[hashX]
+            touched.update(tx_hashXs)
+
+        # Process new transactions
+        new_hashes = list(all_hashes.difference(txs))
+        jobs = [self.tasks.create_task(
+            self._fetch_and_accept(hashes, all_hashes, touched))
+                for hashes in chunks(new_hashes, 2000)]
+        if jobs:
+            await asyncio.wait(jobs)
+        tx_map = {}
+        utxo_map = {}
+        for job in jobs:
+            deferred, unspent = job.result()
+            tx_map.update(deferred)
+            utxo_map.update(unspent)
+
+        # Handle the stragglers
+        if len(tx_map) >= 10:
+            self.logger.info(f'{len(tx_map)} stragglers')
+        prior_count = 0
+        # FIXME: this is not particularly efficient
+        while tx_map and len(tx_map) != prior_count:
+            prior_count = len(tx_map)
+            tx_map, utxo_map = self._accept_transactions(tx_map, utxo_map,
+                                                         touched)
+        if tx_map:
+            self.logger.info(f'{len(tx_map)} txs dropped')
+
+        return touched
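Aside for reviewers: the fan-out above is where the concurrency comes from -
each 2,000-hash chunk becomes its own task whose daemon round-trips overlap,
where the old loop fetched 800 hashes at a time serially. A generic sketch of
the pattern (the chunks() helper below is assumed to behave like
electrumx.lib.util.chunks; fetch_chunk fakes the RPC round trip):

    import asyncio

    def chunks(items, size):
        '''Yield successive size-sized slices of a list.'''
        for i in range(0, len(items), size):
            yield items[i:i + size]

    async def fetch_chunk(chunk):
        await asyncio.sleep(0.01)      # stand-in for a daemon RPC round trip
        return {h: f'raw_{h}' for h in chunk}

    async def fetch_all(hashes, chunk_size=2000):
        # One task per chunk; their I/O waits overlap instead of queueing
        tasks = [asyncio.ensure_future(fetch_chunk(c))
                 for c in chunks(hashes, chunk_size)]
        if tasks:
            await asyncio.wait(tasks)
        result = {}
        for task in tasks:
            result.update(task.result())
        return result

    print(len(asyncio.run(fetch_all(list(range(5000)), 2000))))   # 5000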
+
+    async def _fetch_and_accept(self, hashes, all_hashes, touched):
+        '''Fetch the given mempool transactions and accept those whose
+        inputs can all be resolved.'''
+        hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
+        raw_txs = await self.daemon.getrawtransactions(hex_hashes)
+
+        def deserialize_txs():
+            # This function is pure
+            script_hashX = self.coin.hashX_from_script
+            deserializer = self.coin.DESERIALIZER
+
+            txs = {}
+            for hash, raw_tx in zip(hashes, raw_txs):
+                # The daemon may have evicted the tx from its
+                # mempool or it may have gotten in a block
+                if not raw_tx:
+                    continue
+                tx, tx_size = deserializer(raw_tx).read_tx_and_vsize()
+
+                # Convert the tx outputs into (hashX, value) pairs
+                txout_pairs = [(script_hashX(txout.pk_script), txout.value)
+                               for txout in tx.outputs]
+
+                # Convert the tx inputs to (prev_hash, prev_idx) pairs
+                txin_pairs = [(txin.prev_hash, txin.prev_idx)
+                              for txin in tx.inputs]
+
+                txs[hash] = MemPoolTx(txin_pairs, txout_pairs, 0, tx_size)
+            return txs
+
+        # Thread this potentially slow operation so as not to block
+        tx_map = await self.tasks.run_in_thread(deserialize_txs)
+
+        # Determine all prevouts not in the mempool, and fetch the
+        # UTXO information from the database.  Failed prevout lookups
+        # return None - they can happen because of concurrent
+        # database updates
+        prevouts = [tx_in for tx in tx_map.values() for tx_in in tx.in_pairs
+                    if tx_in[0] not in all_hashes]
+        utxos = await self.lookup_utxos(prevouts)
+        utxo_map = {prevout: utxo for prevout, utxo in zip(prevouts, utxos)}
+
+        # Attempt to complete processing of txs
+        return self._accept_transactions(tx_map, utxo_map, touched)
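Aside for reviewers: deserialize_txs above is pure CPU work, so running it on
the event loop would stall every connected session. self.tasks.run_in_thread
is ElectrumX's own wrapper; a sketch of the same effect using the standard
executor API (toy deserialize function, default thread pool assumed):

    import asyncio
    from hashlib import sha256

    def deserialize(raw_txs):
        # Pure CPU work: no awaits, safe to run off the event loop
        return {sha256(raw).digest(): len(raw) for raw in raw_txs}

    async def main():
        raw_txs = [bytes([n]) * 250 for n in range(3)]
        loop = asyncio.get_event_loop()
        # None -> default ThreadPoolExecutor; this blocks a worker
        # thread, not the loop, so client sessions keep being served
        tx_sizes = await loop.run_in_executor(None, deserialize, raw_txs)
        print(len(tx_sizes))       # 3

    asyncio.run(main())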

     async def _raw_transactions(self, hashX):
         '''Returns an iterable of (hex_hash, raw_tx) pairs for all
@@ -267,9 +229,10 @@ class MemPool(object):
         if hashX not in self.hashXs:
             return []

-        hex_hashes = self.hashXs[hashX]
+        hashes = self.hashXs[hashX]
+        hex_hashes = [hash_to_hex_str(hash) for hash in hashes]
         raw_txs = await self.daemon.getrawtransactions(hex_hashes)
-        return zip(hex_hashes, raw_txs)
+        return zip(hashes, raw_txs)

     def _calc_compact_histogram(self):
         # For efficiency, get_fees returns a compact histogram with
@@ -300,9 +263,12 @@ class MemPool(object):
         '''
         self.logger.info('beginning processing of daemon mempool.  '
                          'This can take some time...')
-        await self._synchronize(True)
+        start = time.time()
+        await self._refresh_hashes(True)
+        elapsed = time.time() - start
+        self.logger.info(f'synced in {elapsed:.2f}s')
         self.tasks.create_task(self._log_stats())
-        self.tasks.create_task(self._synchronize_forever())
+        self.tasks.create_task(self._refresh_hashes(False))

     async def balance_delta(self, hashX):
         '''Return the unconfirmed amount in the mempool for hashX.
@@ -312,8 +278,8 @@ class MemPool(object):
         value = 0
         # hashXs is a defaultdict
         if hashX in self.hashXs:
-            for hex_hash in self.hashXs[hashX]:
-                tx = self.txs[hex_hash]
+            for hash in self.hashXs[hashX]:
+                tx = self.txs[hash]
                 value -= sum(v for h168, v in tx.in_pairs if h168 == hashX)
                 value += sum(v for h168, v in tx.out_pairs if h168 == hashX)
         return value
@@ -335,7 +301,7 @@ class MemPool(object):
         deserializer = self.coin.DESERIALIZER
         pairs = await self._raw_transactions(hashX)
         result = set()
-        for hex_hash, raw_tx in pairs:
+        for hash, raw_tx in pairs:
             if not raw_tx:
                 continue
             tx = deserializer(raw_tx).read_tx()
@@ -344,7 +310,7 @@ class MemPool(object):
         return result

     async def transaction_summaries(self, hashX):
-        '''Return a list of (tx_hex_hash, tx_fee, unconfirmed) tuples for
+        '''Return a list of (tx_hash, tx_fee, unconfirmed) tuples for
         mempool entries for the hashX.

         unconfirmed is True if any txin is unconfirmed.
@@ -352,14 +318,15 @@ class MemPool(object):
         '''
         deserializer = self.coin.DESERIALIZER
         pairs = await self._raw_transactions(hashX)
         result = []
-        for hex_hash, raw_tx in pairs:
-            mempool_tx = self.txs.get(hex_hash)
+        for tx_hash, raw_tx in pairs:
+            mempool_tx = self.txs.get(tx_hash)
             if not mempool_tx or not raw_tx:
                 continue
             tx = deserializer(raw_tx).read_tx()
-            unconfirmed = any(hash_to_hex_str(txin.prev_hash) in self.txs
+            # FIXME: use all_hashes not self.txs
+            unconfirmed = any(txin.prev_hash in self.txs
                               for txin in tx.inputs)
-            result.append((hex_hash, mempool_tx.fee, unconfirmed))
+            result.append((tx_hash, mempool_tx.fee, unconfirmed))
         return result

     async def unordered_UTXOs(self, hashX):
@@ -371,13 +338,11 @@ class MemPool(object):
         '''
         utxos = []
         # hashXs is a defaultdict, so use get() to query
-        for hex_hash in self.hashXs.get(hashX, []):
-            tx = self.txs.get(hex_hash)
+        for tx_hash in self.hashXs.get(hashX, []):
+            tx = self.txs.get(tx_hash)
             if not tx:
                 continue
             for pos, (hX, value) in enumerate(tx.out_pairs):
                 if hX == hashX:
-                    # Unfortunately UTXO holds a binary hash
-                    utxos.append(UTXO(-1, pos, hex_str_to_hash(hex_hash),
-                                      0, value))
+                    utxos.append(UTXO(-1, pos, tx_hash, 0, value))
         return utxos

diff --git a/electrumx/server/session.py b/electrumx/server/session.py
index 8e7f958..a0c8c7c 100644
--- a/electrumx/server/session.py
+++ b/electrumx/server/session.py
@@ -733,7 +733,8 @@ class ElectrumX(SessionBase):

         status = ''.join('{}:{:d}:'.format(hash_to_hex_str(tx_hash), height)
                          for tx_hash, height in history)
-        status += ''.join('{}:{:d}:'.format(hex_hash, -unconfirmed)
+        status += ''.join('{}:{:d}:'.format(hash_to_hex_str(hex_hash),
+                                            -unconfirmed)
                           for hex_hash, tx_fee, unconfirmed in mempool)
         if status:
             status = sha256(status.encode()).hex()
@@ -821,7 +822,8 @@ class ElectrumX(SessionBase):
         # Note unconfirmed history is unordered in electrum-server
         # Height is -1 if unconfirmed txins, otherwise 0
         mempool = await self.mempool.transaction_summaries(hashX)
-        return [{'tx_hash': tx_hash, 'height': -unconfirmed, 'fee': fee}
+        return [{'tx_hash': hash_to_hex_str(tx_hash), 'height': -unconfirmed,
+                 'fee': fee}
                 for tx_hash, fee, unconfirmed in mempool]

     async def confirmed_and_unconfirmed_history(self, hashX):
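Aside for reviewers: the first session.py hunk is the Electrum address-status
computation that the now-binary hashes feed into. A self-contained sketch of
that protocol rule (tx_hash.hex() stands in for hash_to_hex_str, which
additionally byte-reverses the hash; -unconfirmed relies on Python bools being
ints, so True yields height -1):

    from hashlib import sha256

    def address_status(history, mempool):
        '''history: [(tx_hash, height)] confirmed, oldest first;
        mempool: [(tx_hash, fee, unconfirmed)].  The status is the sha256
        of the concatenated "tx_hash:height:" strings, as a hex digest.'''
        status = ''.join('{}:{:d}:'.format(tx_hash.hex(), height)
                         for tx_hash, height in history)
        # Mempool txs use height -1 if they spend unconfirmed inputs, else 0
        status += ''.join('{}:{:d}:'.format(tx_hash.hex(), -unconfirmed)
                          for tx_hash, fee, unconfirmed in mempool)
        return sha256(status.encode()).hex() if status else None

    print(address_status([(bytes(32), 500000)],
                         [(b'\x01' * 32, 226, True)]))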