def test_composite_validator(validators_for_app):
    validator = CompositeValidator(validators_for_app['A'])
    assert ([type(item) for item in validator._validators] ==
        [SchemaValidator, RSASignatureValidator])

    validator.extend(validators_for_app['B'])
    assert ([type(item) for item in validator._validators] ==
        [SchemaValidator, RSASignatureValidator])
    assert len(validator._validators[0]._schemas) == 2

    local_public_key = validators_for_app['A'][0].local_public_key
    record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
                       subkey=DHTProtocol.serializer.dumps(local_public_key),
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    # Expect only one signature since two RSASignatureValidators have been merged
    assert signed_record.value.count(b'[signature:') == 1
    # Expect successful validation since the second SchemaValidator has been merged to the first
    assert validator.validate(signed_record)
    assert validator.strip_value(signed_record) == record.value

    record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
                       subkey=DHTProtocol.IS_REGULAR_VALUE,
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    assert signed_record.value.count(b'[signature:') == 0
    # Expect failed validation since `unknown_key` is not a part of any schema
    assert not validator.validate(signed_record)
Example #2
def test_ids_basic():
    # basic functionality tests
    for i in range(100):
        id1, id2 = DHTID.generate(), DHTID.generate()
        assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 < DHTID.MAX
        assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2,
                                                                  id2) == 0
        assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
        assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(
            id2.to_bytes()) == id2
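The assertions above boil down to standard properties of Kademlia's XOR metric. A minimal standalone sketch of the idea, using plain ints instead of the library's DHTID type:

# Conceptual sketch of XOR distance (plain ints, not the actual DHTID implementation):
def xor_distance(a: int, b: int) -> int:
    return a ^ b

assert xor_distance(0b1010, 0b1010) == 0               # distance to itself is zero
assert xor_distance(0b1010, 0b0011) == 0b1001          # the differing bits form the distance
assert xor_distance(5, 9) == xor_distance(9, 5) > 0    # symmetric, positive for distinct ids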
Example #3
def test_routing_table_search():
    for table_size, lower_active, upper_active in [(10, 10, 10),
                                                   (10_000, 800, 1100)]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
        num_added = 0
        total_nodes = 0

        for phony_neighbor_port in random.sample(range(1_000_000), table_size):
            routing_table.add_or_update_node(
                DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
            new_total = sum(
                len(bucket.nodes_to_endpoint)
                for bucket in routing_table.buckets)
            num_added += new_total > total_nodes
            total_nodes = new_total
Example #4
    async def get_many(
        self,
        keys: Collection[DHTKey],
        sufficient_expiration_time: Optional[DHTExpiration] = None,
        **kwargs
    ) -> Dict[DHTKey,
              Union[Optional[ValueWithExpiration[DHTValue]],
                    Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
        """
        Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.

        :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if a key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time,
            it stops searching for that key; default = time of call, i.e. any value that has not
            expired by the time of call is sufficient.
            If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
        :param kwargs: for full list of parameters, see DHTNode.get_many_by_id
        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
        :note: to check whether get found a value for a key, check whether its expiration_time is None (None means not found)
        """
        keys = tuple(keys)
        key_ids = [DHTID.generate(key) for key in keys]
        id_to_original_key = dict(zip(key_ids, keys))
        results_by_id = await self.get_many_by_id(key_ids,
                                                  sufficient_expiration_time,
                                                  **kwargs)
        return {
            id_to_original_key[key]: result_or_future
            for key, result_or_future in results_by_id.items()
        }
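This wrapper only translates between user keys and key ids; the heavy lifting happens in get_many_by_id. The translation is safe because DHTID.generate hashes a given source deterministically, so every peer maps the same key to the same id. A quick sanity check (the key name is made up):

key = 'expert_weights'  # hypothetical key, used only for illustration
assert DHTID.generate(key) == DHTID.generate(key)          # hashing a key is deterministic
assert DHTID.generate(key) != DHTID.generate('other_key')  # distinct keys collide only with negligible probability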
Example #5
    def __init__(self,
                 schema: pydantic.BaseModel,
                 *,
                 allow_extra_keys: bool = True,
                 prefix: Optional[str] = None):
        """
        :param schema: The Pydantic model (a subclass of pydantic.BaseModel).

            You must always use strict types for the number fields
            (e.g. ``StrictInt`` instead of ``int``,
            ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
            See the validate() docstring for details.

            The model will be patched to adjust it for the schema validation.

        :param allow_extra_keys: Whether to allow keys that are not defined in the schema.

            If a SchemaValidator is merged with another SchemaValidator, this option applies to
            keys that are not defined in any of the schemas.

        :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
        """

        self._patch_schema(schema)
        self._schemas = [schema]

        self._key_id_to_field_name = {}
        for field in schema.__fields__.values():
            raw_key = f'{prefix}_{field.name}' if prefix is not None else field.name
            self._key_id_to_field_name[DHTID.generate(
                source=raw_key).to_bytes()] = field.name
        self._allow_extra_keys = allow_extra_keys
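A minimal construction sketch based on the signature above; the Pydantic model and field names are invented for illustration:

import pydantic

class ExperimentMetadata(pydantic.BaseModel):
    # strict types for numeric fields, as required by the docstring above
    step: pydantic.StrictInt
    loss: pydantic.confloat(strict=True, ge=0.0)

validator = SchemaValidator(ExperimentMetadata, allow_extra_keys=False, prefix='my_experiment')
# with this prefix, fields are looked up under keys like 'my_experiment_step' and 'my_experiment_loss'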
Example #6
def test_routing_table_parameters():
    for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
        (20, 5, 45, 65),
        (50, 5, 35, 45),
        (20, 10, 650, 800),
        (20, 1, 7, 15),
    ]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id,
                                     bucket_size=bucket_size,
                                     depth_modulo=modulo)
        for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
            routing_table.add_or_update_node(
                DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
        for bucket in routing_table.buckets:
            assert len(bucket.replacement_nodes) == 0 or len(
                bucket.nodes_to_endpoint) <= bucket.size
        assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
            f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}"
        )
Example #7
def test_routing_table_basic():
    node_id = DHTID.generate()
    routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
    added_nodes = []

    for phony_neighbor_port in random.sample(range(10000), 100):
        phony_id = DHTID.generate()
        routing_table.add_or_update_node(phony_id,
                                         f'{LOCALHOST}:{phony_neighbor_port}')
        assert phony_id in routing_table
        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
        added_nodes.append(phony_id)

    assert (routing_table.buckets[0].lower == DHTID.MIN
            and routing_table.buckets[-1].upper == DHTID.MAX)
    for bucket in routing_table.buckets:
        assert len(bucket.replacement_nodes) == 0, \
            "There should be no replacement nodes in a table with 100 entries"
    assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)

    random_node = random.choice(added_nodes)
    assert routing_table.get(node_id=random_node) == routing_table[random_node]
    dummy_node = DHTID.generate()
    assert (dummy_node
            not in routing_table) == (routing_table.get(node_id=dummy_node) is
                                      None)

    for node in added_nodes:
        found_bucket_index = routing_table.get_bucket_index(node)
        for bucket_index, bucket in enumerate(routing_table.buckets):
            if bucket.lower <= node < bucket.upper:
                break
        else:
            raise ValueError(
                "Naive search could not find bucket. Universe has gone crazy.")
        assert bucket_index == found_bucket_index
Example #8
            num_added += new_total > total_nodes
            total_nodes = new_total
        num_replacements = sum(
            len(bucket.replacement_nodes) for bucket in routing_table.buckets)

        all_active_neighbors = list(
            chain(*(bucket.nodes_to_endpoint.keys()
                    for bucket in routing_table.buckets)))
        assert lower_active <= len(all_active_neighbors) <= upper_active
        assert len(all_active_neighbors) == num_added
        assert num_added + num_replacements == table_size

        # random queries
        for i in range(1000):
            k = random.randint(1, 100)
            query_id = DHTID.generate()
            exclude = query_id if random.random() < 0.5 else None
            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(
                query_id, k=k, exclude=exclude))
            reference_knn = heapq.nsmallest(k,
                                            all_active_neighbors,
                                            key=query_id.xor_distance)
            assert all(our == ref
                       for our, ref in zip_longest(our_knn, reference_knn))
            assert all(
                our_endpoint == routing_table[our_node]
                for our_node, our_endpoint in zip(our_knn, our_endpoints))

        # queries from table
        for i in range(1000):
            k = random.randint(1, 100)
Example #9
    async def create(cls,
                     node_id: Optional[DHTID] = None,
                     initial_peers: List[Endpoint] = (),
                     bucket_size: int = 20,
                     num_replicas: int = 5,
                     depth_modulo: int = 5,
                     parallel_rpc: int = None,
                     wait_timeout: float = 5,
                     refresh_timeout: Optional[float] = None,
                     bootstrap_timeout: Optional[float] = None,
                     cache_locally: bool = True,
                     cache_nearest: int = 1,
                     cache_size=None,
                     cache_refresh_before_expiry: float = 5,
                     cache_on_store: bool = True,
                     reuse_get_requests: bool = True,
                     num_workers: int = 1,
                     chunk_size: int = 16,
                     listen: bool = True,
                     listen_on: Endpoint = "0.0.0.0:*",
                     **kwargs) -> DHTNode:
        """
        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
          either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
        :param num_replicas: number of nearest nodes that will be asked to store a given key (≈k), default = 5
        :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
          Reduce this value if your RPC requests register no response despite the peer sending the response.
        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
          if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
        :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
        :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
          of the nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
        :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
        :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
          if they are accessed this many seconds before expiration time.
        :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
          all concurrent get requests for the same key will reuse the procedure that is currently in progress
        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
        :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
        :param listen: if True (default), this node will accept incoming requests and act as a full DHT "citizen";
          if False, this node will refuse any incoming requests, effectively being only a "client"
        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
        :param kwargs: extra parameters used in grpc.aio.server
        """
        self = cls(_initialized_with_create=True)
        self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
        self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
        self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh

        self.reuse_get_requests = reuse_get_requests
        self.pending_get_requests = defaultdict(
            partial(SortedList,
                    key=lambda _res: -_res.sufficient_expiration_time))

        # caching policy
        self.refresh_timeout = refresh_timeout
        self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
        self.cache_refresh_before_expiry = cache_refresh_before_expiry
        self.cache_refresh_queue = CacheRefreshQueue()
        self.cache_refresh_evt = asyncio.Event()
        self.cache_refresh_task = None

        self.protocol = await DHTProtocol.create(self.node_id, bucket_size,
                                                 depth_modulo, num_replicas,
                                                 wait_timeout, parallel_rpc,
                                                 cache_size, listen, listen_on,
                                                 **kwargs)
        self.port = self.protocol.port

        if initial_peers:
            # stage 1: ping initial_peers, add each other to the routing table
            bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
            start_time = get_dht_time()
            ping_tasks = map(self.protocol.call_ping, initial_peers)
            finished_pings, unfinished_pings = await asyncio.wait(
                ping_tasks, return_when=asyncio.FIRST_COMPLETED)

            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
            if unfinished_pings:
                finished_in_time, stragglers = await asyncio.wait(
                    unfinished_pings,
                    timeout=bootstrap_timeout - get_dht_time() + start_time)
                for straggler in stragglers:
                    straggler.cancel()
                finished_pings |= finished_in_time

            if not finished_pings:
                warn(
                    "DHTNode bootstrap failed: none of the initial_peers responded to a ping."
                )

            # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
            # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
            # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
            await asyncio.wait([
                asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)
            ],
                               return_when=asyncio.FIRST_COMPLETED)

        if self.refresh_timeout is not None:
            asyncio.create_task(
                self._refresh_routing_table(period=self.refresh_timeout))
        return self
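A hedged usage sketch for this factory; the addresses and ports below are placeholders rather than library defaults, and the calls assume an async context:

first_node = await DHTNode.create(listen_on='0.0.0.0:1337')
# a second node bootstraps from the first one
second_node = await DHTNode.create(initial_peers=[f'{LOCALHOST}:1337'])
# a client-only node that never accepts incoming requests
client_node = await DHTNode.create(initial_peers=[f'{LOCALHOST}:1337'], listen=False)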
Example #10
    async def store_many(self,
                         keys: List[DHTKey],
                         values: List[DHTValue],
                         expiration_time: Union[DHTExpiration,
                                                List[DHTExpiration]],
                         subkeys: Optional[Union[
                             Subkey, List[Optional[Subkey]]]] = None,
                         exclude_self: bool = False,
                         await_all_replicas=True,
                         **kwargs) -> Dict[DHTKey, bool]:
        """
        Traverse the DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.

        :param keys: arbitrary serializable keys associated with each value
        :param values: serializable "payload" for each key
        :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
        :param subkeys: an optional list of the same shape as keys. If specified, each value is stored under the corresponding subkey of its key.
        :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
        :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
        :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
        :param await_all_replicas: if False, this function returns after first store_ok and proceeds in background
            if True, the function will wait for num_replicas successful stores or running out of beam_size nodes
        :returns: for each key: True if store succeeds, False if it fails (due to no response or newer value)
        """
        if isinstance(expiration_time, DHTExpiration):
            expiration_time = [expiration_time] * len(keys)
        if subkeys is None:
            subkeys = [None] * len(keys)

        assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
            "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."

        key_id_to_data: DefaultDict[DHTID, List[Tuple[
            DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
        for key, subkey, value, expiration in zip(keys, subkeys, values,
                                                  expiration_time):
            key_id_to_data[DHTID.generate(source=key)].append(
                (key, subkey, value, expiration))

        # use this set to ensure that each store request is finished
        unfinished_key_ids = set(key_id_to_data.keys())
        store_ok = {(key, subkey): None
                    for key, subkey in zip(keys, subkeys)
                    }  # outputs, updated during search
        store_finished_events = {(key, subkey): asyncio.Event()
                                 for key, subkey in zip(keys, subkeys)}

        # pre-populate node_to_endpoint
        node_to_endpoint: Dict[DHTID, Endpoint] = dict()
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(
                self.protocol.routing_table.get_nearest_neighbors(
                    key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def on_found(key_id: DHTID, nearest_nodes: List[DHTID],
                           visited_nodes: Set[DHTID]) -> None:
            """ This will be called once per key when find_nearest_nodes is done for a particular node """
            # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
            assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
            assert self.node_id not in nearest_nodes
            unfinished_key_ids.remove(key_id)

            # ensure k nodes stored the value, optionally include self.node_id as a candidate
            num_successful_stores = 0
            pending_store_tasks = set()
            store_candidates = sorted(
                nearest_nodes + ([] if exclude_self else [self.node_id]),
                key=key_id.xor_distance,
                reverse=True)  # ordered so that .pop() returns nearest
            [original_key,
             *_], current_subkeys, current_values, current_expirations = zip(
                 *key_id_to_data[key_id])
            binary_values: List[bytes] = list(
                map(self.protocol.serializer.dumps, current_values))

            while num_successful_stores < self.num_replicas and (
                    store_candidates or pending_store_tasks):
                while store_candidates and num_successful_stores + len(
                        pending_store_tasks) < self.num_replicas:
                    node_id: DHTID = store_candidates.pop()  # nearest untried candidate

                    if node_id == self.node_id:
                        num_successful_stores += 1
                        for subkey, value, expiration_time in zip(
                                current_subkeys, binary_values,
                                current_expirations):
                            store_ok[original_key,
                                     subkey] = self.protocol.storage.store(
                                         key_id,
                                         value,
                                         expiration_time,
                                         subkey=subkey)
                            if not await_all_replicas:
                                store_finished_events[original_key,
                                                      subkey].set()
                    else:
                        pending_store_tasks.add(
                            asyncio.create_task(
                                self.protocol.call_store(
                                    node_to_endpoint[node_id],
                                    keys=[key_id] * len(current_values),
                                    values=binary_values,
                                    expiration_time=current_expirations,
                                    subkeys=current_subkeys)))

                # await nearest task. If it fails, dispatch more on the next iteration
                if pending_store_tasks:
                    finished_store_tasks, pending_store_tasks = await asyncio.wait(
                        pending_store_tasks,
                        return_when=asyncio.FIRST_COMPLETED)
                    for task in finished_store_tasks:
                        if task.result() is not None:
                            num_successful_stores += 1
                            for subkey, store_status in zip(
                                    current_subkeys, task.result()):
                                store_ok[original_key, subkey] = store_status
                                if not await_all_replicas:
                                    store_finished_events[original_key,
                                                          subkey].set()

            if self.cache_on_store:
                self._update_cache_on_store(key_id,
                                            current_subkeys,
                                            binary_values,
                                            current_expirations,
                                            store_ok=[
                                                store_ok[original_key, subkey]
                                                for subkey in current_subkeys
                                            ])

            for subkey, value_bytes, expiration in zip(current_subkeys,
                                                       binary_values,
                                                       current_expirations):
                store_finished_events[original_key, subkey].set()

        store_task = asyncio.create_task(
            self.find_nearest_nodes(queries=set(unfinished_key_ids),
                                    k_nearest=self.num_replicas,
                                    node_to_endpoint=node_to_endpoint,
                                    found_callback=on_found,
                                    exclude_self=exclude_self,
                                    **kwargs))
        try:
            await asyncio.wait([
                evt.wait() for evt in store_finished_events.values()
            ])  # wait for items to be stored
            assert not unfinished_key_ids, "Internal error: traverse_dht didn't finish search"
            return {(key, subkey) if subkey else key: status or False
                    for (key, subkey), status in store_ok.items()}
        except asyncio.CancelledError as e:
            store_task.cancel()
            raise e
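A hedged usage sketch for store_many; node stands for an already-created DHTNode, and the keys, payloads, and 60-second lifetime are made up for illustration:

now = get_dht_time()
store_status = await node.store_many(
    keys=['expert_weights', 'expert_bias'],
    values=[b'weights-blob', b'bias-blob'],
    expiration_time=now + 60,  # a single expiration applied to every key
)
assert all(store_status.values()), "some keys could not be stored on any replica"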
Example #11
    async def get_many(
        self,
        keys: Collection[DHTKey],
        sufficient_expiration_time: Optional[DHTExpiration] = None,
        num_workers: Optional[int] = None,
        beam_size: Optional[int] = None
    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
        """
        :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if a key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time,
            it stops searching for that key; default = time of call, i.e. any value that has not
            expired by the time of call is sufficient.
            If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
        :param beam_size: maintains up to this many nearest nodes when crawling the DHT, default beam_size = bucket_size
        :param num_workers: override for default num_workers, see traverse_dht num_workers param
        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
        :note: to check whether get found a value for a key, check whether its expiration_time is None (None means not found)
        """
        key_ids = [DHTID.generate(key) for key in keys]
        id_to_original_key = dict(zip(key_ids, keys))
        sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers

        # search metadata
        unfinished_key_ids = set(
            key_ids)  # track key ids for which the search is not terminated
        node_to_endpoint: Dict[
            DHTID, Endpoint] = dict()  # global routing table for all queries

        SearchResult = namedtuple(
            "SearchResult",
            ["binary_value", "expiration_time", "source_node_id"])
        latest_results = {
            key_id: SearchResult(b'', -float('inf'), None)
            for key_id in key_ids
        }

        # stage 1: value can be stored in our local cache
        for key_id in key_ids:
            maybe_value, maybe_expiration_time = self.protocol.storage.get(
                key_id)
            if maybe_expiration_time is None:
                maybe_value, maybe_expiration_time = self.protocol.cache.get(
                    key_id)
            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[
                    key_id].expiration_time:
                latest_results[key_id] = SearchResult(maybe_value,
                                                      maybe_expiration_time,
                                                      self.node_id)
                if maybe_expiration_time >= sufficient_expiration_time:
                    unfinished_key_ids.remove(key_id)

        # stage 2: traverse the DHT for any unfinished keys
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(
                self.protocol.routing_table.get_nearest_neighbors(
                    key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def get_neighbors(
            peer: DHTID, queries: Collection[DHTID]
        ) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
            queries = list(queries)
            response = await self.protocol.call_find(node_to_endpoint[peer],
                                                     queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
            for key_id, (maybe_value, maybe_expiration_time,
                         peers) in response.items():
                node_to_endpoint.update(peers)
                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[
                        key_id].expiration_time:
                    latest_results[key_id] = SearchResult(
                        maybe_value, maybe_expiration_time, peer)
                should_interrupt = (latest_results[key_id].expiration_time >=
                                    sufficient_expiration_time)
                output[key_id] = list(peers.keys()), should_interrupt
            return output

        nearest_nodes_per_query, visited_nodes = await traverse_dht(
            queries=list(unfinished_key_ids),
            initial_nodes=list(node_to_endpoint),
            beam_size=beam_size,
            num_workers=num_workers,
            queries_per_call=int(len(unfinished_key_ids)**0.5),
            get_neighbors=get_neighbors,
            visited_nodes={
                key_id: {self.node_id}
                for key_id in unfinished_key_ids
            })

        # stage 3: cache any new results depending on caching parameters
        for key_id, nearest_nodes in nearest_nodes_per_query.items():
            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[
                key_id]
            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
            if should_cache and self.cache_locally:
                self.protocol.cache.store(key_id, latest_value_bytes,
                                          latest_expiration_time)

            if should_cache and self.cache_nearest:
                num_cached_nodes = 0
                for node_id in nearest_nodes:
                    if node_id == latest_node_id:
                        continue
                    asyncio.create_task(
                        self.protocol.call_store(node_to_endpoint[node_id],
                                                 [key_id],
                                                 [latest_value_bytes],
                                                 [latest_expiration_time],
                                                 in_cache=True))
                    num_cached_nodes += 1
                    if num_cached_nodes >= self.cache_nearest:
                        break

        # stage 4: deserialize data and assemble function output
        find_result: Dict[DHTKey, Tuple[Optional[DHTValue],
                                        Optional[DHTExpiration]]] = {}
        for key_id, (latest_value_bytes, latest_expiration_time,
                     _) in latest_results.items():
            if latest_expiration_time != -float('inf'):
                latest_value = self.serializer.loads(latest_value_bytes)
                find_result[id_to_original_key[key_id]] = (
                    latest_value, latest_expiration_time)
            else:
                find_result[id_to_original_key[key_id]] = None, None
        return find_result
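A hedged usage sketch for consuming the result above; node stands for an already-created DHTNode, the call assumes an async context, and the key names are made up:

found = await node.get_many(['expert_weights', 'optimizer_state'])
for key, (value, expiration_time) in found.items():
    if expiration_time is None:  # per the docstring, (None, None) means the key was not found
        print(f'{key}: not found')
    else:
        print(f'{key}: found {value!r}, valid until {expiration_time}')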