compressed int bug fix + more tests

This commit is contained in:
4tochka 2019-04-09 17:03:25 +02:00
parent 8ee9602526
commit 9b2df3f7cd
3 changed files with 61 additions and 35 deletions

View File

@ -1,4 +1,4 @@
from math import ceil from math import ceil, floor
from io import BytesIO from io import BytesIO
from struct import pack, unpack from struct import pack, unpack
@ -167,11 +167,9 @@ def read_var_list(stream, data_type):
# compressed integer # compressed integer
def int_to_c_int(n, base_bytes=1): def int_to_c_int(n, base_bytes=1):
""" """
Convert integer to compresed integer Convert integer to compressed integer
:param n: integer. :param n: integer.
:param base_bytes: len of bytes base from which start compression. :param base_bytes: len of bytes base from which start compression.
:return: bytes. :return: bytes.
@ -183,7 +181,15 @@ def int_to_c_int(n, base_bytes=1):
if l <= base_bytes * 8: if l <= base_bytes * 8:
return n.to_bytes(base_bytes, byteorder="big") return n.to_bytes(base_bytes, byteorder="big")
prefix = 0 prefix = 0
payload_bytes = ceil((l)/8) - base_bytes + 1 payload_bytes = ceil((l)/8) - base_bytes
a=payload_bytes
while True:
add_bytes = floor((a) / 8)
a = add_bytes
if add_bytes>=1:
add_bytes+=floor((payload_bytes+add_bytes) / 8) - floor((payload_bytes) / 8)
payload_bytes+=add_bytes
if a==0: break
extra_bytes = int(ceil((l+payload_bytes)/8) - base_bytes) extra_bytes = int(ceil((l+payload_bytes)/8) - base_bytes)
for i in range(extra_bytes): for i in range(extra_bytes):
prefix += 2 ** i prefix += 2 ** i
@ -193,7 +199,31 @@ def int_to_c_int(n, base_bytes=1):
if prefix.bit_length() % 8: if prefix.bit_length() % 8:
prefix = prefix << 8 - prefix.bit_length() % 8 prefix = prefix << 8 - prefix.bit_length() % 8
n ^= prefix n ^= prefix
return n.to_bytes(ceil(n.bit_length()/8), byteorder="big") return n.to_bytes(ceil(n.bit_length() / 8), byteorder="big")
def c_int_len(n, base_bytes=1):
"""
Get length of compressed integer from integer value
:param n: bytes.
:param base_bytes: len of bytes base from which start compression.
:return: integer.
"""
if n == 0:
return base_bytes
l = n.bit_length() + 1
if l <= base_bytes * 8:
return base_bytes
payload_bytes = ceil((l) / 8) - base_bytes
a = payload_bytes
while True:
add_bytes = floor((a) / 8)
a = add_bytes
if add_bytes >= 1:
add_bytes += floor((payload_bytes + add_bytes) / 8) - floor((payload_bytes) / 8)
payload_bytes += add_bytes
if a == 0: break
return int(ceil((l+payload_bytes)/8))
def c_int_to_int(b, base_bytes=1): def c_int_to_int(b, base_bytes=1):
@ -222,24 +252,6 @@ def c_int_to_int(b, base_bytes=1):
return n return n
def c_int_len(n, base_bytes=1):
"""
Get length of compressed integer from integer value
:param n: bytes.
:param base_bytes: len of bytes base from which start compression.
:return: integer.
"""
if n == 0:
return base_bytes
l = n.bit_length() + 1
min_bits = base_bytes * 8 - 1
if l <= min_bits + 1:
return base_bytes
payload_bytes = ceil((l)/8) - base_bytes + 1
return int(ceil((l+payload_bytes)/8))
# generic big endian MPI format # generic big endian MPI format
def bn_bytes(v, have_ext=False): def bn_bytes(v, have_ext=False):
ext = 0 ext = 0

View File

@ -1,14 +1,14 @@
from .hash_functions import * # from .hash_functions import *
from .integer import * from .integer import *
from .address_functions import * # from .address_functions import *
from .script_functions import * # from .script_functions import *
from .ecdsa import * # from .ecdsa import *
from .mnemonic import * # from .mnemonic import *
from .sighash import * # from .sighash import *
from .address_class import * # from .address_class import *
from .transaction_deserialize import * # from .transaction_deserialize import *
from .transaction_constructor import * # from .transaction_constructor import *
from .block import * # from .block import *
# from .script_deserialize import * # from .script_deserialize import *
# from .create_transaction import * # from .create_transaction import *

View File

@ -4,7 +4,7 @@ parentPath = os.path.abspath("..")
if parentPath not in sys.path: if parentPath not in sys.path:
sys.path.insert(0, parentPath) sys.path.insert(0, parentPath)
from pybtc import tools from pybtc import tools
import math
def print_bytes(b): def print_bytes(b):
@ -126,6 +126,20 @@ class IntegerFunctionsTests(unittest.TestCase):
for i in range(341616807575530379006368233343265341697 - 10, 341616807575530379006368233343265341697 + 10): for i in range(341616807575530379006368233343265341697 - 10, 341616807575530379006368233343265341697 + 10):
self.assertEqual(tools.c_int_to_int((tools.int_to_c_int(i))), i) self.assertEqual(tools.c_int_to_int((tools.int_to_c_int(i))), i)
self.assertEqual(tools.c_int_len(i), len(tools.int_to_c_int(i))) self.assertEqual(tools.c_int_len(i), len(tools.int_to_c_int(i)))
number = 0
old_number = 0
for i in range(0, 1024, 8):
number += 2 ** i
for i in range(old_number, number, int(math.ceil(2 ** i / 20))):
b = 1
a = tools.int_to_c_int(i, b)
c = tools.c_int_to_int(a, b)
l = tools.c_int_len(i)
self.assertEqual(c, i)
self.assertEqual(l, len(a))
old_number = number
def test_variable_integer(self): def test_variable_integer(self):
for i in range(0, 0xfd): for i in range(0, 0xfd):
self.assertEqual(tools.var_int_to_int((tools.int_to_var_int(i))), i) self.assertEqual(tools.var_int_to_int((tools.int_to_var_int(i))), i)