class ExpiringDict(Generic[KT, VT]): """ Set with expiration time. For determining if items are in the set, use "if item in expiring_set.contents". __contains__ is intentionally not overwritten. This is a performance critical class, and we're avoiding extra function call overhead. """ contents: Dict[KT, VT] _alarm_queue: AlarmQueue _expiration_queue: ExpirationQueue[KT] _expiration_time: int def __init__(self, alarm_queue, expiration_time_s): self.contents = {} self._alarm_queue = alarm_queue self._expiration_queue = ExpirationQueue(expiration_time_s) self._expiration_time = expiration_time_s def add(self, key, value): self.contents[key] = value self._expiration_queue.add(key) self._alarm_queue.register_approx_alarm(self._expiration_time * 2, self._expiration_time, self.cleanup) def cleanup(self): self._expiration_queue.remove_expired(remove_callback=self.remove_item) return 0 def remove_item(self, key): if key in self.contents: del self.contents[key]
class ExpiringDict(Generic[KT, VT]):
    """
    Dictionary with expiration time.
    """

    contents: Dict[KT, VT]
    _alarm_queue: AlarmQueue
    _expiration_queue: ExpirationQueue[KT]
    _expiration_time: int
    _name: str

    def __init__(self, alarm_queue: AlarmQueue, expiration_time_s: int, name: str) -> None:
        self.contents = {}
        self._alarm_queue = alarm_queue
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._expiration_time = expiration_time_s
        self._name = name

    def __contains__(self, item: KT) -> bool:
        return item in self.contents

    def __setitem__(self, key: KT, value: VT) -> None:
        if key in self.contents:
            self.contents[key] = value
        else:
            self.add(key, value)

    def __delitem__(self, key: KT) -> None:
        del self.contents[key]
        self._expiration_queue.remove(key)

    def __getitem__(self, item: KT) -> VT:
        return self.contents[item]

    def add(self, key: KT, value: VT) -> None:
        self.contents[key] = value
        self._expiration_queue.add(key)
        self._alarm_queue.register_approx_alarm(
            self._expiration_time * 2,
            self._expiration_time,
            self.cleanup,
            alarm_name=f"ExpiringDict[{self._name}]#cleanup",
        )

    def cleanup(self) -> float:
        self._expiration_queue.remove_expired(remove_callback=self.remove_item)
        return 0

    def remove_item(self, key: KT) -> Optional[VT]:
        if key in self.contents:
            return self.contents.pop(key)
        else:
            return None
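# Usage sketch for ExpiringDict (illustrative only). AlarmQueue comes from the
# surrounding codebase and its construction here is an assumption for the
# example; entries are evicted roughly expiration_time_s after insertion once
# the alarm fires cleanup().
alarm_queue = AlarmQueue()  # assumed no-arg construction
seen_messages: ExpiringDict[str, int] = ExpiringDict(alarm_queue, 30, "seen_messages")
seen_messages["msg-1"] = 1          # first assignment registers the expiration alarm
if "msg-1" in seen_messages:        # this version defines __contains__
    value = seen_messages["msg-1"]
# ~30s later the alarm queue invokes cleanup() and "msg-1" is removed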
class BlockRecoveryService:
    """
    Service class that handles blocks the gateway receives with unknown
    transaction short ids or contents.

    Attributes
    ----------
    recovered_blocks: queue to which recovered blocks are pushed
    _alarm_queue: reference to alarm queue to schedule cleanup on
    _bx_block_hash_to_sids: map of compressed block hash to its set of unknown short ids
    _bx_block_hash_to_tx_hashes: map of compressed block hash to its set of unknown transaction hashes
    _bx_block_hash_to_block_hash: map of compressed block hash to its original block hash
    _bx_block_hash_to_block: map of compressed block hash to its compressed byte representation
    _block_hash_to_bx_block_hashes: map of original block hash to compressed block hashes waiting for recovery
    _sid_to_bx_block_hashes: map of short id to compressed block hashes waiting for recovery
    _tx_hash_to_bx_block_hashes: map of transaction hash to compressed block hashes waiting for recovery
    _cleanup_scheduled: whether block recovery has an alarm scheduled to clean up recovering blocks
    _blocks_expiration_queue: queue that triggers expiration of blocks waiting for recovery
    """

    _alarm_queue: AlarmQueue
    _bx_block_hash_to_sids: Dict[Sha256Hash, Set[int]]
    _bx_block_hash_to_tx_hashes: Dict[Sha256Hash, Set[Sha256Hash]]
    _bx_block_hash_to_block_hash: Dict[Sha256Hash, Sha256Hash]
    _bx_block_hash_to_block: Dict[Sha256Hash, memoryview]
    _block_hash_to_bx_block_hashes: Dict[Sha256Hash, Set[Sha256Hash]]
    _sid_to_bx_block_hashes: Dict[int, Set[Sha256Hash]]
    _tx_hash_to_bx_block_hashes: Dict[Sha256Hash, Set[Sha256Hash]]
    _blocks_expiration_queue: ExpirationQueue
    _cleanup_scheduled: bool = False
    recovery_attempts_by_block: Dict[Sha256Hash, int]

    def __init__(self, alarm_queue: AlarmQueue):
        self.recovered_blocks = []
        self._alarm_queue = alarm_queue

        self._bx_block_hash_to_sids = {}
        self._bx_block_hash_to_tx_hashes = {}
        self._bx_block_hash_to_block_hash = {}
        self._bx_block_hash_to_block = {}
        self.recovery_attempts_by_block = defaultdict(int)

        self._block_hash_to_bx_block_hashes = defaultdict(set)
        self._sid_to_bx_block_hashes = defaultdict(set)
        self._tx_hash_to_bx_block_hashes: Dict[Sha256Hash, Set[Sha256Hash]] = defaultdict(set)

        self._blocks_expiration_queue = ExpirationQueue(
            gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME
        )

    def add_block(
        self,
        bx_block: memoryview,
        block_hash: Sha256Hash,
        unknown_tx_sids: List[int],
        unknown_tx_hashes: List[Sha256Hash],
    ):
        """
        Adds a block that needs recovery. Tracks unknown short ids and
        transaction contents as they come in.

        :param bx_block: bytearray representation of compressed block
        :param block_hash: original ObjectHash of block
        :param unknown_tx_sids: list of unknown short ids
        :param unknown_tx_hashes: list of unknown tx ObjectHashes
        """
        logger.trace(
            "Recovering block with {} unknown short ids and {} contents: {}",
            len(unknown_tx_sids), len(unknown_tx_hashes), block_hash
        )
        bx_block_hash = Sha256Hash(crypto.double_sha256(bx_block))

        self._bx_block_hash_to_block[bx_block_hash] = bx_block
        self._bx_block_hash_to_block_hash[bx_block_hash] = block_hash
        self._bx_block_hash_to_sids[bx_block_hash] = set(unknown_tx_sids)
        self._bx_block_hash_to_tx_hashes[bx_block_hash] = set(unknown_tx_hashes)

        self._block_hash_to_bx_block_hashes[block_hash].add(bx_block_hash)
        for sid in unknown_tx_sids:
            self._sid_to_bx_block_hashes[sid].add(bx_block_hash)
        for tx_hash in unknown_tx_hashes:
            self._tx_hash_to_bx_block_hashes[tx_hash].add(bx_block_hash)

        self._blocks_expiration_queue.add(bx_block_hash)
        self._schedule_cleanup()

    def get_blocks_awaiting_recovery(self) -> List[BlockRecoveryInfo]:
        """
        Fetches all blocks still awaiting recovery for retry.
        """
        blocks_awaiting_recovery = []
        for block_hash, bx_block_hashes in self._block_hash_to_bx_block_hashes.items():
            unknown_short_ids = set()
            unknown_transaction_hashes = set()
            for bx_block_hash in bx_block_hashes:
                unknown_short_ids.update(self._bx_block_hash_to_sids[bx_block_hash])
                unknown_transaction_hashes.update(self._bx_block_hash_to_tx_hashes[bx_block_hash])
            blocks_awaiting_recovery.append(
                BlockRecoveryInfo(
                    block_hash, unknown_short_ids, unknown_transaction_hashes, time.time()
                )
            )
        return blocks_awaiting_recovery

    def check_missing_sid(self, sid: int, recovered_txs_source: RecoveredTxsSource) -> bool:
        """
        Resolves recovering blocks that depend on the given short id.

        :param sid: short id that has been processed
        :param recovered_txs_source: source of recovered transaction
        """
        if sid in self._sid_to_bx_block_hashes:
            logger.trace("Resolved previously unknown short id: {0}.", sid)

            bx_block_hashes = self._sid_to_bx_block_hashes[sid]
            for bx_block_hash in bx_block_hashes:
                if bx_block_hash in self._bx_block_hash_to_sids:
                    if sid in self._bx_block_hash_to_sids[bx_block_hash]:
                        self._bx_block_hash_to_sids[bx_block_hash].discard(sid)
                        self._check_if_recovered(bx_block_hash, recovered_txs_source)

            del self._sid_to_bx_block_hashes[sid]
            return True
        else:
            return False

    def check_missing_tx_hash(
        self, tx_hash: Sha256Hash, recovered_txs_source: RecoveredTxsSource
    ) -> bool:
        """
        Resolves recovering blocks that depend on the given transaction hash.

        :param tx_hash: transaction hash that has been processed
        :param recovered_txs_source: source of recovered transaction
        """
        if tx_hash in self._tx_hash_to_bx_block_hashes:
            logger.trace("Resolved previously unknown transaction hash {0}.", tx_hash)

            bx_block_hashes = self._tx_hash_to_bx_block_hashes[tx_hash]
            for bx_block_hash in bx_block_hashes:
                if bx_block_hash in self._bx_block_hash_to_tx_hashes:
                    if tx_hash in self._bx_block_hash_to_tx_hashes[bx_block_hash]:
                        self._bx_block_hash_to_tx_hashes[bx_block_hash].discard(tx_hash)
                        self._check_if_recovered(bx_block_hash, recovered_txs_source)

            del self._tx_hash_to_bx_block_hashes[tx_hash]
            return True
        else:
            return False

    def cancel_recovery_for_block(self, block_hash: Sha256Hash) -> bool:
        """
        Cancels recovery for all compressed blocks matching a block hash.

        :param block_hash: ObjectHash
        """
        if block_hash in self._block_hash_to_bx_block_hashes:
            logger.trace("Cancelled block recovery for block: {}", block_hash)
            self._remove_recovered_block_hash(block_hash)
            return True
        else:
            return False

    def awaiting_recovery(self, block_hash: Sha256Hash) -> bool:
        return block_hash in self._block_hash_to_bx_block_hashes

    def cleanup_old_blocks(self, clean_up_time: Optional[float] = None):
        """
        Cleans up old compressed blocks awaiting recovery.

        :param clean_up_time: time to evaluate expiration against; defaults to now
        """
        logger.debug("Cleaning up block recovery.")
        num_blocks_awaiting_recovery = len(self._bx_block_hash_to_block)
        self._blocks_expiration_queue.remove_expired(
            current_time=clean_up_time,
            remove_callback=self._remove_not_recovered_block
        )
        logger.debug(
            "Cleaned up {} blocks awaiting recovery.",
            num_blocks_awaiting_recovery - len(self._bx_block_hash_to_block)
        )

        if self._bx_block_hash_to_block:
            return gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME

        # disable cleanup until the next block with unknown txs is received
        self._cleanup_scheduled = False
        return 0

    def clean_up_recovered_blocks(self):
        """
        Cleans up blocks that have finished recovery.
        """
        logger.trace("Cleaning up {} recovered blocks.", len(self.recovered_blocks))
        del self.recovered_blocks[:]

    def _check_if_recovered(self, bx_block_hash: Sha256Hash, recovered_txs_source: RecoveredTxsSource):
        """
        Checks if a compressed block has received all short ids and transaction
        hashes necessary to recover. Adds the block to recovered blocks if so.

        :param bx_block_hash: ObjectHash
        """
        if self._is_block_recovered(bx_block_hash):
            bx_block = self._bx_block_hash_to_block[bx_block_hash]
            block_hash = self._bx_block_hash_to_block_hash[bx_block_hash]
            logger.debug(
                "Recovery status for block {}, compressed block hash {}: "
                "Block recovered by gateway. Source of recovered txs is {}.",
                block_hash, bx_block_hash, recovered_txs_source
            )
            self._remove_recovered_block_hash(block_hash)
            self.recovered_blocks.append((bx_block, recovered_txs_source))

    def _is_block_recovered(self, bx_block_hash: Sha256Hash):
        """
        Indicates if a compressed block has received all short ids and
        transaction hashes necessary to recover.

        :param bx_block_hash: ObjectHash
        """
        return (
            len(self._bx_block_hash_to_sids[bx_block_hash]) == 0
            and len(self._bx_block_hash_to_tx_hashes[bx_block_hash]) == 0
        )

    def _remove_recovered_block_hash(self, block_hash: Sha256Hash):
        """
        Removes all compressed blocks awaiting recovery that match a recovered block hash.

        :param block_hash: ObjectHash
        """
        if block_hash in self._block_hash_to_bx_block_hashes:
            for bx_block_hash in self._block_hash_to_bx_block_hashes[block_hash]:
                if bx_block_hash in self._bx_block_hash_to_block:
                    self._remove_sid_and_tx_mapping_for_bx_block_hash(bx_block_hash)
                    del self._bx_block_hash_to_block[bx_block_hash]
                    del self._bx_block_hash_to_block_hash[bx_block_hash]
            del self._block_hash_to_bx_block_hashes[block_hash]

    def _remove_sid_and_tx_mapping_for_bx_block_hash(self, bx_block_hash: Sha256Hash):
        """
        Removes all short id and transaction mappings for a compressed block.

        :param bx_block_hash: ObjectHash
        """
        if bx_block_hash not in self._bx_block_hash_to_block:
            raise ValueError("Can't remove mapping for a block that isn't being recovered.")

        for sid in self._bx_block_hash_to_sids[bx_block_hash]:
            if sid in self._sid_to_bx_block_hashes:
                self._sid_to_bx_block_hashes[sid].discard(bx_block_hash)
                if len(self._sid_to_bx_block_hashes[sid]) == 0:
                    del self._sid_to_bx_block_hashes[sid]
        del self._bx_block_hash_to_sids[bx_block_hash]

        for tx_hash in self._bx_block_hash_to_tx_hashes[bx_block_hash]:
            if tx_hash in self._tx_hash_to_bx_block_hashes:
                self._tx_hash_to_bx_block_hashes[tx_hash].discard(bx_block_hash)
                if len(self._tx_hash_to_bx_block_hashes[tx_hash]) == 0:
                    del self._tx_hash_to_bx_block_hashes[tx_hash]
        del self._bx_block_hash_to_tx_hashes[bx_block_hash]

    def _remove_not_recovered_block(self, bx_block_hash: Sha256Hash):
        """
        Removes a compressed block that has not recovered.

        :param bx_block_hash: ObjectHash
        """
        if bx_block_hash in self._bx_block_hash_to_block:
            logger.trace("Block has failed recovery: {}", bx_block_hash)

            self._remove_sid_and_tx_mapping_for_bx_block_hash(bx_block_hash)
            del self._bx_block_hash_to_block[bx_block_hash]

            block_hash = self._bx_block_hash_to_block_hash.pop(bx_block_hash)
            self._block_hash_to_bx_block_hashes[block_hash].discard(bx_block_hash)
            if len(self._block_hash_to_bx_block_hashes[block_hash]) == 0:
                del self._block_hash_to_bx_block_hashes[block_hash]

    def _schedule_cleanup(self):
        if not self._cleanup_scheduled and self._bx_block_hash_to_block:
            logger.trace(
                "Scheduling block recovery cleanup in {} seconds.",
                gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME
            )
            self._alarm_queue.register_alarm(
                gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME,
                self.cleanup_old_blocks
            )
            self._cleanup_scheduled = True
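# Usage sketch for BlockRecoveryService (illustrative only). The AlarmQueue
# construction and the block_hash / missing_tx_hash / recovered_txs_source
# values below are placeholders assumed for the example; real wiring lives in
# the gateway node.
recovery_service = BlockRecoveryService(AlarmQueue())
recovery_service.add_block(
    bx_block=memoryview(b"compressed block bytes"),
    block_hash=block_hash,                # Sha256Hash of the original block
    unknown_tx_sids=[101, 102],
    unknown_tx_hashes=[missing_tx_hash],  # Sha256Hash values
)
# As transactions arrive, resolve the missing pieces. Once a compressed block
# has no remaining unknown sids or hashes, it is appended to recovered_blocks.
recovery_service.check_missing_sid(101, recovered_txs_source)
recovery_service.check_missing_sid(102, recovered_txs_source)
recovery_service.check_missing_tx_hash(missing_tx_hash, recovered_txs_source)
for bx_block, source in recovery_service.recovered_blocks:
    pass  # decompress and process the recovered block here
recovery_service.clean_up_recovered_blocks()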
class ExpiringSet(Generic[T]): """ Set with expiration time. """ contents: Set[T] _alarm_queue: AlarmQueue _expiration_queue: ExpirationQueue[T] _expiration_time: int _log_removal: bool _name: str def __init__(self, alarm_queue: AlarmQueue, expiration_time_s: int, name: str, log_removal: bool = False): self.contents = set() self._alarm_queue = alarm_queue self._expiration_queue = ExpirationQueue(expiration_time_s) self._expiration_time = expiration_time_s self._log_removal = log_removal self._name = name def __contains__(self, item: T) -> bool: return item in self.contents def __len__(self) -> int: return len(self.contents) def add(self, item: T) -> None: self.contents.add(item) self._expiration_queue.add(item) self._alarm_queue.register_approx_alarm( self._expiration_time * 2, self._expiration_time, self.cleanup, alarm_name=f"ExpiringSet[{self._name}]#cleanup") def remove(self, item: T) -> None: self.contents.remove(item) def get_recent_items(self, count: int) -> List[T]: items = [] # noinspection PyTypeChecker entries = reversed(self._expiration_queue.queue.keys()) try: for _ in range(count): items.append(next(entries)) except StopIteration as _e: logger.debug("Attempted to fetch {} entries, but only {} existed.", count, len(items)) return items def cleanup(self) -> int: self._expiration_queue.remove_expired( remove_callback=self._safe_remove_item) return 0 def _safe_remove_item(self, item: T): if self._log_removal: logger.debug("Removing {} from expiring set.", item) if item in self.contents: self.contents.remove(item)
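# Usage sketch for ExpiringSet (illustrative only): the AlarmQueue construction
# is an assumption for the example.
alarm_queue = AlarmQueue()
recent_hashes: ExpiringSet[str] = ExpiringSet(alarm_queue, 60, "recent_hashes")
recent_hashes.add("abc")
if "abc" in recent_hashes:                # this version defines __contains__
    latest = recent_hashes.get_recent_items(1)  # ["abc"], newest first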
class EncryptedCache(object):
    """
    Storage for in-progress received or sent encrypted blocks.
    """

    def __init__(self, expiration_time_s, alarm_queue):
        self._cache = {}
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._expiration_time_s = expiration_time_s
        self._alarm_queue = alarm_queue

    def encrypt_and_add_payload(self, payload):
        """
        Encrypts payload, computing a hash and storing it along with the key for later release.
        If encryption is disabled for dev, stores ciphertext identical to hash_key.
        """
        key, ciphertext = symmetric_encrypt(bytes(payload))
        hash_key = crypto.double_sha256(ciphertext)
        self._add(hash_key, key, ciphertext, payload)
        return ciphertext, hash_key

    def add_ciphertext(self, hash_key, ciphertext):
        if hash_key in self._cache:
            self._cache[hash_key].ciphertext = ciphertext
        else:
            self._add(hash_key, None, ciphertext, None)

    def add_key(self, hash_key, encryption_key):
        if hash_key in self._cache:
            self._cache[hash_key].key = encryption_key
        else:
            self._add(hash_key, encryption_key, None, None)

    def decrypt_and_get_payload(self, hash_key, encryption_key):
        """
        Retrieves and decrypts stored ciphertext. Returns None if unable to decrypt.
        """
        cache_item = self._cache[hash_key]
        cache_item.key = encryption_key
        return self._safe_decrypt_item(cache_item, hash_key)

    def decrypt_ciphertext(self, hash_key, ciphertext):
        """
        Retrieves and decrypts ciphertext with stored key info. Stores info in cache.
        Returns None if unable to decrypt.
        """
        cache_item = self._cache[hash_key]
        cache_item.ciphertext = ciphertext
        return self._safe_decrypt_item(cache_item, hash_key)

    def get_encryption_key(self, hash_key):
        return self._cache[hash_key].key

    def pop_ciphertext(self, hash_key):
        return self._cache.pop(hash_key).ciphertext

    def has_encryption_key_for_hash(self, hash_key):
        return hash_key in self._cache and self._cache[hash_key].key is not None

    def has_ciphertext_for_hash(self, hash_key):
        return hash_key in self._cache and self._cache[hash_key].ciphertext is not None

    def hash_keys(self):
        return self._cache.keys()

    def encryption_items(self):
        return self._cache.values()

    def remove_item(self, hash_key):
        if hash_key in self._cache:
            del self._cache[hash_key]

    def _add(self, hash_key, encryption_key, ciphertext, payload):
        self._cache[hash_key] = EncryptionCacheItem(encryption_key, ciphertext, payload)
        self._expiration_queue.add(hash_key)
        self._alarm_queue.register_approx_alarm(
            self._expiration_time_s * 2, self._expiration_time_s, self._cleanup_old_cache_items
        )

    def _cleanup_old_cache_items(self):
        self._expiration_queue.remove_expired(remove_callback=self.remove_item)

    def __iter__(self):
        return iter(self._cache)

    def __len__(self):
        return len(self._cache)

    def _safe_decrypt_item(self, cache_item, hash_key):
        try:
            return cache_item.decrypt()
        except DecryptionError:
            failed_ciphertext = self.pop_ciphertext(hash_key)
            logger.warning(
                "Could not decrypt encrypted item with hash {}. Last four bytes: {}",
                convert.bytes_to_hex(hash_key),
                convert.bytes_to_hex(failed_ciphertext[-4:])
            )
            return None
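# Usage sketch for EncryptedCache (illustrative only). The AlarmQueue
# construction is an assumption; the payload is a placeholder. Ciphertext and
# key may arrive in either order on the receiving side, which is why the cache
# accepts partial items via add_ciphertext / add_key.
cache = EncryptedCache(expiration_time_s=60, alarm_queue=AlarmQueue())
# Sender side: encrypt a block, broadcast ciphertext, release the key later.
ciphertext, hash_key = cache.encrypt_and_add_payload(b"raw block payload")
key = cache.get_encryption_key(hash_key)
# Receiver side: store whichever piece arrives first, then decrypt.
cache.add_ciphertext(hash_key, ciphertext)
payload = cache.decrypt_and_get_payload(hash_key, key)
# payload should equal the original plaintext bytes; None means decryption failed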
class ExpiringSet(Generic[T]):
    """
    Set with expiration time.

    In performance critical paths, prefer "if item in expiring_set.contents"
    over "in expiring_set": checking contents directly avoids the extra
    function call overhead of __contains__.
    """

    contents: Set[T]
    _alarm_queue: AlarmQueue
    _expiration_queue: ExpirationQueue[T]
    _expiration_time: int
    _log_removal: bool

    def __init__(self, alarm_queue: AlarmQueue, expiration_time_s: int, log_removal: bool = False):
        self.contents = set()
        self._alarm_queue = alarm_queue
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._expiration_time = expiration_time_s
        self._log_removal = log_removal

    def __contains__(self, item: T):
        return item in self.contents

    def __len__(self) -> int:
        return len(self.contents)

    def add(self, item: T):
        self.contents.add(item)
        self._expiration_queue.add(item)
        self._alarm_queue.register_approx_alarm(
            self._expiration_time * 2, self._expiration_time, self.cleanup
        )

    def get_recent_items(self, count: int) -> List[T]:
        items = []
        # noinspection PyTypeChecker
        entries = reversed(self._expiration_queue.queue.keys())  # pyre-ignore queue is actually an OrderedDict
        try:
            for i in range(count):
                items.append(next(entries))
        except StopIteration as _e:
            logger.debug("Attempted to fetch {} entries, but only {} existed.", count, len(items))
        return items

    def cleanup(self):
        self._expiration_queue.remove_expired(remove_callback=self._safe_remove_item)
        return 0

    def _safe_remove_item(self, item: T):
        if self._log_removal:
            logger.debug("Removing {} from expiring set.", item)
        if item in self.contents:
            self.contents.remove(item)
class ExpirationQueueTests(unittest.TestCase):
    def setUp(self):
        self.time_to_live = 60
        self.queue = ExpirationQueue(self.time_to_live)
        self.removed_items = []

    def test_expiration_queue(self):
        # adding 2 items to the queue with 1 second difference
        item1 = 1
        item2 = 2

        self.queue.add(item1)
        time_1_added = time.time()

        time.time = MagicMock(return_value=time.time() + 1)

        self.queue.add(item2)
        time_2_added = time.time()

        self.assertEqual(len(self.queue), 2)
        self.assertEqual(int(time_1_added), int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item1, self.queue.get_oldest())

        # check that nothing is removed from queue before the first item expires
        self.queue.remove_expired(time_1_added + self.time_to_live / 2, remove_callback=self._remove_item)
        self.assertEqual(len(self.queue), 2)
        self.assertEqual(len(self.removed_items), 0)

        # check that the first item is removed after it expires
        self.queue.remove_expired(time_1_added + self.time_to_live + 1, remove_callback=self._remove_item)
        self.assertEqual(len(self.queue), 1)
        self.assertEqual(len(self.removed_items), 1)
        self.assertEqual(self.removed_items[0], item1)
        self.assertEqual(int(time_2_added), int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item2, self.queue.get_oldest())

        # check that the second item is removed after it expires
        self.queue.remove_expired(time_2_added + self.time_to_live + 1, remove_callback=self._remove_item)
        self.assertEqual(len(self.queue), 0)
        self.assertEqual(len(self.removed_items), 2)
        self.assertEqual(self.removed_items[0], item1)
        self.assertEqual(self.removed_items[1], item2)

    def test_remove_oldest_item(self):
        items_count = 10
        for i in range(items_count):
            self.queue.add(i)
        self.assertEqual(items_count, len(self.queue))

        removed_items_1 = []
        for i in range(items_count):
            self.assertEqual(i, self.queue.get_oldest())
            self.queue.remove_oldest(removed_items_1.append)
            self.queue.add(1000 + i)
        for i in range(items_count):
            self.assertEqual(i, removed_items_1[i])
        self.assertEqual(items_count, len(self.queue))

        removed_items_2 = []
        for i in range(items_count):
            self.assertEqual(i + 1000, self.queue.get_oldest())
            self.queue.remove_oldest(removed_items_2.append)
        for i in range(items_count):
            self.assertEqual(i + 1000, removed_items_2[i])
        self.assertEqual(0, len(self.queue))

    def test_remove_not_oldest_item(self):
        # adding 2 items to the queue with 1 second difference
        item1 = 9
        item2 = 5

        self.queue.add(item1)
        time_1_added = time.time()

        time.time = MagicMock(return_value=time.time() + 1)

        self.queue.add(item2)

        self.assertEqual(len(self.queue), 2)
        self.assertEqual(int(time_1_added), int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item1, self.queue.get_oldest())

        self.queue.remove(item2)
        self.assertEqual(len(self.queue), 1)
        self.assertEqual(int(time_1_added), int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item1, self.queue.get_oldest())

    def _remove_item(self, item):
        self.removed_items.append(item)