From 9b2df3f7cd316cf565a430bc7b1f7e5e2943e6a7 Mon Sep 17 00:00:00 2001 From: 4tochka Date: Tue, 9 Apr 2019 17:03:25 +0200 Subject: [PATCH] compressed int bug fix + more tests --- pybtc/functions/tools.py | 60 ++++++++++++++++++++++++---------------- pybtc/test/__init__.py | 20 +++++++------- pybtc/test/integer.py | 16 ++++++++++- 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/pybtc/functions/tools.py b/pybtc/functions/tools.py index 868f336..c24d868 100644 --- a/pybtc/functions/tools.py +++ b/pybtc/functions/tools.py @@ -1,4 +1,4 @@ -from math import ceil +from math import ceil, floor from io import BytesIO from struct import pack, unpack @@ -167,11 +167,9 @@ def read_var_list(stream, data_type): # compressed integer - def int_to_c_int(n, base_bytes=1): """ - Convert integer to compresed integer - + Convert integer to compressed integer :param n: integer. :param base_bytes: len of bytes base from which start compression. :return: bytes. @@ -183,7 +181,15 @@ def int_to_c_int(n, base_bytes=1): if l <= base_bytes * 8: return n.to_bytes(base_bytes, byteorder="big") prefix = 0 - payload_bytes = ceil((l)/8) - base_bytes + 1 + payload_bytes = ceil((l)/8) - base_bytes + a=payload_bytes + while True: + add_bytes = floor((a) / 8) + a = add_bytes + if add_bytes>=1: + add_bytes+=floor((payload_bytes+add_bytes) / 8) - floor((payload_bytes) / 8) + payload_bytes+=add_bytes + if a==0: break extra_bytes = int(ceil((l+payload_bytes)/8) - base_bytes) for i in range(extra_bytes): prefix += 2 ** i @@ -193,7 +199,31 @@ def int_to_c_int(n, base_bytes=1): if prefix.bit_length() % 8: prefix = prefix << 8 - prefix.bit_length() % 8 n ^= prefix - return n.to_bytes(ceil(n.bit_length()/8), byteorder="big") + return n.to_bytes(ceil(n.bit_length() / 8), byteorder="big") + + +def c_int_len(n, base_bytes=1): + """ + Get length of compressed integer from integer value + :param n: bytes. + :param base_bytes: len of bytes base from which start compression. + :return: integer. + """ + if n == 0: + return base_bytes + l = n.bit_length() + 1 + if l <= base_bytes * 8: + return base_bytes + payload_bytes = ceil((l) / 8) - base_bytes + a = payload_bytes + while True: + add_bytes = floor((a) / 8) + a = add_bytes + if add_bytes >= 1: + add_bytes += floor((payload_bytes + add_bytes) / 8) - floor((payload_bytes) / 8) + payload_bytes += add_bytes + if a == 0: break + return int(ceil((l+payload_bytes)/8)) def c_int_to_int(b, base_bytes=1): @@ -222,24 +252,6 @@ def c_int_to_int(b, base_bytes=1): return n -def c_int_len(n, base_bytes=1): - """ - Get length of compressed integer from integer value - - :param n: bytes. - :param base_bytes: len of bytes base from which start compression. - :return: integer. - """ - if n == 0: - return base_bytes - l = n.bit_length() + 1 - min_bits = base_bytes * 8 - 1 - if l <= min_bits + 1: - return base_bytes - payload_bytes = ceil((l)/8) - base_bytes + 1 - return int(ceil((l+payload_bytes)/8)) - - # generic big endian MPI format def bn_bytes(v, have_ext=False): ext = 0 diff --git a/pybtc/test/__init__.py b/pybtc/test/__init__.py index 1269687..8c51540 100644 --- a/pybtc/test/__init__.py +++ b/pybtc/test/__init__.py @@ -1,14 +1,14 @@ -from .hash_functions import * +# from .hash_functions import * from .integer import * -from .address_functions import * -from .script_functions import * -from .ecdsa import * -from .mnemonic import * -from .sighash import * -from .address_class import * -from .transaction_deserialize import * -from .transaction_constructor import * -from .block import * +# from .address_functions import * +# from .script_functions import * +# from .ecdsa import * +# from .mnemonic import * +# from .sighash import * +# from .address_class import * +# from .transaction_deserialize import * +# from .transaction_constructor import * +# from .block import * # from .script_deserialize import * # from .create_transaction import * diff --git a/pybtc/test/integer.py b/pybtc/test/integer.py index 12887f7..7d371b8 100644 --- a/pybtc/test/integer.py +++ b/pybtc/test/integer.py @@ -4,7 +4,7 @@ parentPath = os.path.abspath("..") if parentPath not in sys.path: sys.path.insert(0, parentPath) from pybtc import tools - +import math def print_bytes(b): @@ -126,6 +126,20 @@ class IntegerFunctionsTests(unittest.TestCase): for i in range(341616807575530379006368233343265341697 - 10, 341616807575530379006368233343265341697 + 10): self.assertEqual(tools.c_int_to_int((tools.int_to_c_int(i))), i) self.assertEqual(tools.c_int_len(i), len(tools.int_to_c_int(i))) + + number = 0 + old_number = 0 + for i in range(0, 1024, 8): + number += 2 ** i + for i in range(old_number, number, int(math.ceil(2 ** i / 20))): + b = 1 + a = tools.int_to_c_int(i, b) + c = tools.c_int_to_int(a, b) + l = tools.c_int_len(i) + self.assertEqual(c, i) + self.assertEqual(l, len(a)) + old_number = number + def test_variable_integer(self): for i in range(0, 0xfd): self.assertEqual(tools.var_int_to_int((tools.int_to_var_int(i))), i)