Exemple #1
0
class ExpiringDict(Generic[KT, VT]):
    """
    Dictionary with expiration time.

    Keys stored through ``add`` are tracked in an expiration queue, and every
    ``add`` registers an approximate cleanup alarm that removes from
    ``contents`` the keys the expiration queue reports as expired.

    For determining if keys are in the dictionary, use
    "if key in expiring_dict.contents". __contains__ is intentionally not
    overwritten. This is a performance critical class, and we're avoiding
    extra function call overhead.
    """

    contents: Dict[KT, VT]
    _alarm_queue: AlarmQueue
    _expiration_queue: ExpirationQueue[KT]
    _expiration_time: int

    def __init__(self, alarm_queue: AlarmQueue, expiration_time_s: int) -> None:
        """
        :param alarm_queue: alarm queue used to schedule cleanups
        :param expiration_time_s: expiration time handed to the ExpirationQueue
        """
        self.contents = {}
        self._alarm_queue = alarm_queue
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._expiration_time = expiration_time_s

    def add(self, key: KT, value: VT) -> None:
        """
        Store ``value`` under ``key`` and start tracking ``key`` for expiration.

        A cleanup alarm is registered on every call, with arguments
        ``2 * expiration_time`` and ``expiration_time``.
        """
        self.contents[key] = value
        self._expiration_queue.add(key)
        self._alarm_queue.register_approx_alarm(
            self._expiration_time * 2, self._expiration_time, self.cleanup)

    def cleanup(self) -> int:
        """
        Remove expired keys from ``contents``.

        :return: 0 — presumably tells the alarm queue not to reschedule this
            alarm; the next ``add`` registers a new one (TODO confirm in AlarmQueue)
        """
        self._expiration_queue.remove_expired(remove_callback=self.remove_item)
        return 0

    def remove_item(self, key: KT) -> None:
        """Remove ``key`` from ``contents`` if present; missing keys are ignored."""
        self.contents.pop(key, None)
Exemple #2
0
 def __init__(self,
              alarm_queue: AlarmQueue,
              expiration_time_s: int,
              name: str) -> None:
     """
     Create an empty expiring dictionary.

     :param alarm_queue: alarm queue stored for scheduling cleanups
     :param expiration_time_s: expiration time handed to the ExpirationQueue
     :param name: name stored on the instance
     """
     self._name = name
     self._expiration_time = expiration_time_s
     self._expiration_queue = ExpirationQueue(expiration_time_s)
     self._alarm_queue = alarm_queue
     self.contents = {}
Exemple #3
0
 def __init__(self,
              alarm_queue: AlarmQueue,
              expiration_time_s: int,
              log_removal: bool = False):
     """
     Create an empty expiring set.

     :param alarm_queue: alarm queue stored for scheduling cleanups
     :param expiration_time_s: expiration time handed to the ExpirationQueue
     :param log_removal: flag stored on the instance (whether removals are logged)
     """
     self._log_removal = log_removal
     self._expiration_time = expiration_time_s
     self._expiration_queue = ExpirationQueue(expiration_time_s)
     self._alarm_queue = alarm_queue
     self.contents = set()
    def __init__(self, alarm_queue: AlarmQueue):
        """
        Create a recovery service with no blocks pending recovery.

        :param alarm_queue: alarm queue stored for later scheduling
        """
        self._alarm_queue = alarm_queue
        self.recovered_blocks = []

        # compressed block hash -> per-block recovery state
        self._bx_block_hash_to_block = {}
        self._bx_block_hash_to_block_hash = {}
        self._bx_block_hash_to_sids = {}
        self._bx_block_hash_to_tx_hashes = {}

        # reverse indices pointing back at compressed block hashes
        self._block_hash_to_bx_block_hashes = defaultdict(set)
        self._sid_to_bx_block_hashes = defaultdict(set)
        self._tx_hash_to_bx_block_hashes: Dict[
            Sha256Hash, Set[Sha256Hash]] = defaultdict(set)

        self.recovery_attempts_by_block = defaultdict(int)
        self._blocks_expiration_queue = ExpirationQueue(
            gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME)
Exemple #5
0
    def __init__(self, node: AbstractNode, network_num: int):
        """
        Initialize the transaction service for a single network.

        :param node: reference to node object; ``node.opts`` and
            ``node.alarm_queue`` are read here
        :param network_num: network number
        :raises ValueError: if ``node`` or ``network_num`` is None
        """
        for required, error_message in ((node, "Node is required"),
                                        (network_num,
                                         "Network number is required")):
            if required is None:
                raise ValueError(error_message)

        self.node = node
        self.network_num = network_num
        self.tx_assign_alarm_scheduled = False

        # short id <-> transaction cache key indices and cached contents
        self._tx_cache_key_to_short_ids = defaultdict(set)
        self._short_id_to_tx_cache_key = {}
        self._tx_cache_key_to_contents = {}
        self._tx_assignment_expire_queue: ExpirationQueue[int] = ExpirationQueue(
            node.opts.sid_expire_time)

        # These helpers are methods on self and may read the attributes set
        # above, so they must stay after them.
        self._final_tx_confirmations_count = self._get_final_tx_confirmations_count()
        self._tx_content_memory_limit = self._get_tx_contents_memory_limit()
        logger.debug(
            "Memory limit for transaction service by network number {} is {} bytes.",
            self.network_num, self._tx_content_memory_limit)

        # short ids seen in block, ordered by their block hash
        self._short_ids_seen_in_block: OrderedDict[
            Sha256Hash, List[int]] = OrderedDict()
        self._total_tx_contents_size = 0
        self._total_tx_removed_by_memory_limit = 0

        self._last_transaction_stats = TransactionServiceStats()
        self._removed_short_ids = set()

        if node.opts.dump_removed_short_ids:
            self.node.alarm_queue.register_alarm(
                constants.DUMP_REMOVED_SHORT_IDS_INTERVAL_S,
                self._dump_removed_short_ids)

        histogram_interval_s = constants.TRANSACTION_SERVICE_LOG_TRANSACTIONS_INTERVAL_S
        if histogram_interval_s > 0:
            self.node.alarm_queue.register_alarm(
                histogram_interval_s, self._log_transaction_service_histogram)
Exemple #6
0
class ExpiringDict(Generic[KT, VT]):
    """
    Dictionary whose keys are dropped once they expire.

    Keys inserted with ``add`` (or with ``d[key] = value`` for a key that is
    not yet present) are recorded in an expiration queue, and a named
    approximate cleanup alarm is registered; cleanup pops whichever keys the
    queue reports as expired. Overwriting an existing key only updates its
    value and does not re-register it for expiration.
    """

    contents: Dict[KT, VT]
    _alarm_queue: AlarmQueue
    _expiration_queue: ExpirationQueue[KT]
    _expiration_time: int
    _name: str

    def __init__(self, alarm_queue: AlarmQueue, expiration_time_s: int,
                 name: str) -> None:
        self._name = name
        self._expiration_time = expiration_time_s
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._alarm_queue = alarm_queue
        self.contents = {}

    def __contains__(self, item: KT):
        return item in self.contents

    def __setitem__(self, key: KT, value: VT):
        # New keys go through add() so they are tracked for expiration.
        if key not in self.contents:
            self.add(key, value)
            return
        self.contents[key] = value

    def __delitem__(self, key: KT):
        # Raises KeyError (before touching the queue) if key is absent.
        del self.contents[key]
        self._expiration_queue.remove(key)

    def __getitem__(self, item: KT) -> VT:
        return self.contents[item]

    def add(self, key: KT, value: VT) -> None:
        cleanup_alarm_name = f"ExpiringDict[{self._name}]#cleanup"
        self.contents[key] = value
        self._expiration_queue.add(key)
        self._alarm_queue.register_approx_alarm(
            self._expiration_time * 2,
            self._expiration_time,
            self.cleanup,
            alarm_name=cleanup_alarm_name)

    def cleanup(self) -> float:
        self._expiration_queue.remove_expired(remove_callback=self.remove_item)
        return 0

    def remove_item(self, key: KT) -> Optional[VT]:
        """Pop ``key`` and return its value, or None when it is absent."""
        return self.contents.pop(key, None)
class BlockRecoveryService:
    """
    Service class that handles blocks the gateway receives with unknown transaction short ids or contents.

    Attributes
    ----------
    recovered_blocks: queue to which recovered blocks are pushed, as (compressed block, recovered txs source) tuples
    _alarm_queue: reference to alarm queue to schedule cleanup on

    _bx_block_hash_to_sids: map of compressed block hash to its set of unknown short ids
    _bx_block_hash_to_tx_hashes: map of compressed block hash to its set of unknown transaction hashes
    _bx_block_hash_to_block_hash: map of compressed block hash to its original block hash
    _bx_block_hash_to_block: map of compressed block hash to its compressed byte representation
    _block_hash_to_bx_block_hashes: map of original block hash to compressed block hashes waiting for recovery
    _sid_to_bx_block_hashes: map of short id to compressed block hashes waiting for recovery
    _tx_hash_to_bx_block_hashes: map of transaction hash to compressed block hashes waiting for recovery
    recovery_attempts_by_block: per original block hash counter (defaults to 0); not modified within this class

    _cleanup_scheduled: whether block recovery has an alarm scheduled to clean up recovering blocks
    _blocks_expiration_queue: queue to trigger expiration of blocks waiting for recovery
    """

    _alarm_queue: AlarmQueue
    _bx_block_hash_to_sids: Dict[Sha256Hash, Set[int]]
    _bx_block_hash_to_tx_hashes: Dict[Sha256Hash, Set[Sha256Hash]]
    _bx_block_hash_to_block_hash: Dict[Sha256Hash, Sha256Hash]
    _bx_block_hash_to_block: Dict[Sha256Hash, memoryview]
    _block_hash_to_bx_block_hashes: Dict[Sha256Hash, Set[Sha256Hash]]
    _sid_to_bx_block_hashes: Dict[int, Set[Sha256Hash]]
    _tx_hash_to_bx_block_hashes: Dict[Sha256Hash, Set[Sha256Hash]]
    _blocks_expiration_queue: ExpirationQueue
    _cleanup_scheduled: bool = False

    recovery_attempts_by_block: Dict[Sha256Hash, int]

    def __init__(self, alarm_queue: AlarmQueue):
        self.recovered_blocks = []

        self._alarm_queue = alarm_queue

        self._bx_block_hash_to_sids = {}
        self._bx_block_hash_to_tx_hashes = {}
        self._bx_block_hash_to_block_hash = {}
        self._bx_block_hash_to_block = {}
        self.recovery_attempts_by_block = defaultdict(int)
        self._block_hash_to_bx_block_hashes = defaultdict(set)
        self._sid_to_bx_block_hashes = defaultdict(set)
        self._tx_hash_to_bx_block_hashes: Dict[
            Sha256Hash, Set[Sha256Hash]] = defaultdict(set)
        self._blocks_expiration_queue = ExpirationQueue(
            gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME)

    def add_block(self, bx_block: memoryview, block_hash: Sha256Hash,
                  unknown_tx_sids: List[int],
                  unknown_tx_hashes: List[Sha256Hash]):
        """
        Adds a block that needs recovery. Tracks unknown short ids and contents as they come in.

        The compressed block is keyed by the double SHA256 of its bytes, indexed
        by each unknown short id / transaction hash, added to the expiration
        queue, and a cleanup alarm is scheduled if none is pending.

        :param bx_block: bytearray representation of compressed block
        :param block_hash: original ObjectHash of block
        :param unknown_tx_sids: list of unknown short ids
        :param unknown_tx_hashes: list of unknown tx ObjectHashes
        """
        logger.trace(
            "Recovering block with {} unknown short ids and {} contents: {}",
            len(unknown_tx_sids), len(unknown_tx_hashes), block_hash)
        bx_block_hash = Sha256Hash(crypto.double_sha256(bx_block))

        self._bx_block_hash_to_block[bx_block_hash] = bx_block
        self._bx_block_hash_to_block_hash[bx_block_hash] = block_hash
        self._bx_block_hash_to_sids[bx_block_hash] = set(unknown_tx_sids)
        self._bx_block_hash_to_tx_hashes[bx_block_hash] = set(
            unknown_tx_hashes)

        self._block_hash_to_bx_block_hashes[block_hash].add(bx_block_hash)
        for sid in unknown_tx_sids:
            self._sid_to_bx_block_hashes[sid].add(bx_block_hash)
        for tx_hash in unknown_tx_hashes:
            self._tx_hash_to_bx_block_hashes[tx_hash].add(bx_block_hash)

        self._blocks_expiration_queue.add(bx_block_hash)
        self._schedule_cleanup()

    def get_blocks_awaiting_recovery(self) -> List[BlockRecoveryInfo]:
        """
        Fetch all blocks still awaiting recovery and retry.

        :return: one BlockRecoveryInfo per original block hash, with the union of
            unknown short ids / transaction hashes across its compressed variants
        """
        blocks_awaiting_recovery = []
        for block_hash, bx_block_hashes in self._block_hash_to_bx_block_hashes.items():
            unknown_short_ids = set()
            unknown_transaction_hashes = set()
            for bx_block_hash in bx_block_hashes:
                unknown_short_ids.update(
                    self._bx_block_hash_to_sids[bx_block_hash])
                unknown_transaction_hashes.update(
                    self._bx_block_hash_to_tx_hashes[bx_block_hash])
            blocks_awaiting_recovery.append(
                BlockRecoveryInfo(block_hash, unknown_short_ids,
                                  unknown_transaction_hashes, time.time()))
        return blocks_awaiting_recovery

    def check_missing_sid(self, sid: int,
                          recovered_txs_source: RecoveredTxsSource) -> bool:
        """
        Resolves recovering blocks that depend on a short id.

        :param sid: short id that has been processed
        :param recovered_txs_source: source of recovered transaction
        :return: True if any block was waiting on ``sid``, False otherwise
        """
        if sid not in self._sid_to_bx_block_hashes:
            return False

        logger.trace("Resolved previously unknown short id: {0}.", sid)

        # Detach the index entry before iterating: recovering one compressed
        # block removes every other compressed block of the same original
        # block, and that cleanup discards from this very set. Iterating it in
        # place raised "RuntimeError: Set changed size during iteration" when
        # two compressed variants shared the same unknown short id.
        bx_block_hashes = self._sid_to_bx_block_hashes.pop(sid)
        for bx_block_hash in bx_block_hashes:
            unknown_sids = self._bx_block_hash_to_sids.get(bx_block_hash)
            if unknown_sids is not None and sid in unknown_sids:
                unknown_sids.discard(sid)
                self._check_if_recovered(bx_block_hash, recovered_txs_source)
        return True

    def check_missing_tx_hash(
            self, tx_hash: Sha256Hash,
            recovered_txs_source: RecoveredTxsSource) -> bool:
        """
        Resolves recovering blocks that depend on a transaction hash.

        :param tx_hash: transaction hash that has been processed
        :param recovered_txs_source: source of recovered transaction
        :return: True if any block was waiting on ``tx_hash``, False otherwise
        """
        if tx_hash not in self._tx_hash_to_bx_block_hashes:
            return False

        logger.trace("Resolved previously unknown transaction hash {0}.",
                     tx_hash)

        # Detached before iterating for the same reason as in check_missing_sid.
        bx_block_hashes = self._tx_hash_to_bx_block_hashes.pop(tx_hash)
        for bx_block_hash in bx_block_hashes:
            unknown_tx_hashes = self._bx_block_hash_to_tx_hashes.get(
                bx_block_hash)
            if unknown_tx_hashes is not None and tx_hash in unknown_tx_hashes:
                unknown_tx_hashes.discard(tx_hash)
                self._check_if_recovered(bx_block_hash, recovered_txs_source)
        return True

    def cancel_recovery_for_block(self, block_hash: Sha256Hash) -> bool:
        """
        Cancels recovery for all compressed blocks matching a block hash.

        :param block_hash: ObjectHash
        :return: True if the block was awaiting recovery, False otherwise
        """
        if block_hash in self._block_hash_to_bx_block_hashes:
            logger.trace("Cancelled block recovery for block: {}", block_hash)
            self._remove_recovered_block_hash(block_hash)
            return True
        else:
            return False

    def awaiting_recovery(self, block_hash: Sha256Hash) -> bool:
        """Return True if any compressed block for ``block_hash`` is awaiting recovery."""
        return block_hash in self._block_hash_to_bx_block_hashes

    # pyre-fixme[9]: clean_up_time has type `float`; used as `None`.
    def cleanup_old_blocks(self, clean_up_time: float = None):
        """
        Cleans up old compressed blocks awaiting recovery.

        :param clean_up_time: time passed to the expiration queue as the current
            time; None lets the queue decide
        :return: BLOCK_RECOVERY_MAX_QUEUE_TIME while blocks are still awaiting
            recovery, otherwise 0 (and the cleanup is marked as unscheduled).
            Presumably the alarm queue reschedules on a positive return value.
        """
        logger.debug("Cleaning up block recovery.")
        num_blocks_awaiting_recovery = len(self._bx_block_hash_to_block)
        self._blocks_expiration_queue.remove_expired(
            current_time=clean_up_time,
            remove_callback=self._remove_not_recovered_block)
        logger.debug(
            "Cleaned up {} blocks awaiting recovery.",
            num_blocks_awaiting_recovery - len(self._bx_block_hash_to_block))

        if self._bx_block_hash_to_block:
            return gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME

        # disable clean up until receive the next block with unknown tx
        self._cleanup_scheduled = False
        return 0

    def clean_up_recovered_blocks(self):
        """
        Cleans up blocks that have finished recovery by emptying ``recovered_blocks`` in place.
        """
        logger.trace("Cleaning up {} recovered blocks.",
                     len(self.recovered_blocks))
        del self.recovered_blocks[:]

    def _check_if_recovered(self, bx_block_hash: Sha256Hash,
                            recovered_txs_source: RecoveredTxsSource):
        """
        Checks if a compressed block has received all short ids and transaction hashes necessary to recover.
        Adds block to recovered blocks if so, and drops every compressed block
        tracked for the same original block hash.

        :param bx_block_hash: ObjectHash; must currently be tracked
        :param recovered_txs_source: source of recovered transaction
        """
        if self._is_block_recovered(bx_block_hash):
            bx_block = self._bx_block_hash_to_block[bx_block_hash]
            block_hash = self._bx_block_hash_to_block_hash[bx_block_hash]
            logger.debug(
                "Recovery status for block {}, compress block hash {}: "
                "Block recovered by gateway. Source of recovered txs is {}.",
                block_hash, bx_block_hash, recovered_txs_source)
            self._remove_recovered_block_hash(block_hash)
            self.recovered_blocks.append((bx_block, recovered_txs_source))

    def _is_block_recovered(self, bx_block_hash: Sha256Hash):
        """
        Indicates if a compressed block has received all short ids and transaction hashes necessary to recover.

        :param bx_block_hash: ObjectHash; raises KeyError if not tracked
        :return: True when no unknown short ids and no unknown transaction hashes remain
        """
        return len(self._bx_block_hash_to_sids[bx_block_hash]) == 0 and len(
            self._bx_block_hash_to_tx_hashes[bx_block_hash]) == 0

    def _remove_recovered_block_hash(self, block_hash: Sha256Hash):
        """
        Removes all compressed blocks awaiting recovery that are matches to a recovered block hash.

        Entries in the expiration queue are left in place; ``_remove_not_recovered_block``
        ignores hashes that are no longer tracked.

        :param block_hash: ObjectHash
        """
        if block_hash in self._block_hash_to_bx_block_hashes:
            for bx_block_hash in self._block_hash_to_bx_block_hashes[
                    block_hash]:
                if bx_block_hash in self._bx_block_hash_to_block:
                    self._remove_sid_and_tx_mapping_for_bx_block_hash(
                        bx_block_hash)
                    del self._bx_block_hash_to_block[bx_block_hash]
                    del self._bx_block_hash_to_block_hash[bx_block_hash]
            del self._block_hash_to_bx_block_hashes[block_hash]

    def _remove_sid_and_tx_mapping_for_bx_block_hash(
            self, bx_block_hash: Sha256Hash):
        """
        Removes all short id and transaction mapping for a compressed block.

        Reverse-index entries that become empty are deleted.

        :param bx_block_hash: ObjectHash of the compressed block
        :raises ValueError: if the compressed block is not being recovered
        """
        if bx_block_hash not in self._bx_block_hash_to_block:
            raise ValueError(
                "Can't remove mapping for a block that isn't being recovered.")

        for sid in self._bx_block_hash_to_sids[bx_block_hash]:
            if sid in self._sid_to_bx_block_hashes:
                self._sid_to_bx_block_hashes[sid].discard(bx_block_hash)
                if len(self._sid_to_bx_block_hashes[sid]) == 0:
                    del self._sid_to_bx_block_hashes[sid]
        del self._bx_block_hash_to_sids[bx_block_hash]

        for tx_hash in self._bx_block_hash_to_tx_hashes[bx_block_hash]:
            if tx_hash in self._tx_hash_to_bx_block_hashes:
                self._tx_hash_to_bx_block_hashes[tx_hash].discard(
                    bx_block_hash)
                if len(self._tx_hash_to_bx_block_hashes[tx_hash]) == 0:
                    del self._tx_hash_to_bx_block_hashes[tx_hash]
        del self._bx_block_hash_to_tx_hashes[bx_block_hash]

    def _remove_not_recovered_block(self, bx_block_hash: Sha256Hash):
        """
        Removes compressed block that has not recovered (expiration queue callback).

        No-op if the compressed block is no longer tracked.

        :param bx_block_hash: ObjectHash
        """
        if bx_block_hash in self._bx_block_hash_to_block:
            logger.trace("Block has failed recovery: {}", bx_block_hash)

            self._remove_sid_and_tx_mapping_for_bx_block_hash(bx_block_hash)
            del self._bx_block_hash_to_block[bx_block_hash]

            block_hash = self._bx_block_hash_to_block_hash.pop(bx_block_hash)
            self._block_hash_to_bx_block_hashes[block_hash].discard(
                bx_block_hash)
            if len(self._block_hash_to_bx_block_hashes[block_hash]) == 0:
                del self._block_hash_to_bx_block_hashes[block_hash]

    def _schedule_cleanup(self):
        """Register ``cleanup_old_blocks`` with the alarm queue unless already scheduled or nothing is pending."""
        if not self._cleanup_scheduled and self._bx_block_hash_to_block:
            logger.trace("Scheduling block recovery cleanup in {} seconds.",
                         gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME)
            self._alarm_queue.register_alarm(
                gateway_constants.BLOCK_RECOVERY_MAX_QUEUE_TIME,
                self.cleanup_old_blocks)
            self._cleanup_scheduled = True
Exemple #8
0
 def __init__(self, alarm_queue, expiration_time_s):
     """
     Create an empty expiring dictionary.

     :param alarm_queue: alarm queue stored for scheduling cleanups
     :param expiration_time_s: expiration time handed to the ExpirationQueue
     """
     self._expiration_time = expiration_time_s
     self._expiration_queue = ExpirationQueue(expiration_time_s)
     self._alarm_queue = alarm_queue
     self.contents = {}
Exemple #9
0
class ExpiringSet(Generic[T]):
    """
    Set whose items are dropped once they expire.

    Every ``add`` records the item in an expiration queue and registers a
    named approximate cleanup alarm; cleanup discards from ``contents``
    whatever the queue reports as expired.
    """

    contents: Set[T]
    _alarm_queue: AlarmQueue
    _expiration_queue: ExpirationQueue[T]
    _expiration_time: int
    _log_removal: bool
    _name: str

    def __init__(self,
                 alarm_queue: AlarmQueue,
                 expiration_time_s: int,
                 name: str,
                 log_removal: bool = False):
        self._name = name
        self._log_removal = log_removal
        self._expiration_time = expiration_time_s
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._alarm_queue = alarm_queue
        self.contents = set()

    def __contains__(self, item: T) -> bool:
        return item in self.contents

    def __len__(self) -> int:
        return len(self.contents)

    def add(self, item: T) -> None:
        cleanup_alarm_name = f"ExpiringSet[{self._name}]#cleanup"
        self.contents.add(item)
        self._expiration_queue.add(item)
        self._alarm_queue.register_approx_alarm(
            self._expiration_time * 2,
            self._expiration_time,
            self.cleanup,
            alarm_name=cleanup_alarm_name)

    def remove(self, item: T) -> None:
        """Remove ``item`` from ``contents`` (KeyError if absent); the expiration queue is untouched."""
        self.contents.remove(item)

    def get_recent_items(self, count: int) -> List[T]:
        """
        Return up to ``count`` keys of the expiration queue, iterated in reverse
        (presumably most recently added first).
        """
        recent = []
        # noinspection PyTypeChecker
        newest_first = reversed(self._expiration_queue.queue.keys())

        try:
            for _ in range(count):
                recent.append(next(newest_first))
        except StopIteration:
            logger.debug("Attempted to fetch {} entries, but only {} existed.",
                         count, len(recent))

        return recent

    def cleanup(self) -> int:
        self._expiration_queue.remove_expired(
            remove_callback=self._safe_remove_item)
        return 0

    def _safe_remove_item(self, item: T):
        if self._log_removal:
            logger.debug("Removing {} from expiring set.", item)
        self.contents.discard(item)
Exemple #10
0
 def __init__(self, expiration_time_s, alarm_queue):
     """
     Create an empty cache.

     :param expiration_time_s: expiration time handed to the ExpirationQueue
     :param alarm_queue: alarm queue stored for scheduling cleanups
     """
     self._alarm_queue = alarm_queue
     self._expiration_time_s = expiration_time_s
     self._expiration_queue = ExpirationQueue(expiration_time_s)
     self._cache = {}
Exemple #11
0
class EncryptedCache(object):
    """
    Storage for in-progress received or sent encrypted blocks.

    Each entry is keyed by a hash key and holds an EncryptionCacheItem built
    from (encryption key, ciphertext, payload); the key or ciphertext may be
    None until it arrives via ``add_key`` / ``add_ciphertext``. Every new entry
    is added to an expiration queue, and an approximate alarm is registered to
    drop entries the queue reports as expired.
    """
    def __init__(self, expiration_time_s, alarm_queue):
        self._cache = {}
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._expiration_time_s = expiration_time_s
        self._alarm_queue = alarm_queue

    def encrypt_and_add_payload(self, payload):
        """
        Encrypts payload, computing a hash and storing it along with the key for later release.
        If encryption is disabled for dev, store ciphertext identical to hash_key.

        :param payload: payload to encrypt (converted with ``bytes()``)
        :return: tuple of (ciphertext, hash_key), hash_key being the double SHA256 of the ciphertext
        """
        key, ciphertext = symmetric_encrypt(bytes(payload))
        hash_key = crypto.double_sha256(ciphertext)
        self._add(hash_key, key, ciphertext, payload)
        return ciphertext, hash_key

    def add_ciphertext(self, hash_key, ciphertext):
        """Attach ``ciphertext`` to an existing entry, or create an entry holding only the ciphertext."""
        if hash_key in self._cache:
            self._cache[hash_key].ciphertext = ciphertext
        else:
            self._add(hash_key, None, ciphertext, None)

    def add_key(self, hash_key, encryption_key):
        """Attach ``encryption_key`` to an existing entry, or create an entry holding only the key."""
        if hash_key in self._cache:
            self._cache[hash_key].key = encryption_key
        else:
            self._add(hash_key, encryption_key, None, None)

    def decrypt_and_get_payload(self, hash_key, encryption_key):
        """
        Retrieves and decrypts stored ciphertext.
        Returns None if unable to decrypt.

        :raises KeyError: if nothing is stored under ``hash_key``
        """
        cache_item = self._cache[hash_key]
        cache_item.key = encryption_key
        return self._safe_decrypt_item(cache_item, hash_key)

    def decrypt_ciphertext(self, hash_key, ciphertext):
        """
        Retrieves and decrypts ciphertext with stored key info. Stores info in cache.
        Returns None if unable to decrypt.

        :raises KeyError: if nothing is stored under ``hash_key``
        """

        cache_item = self._cache[hash_key]
        cache_item.ciphertext = ciphertext

        return self._safe_decrypt_item(cache_item, hash_key)

    def get_encryption_key(self, hash_key):
        return self._cache[hash_key].key

    def pop_ciphertext(self, hash_key):
        """Remove the whole entry for ``hash_key`` and return its ciphertext (KeyError if absent)."""
        return self._cache.pop(hash_key).ciphertext

    def has_encryption_key_for_hash(self, hash_key):
        cache_item = self._cache.get(hash_key)
        return cache_item is not None and cache_item.key is not None

    def has_ciphertext_for_hash(self, hash_key):
        cache_item = self._cache.get(hash_key)
        return cache_item is not None and cache_item.ciphertext is not None

    def hash_keys(self):
        return self._cache.keys()

    def encryption_items(self):
        return self._cache.values()

    def remove_item(self, hash_key):
        """Remove the entry for ``hash_key`` if present; missing keys are ignored."""
        self._cache.pop(hash_key, None)

    def _add(self, hash_key, encryption_key, ciphertext, payload):
        self._cache[hash_key] = EncryptionCacheItem(encryption_key, ciphertext,
                                                    payload)
        self._expiration_queue.add(hash_key)
        self._alarm_queue.register_approx_alarm(self._expiration_time_s * 2,
                                                self._expiration_time_s,
                                                self._cleanup_old_cache_items)

    def _cleanup_old_cache_items(self):
        """
        Alarm callback: drop entries the expiration queue reports as expired.

        Returns 0 like the other approx-alarm cleanup callbacks; a new alarm is
        registered by the next ``_add``.
        """
        self._expiration_queue.remove_expired(remove_callback=self.remove_item)
        return 0

    def __iter__(self):
        return iter(self._cache)

    def __len__(self):
        return len(self._cache)

    def _safe_decrypt_item(self, cache_item, hash_key):
        """
        Decrypt ``cache_item``; on DecryptionError, remove the entry, log a
        warning and return None.
        """
        try:
            return cache_item.decrypt()
        except DecryptionError:
            failed_ciphertext = self.pop_ciphertext(hash_key)
            # Guard against a missing ciphertext so the warning is still logged
            # instead of a TypeError from slicing None.
            last_four_bytes = (convert.bytes_to_hex(failed_ciphertext[-4:])
                               if failed_ciphertext is not None else None)
            logger.warning(
                "Could not decrypt encrypted item with hash {}. Last four bytes: {}",
                convert.bytes_to_hex(hash_key),
                last_four_bytes)
            return None
Exemple #12
0
class ExpiringSet(Generic[T]):
    """
    Set with expiration time.

    Items added through ``add`` are tracked in an expiration queue, and every
    ``add`` registers an approximate cleanup alarm that removes from
    ``contents`` the items the queue reports as expired.

    ``__contains__`` is provided for convenience, but this is a performance
    critical class: in hot paths prefer "if item in expiring_set.contents",
    which avoids the extra function call overhead.
    """

    contents: Set[T]
    _alarm_queue: AlarmQueue
    _expiration_queue: ExpirationQueue[T]
    _expiration_time: int
    _log_removal: bool

    def __init__(self,
                 alarm_queue: AlarmQueue,
                 expiration_time_s: int,
                 log_removal: bool = False):
        """
        :param alarm_queue: alarm queue used to schedule cleanups
        :param expiration_time_s: expiration time handed to the ExpirationQueue
        :param log_removal: log each item removed by cleanup at debug level
        """
        self.contents = set()
        self._alarm_queue = alarm_queue
        self._expiration_queue = ExpirationQueue(expiration_time_s)
        self._expiration_time = expiration_time_s
        self._log_removal = log_removal

    def __contains__(self, item: T) -> bool:
        return item in self.contents

    def __len__(self) -> int:
        return len(self.contents)

    def add(self, item: T) -> None:
        """Add ``item`` and register a cleanup alarm for it."""
        self.contents.add(item)
        self._expiration_queue.add(item)
        self._alarm_queue.register_approx_alarm(self._expiration_time * 2,
                                                self._expiration_time,
                                                self.cleanup)

    def get_recent_items(self, count: int) -> List[T]:
        """
        Return up to ``count`` keys of the expiration queue, iterated in reverse
        (presumably most recently added first).
        """
        items = []
        # noinspection PyTypeChecker
        entries = reversed(self._expiration_queue.queue.keys()
                           )  # pyre-ignore queue is actually an OrderedDict

        try:
            for _ in range(count):
                items.append(next(entries))
        except StopIteration:
            logger.debug("Attempted to fetch {} entries, but only {} existed.",
                         count, len(items))

        return items

    def cleanup(self) -> int:
        """Remove expired items; returns 0 (a new alarm is registered by the next ``add``)."""
        self._expiration_queue.remove_expired(
            remove_callback=self._safe_remove_item)
        return 0

    def _safe_remove_item(self, item: T):
        if self._log_removal:
            logger.debug("Removing {} from expiring set.", item)
        if item in self.contents:
            self.contents.remove(item)
Exemple #13
0
 def setUp(self):
     """Create a fresh queue with a time-to-live of 60 and an empty removal log."""
     ttl = 60
     self.time_to_live = ttl
     self.removed_items = []
     self.queue = ExpirationQueue(ttl)
Exemple #14
0
class ExpirationQueueTests(unittest.TestCase):
    """Tests for ExpirationQueue ordering, expiration and removal."""

    def setUp(self):
        self.time_to_live = 60
        self.queue = ExpirationQueue(self.time_to_live)
        self.removed_items = []
        # Tests replace the module-level time.time with a MagicMock to simulate
        # the clock moving forward. Restore whatever was installed before the
        # test so the fake clock does not leak into other tests in the process.
        self.addCleanup(setattr, time, "time", time.time)

    def _advance_clock(self, seconds):
        """Freeze time.time() at ``seconds`` past its current value."""
        time.time = MagicMock(return_value=time.time() + seconds)

    def test_expiration_queue(self):
        # adding 2 items to the queue with 1 second difference
        item1 = 1
        item2 = 2

        self.queue.add(item1)
        time_1_added = time.time()

        self._advance_clock(1)

        self.queue.add(item2)
        time_2_added = time.time()

        self.assertEqual(len(self.queue), 2)
        self.assertEqual(int(time_1_added),
                         int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item1, self.queue.get_oldest())

        # check that nothing is removed from queue before the first item expires
        self.queue.remove_expired(time_1_added + self.time_to_live / 2,
                                  remove_callback=self._remove_item)
        self.assertEqual(len(self.queue), 2)
        self.assertEqual(len(self.removed_items), 0)

        # check that first item removed after first item expired
        self.queue.remove_expired(time_1_added + self.time_to_live + 1,
                                  remove_callback=self._remove_item)
        self.assertEqual(len(self.queue), 1)
        self.assertEqual(len(self.removed_items), 1)
        self.assertEqual(self.removed_items[0], item1)
        self.assertEqual(int(time_2_added),
                         int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item2, self.queue.get_oldest())

        # check that second item is removed after second item expires
        self.queue.remove_expired(time_2_added + self.time_to_live + 1,
                                  remove_callback=self._remove_item)
        self.assertEqual(len(self.queue), 0)
        self.assertEqual(len(self.removed_items), 2)
        self.assertEqual(self.removed_items[0], item1)
        self.assertEqual(self.removed_items[1], item2)

    def test_remove_oldest_item(self):
        items_count = 10

        for i in range(items_count):
            self.queue.add(i)

        self.assertEqual(items_count, len(self.queue))

        removed_items_1 = []

        for i in range(items_count):
            self.assertEqual(i, self.queue.get_oldest())
            self.queue.remove_oldest(removed_items_1.append)
            self.queue.add(1000 + i)

        self.assertEqual(list(range(items_count)), removed_items_1)
        self.assertEqual(items_count, len(self.queue))

        removed_items_2 = []

        for i in range(items_count):
            self.assertEqual(i + 1000, self.queue.get_oldest())
            self.queue.remove_oldest(removed_items_2.append)

        self.assertEqual([i + 1000 for i in range(items_count)],
                         removed_items_2)
        self.assertEqual(0, len(self.queue))

    def test_remove_not_oldest_item(self):
        # adding 2 items to the queue with 1 second difference
        item1 = 9
        item2 = 5

        self.queue.add(item1)
        time_1_added = time.time()

        self._advance_clock(1)

        self.queue.add(item2)

        self.assertEqual(len(self.queue), 2)
        self.assertEqual(int(time_1_added),
                         int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item1, self.queue.get_oldest())

        self.queue.remove(item2)
        self.assertEqual(len(self.queue), 1)
        self.assertEqual(int(time_1_added),
                         int(self.queue.get_oldest_item_timestamp()))
        self.assertEqual(item1, self.queue.get_oldest())

    def _remove_item(self, item):
        self.removed_items.append(item)