Example #1
0
    async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
        """
        Periodically look for new nodes in k-buckets that were not updated for more than :period: seconds.

        :param period: refresh interval and staleness threshold, in seconds. If None, return immediately without
          refreshing anything; otherwise loop forever, running one refresh round roughly every :period: seconds.
        """
        if period is None:
            return
        while True:
            refresh_time = get_dht_time()
            staleness_threshold = refresh_time - period
            stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
                             if bucket.last_updated < staleness_threshold]
            for bucket in stale_buckets:
                # search around a random id from the bucket's [lower, upper) range to discover nodes in that bucket
                refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
                # find_nearest_nodes takes a collection of queries, not a single id
                await self.find_nearest_nodes([refresh_id])

            # sleep for whatever is left of the period after the time spent refreshing
            await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
Example #2
0
 def remove_outdated(self):
     """
     Evict expired records, then keep evicting the earliest-expiring ones while the heap exceeds self.cache_size.

     Heap entries that no longer match self.key_to_heap (superseded by a later store of the same key, or left
     behind after the key was removed) are dropped without touching self.data.
     NOTE(review): len(self.expiration_heap) also counts superseded entries, so size-based eviction can begin
     before len(self.data) reaches self.cache_size. Also assumes cache_size is numeric (comparing with None raises).
     """
     while self.expiration_heap and (
             self.expiration_heap[0][0] < get_dht_time()
             or len(self.expiration_heap) > self.cache_size):
         heap_entry = heapq.heappop(self.expiration_heap)
         key = heap_entry[1]
         # .get(): a leftover heap entry may outlive its key; a missing key must not raise KeyError
         if self.key_to_heap.get(key) == heap_entry:
             del self.data[key], self.key_to_heap[key]
Example #3
0
 async def _get_experts(self, node: DHTNode, uids: List[str],
                        expiration_time: Optional[DHTExpiration],
                        future: MPFuture):
     """
     Look up experts by uid in the DHT and resolve :future: with the results.

     :param node: DHT node used to run the search
     :param uids: expert uids to look up
     :param expiration_time: passed to node.get_many as its sufficient expiration time; defaults to current DHT time
     :param future: receives a list aligned with :uids:, holding RemoteExpert(uid, endpoint) for every uid that
       was found and None for every uid that was not
     """
     if not uids:
         # nothing to look up: skip the DHT round-trip (which would otherwise run with num_workers=0)
         future.set_result([])
         return
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if self.max_workers is None else min(
         len(uids), self.max_workers)
     response = await node.get_many(uids,
                                    expiration_time,
                                    num_workers=num_workers)
     # iterate over uids rather than response.items(): response is keyed by uid, so duplicate uids collapse
     # there, and the output must stay aligned one-to-one with the input
     experts = []
     for uid in uids:
         maybe_endpoint, maybe_expiration_time = response[uid]
         # get_many reports a missing key as (None, None): compare with None rather than relying on truthiness
         experts.append(RemoteExpert(uid, maybe_endpoint)
                        if maybe_expiration_time is not None else None)
     future.set_result(experts)
Example #4
0
 def store(self, key: DHTID, value: BinaryDHTValue,
           expiration_time: DHTExpiration) -> bool:
     """
     Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
     :returns: True if new value was stored, False if it was rejected (already expired, or current value is newer
       or equally new)
     """
     if expiration_time < get_dht_time():
         return False
     if key in self.data and self.data[key][1] >= expiration_time:
         # Reject without touching the expiration heap. Registering the rejected (earlier) expiration in
         # key_to_heap would make remove_outdated delete the newer stored value once that earlier time passes.
         return False
     is_new_key = key not in self.data
     self.key_to_heap[key] = (expiration_time, key)
     heapq.heappush(self.expiration_heap, (expiration_time, key))
     self.data[key] = (value, expiration_time)
     if is_new_key:
         # only a new key can grow self.data, so only then can the cache exceed its size limit
         self.remove_outdated()
     return True
Example #5
0
    async def _declare_experts(self, node: DHTNode, uids: List[str],
                               endpoint: Endpoint, future: Optional[MPFuture]):
        """
        Store every expert uid and each of its prefixes in the DHT, all pointing at :endpoint:.

        A uid is split on self.UID_DELIMITER. E.g. (assuming the delimiter is '.') "ffn.1.2" stores the keys
        "ffn", "ffn.1" and "ffn.1.2". All records expire self.expiration seconds after the call.

        :param node: DHT node used to store the records
        :param uids: expert uids to declare
        :param endpoint: value stored under every key
        :param future: if not None, receives one store result per distinct stored key, in the order the keys were
          first generated (an empty list if :uids: is empty)
        """
        if not uids:
            # nothing to store; unpacking zip(*data_to_store.items()) below fails on an empty mapping
            if future is not None:
                future.set_result([])
            return

        num_workers = len(uids) if self.max_workers is None else min(
            len(uids), self.max_workers)
        expiration_time = get_dht_time() + self.expiration

        data_to_store = {}
        for uid in uids:
            uid_parts = uid.split(self.UID_DELIMITER)
            for i in range(len(uid_parts)):
                uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
                data_to_store[uid_prefix_i] = endpoint

        store_keys, store_values = zip(*data_to_store.items())
        store_ok = await node.store_many(store_keys,
                                         store_values,
                                         expiration_time,
                                         num_workers=num_workers)
        if future is not None:
            future.set_result([store_ok[key] for key in data_to_store.keys()])
Example #6
0
    async def create(cls,
                     node_id: Optional[DHTID] = None,
                     initial_peers: List[Endpoint] = (),
                     bucket_size: int = 20,
                     num_replicas: int = 5,
                     depth_modulo: int = 5,
                     parallel_rpc: Optional[int] = None,
                     wait_timeout: float = 5,
                     refresh_timeout: Optional[float] = None,
                     bootstrap_timeout: Optional[float] = None,
                     num_workers: int = 1,
                     cache_locally: bool = True,
                     cache_nearest: int = 1,
                     cache_size: Optional[int] = None,
                     listen: bool = True,
                     listen_on: Endpoint = "0.0.0.0:*",
                     **kwargs) -> DHTNode:
        """
        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
          either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after refresh_timeout
        :param num_replicas: number of nearest nodes that will be asked to store a given key, default = 5
        :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
          Reduce this value if your RPC requests register no response despite the peer sending the response.
        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
          if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
        :param bootstrap_timeout: time budget (seconds, counted from the start of bootstrap) for collecting ping
          replies after the first peer responds and for the initial nearest-neighbor search; defaults to wait_timeout
        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
          nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
        :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
        :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citizen"
          if False, this node will refuse any incoming request, effectively being only a "client"
        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
        :param kwargs: extra parameters used in grpc.aio.server
        :returns: a fully constructed node; if initial_peers were given, it has already tried to bootstrap from them
        """
        self = cls(_initialized_with_create=True)
        self.node_id = node_id if node_id is not None else DHTID.generate()
        self.num_replicas, self.num_workers = num_replicas, num_workers
        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
        self.refresh_timeout = refresh_timeout

        self.protocol = await DHTProtocol.create(self.node_id, bucket_size,
                                                 depth_modulo, num_replicas,
                                                 wait_timeout, parallel_rpc,
                                                 cache_size, listen, listen_on,
                                                 **kwargs)
        self.port = self.protocol.port

        if initial_peers:
            bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
            start_time = get_dht_time()

            def remaining_bootstrap_time() -> float:
                """ Seconds left of bootstrap_timeout (measured from start_time), never negative """
                return max(0.0, bootstrap_timeout - (get_dht_time() - start_time))

            # stage 1: ping initial_peers, add each other to the routing table
            # note: asyncio.wait needs tasks/futures; bare coroutines are deprecated and rejected since Python 3.11
            ping_tasks = [asyncio.ensure_future(self.protocol.call_ping(peer)) for peer in initial_peers]
            finished_pings, unfinished_pings = await asyncio.wait(
                ping_tasks, return_when=asyncio.FIRST_COMPLETED)

            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
            if unfinished_pings:
                finished_in_time, stragglers = await asyncio.wait(
                    unfinished_pings, timeout=remaining_bootstrap_time())
                for straggler in stragglers:
                    straggler.cancel()
                finished_pings |= finished_in_time

            # NOTE(review): stage 1 waits without a timeout, so finished_pings always holds at least one task here and
            # this warning cannot fire. A failed ping presumably shows up in the task's result instead - confirm
            # what call_ping returns on failure.
            if not finished_pings:
                warn(
                    "DHTNode bootstrap failed: none of the initial_peers responded to a ping."
                )

            # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
            # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
            # note: unlike asyncio.wait_for, asyncio.wait does not cancel the task on timeout, so the search
            # keeps running in the background if it needs more than the remaining bootstrap time
            find_task = asyncio.create_task(self.find_nearest_nodes([self.node_id]))
            await asyncio.wait([find_task], timeout=remaining_bootstrap_time())

        if self.refresh_timeout is not None:
            # keep a reference: the event loop holds only weak references to tasks, so an unreferenced
            # background task may be garbage-collected mid-execution
            self._refresh_task = asyncio.create_task(
                self._refresh_routing_table(period=self.refresh_timeout))
        return self
Example #7
0
    async def get_many(
        self,
        keys: Collection[DHTKey],
        sufficient_expiration_time: Optional[DHTExpiration] = None,
        num_workers: Optional[int] = None,
        beam_size: Optional[int] = None
    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
        """
        Search the local storage/cache and then the DHT for the latest value of each key.

        :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if key not found)
        :param sufficient_expiration_time: the search for a key stops once it finds a value that expires at or after
            this time; default = time of call, i.e. find any value that did not expire by the time of call.
            If sufficient_expiration_time=float('inf'), this method will find a value with _latest_ expiration
        :param num_workers: override for default num_workers, see traverse_dht num_workers param
        :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
        :note: in order to check if get returned a value, please check (expiration_time is None)
        """
        key_ids = [DHTID.generate(key) for key in keys]
        id_to_original_key = dict(zip(key_ids, keys))
        # compare with None explicitly: a caller-supplied 0.0 is a valid threshold, not a request for the default
        if sufficient_expiration_time is None:
            sufficient_expiration_time = get_dht_time()
        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers

        # search metadata
        unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries

        SearchResult = namedtuple(
            "SearchResult",
            ["binary_value", "expiration_time", "source_node_id"])
        latest_results = {
            key_id: SearchResult(b'', -float('inf'), None)
            for key_id in key_ids
        }

        # stage 1: value can be stored in our local storage and/or local cache; keep whichever expires later
        for key_id in key_ids:
            for local_storage in (self.protocol.storage, self.protocol.cache):
                maybe_value, maybe_expiration_time = local_storage.get(key_id)
                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[
                        key_id].expiration_time:
                    latest_results[key_id] = SearchResult(maybe_value,
                                                          maybe_expiration_time,
                                                          self.node_id)
                    if maybe_expiration_time >= sufficient_expiration_time:
                        # discard, not remove: both storage and cache may satisfy the threshold for the same key
                        unfinished_key_ids.discard(key_id)

        # stage 2: traverse the DHT for any unfinished keys
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(
                self.protocol.routing_table.get_nearest_neighbors(
                    key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def get_neighbors(
            peer: DHTID, queries: Collection[DHTID]
        ) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
            """ Ask :peer: about :queries:, record any newer values, report its neighbors and whether to stop """
            queries = list(queries)
            response = await self.protocol.call_find(node_to_endpoint[peer],
                                                     queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
            for key_id, (maybe_value, maybe_expiration_time,
                         peers) in response.items():
                node_to_endpoint.update(peers)
                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[
                        key_id].expiration_time:
                    latest_results[key_id] = SearchResult(
                        maybe_value, maybe_expiration_time, peer)
                should_interrupt = (latest_results[key_id].expiration_time >=
                                    sufficient_expiration_time)
                output[key_id] = list(peers.keys()), should_interrupt
            return output

        if unfinished_key_ids:
            nearest_nodes_per_query, _ = await traverse_dht(
                queries=list(unfinished_key_ids),
                initial_nodes=list(node_to_endpoint),
                beam_size=beam_size,
                num_workers=num_workers,
                queries_per_call=int(len(unfinished_key_ids)**0.5),
                get_neighbors=get_neighbors,
                visited_nodes={
                    key_id: {self.node_id}
                    for key_id in unfinished_key_ids
                })
        else:
            # every key was resolved locally: nothing to traverse (queries_per_call would otherwise be 0)
            nearest_nodes_per_query = {}

        # stage 3: cache any new results depending on caching parameters
        for key_id, nearest_nodes in nearest_nodes_per_query.items():
            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[
                key_id]
            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
            if should_cache and self.cache_locally:
                self.protocol.cache.store(key_id, latest_value_bytes,
                                          latest_expiration_time)

            if should_cache and self.cache_nearest:
                num_cached_nodes = 0
                for node_id in nearest_nodes:
                    if node_id == latest_node_id:
                        continue  # the node we got the value from already has it
                    asyncio.create_task(
                        self.protocol.call_store(node_to_endpoint[node_id],
                                                 [key_id],
                                                 [latest_value_bytes],
                                                 [latest_expiration_time],
                                                 in_cache=True))
                    num_cached_nodes += 1
                    if num_cached_nodes >= self.cache_nearest:
                        break

        # stage 4: deserialize data and assemble function output
        find_result: Dict[DHTKey, Tuple[Optional[DHTValue],
                                        Optional[DHTExpiration]]] = {}
        for key_id, (latest_value_bytes, latest_expiration_time,
                     _) in latest_results.items():
            if latest_expiration_time != -float('inf'):
                latest_value = self.serializer.loads(latest_value_bytes)
                find_result[id_to_original_key[key_id]] = (
                    latest_value, latest_expiration_time)
            else:
                find_result[id_to_original_key[key_id]] = None, None
        return find_result