compressed int bug fix + more tests
This commit is contained in:
parent
8ee9602526
commit
9b2df3f7cd
@ -1,4 +1,4 @@
|
||||
from math import ceil
|
||||
from math import ceil, floor
|
||||
from io import BytesIO
|
||||
from struct import pack, unpack
|
||||
|
||||
@ -167,11 +167,9 @@ def read_var_list(stream, data_type):
|
||||
|
||||
# compressed integer
|
||||
|
||||
|
||||
def int_to_c_int(n, base_bytes=1):
|
||||
"""
|
||||
Convert integer to compresed integer
|
||||
|
||||
Convert integer to compressed integer
|
||||
:param n: integer.
|
||||
:param base_bytes: len of bytes base from which start compression.
|
||||
:return: bytes.
|
||||
@ -183,7 +181,15 @@ def int_to_c_int(n, base_bytes=1):
|
||||
if l <= base_bytes * 8:
|
||||
return n.to_bytes(base_bytes, byteorder="big")
|
||||
prefix = 0
|
||||
payload_bytes = ceil((l)/8) - base_bytes + 1
|
||||
payload_bytes = ceil((l)/8) - base_bytes
|
||||
a=payload_bytes
|
||||
while True:
|
||||
add_bytes = floor((a) / 8)
|
||||
a = add_bytes
|
||||
if add_bytes>=1:
|
||||
add_bytes+=floor((payload_bytes+add_bytes) / 8) - floor((payload_bytes) / 8)
|
||||
payload_bytes+=add_bytes
|
||||
if a==0: break
|
||||
extra_bytes = int(ceil((l+payload_bytes)/8) - base_bytes)
|
||||
for i in range(extra_bytes):
|
||||
prefix += 2 ** i
|
||||
@ -193,7 +199,31 @@ def int_to_c_int(n, base_bytes=1):
|
||||
if prefix.bit_length() % 8:
|
||||
prefix = prefix << 8 - prefix.bit_length() % 8
|
||||
n ^= prefix
|
||||
return n.to_bytes(ceil(n.bit_length()/8), byteorder="big")
|
||||
return n.to_bytes(ceil(n.bit_length() / 8), byteorder="big")
|
||||
|
||||
|
||||
def c_int_len(n, base_bytes=1):
|
||||
"""
|
||||
Get length of compressed integer from integer value
|
||||
:param n: bytes.
|
||||
:param base_bytes: len of bytes base from which start compression.
|
||||
:return: integer.
|
||||
"""
|
||||
if n == 0:
|
||||
return base_bytes
|
||||
l = n.bit_length() + 1
|
||||
if l <= base_bytes * 8:
|
||||
return base_bytes
|
||||
payload_bytes = ceil((l) / 8) - base_bytes
|
||||
a = payload_bytes
|
||||
while True:
|
||||
add_bytes = floor((a) / 8)
|
||||
a = add_bytes
|
||||
if add_bytes >= 1:
|
||||
add_bytes += floor((payload_bytes + add_bytes) / 8) - floor((payload_bytes) / 8)
|
||||
payload_bytes += add_bytes
|
||||
if a == 0: break
|
||||
return int(ceil((l+payload_bytes)/8))
|
||||
|
||||
|
||||
def c_int_to_int(b, base_bytes=1):
|
||||
@ -222,24 +252,6 @@ def c_int_to_int(b, base_bytes=1):
|
||||
return n
|
||||
|
||||
|
||||
def c_int_len(n, base_bytes=1):
|
||||
"""
|
||||
Get length of compressed integer from integer value
|
||||
|
||||
:param n: bytes.
|
||||
:param base_bytes: len of bytes base from which start compression.
|
||||
:return: integer.
|
||||
"""
|
||||
if n == 0:
|
||||
return base_bytes
|
||||
l = n.bit_length() + 1
|
||||
min_bits = base_bytes * 8 - 1
|
||||
if l <= min_bits + 1:
|
||||
return base_bytes
|
||||
payload_bytes = ceil((l)/8) - base_bytes + 1
|
||||
return int(ceil((l+payload_bytes)/8))
|
||||
|
||||
|
||||
# generic big endian MPI format
|
||||
def bn_bytes(v, have_ext=False):
|
||||
ext = 0
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
from .hash_functions import *
|
||||
# from .hash_functions import *
|
||||
from .integer import *
|
||||
from .address_functions import *
|
||||
from .script_functions import *
|
||||
from .ecdsa import *
|
||||
from .mnemonic import *
|
||||
from .sighash import *
|
||||
from .address_class import *
|
||||
from .transaction_deserialize import *
|
||||
from .transaction_constructor import *
|
||||
from .block import *
|
||||
# from .address_functions import *
|
||||
# from .script_functions import *
|
||||
# from .ecdsa import *
|
||||
# from .mnemonic import *
|
||||
# from .sighash import *
|
||||
# from .address_class import *
|
||||
# from .transaction_deserialize import *
|
||||
# from .transaction_constructor import *
|
||||
# from .block import *
|
||||
|
||||
# from .script_deserialize import *
|
||||
# from .create_transaction import *
|
||||
|
||||
@ -4,7 +4,7 @@ parentPath = os.path.abspath("..")
|
||||
if parentPath not in sys.path:
|
||||
sys.path.insert(0, parentPath)
|
||||
from pybtc import tools
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def print_bytes(b):
|
||||
@ -126,6 +126,20 @@ class IntegerFunctionsTests(unittest.TestCase):
|
||||
for i in range(341616807575530379006368233343265341697 - 10, 341616807575530379006368233343265341697 + 10):
|
||||
self.assertEqual(tools.c_int_to_int((tools.int_to_c_int(i))), i)
|
||||
self.assertEqual(tools.c_int_len(i), len(tools.int_to_c_int(i)))
|
||||
|
||||
number = 0
|
||||
old_number = 0
|
||||
for i in range(0, 1024, 8):
|
||||
number += 2 ** i
|
||||
for i in range(old_number, number, int(math.ceil(2 ** i / 20))):
|
||||
b = 1
|
||||
a = tools.int_to_c_int(i, b)
|
||||
c = tools.c_int_to_int(a, b)
|
||||
l = tools.c_int_len(i)
|
||||
self.assertEqual(c, i)
|
||||
self.assertEqual(l, len(a))
|
||||
old_number = number
|
||||
|
||||
def test_variable_integer(self):
|
||||
for i in range(0, 0xfd):
|
||||
self.assertEqual(tools.var_int_to_int((tools.int_to_var_int(i))), i)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user