diff --git a/lib/util.py b/lib/util.py index c60a0d3..cc755b8 100644 --- a/lib/util.py +++ b/lib/util.py @@ -128,20 +128,13 @@ def int_to_bytes(value): def increment_byte_string(bs): - bs = bytearray(bs) - incremented = False - for i in reversed(range(len(bs))): - if bs[i] < 0xff: - # This is easy - bs[i] += 1 - incremented = True - break - # Otherwise we need to look at the previous character - bs[i] = 0 - if not incremented: - # This can only happen if all characters are 0xff - bs = bytes([1]) + bs - return bytes(bs) + '''Return the lexicographically next byte string of the same length. + + Return None if there is none (when the input is all 0xff bytes).''' + for n in range(1, len(bs) + 1): + if bs[-n] != 0xff: + return bs[:-n] + bytes([bs[-n] + 1]) + bytes(n - 1) + return None class LogicalFile(object): diff --git a/server/storage.py b/server/storage.py index 40d74b6..cd44746 100644 --- a/server/storage.py +++ b/server/storage.py @@ -115,46 +115,53 @@ class RocksDB(Storage): import gc gc.collect() - class WriteBatch(object): - def __init__(self, db): - self.batch = RocksDB.module.WriteBatch() - self.db = db - - def __enter__(self): - return self.batch - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_val: - self.db.write(self.batch) - def write_batch(self): - return RocksDB.WriteBatch(self.db) - - class Iterator(object): - def __init__(self, db, prefix, reverse): - self.it = db.iteritems() - self.reverse = reverse - self.prefix = prefix - # Whether we are at the first item - self.first = True - - def __iter__(self): - prefix = self.prefix - if self.reverse: - prefix = increment_byte_string(prefix) - self.it = reversed(self.it) - self.it.seek(prefix) - return self - - def __next__(self): - k, v = self.it.__next__() - if self.first and self.reverse and not k.startswith(self.prefix): - k, v = self.it.__next__() - self.first = False - if not k.startswith(self.prefix): - # We're already ahead of the prefix - raise StopIteration - return k, v + return RocksDBWriteBatch(self.db) def iterator(self, prefix=b'', reverse=False): - return RocksDB.Iterator(self.db, prefix, reverse) + return RocksDBIterator(self.db, prefix, reverse) + + +class RocksDBWriteBatch(object): + '''A write batch for RocksDB.''' + + def __init__(self, db): + self.batch = RocksDB.module.WriteBatch() + self.db = db + + def __enter__(self): + return self.batch + + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_val: + self.db.write(self.batch) + + +class RocksDBIterator(object): + '''An iterator for RocksDB.''' + + def __init__(self, db, prefix, reverse): + self.prefix = prefix + if reverse: + self.iterator = reversed(db.iteritems()) + nxt_prefix = util.increment_byte_string(prefix) + if nxt_prefix: + self.iterator.seek(nxt_prefix) + try: + next(self.iterator) + except StopIteration: + self.iterator.seek(nxt_prefix) + else: + self.iterator.seek_to_last() + else: + self.iterator = db.iteritems() + self.iterator.seek(prefix) + + def __iter__(self): + return self + + def __next__(self): + k, v = next(self.iterator) + if not k.startswith(self.prefix): + raise StopIteration + return k, v