Implement deserialization of SegWit transactions

tx_hash needs to be that of the prior serialization, so
need to change internal read_block API.

Bitcoin core 0.13.1 broke backwards compat of the RPC interface.
Closes #92
This commit is contained in:
Neil Booth 2017-01-07 16:15:10 +09:00
parent 2294f5c791
commit 852753cb94
6 changed files with 115 additions and 40 deletions

View File

@ -21,7 +21,7 @@ import sys
from lib.hash import Base58, hash160, ripemd160, double_sha256, hash_to_str from lib.hash import Base58, hash160, ripemd160, double_sha256, hash_to_str
from lib.script import ScriptPubKey from lib.script import ScriptPubKey
from lib.tx import Deserializer from lib.tx import Deserializer, DeserializerSegWit
from lib.util import cachedproperty, subclasses from lib.util import cachedproperty, subclasses
@ -204,11 +204,14 @@ class Coin(object):
@classmethod @classmethod
def read_block(cls, block, height): def read_block(cls, block, height):
'''Return a tuple (header, tx_hashes, txs) given a raw block at '''Returns a pair (header, tx_list) given a raw block and height.
the given height.'''
tx_list is a list of (deserialized_tx, tx_hash) pairs.
'''
deserializer = cls.deserializer()
hlen = cls.header_len(height) hlen = cls.header_len(height)
header, rest = block[:hlen], block[hlen:] header, rest = block[:hlen], block[hlen:]
return (header, ) + Deserializer(rest).read_block() return (header, deserializer(rest).read_block())
@classmethod @classmethod
def decimal_value(cls, value): def decimal_value(cls, value):
@ -234,6 +237,10 @@ class Coin(object):
'nonce': nonce, 'nonce': nonce,
} }
@classmethod
def deserializer(cls):
return Deserializer
class Bitcoin(Coin): class Bitcoin(Coin):
NAME = "Bitcoin" NAME = "Bitcoin"
@ -271,6 +278,19 @@ class BitcoinTestnet(Bitcoin):
IRC_PREFIX = "ET_" IRC_PREFIX = "ET_"
class BitcoinTestnetSegWit(BitcoinTestnet):
'''Bitcoin Testnet for Core bitcoind >= 0.13.1.
Unfortunately 0.13.1 broke backwards compatibility of the RPC
interface's TX serialization, SegWit transactions serialize
differently than with earlier versions. If you are using such a
bitcoind on testnet, you must use this class as your "COIN".
'''
@classmethod
def deserializer(cls):
return DeserializerSegWit
class Litecoin(Coin): class Litecoin(Coin):
NAME = "Litecoin" NAME = "Litecoin"
SHORTNAME = "LTC" SHORTNAME = "LTC"

View File

@ -72,28 +72,25 @@ class Deserializer(object):
self.cursor = 0 self.cursor = 0
def read_tx(self): def read_tx(self):
'''Return a (Deserialized TX, TX_HASH) pair.
The hash needs to be reversed for human display; for efficiency
we process it in the natural serialized order.
'''
start = self.cursor
return Tx( return Tx(
self._read_le_int32(), # version self._read_le_int32(), # version
self._read_inputs(), # inputs self._read_inputs(), # inputs
self._read_outputs(), # outputs self._read_outputs(), # outputs
self._read_le_uint32() # locktime self._read_le_uint32() # locktime
) ), double_sha256(self.binary[start:self.cursor])
def read_block(self): def read_block(self):
tx_hashes = [] '''Returns a list of (deserialized_tx, tx_hash) pairs.'''
txs = []
binary = self.binary
hash = double_sha256
read_tx = self.read_tx read_tx = self.read_tx
append_hash = tx_hashes.append txs = [read_tx() for n in range(self._read_varint())]
for n in range(self._read_varint()): assert self.cursor == len(self.binary)
start = self.cursor return txs
txs.append(read_tx())
# Note this hash needs to be reversed for human display
# For efficiency we store it in the natural serialized order
append_hash(hash(binary[start:self.cursor]))
assert self.cursor == len(binary)
return tx_hashes, txs
def _read_inputs(self): def _read_inputs(self):
read_input = self._read_input read_input = self._read_input
@ -161,3 +158,62 @@ class Deserializer(object):
result, = unpack_from('<Q', self.binary, self.cursor) result, = unpack_from('<Q', self.binary, self.cursor)
self.cursor += 8 self.cursor += 8
return result return result
class TxSegWit(namedtuple("Tx", "version marker flag inputs outputs "
"witness locktime")):
'''Class representing a SegWit transaction.'''
@cachedproperty
def is_coinbase(self):
return self.inputs[0].is_coinbase
class DeserializerSegWit(Deserializer):
# https://bitcoincore.org/en/segwit_wallet_dev/#transaction-serialization
def _read_byte(self):
cursor = self.cursor
self.cursor += 1
return self.binary[cursor]
def _read_witness(self, fields):
read_witness_field = self._read_witness_field
return [read_witness_field() for i in range(fields)]
def _read_witness_field(self):
read_varbytes = self._read_varbytes
return [read_varbytes() for i in range(self._read_varint())]
def read_tx(self):
'''Return a (Deserialized TX, TX_HASH) pair.
The hash needs to be reversed for human display; for efficiency
we process it in the natural serialized order.
'''
marker = self.binary[self.cursor + 4]
if marker:
return super().read_tx()
# Ugh, this is nasty.
start = self.cursor
version = self._read_le_int32()
orig_ser = self.binary[start:self.cursor]
marker = self._read_byte()
flag = self._read_byte()
start = self.cursor
inputs = self._read_inputs()
outputs = self._read_outputs()
orig_ser += self.binary[start:self.cursor]
witness = self._read_witness(len(inputs))
start = self.cursor
locktime = self._read_le_uint32()
orig_ser += self.binary[start:self.cursor]
return TxSegWit(version, marker, flag, inputs,
outputs, witness, locktime), double_sha256(orig_ser)

View File

@ -444,7 +444,9 @@ class BlockProcessor(server.db.DB):
utxo_cache_size = len(self.utxo_cache) * 205 utxo_cache_size = len(self.utxo_cache) * 205
db_deletes_size = len(self.db_deletes) * 57 db_deletes_size = len(self.db_deletes) * 57
hist_cache_size = len(self.history) * 180 + self.history_size * 4 hist_cache_size = len(self.history) * 180 + self.history_size * 4
tx_hash_size = (self.tx_count - self.fs_tx_count) * 74 # Roughly ntxs * 32 + nblocks * 42
tx_hash_size = ((self.tx_count - self.fs_tx_count) * 32
+ (self.height - self.fs_height) * 42)
utxo_MB = (db_deletes_size + utxo_cache_size) // one_MB utxo_MB = (db_deletes_size + utxo_cache_size) // one_MB
hist_MB = (hist_cache_size + tx_hash_size) // one_MB hist_MB = (hist_cache_size + tx_hash_size) // one_MB
@ -458,28 +460,28 @@ class BlockProcessor(server.db.DB):
if utxo_MB + hist_MB >= self.cache_MB or hist_MB >= self.cache_MB // 5: if utxo_MB + hist_MB >= self.cache_MB or hist_MB >= self.cache_MB // 5:
self.flush(utxo_MB >= self.cache_MB * 4 // 5) self.flush(utxo_MB >= self.cache_MB * 4 // 5)
def fs_advance_block(self, header, tx_hashes, txs): def fs_advance_block(self, header, txs):
'''Update unflushed FS state for a new block.''' '''Update unflushed FS state for a new block.'''
prior_tx_count = self.tx_counts[-1] if self.tx_counts else 0 prior_tx_count = self.tx_counts[-1] if self.tx_counts else 0
# Cache the new header, tx hashes and cumulative tx count # Cache the new header, tx hashes and cumulative tx count
self.headers.append(header) self.headers.append(header)
self.tx_hashes.append(tx_hashes) self.tx_hashes.append(b''.join(tx_hash for tx, tx_hash in txs))
self.tx_counts.append(prior_tx_count + len(txs)) self.tx_counts.append(prior_tx_count + len(txs))
def advance_block(self, block, touched): def advance_block(self, block, touched):
header, tx_hashes, txs = self.coin.read_block(block, self.height + 1) header, txs = self.coin.read_block(block, self.height + 1)
if self.tip != self.coin.header_prevhash(header): if self.tip != self.coin.header_prevhash(header):
raise ChainReorg raise ChainReorg
self.fs_advance_block(header, tx_hashes, txs) self.fs_advance_block(header, txs)
self.tip = self.coin.header_hash(header) self.tip = self.coin.header_hash(header)
self.height += 1 self.height += 1
undo_info = self.advance_txs(tx_hashes, txs, touched) undo_info = self.advance_txs(txs, touched)
if self.daemon.cached_height() - self.height <= self.env.reorg_limit: if self.daemon.cached_height() - self.height <= self.env.reorg_limit:
self.write_undo_info(self.height, b''.join(undo_info)) self.write_undo_info(self.height, b''.join(undo_info))
def advance_txs(self, tx_hashes, txs, touched): def advance_txs(self, txs, touched):
undo_info = [] undo_info = []
# Use local vars for speed in the loops # Use local vars for speed in the loops
@ -492,7 +494,7 @@ class BlockProcessor(server.db.DB):
spend_utxo = self.spend_utxo spend_utxo = self.spend_utxo
undo_info_append = undo_info.append undo_info_append = undo_info.append
for tx, tx_hash in zip(txs, tx_hashes): for tx, tx_hash in txs:
hashXs = set() hashXs = set()
add_hashX = hashXs.add add_hashX = hashXs.add
tx_numb = s_pack('<I', tx_num) tx_numb = s_pack('<I', tx_num)
@ -533,14 +535,14 @@ class BlockProcessor(server.db.DB):
self.assert_flushed() self.assert_flushed()
for block in blocks: for block in blocks:
header, tx_hashes, txs = self.coin.read_block(block, self.height) header, txs = self.coin.read_block(block, self.height)
header_hash = self.coin.header_hash(header) header_hash = self.coin.header_hash(header)
if header_hash != self.tip: if header_hash != self.tip:
raise ChainError('backup block {} is not tip {} at height {:,d}' raise ChainError('backup block {} is not tip {} at height {:,d}'
.format(hash_to_str(header_hash), .format(hash_to_str(header_hash),
hash_to_str(self.tip), self.height)) hash_to_str(self.tip), self.height))
self.backup_txs(tx_hashes, txs, touched) self.backup_txs(txs, touched)
self.tip = self.coin.header_prevhash(header) self.tip = self.coin.header_prevhash(header)
assert self.height >= 0 assert self.height >= 0
self.height -= 1 self.height -= 1
@ -553,7 +555,7 @@ class BlockProcessor(server.db.DB):
touched.discard(None) touched.discard(None)
self.backup_flush(touched) self.backup_flush(touched)
def backup_txs(self, tx_hashes, txs, touched): def backup_txs(self, txs, touched):
# Prevout values, in order down the block (coinbase first if present) # Prevout values, in order down the block (coinbase first if present)
# undo_info is in reverse block order # undo_info is in reverse block order
undo_info = self.read_undo_info(self.height) undo_info = self.read_undo_info(self.height)
@ -569,10 +571,7 @@ class BlockProcessor(server.db.DB):
script_hashX = self.coin.hashX_from_script script_hashX = self.coin.hashX_from_script
undo_entry_len = 12 + self.coin.HASHX_LEN undo_entry_len = 12 + self.coin.HASHX_LEN
rtxs = reversed(txs) for tx, tx_hash in reversed(txs):
rtx_hashes = reversed(tx_hashes)
for tx_hash, tx in zip(rtx_hashes, rtxs):
for idx, txout in enumerate(tx.outputs): for idx, txout in enumerate(tx.outputs):
# Spend the TX outputs. Be careful with unspendable # Spend the TX outputs. Be careful with unspendable
# outputs - we didn't save those in the first place. # outputs - we didn't save those in the first place.

View File

@ -10,7 +10,6 @@
import array import array
import ast import ast
import itertools
import os import os
from struct import pack, unpack from struct import pack, unpack
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
@ -234,7 +233,7 @@ class DB(util.LoggedClass):
assert len(self.tx_hashes) == blocks_done assert len(self.tx_hashes) == blocks_done
assert len(self.tx_counts) == new_height + 1 assert len(self.tx_counts) == new_height + 1
hashes = b''.join(itertools.chain(*block_tx_hashes)) hashes = b''.join(block_tx_hashes)
assert len(hashes) % 32 == 0 assert len(hashes) % 32 == 0
assert len(hashes) // 32 == txs_done assert len(hashes) // 32 == txs_done

View File

@ -13,7 +13,6 @@ import time
from collections import defaultdict from collections import defaultdict
from lib.hash import hash_to_str, hex_str_to_hash from lib.hash import hash_to_str, hex_str_to_hash
from lib.tx import Deserializer
import lib.util as util import lib.util as util
from server.daemon import DaemonError from server.daemon import DaemonError
@ -200,6 +199,7 @@ class MemPool(util.LoggedClass):
not depend on the result remaining the same are fine. not depend on the result remaining the same are fine.
''' '''
script_hashX = self.coin.hashX_from_script script_hashX = self.coin.hashX_from_script
deserializer = self.coin.deserializer()
db_utxo_lookup = self.db.db_utxo_lookup db_utxo_lookup = self.db.db_utxo_lookup
txs = self.txs txs = self.txs
@ -207,7 +207,7 @@ class MemPool(util.LoggedClass):
for tx_hash, raw_tx in raw_tx_map.items(): for tx_hash, raw_tx in raw_tx_map.items():
if not tx_hash in txs: if not tx_hash in txs:
continue continue
tx = Deserializer(raw_tx).read_tx() tx, _tx_hash = deserializer(raw_tx).read_tx()
# Convert the tx outputs into (hashX, value) pairs # Convert the tx outputs into (hashX, value) pairs
txout_pairs = [(script_hashX(txout.pk_script), txout.value) txout_pairs = [(script_hashX(txout.pk_script), txout.value)
@ -271,6 +271,7 @@ class MemPool(util.LoggedClass):
if not hashX in self.hashXs: if not hashX in self.hashXs:
return [] return []
deserializer = self.coin.deserializer()
hex_hashes = self.hashXs[hashX] hex_hashes = self.hashXs[hashX]
raw_txs = await self.daemon.getrawtransactions(hex_hashes) raw_txs = await self.daemon.getrawtransactions(hex_hashes)
result = [] result = []
@ -281,7 +282,7 @@ class MemPool(util.LoggedClass):
txin_pairs, txout_pairs = item txin_pairs, txout_pairs = item
tx_fee = (sum(v for hashX, v in txin_pairs) tx_fee = (sum(v for hashX, v in txin_pairs)
- sum(v for hashX, v in txout_pairs)) - sum(v for hashX, v in txout_pairs))
tx = Deserializer(raw_tx).read_tx() tx, tx_hash = deserializer(raw_tx).read_tx()
unconfirmed = any(txin.prev_hash in self.txs for txin in tx.inputs) unconfirmed = any(txin.prev_hash in self.txs for txin in tx.inputs)
result.append((hex_hash, tx_fee, unconfirmed)) result.append((hex_hash, tx_fee, unconfirmed))
return result return result

View File

@ -14,7 +14,6 @@ import traceback
from lib.hash import sha256, double_sha256, hash_to_str, hex_str_to_hash from lib.hash import sha256, double_sha256, hash_to_str, hex_str_to_hash
from lib.jsonrpc import JSONRPC from lib.jsonrpc import JSONRPC
from lib.tx import Deserializer
from server.daemon import DaemonError from server.daemon import DaemonError
from server.version import VERSION from server.version import VERSION
@ -427,7 +426,8 @@ class ElectrumX(Session):
if not raw_tx: if not raw_tx:
return None return None
raw_tx = bytes.fromhex(raw_tx) raw_tx = bytes.fromhex(raw_tx)
tx = Deserializer(raw_tx).read_tx() deserializer = self.coin.deserializer()
tx, tx_hash = deserializer(raw_tx).read_tx()
if index >= len(tx.outputs): if index >= len(tx.outputs):
return None return None
return self.coin.address_from_script(tx.outputs[index].pk_script) return self.coin.address_from_script(tx.outputs[index].pk_script)