diff --git a/lib/coins.py b/lib/coins.py index eea54f1..e1f96b3 100644 --- a/lib/coins.py +++ b/lib/coins.py @@ -70,6 +70,8 @@ class Coin(object): BLOCK_PROCESSOR = BlockProcessor XPUB_VERBYTES = bytes('????', 'utf-8') XPRV_VERBYTES = bytes('????', 'utf-8') + ENCODE_CHECK = Base58.encode_check + DECODE_CHECK = Base58.decode_check # Peer discovery PEER_DEFAULT_PORTS = {'t': '50001', 's': '50002'} PEERS = [] @@ -168,7 +170,7 @@ class Coin(object): def P2PKH_address_from_hash160(cls, hash160): '''Return a P2PKH address given a public key.''' assert len(hash160) == 20 - return Base58.encode_check(cls.P2PKH_VERBYTE + hash160) + return cls.ENCODE_CHECK(cls.P2PKH_VERBYTE + hash160) @classmethod def P2PKH_address_from_pubkey(cls, pubkey): @@ -179,7 +181,7 @@ class Coin(object): def P2SH_address_from_hash160(cls, hash160): '''Return a coin address given a hash160.''' assert len(hash160) == 20 - return Base58.encode_check(cls.P2SH_VERBYTES[0] + hash160) + return cls.ENCODE_CHECK(cls.P2SH_VERBYTES[0] + hash160) @classmethod def multisig_address(cls, m, pubkeys): @@ -212,7 +214,7 @@ class Coin(object): Pass the address (either P2PKH or P2SH) in base58 form. ''' - raw = Base58.decode_check(address) + raw = cls.DECODE_CHECK(address) # Require version byte(s) plus hash160. verbyte = -1 @@ -233,7 +235,7 @@ class Coin(object): payload = bytearray(cls.WIF_BYTE) + privkey_bytes if compressed: payload.append(0x01) - return Base58.encode_check(payload) + return cls.ENCODE_CHECK(payload) @classmethod def header_hash(cls, header): diff --git a/lib/hash.py b/lib/hash.py index 0f9436d..47039bd 100644 --- a/lib/hash.py +++ b/lib/hash.py @@ -143,18 +143,18 @@ class Base58(object): return txt[::-1] @staticmethod - def decode_check(txt): + def decode_check(txt, *, hash_fn=double_sha256): '''Decodes a Base58Check-encoded string to a payload. The version prefixes it.''' be_bytes = Base58.decode(txt) result, check = be_bytes[:-4], be_bytes[-4:] - if check != double_sha256(result)[:4]: + if check != hash_fn(result)[:4]: raise Base58Error('invalid base 58 checksum for {}'.format(txt)) return result @staticmethod - def encode_check(payload): + def encode_check(payload, *, hash_fn=double_sha256): """Encodes a payload bytearray (which includes the version byte(s)) into a Base58Check string.""" - be_bytes = payload + double_sha256(payload)[:4] + be_bytes = payload + hash_fn(payload)[:4] return Base58.encode(be_bytes) diff --git a/tests/lib/test_addresses.py b/tests/lib/test_addresses.py index e832f04..293f2eb 100644 --- a/tests/lib/test_addresses.py +++ b/tests/lib/test_addresses.py @@ -60,7 +60,7 @@ def test_address_to_hashX(address): def test_address_from_hash160(address): coin, addr, hash, _ = address - raw = Base58.decode_check(addr) + raw = coin.DECODE_CHECK(addr) verlen = len(raw) - 20 assert verlen > 0 verbyte, hash_bytes = raw[:verlen], raw[verlen:] diff --git a/tests/lib/test_hash.py b/tests/lib/test_hash.py index f1b9d46..87aa355 100644 --- a/tests/lib/test_hash.py +++ b/tests/lib/test_hash.py @@ -1,6 +1,7 @@ # # Tests of lib/hash.py # +from functools import partial import pytest @@ -66,3 +67,19 @@ def test_Base58_encode_check(): with pytest.raises(TypeError): lib_hash.Base58.encode_check('foo') assert lib_hash.Base58.encode_check(b'foo') == '4t9WKfuAB8' + +def test_Base58_decode_check_custom(): + decode_check_sha256 = partial(lib_hash.Base58.decode_check, + hash_fn=lib_hash.sha256) + with pytest.raises(TypeError): + decode_check_sha256(b'foo') + assert decode_check_sha256('4t9WFhKfWr') == b'foo' + with pytest.raises(lib_hash.Base58Error): + decode_check_sha256('4t9WFhKfWp') + +def test_Base58_encode_check_custom(): + encode_check_sha256 = partial(lib_hash.Base58.encode_check, + hash_fn=lib_hash.sha256) + with pytest.raises(TypeError): + encode_check_sha256('foo') + assert encode_check_sha256(b'foo') == '4t9WFhKfWr'