Ejemplo n.º 1
0
    async def create(cls,
                     node_id: DHTID,
                     bucket_size: int,
                     depth_modulo: int,
                     num_replicas: int,
                     wait_timeout: float,
                     parallel_rpc: Optional[int] = None,
                     cache_size: Optional[int] = None,
                     listen: bool = True,
                     listen_on: str = '0.0.0.0:*',
                     endpoint: Optional[Endpoint] = None,
                     channel_options: Sequence[Tuple[str, Any]] = (),
                     **kwargs) -> DHTProtocol:
        """
        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
        As a side-effect, DHTProtocol also maintains a routing table as described in
        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf

        See DHTNode (node.py) for a more detailed description.

        :param node_id: DHT identifier of this node; when listening, it is advertised to peers via node_info
        :param bucket_size: bucket size for the local RoutingTable (also stored as self.bucket_size)
        :param depth_modulo: passed to the RoutingTable constructor
        :param num_replicas: stored as self.num_replicas
        :param wait_timeout: timeout for outgoing RPCs, stored as self.wait_timeout (passed as the gRPC `timeout`)
        :param parallel_rpc: max number of concurrent outgoing RPCs guarded by self.rpc_semaphore; None = unlimited
        :param cache_size: passed as `maxsize` to the DHTLocalStorage used for self.cache
        :param listen: if True, start a gRPC server that serves incoming requests;
          if False, run in client-only mode with an empty node_info and no server
        :param listen_on: address passed to server.add_insecure_port (only used when listen=True)
        :param endpoint: public endpoint advertised in node_info; if it ends with '*', it is passed through
          replace_port together with the port that was actually bound
        :param channel_options: stored as a tuple in self.channel_options
        :param kwargs: extra keyword arguments forwarded to grpc.aio.server (only used when listen=True)
        :returns: an initialized protocol instance; when listen=True, its server has already been started
        :raises AssertionError: if the server could not bind to listen_on (add_insecure_port returned 0)

        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
        """
        self = cls(_initialized_with_create=True)
        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
        self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
        self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
        self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
        # an infinite semaphore value means outgoing RPCs are never throttled
        self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

        if listen:  # set up server to process incoming rpc requests
            grpc.aio.init_grpc_aio()
            self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
            dht_grpc.add_DHTServicer_to_server(self, self.server)

            self.port = self.server.add_insecure_port(listen_on)
            assert self.port != 0, f"Failed to listen to {listen_on}"
            if endpoint is not None and endpoint.endswith('*'):
                # advertise the port that was actually bound instead of the wildcard
                endpoint = replace_port(endpoint, self.port)
            self.node_info = dht_pb2.NodeInfo(
                node_id=node_id.to_bytes(),
                rpc_port=self.port,
                endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
            await self.server.start()
        else:  # not listening to incoming requests, client-only mode
            # note: use empty node_info so peers won't add you to their routing tables
            self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
            if listen_on != '0.0.0.0:*' or kwargs:
                # the trailing space in the first fragment keeps "listen_on" and "and" apart in the message
                logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on "
                               f"and kwargs have no effect (unused kwargs: {kwargs})")
        return self
Ejemplo n.º 2
0
    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[Dict[
        DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
        """
        Request keys from a peer. For each key, the peer looks for its (value, expiration time) locally and
         returns additional peers that are most likely to have this key (ranked by XOR distance)

        :param peer: endpoint of the remote DHT node to query
        :param keys: DHT keys to look up; results are matched to keys by their order in the request
        :returns: A dict with one entry per requested key: key => Tuple[optional value with expiration, neighbors]
         value with expiration: ValueWithExpiration(value, expiration_time) if the peer returned a value that
          passed local validation; None if the peer had no value, validation failed, or the result type is unknown
         neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
         If the RPC fails with grpc.aio.AioRpcError (e.g. peer didn't respond), returns None
        :raises AssertionError: if the peer returns a different number of results than the number of keys requested
        """
        keys = list(keys)
        keys_bytes = [DHTID.to_bytes(key) for key in keys]
        find_request = dht_pb2.FindRequest(keys=keys_bytes, peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get_dht_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
            if response.peer and response.peer.node_id:
                # the peer identified itself: update our routing table in the background
                # NOTE(review): the task is not kept referenced; asyncio docs advise holding a reference
                # so the task is not garbage-collected before it finishes
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
            assert len(keys) == len(response.results), "DHTProtocol: response is not aligned with keys"

            output = {}  # unpack data depending on its type
            for key, key_bytes, result in zip(keys, keys_bytes, response.results):
                nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))
                # default for NOT_FOUND, failed validation, and unknown result types
                value_with_expiration = None

                if result.type == dht_pb2.FOUND_REGULAR:
                    if self._validate_record(
                            key_bytes, self.IS_REGULAR_VALUE, result.value, result.expiration_time):
                        value_with_expiration = ValueWithExpiration(result.value, result.expiration_time)
                elif result.type == dht_pb2.FOUND_DICTIONARY:
                    # NOTE(review): result.value comes from a remote peer; presumably self.serializer is a
                    # safe (non-pickle) format - confirm
                    value_dictionary = self.serializer.loads(result.value)
                    if self._validate_dictionary(key_bytes, value_dictionary):
                        value_with_expiration = ValueWithExpiration(value_dictionary, result.expiration_time)
                elif result.type != dht_pb2.NOT_FOUND:
                    logger.error(f"Unknown result type: {result.type}")

                # every requested key gets an entry, as promised by the docstring
                output[key] = value_with_expiration, nearest

            return output
        except grpc.aio.AioRpcError as error:
            logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
            # the peer failed to respond: report it to the routing table in the background
            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))