diff --git a/test_rollback_new.py b/test_rollback_new.py index 599f2c3..043846c 100644 --- a/test_rollback_new.py +++ b/test_rollback_new.py @@ -228,7 +228,7 @@ def rollback_database(blockNumber, dbtype, dbname): db_session = create_database_session_orm('token', {'token_name':dbname}, Base) while(True): subqry = db_session.query(func.max(ActiveTable.id)) - activeTable_entry = db_session.query(ActiveTable).filter(ActiveTable.id == subqry).first() + activeTable_entry = db_session.query(ActiveTable).filter(ActiveTable.id == subqry).first() if activeTable_entry.blockNumber <= blockNumber: break outputAddress = activeTable_entry.address @@ -278,6 +278,12 @@ def rollback_database(blockNumber, dbtype, dbname): for orphan_entry in orphaned_parentid_entries: orphan_entry.parentid = orphan_entry.orphaned_parentid orphan_entry.orphaned_parentid = None + + orphaned_parentid_entries = db_session.query(ConsumedTable).filter(ConsumedTable.orphaned_parentid == key).all() + if len(orphaned_parentid_entries) != 0: + for orphan_entry in orphaned_parentid_entries: + orphan_entry.parentid = orphan_entry.orphaned_parentid + orphan_entry.orphaned_parentid = None # update addressBalance @@ -290,6 +296,7 @@ def rollback_database(blockNumber, dbtype, dbname): db_session.delete(activeTable_entry) db_session.query(TransactionHistory).filter(TransactionHistory.blockNumber > blockNumber).delete() + db_session.query(TransferLogs).filter(TransferLogs.blockNumber > blockNumber).delete() db_session.commit() elif dbtype == 'smartcontract': diff --git a/tracktokens_smartcontracts.py b/tracktokens_smartcontracts.py index dc6a849..fbab76f 100755 --- a/tracktokens_smartcontracts.py +++ b/tracktokens_smartcontracts.py @@ -271,6 +271,9 @@ def transferToken(tokenIdentification, tokenAmount, inputAddress, outputAddress, entries = session.query(ActiveTable).filter(ActiveTable.parentid == piditem[0]).all() process_pids(entries, session, piditem) + entries = session.query(ConsumedTable).filter(ConsumedTable.parentid == piditem[0]).all() + process_pids(entries, session, piditem) + # move the pids consumed in the transaction to consumedTable and delete them from activeTable session.execute('INSERT INTO consumedTable (id, address, parentid, consumedpid, transferBalance, addressBalance, orphaned_parentid, blockNumber) SELECT id, address, parentid, consumedpid, transferBalance, addressBalance, orphaned_parentid, blockNumber FROM activeTable WHERE id={}'.format(piditem[0])) session.execute('DELETE FROM activeTable WHERE id={}'.format(piditem[0])) @@ -324,6 +327,9 @@ def transferToken(tokenIdentification, tokenAmount, inputAddress, outputAddress, entries = session.query(ActiveTable).filter(ActiveTable.parentid == piditem[0]).all() process_pids(entries, session, piditem) + entries = session.query(ConsumedTable).filter(ConsumedTable.parentid == piditem[0]).all() + process_pids(entries, session, piditem) + # move the pids consumed in the transaction to consumedTable and delete them from activeTable session.execute('INSERT INTO consumedTable (id, address, parentid, consumedpid, transferBalance, addressBalance, orphaned_parentid, blockNumber) SELECT id, address, parentid, consumedpid, transferBalance, addressBalance, orphaned_parentid, blockNumber FROM activeTable WHERE id={}'.format(piditem[0])) session.execute('DELETE FROM activeTable WHERE id={}'.format(piditem[0]))