Пример #1
0
    def _prepare(*,
                 nonce: Optional[Nonce] = None,
                 initiator_key: Optional[AES128Key] = None,
                 message: BaseMessage,
                 auth_data: TAuthData,
                 source_node_id: Optional[NodeID] = None,
                 dest_node_id: Optional[NodeID] = None,
                 protocol_id: bytes = PROTOCOL_ID) -> Packet[TAuthData]:
        """
        Build a packet through `Packet.prepare`, substituting freshly
        generated random bytes for every optional value that was left as
        `None`: 12 bytes for the nonce, 16 bytes for the initiator key and
        32 bytes for each of the two node ids.
        """
        return Packet.prepare(
            nonce=(
                Nonce(secrets.token_bytes(12)) if nonce is None else nonce
            ),
            initiator_key=(
                AES128Key(secrets.token_bytes(16))
                if initiator_key is None
                else initiator_key
            ),
            message=message,
            auth_data=auth_data,
            source_node_id=(
                NodeID(secrets.token_bytes(32))
                if source_node_id is None
                else source_node_id
            ),
            dest_node_id=(
                NodeID(secrets.token_bytes(32))
                if dest_node_id is None
                else dest_node_id
            ),
            protocol_id=protocol_id,
        )
Пример #2
0
def validate_and_extract_destination(
        value: Any) -> Tuple[NodeID, Optional[Endpoint]]:
    """
    Parse a node identifier into a `(node_id, endpoint)` pair.

    Three forms are recognized:

    - a value accepted by `is_hex_node_id`: the endpoint is `None`
    - `enode://<hex node id>@<ip address>:<port>`
    - an ENR text representation starting with `enr:`: the endpoint is
      derived from the record via `Endpoint.from_enr`

    Raises `RPCError` when the value matches none of these forms, including
    when it is not a string at all.  The `validate_*` helpers and
    `ENR.from_repr` may raise their own errors for malformed input.
    """
    node_id: NodeID
    endpoint: Optional[Endpoint]

    if is_hex_node_id(value):
        node_id = NodeID(decode_hex(value))
        endpoint = None
    elif not isinstance(value, str):
        # `value` is typed as `Any`.  Without this guard a non-string value
        # would fail on `.startswith` with an `AttributeError` rather than
        # the `RPCError` used for every other unrecognized input.
        raise RPCError(f"Unrecognized node identifier: {value}")
    elif value.startswith("enode://"):
        # Strip the 8 character `enode://` scheme, then split the node id
        # from the endpoint.
        raw_node_id, _, raw_endpoint = value[8:].partition("@")

        validate_hex_node_id(raw_node_id)
        validate_endpoint(raw_endpoint)

        node_id = NodeID(decode_hex(raw_node_id))

        raw_ip_address, _, raw_port = raw_endpoint.partition(":")
        ip_address = ipaddress.ip_address(raw_ip_address)
        port = int(raw_port)
        endpoint = Endpoint(ip_address.packed, port)
    elif value.startswith("enr:"):
        enr = ENR.from_repr(value)
        node_id = enr.node_id
        endpoint = Endpoint.from_enr(enr)
    else:
        raise RPCError(f"Unrecognized node identifier: {value}")

    return node_id, endpoint
Пример #3
0
 def from_rpc_response(cls, response: BucketInfoDict) -> "BucketInfo":
     """
     Build an instance from its RPC dictionary form, decoding the hex
     encoded node ids under `nodes` and `replacement_cache`.
     """
     bucket_index = response["idx"]
     member_ids = tuple(
         [NodeID(decode_hex(hex_id)) for hex_id in response["nodes"]]
     )
     replacement_ids = tuple(
         [NodeID(decode_hex(hex_id)) for hex_id in response["replacement_cache"]]
     )
     full = response["is_full"]

     return cls(
         idx=bucket_index,
         nodes=member_ids,
         replacement_cache=replacement_ids,
         is_full=full,
     )
Пример #4
0
    def at_log_distance(cls, reference: NodeID, log_distance: int) -> NodeID:
        """
        Return a node id at the given log distance from `reference`.

        The result keeps the leading `num_bits - log_distance` bits of
        `reference`, flips the next bit and fills the remaining
        `log_distance - 1` trailing bits randomly.

        Raises `ValueError` unless `1 <= log_distance <= len(reference) * 8`.
        """
        num_bits = len(reference) * 8

        if log_distance > num_bits:
            raise ValueError(
                "Log distance must not be greater than number of bits in the node id"
            )
        elif log_distance < 0:
            raise ValueError("Log distance cannot be negative")
        elif log_distance == 0:
            # With zero distance there is no bit left to flip: the index of
            # the flipped bit would equal `num_bits` and run past the end of
            # the bit sequence.  Reject it up front with a clear error.
            raise ValueError("Log distance must be at least 1")

        num_common_bits = num_bits - log_distance
        # The flipped bit sits directly after the shared prefix.
        flipped_bit_index = num_common_bits
        num_random_bits = log_distance - 1

        reference_bits = bytes_to_bits(reference)

        shared_bits = reference_bits[:num_common_bits]
        flipped_bit = not reference_bits[flipped_bit_index]
        random_bits = [bool(random.randint(0, 1)) for _ in range(num_random_bits)]

        result_bits = tuple(list(shared_bits) + [flipped_bit] + random_bits)
        result = NodeID(bits_to_bytes(result_bits))

        # Sanity check of the construction above.
        assert compute_log_distance(result, reference) == log_distance
        return result
Пример #5
0
 async def get_blacklisted(self) -> Tuple[NodeID, ...]:
     """
     Return the node ids of every `BlacklistRecord` whose `expires_at` is
     later than the current UTC time.
     """
     cutoff = datetime.datetime.utcnow()
     # mypy doesn't know about the type of the `query()` function
     unexpired = self.session.query(BlacklistRecord).filter(  # type: ignore
         BlacklistRecord.expires_at > cutoff
     )

     node_ids = []
     for entry in unexpired:
         # The record stores the node id as a hex string.
         node_ids.append(NodeID(to_bytes(hexstr=entry.node_id)))
     return tuple(node_ids)
Пример #6
0
    async def _manage_routing_table(self) -> None:
        """
        Populate and maintain the routing table.

        Runs in three phases:

        1. Store every bootnode ENR in `self.enr_db`.
        2. Keep starting bond attempts with the bootnodes until
           `self._routing_table_ready` is set.
        3. On every tick of `every(30)`, run `recursive_find_nodes` for a
           random 32 byte target and start a bond attempt with each ENR
           that was found.
        """
        # First load all the bootnode ENRs into our database
        for enr in self._bootnodes:
            try:
                self.enr_db.set_enr(enr)
            except OldSequenceNumber:
                # NOTE(review): presumably the database already holds a
                # record with an equal or newer sequence number -- confirm.
                pass

        # Now repeatedly try to bond with each bootnode until one succeeds.
        async with trio.open_nursery() as nursery:
            while self.manager.is_running:
                for enr in self._bootnodes:
                    # Never try to bond with ourselves.
                    if enr.node_id == self.local_node_id:
                        continue
                    endpoint = self._endpoint_for_enr(enr)
                    nursery.start_soon(self._bond, enr.node_id, endpoint)

                # Wait up to 10 seconds for the ready event.  If it fires the
                # `break` leaves the retry loop; if the timeout expires first
                # the `break` is skipped and another round of bond attempts
                # is started.
                with trio.move_on_after(10):
                    await self._routing_table_ready.wait()
                    break

        # TODO: Need better logic here for more quickly populating the
        # routing table.  Should start off aggressively filling in the
        # table, only backing off once the table contains some minimum
        # number of records **or** searching for new records fails to find
        # new nodes.  Maybe use a TokenBucket
        async for _ in every(30):
            # A fresh nursery per tick: leaving the `async with` waits for
            # all bond attempts of this round to finish.
            async with trio.open_nursery() as nursery:
                target_node_id = NodeID(secrets.token_bytes(32))
                found_enrs = await self.recursive_find_nodes(target_node_id)
                for enr in found_enrs:
                    endpoint = self._endpoint_for_enr(enr)
                    nursery.start_soon(self._bond, enr.node_id, endpoint)
Пример #7
0
    async def _explore(
        self,
        node_id: NodeID,
        max_distance: int,
    ) -> None:
        """
        Explore the neighborhood around the given `node_id` out to the
        specified `max_distance`

        Distances are requested two at a time, counting down from
        `max_distance` to 1.  Exploration of this node stops as soon as a
        request fails or returns no records.  A failing node is recorded in
        `self.unresponsive`, `self.unreachable` or `self.invalid` depending
        on the error.  Every returned ENR is stored in the ENR database, and
        each previously unseen one is added to `self.seen` and handed to
        `self._bond_then_send`.
        """
        async with trio.open_nursery() as nursery:
            # `partition_all(2, ...)` groups the descending distances into
            # pairs (the final group may hold a single distance).
            for distances in partition_all(2, range(max_distance, 0, -1)):
                try:
                    found_enrs = await self._network.find_nodes(
                        node_id, *distances)
                except trio.TooSlowError:
                    self.unresponsive.add(node_id)
                    return
                except MissingEndpointFields:
                    self.unreachable.add(node_id)
                    return
                except ValidationError:
                    self.invalid.add(node_id)
                    return
                else:
                    # once we encounter a pair of buckets that elicits an empty
                    # response we assume that all subsequent buckets will also
                    # be empty.
                    if not found_enrs:
                        self.logger.debug(
                            "explore-finish: node_id=%s  covered=%d-%d",
                            node_id.hex(),
                            max_distance,
                            distances[0],
                        )
                        break

                for enr in found_enrs:
                    try:
                        self._network.enr_db.set_enr(enr)
                    except OldSequenceNumber:
                        # NOTE(review): presumably the database already has
                        # an equal or newer record for this node -- confirm.
                        pass

                # check if we have found any new records.  If so, queue them and
                # wake up the new workers.  This is guarded by the `condition`
                # object to ensure we maintain a consistent view of the `seen`
                # nodes.
                async with self._condition:
                    new_enrs = tuple(enr for enr in reduce_enrs(found_enrs)
                                     if enr.node_id not in self.seen)

                    if new_enrs:
                        self.seen.update(enr.node_id for enr in new_enrs)
                        self._condition.notify_all()

                # use the `NetworkProtocol.bond` to perform a liveliness check
                for enr in new_enrs:
                    nursery.start_soon(self._bond_then_send, enr)
Пример #8
0
 def from_wire_bytes(cls, data: bytes) -> "HandshakeHeader":
     """
     Decode a handshake header from its wire form: a 32 byte source node id
     followed by one byte each for the signature size and the ephemeral key
     size.

     Raises `DecodingError` if `data` is not exactly
     `HANDSHAKE_HEADER_PACKET_SIZE` bytes long.
     """
     if len(data) == HANDSHAKE_HEADER_PACKET_SIZE:
         reader = BytesIO(data)
         sender_id = NodeID(reader.read(32))
         sig_size = reader.read(1)[0]
         eph_key_size = reader.read(1)[0]
         return cls(sender_id, sig_size, eph_key_size)

     raise DecodingError(
         f"Invalid length for HandshakeHeader: length={len(data)}  data={data.hex()}"
     )
Пример #9
0
 def from_rpc_response(cls, response: TableInfoResponse) -> "TableInfo":
     """
     Build an instance from its RPC dictionary form.  The center node id is
     hex decoded, and each entry under `buckets` is converted with
     `BucketInfo.from_rpc_response` and keyed by its index as an `int`.
     """
     center = NodeID(decode_hex(response["center_node_id"]))
     bucket_count = response["num_buckets"]
     per_bucket_size = response["bucket_size"]
     buckets_by_index = dict(
         (int(raw_index), BucketInfo.from_rpc_response(raw_bucket))
         for raw_index, raw_bucket in response["buckets"].items()
     )

     return cls(
         center_node_id=center,
         num_buckets=bucket_count,
         bucket_size=per_bucket_size,
         buckets=buckets_by_index,
     )
Пример #10
0
    def extract_params(
            self, request: RPCRequest) -> Tuple[NodeID, Optional[Endpoint]]:
        """
        Extract the ping destination from the request's single parameter.

        The parameter may be a value accepted by `is_hex_node_id` (the
        endpoint is then `None`), an `enode://<hex node id>@<ip>:<port>`
        URI, or an ENR text representation starting with `enr:`.

        Raises `RPCError` when `params` is missing, when it does not hold
        exactly one entry, or when the entry matches none of the supported
        forms (including when it is not a string).  The `validate_*`
        helpers and `ENR.from_repr` may raise their own errors for
        malformed input.
        """
        try:
            raw_params = request["params"]
        except KeyError as err:
            # Chain the original error so the traceback shows the cause.
            raise RPCError(f"Missing call params: {err}") from err

        if len(raw_params) != 1:
            raise RPCError(f"`ddht_ping` endpoint expects a single parameter: "
                           f"Got {len(raw_params)} params: {raw_params}")

        value = raw_params[0]

        node_id: NodeID
        endpoint: Optional[Endpoint]

        if is_hex_node_id(value):
            node_id = NodeID(decode_hex(value))
            endpoint = None
        elif not isinstance(value, str):
            # RPC params are untrusted: a non-string value (number, list,
            # null, ...) would otherwise fail on `.startswith` with an
            # `AttributeError` instead of a proper `RPCError`.
            raise RPCError(f"Unrecognized node identifier: {value}")
        elif value.startswith("enode://"):
            # Strip the 8 character `enode://` scheme, then split the node
            # id from the endpoint.
            raw_node_id, _, raw_endpoint = value[8:].partition("@")

            validate_hex_node_id(raw_node_id)
            validate_endpoint(raw_endpoint)

            node_id = NodeID(decode_hex(raw_node_id))

            raw_ip_address, _, raw_port = raw_endpoint.partition(":")
            ip_address = ipaddress.ip_address(raw_ip_address)
            port = int(raw_port)
            endpoint = Endpoint(ip_address.packed, port)
        elif value.startswith("enr:"):
            enr = ENR.from_repr(value)
            node_id = enr.node_id
            endpoint = Endpoint.from_enr(enr)
        else:
            raise RPCError(f"Unrecognized node identifier: {value}")

        return node_id, endpoint
Пример #11
0
def at_log_distance(target: NodeID, distance: int) -> NodeID:
    """
    Return a random node id at log distance `distance` from `target`.

    The result keeps every bit of `target` above bit `distance - 1`, flips
    bit `distance - 1` and randomizes all bits below it.  The bit width is
    taken from the length of `target`.

    Raises `ValueError` unless `1 <= distance <= len(target) * 8`.
    """
    num_bits = len(target) * 8
    if not 1 <= distance <= num_bits:
        # Without this check a distance of zero fails on a float bit mask
        # (`2 ** -1`) and an oversized distance overflows `to_bytes`.
        raise ValueError(
            f"Log distance must be between 1 and {num_bits}: got {distance}"
        )

    node_as_int = int.from_bytes(target, "big")
    flipped_bit = 1 << (distance - 1)

    # This is the common prefix: every bit at position `distance` and above.
    high_mask = ((1 << num_bits) - 1) >> distance << distance
    # We flip the bit at the appropriate distance
    differential_bit = ~node_as_int & flipped_bit
    # We randomize all of the low bits.
    low_mask = secrets.randbelow(flipped_bit)

    node_at_distance = (node_as_int & high_mask) | differential_bit | low_mask
    return NodeID(node_at_distance.to_bytes(len(target), "big"))
Пример #12
0
 def from_row(cls, row: Tuple[bytes, int, bytes, bytes]) -> "Field":
     """
     Build an instance from a four column row of
     `(node id bytes, sequence number, key, value)`.
     """
     node_id_bytes, seq_no, field_key, field_value = row
     node_id = NodeID(node_id_bytes)
     return cls(
         node_id=node_id,
         sequence_number=seq_no,
         key=field_key,
         value=field_value,
     )
Пример #13
0
    async def bond(
        self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None
    ) -> bool:
        """
        Ping the node, fetch its ENR and add it to the routing table.

        Returns `True` on success.  Returns `False` if the ping times out or
        raises `KeyError`, or if retrieving the ENR times out.
        """
        # The hex form is only needed for log output; compute it once.
        node_hex = node_id.hex()

        self.logger.debug2("Bonding with %s", node_hex)

        # Step 1: liveliness check.
        try:
            ping_response = await self.ping(node_id, endpoint=endpoint)
        except trio.TooSlowError:
            self.logger.debug("Bonding with %s timed out during ping", node_hex)
            return False
        except KeyError:
            self.logger.debug(
                "Unable to lookup endpoint information for node: %s", node_hex
            )
            return False

        # Step 2: fetch the record at the sequence number reported by the pong.
        try:
            remote_enr = await self.lookup_enr(
                node_id, enr_seq=ping_response.enr_seq, endpoint=endpoint
            )
        except trio.TooSlowError:
            self.logger.debug(
                "Bonding with %s timed out during ENR retrieval", node_hex
            )
            return False

        # Step 3: record the node and signal that the table has an entry.
        self.routing_table.update(remote_enr.node_id)
        self.logger.debug("Bonded with %s successfully", node_hex)
        self._routing_table_ready.set()
        return True
Пример #14
0
 def from_row(
     cls, row: Tuple[bytes, int, bytes, str], fields: Collection[Field]
 ) -> "Record":
     """
     Build an instance from a four column row of
     `(node id bytes, sequence number, signature, created-at text)` plus its
     fields.  The timestamp is parsed with `DB_DATETIME_FORMAT` and the
     fields are stored as a tuple sorted by their `key`.
     """
     node_id_bytes, seq_no, sig, created_at_text = row

     node_id = NodeID(node_id_bytes)
     created_at = datetime.datetime.strptime(created_at_text, DB_DATETIME_FORMAT)
     ordered_fields = tuple(sorted(fields, key=lambda field: field.key))

     return cls(
         node_id=node_id,
         sequence_number=seq_no,
         signature=sig,
         created_at=created_at,
         fields=ordered_fields,
     )
Пример #15
0
 def _init(self, enr: ENRAPI) -> None:
     """
     Initialise the address, public key, id and ENR attributes from `enr`.

     The address is `None` when the record lacks the IPv4 address or the
     UDP port.  The TCP port falls back to the UDP port when absent.
     """
     try:
         ip_address = enr[IP_V4_ADDRESS_ENR_KEY]
         udp = enr[UDP_PORT_ENR_KEY]
     except KeyError:
         address = None
     else:
         tcp = enr.get(TCP_PORT_ENR_KEY, udp)
         address = Address(ip_address, udp, tcp)
     self._address = address

     # FIXME: ENRs may use different pubkey formats and this would break, so instead of storing
     # a PublicKey with a certain format here we should simply use the APIs in the
     # ENR.identity_scheme for the crypto related operations.
     self._pubkey = keys.PublicKey.from_compressed_bytes(enr.public_key)

     # The id is the keccak hash of the public key bytes; the integer form
     # is derived from it.
     pubkey_bytes = self.pubkey.to_bytes()
     self._id = NodeID(keccak(pubkey_bytes))
     self._id_int = big_endian_to_int(self.id)
     self._enr = enr
Пример #16
0
    async def bond(
        self, node_id: NodeID, *, endpoint: Optional[Endpoint] = None
    ) -> bool:
        """
        Ping the node, fetch its ENR and add it to the routing table.

        Returns `True` on success.  Returns `False` if the ping times out or
        endpoint information is missing, or if the ENR retrieval times out
        or the node does not return its record.
        """
        # The hex form is only needed for log output; compute it once.
        node_hex = node_id.hex()

        self.logger.debug2("Bonding with %s", node_hex)

        # Step 1: liveliness check.
        try:
            ping_response = await self.ping(node_id, endpoint=endpoint)
        except trio.TooSlowError:
            self.logger.debug("Bonding with %s timed out during ping", node_hex)
            return False
        except MissingEndpointFields:
            self.logger.debug(
                "Bonding with %s failed due to missing endpoint information",
                node_hex,
            )
            return False

        # Step 2: fetch the record at the sequence number reported by the pong.
        try:
            remote_enr = await self.lookup_enr(
                node_id, enr_seq=ping_response.enr_seq, endpoint=endpoint
            )
        except trio.TooSlowError:
            self.logger.debug(
                "Bonding with %s timed out during ENR retrieval", node_hex
            )
            return False
        except EmptyFindNodesResponse:
            self.logger.debug(
                "Bonding with %s failed due to them not returing their ENR record",
                node_hex,
            )
            return False

        # Step 3: record the node and signal that the table has an entry.
        self.routing_table.update(remote_enr.node_id)
        self.logger.debug("Bonded with %s successfully", node_hex)
        self._routing_table_ready.set()
        return True
Пример #17
0
def recover_source_id_from_tag(tag: Tag, destination_node_id: NodeID) -> NodeID:
    """
    Recover the source node id from a message packet tag by XOR-ing the tag
    with the SHA-256 digest of the destination node id.
    """
    return NodeID(sxor(tag, hashlib.sha256(destination_node_id).digest()))
Пример #18
0
 def recursive_find_nodes(
     self, target: Union[NodeID, ContentID],
 ) -> AsyncContextManager[trio.abc.ReceiveChannel[ENRAPI]]:
     """
     Start a recursive lookup for `target` on this network by delegating to
     `common_recursive_find_nodes`.  The target is passed on as a `NodeID`.
     """
     lookup_target = NodeID(target)
     return common_recursive_find_nodes(self, lookup_target)
Пример #19
0
 def from_wire_bytes(cls, data: bytes) -> "MessagePacket":
     """
     Decode a message packet whose wire form is a bare node id.

     Raises `DecodingError` if `data` is not exactly `MESSAGE_PACKET_SIZE`
     bytes long.
     """
     if len(data) == MESSAGE_PACKET_SIZE:
         return cls(NodeID(data))

     raise DecodingError(
         f"Invalid length for MessagePacket: length={len(data)}  data={data.hex()}"
     )
Пример #20
0
 def extract_node_id(cls, enr: CommonENRAPI) -> NodeID:
     """
     Derive the node id of `enr`: the keccak hash of the `to_bytes()` form
     of the record's compressed public key.
     """
     return NodeID(
         keccak(PublicKey.from_compressed_bytes(enr.public_key).to_bytes())
     )
Пример #21
0
async def common_recursive_find_nodes(
    network: NetworkProtocol,
    target: NodeID,
    *,
    concurrency: int = 3,
    unresponsive_cache: Dict[NodeID, float] = UNRESPONSIVE_CACHE,
) -> AsyncIterator[trio.abc.ReceiveChannel[ENRAPI]]:
    """
    An optimized version of the recursive lookup algorithm for a kademlia
    network.

    Continually lookup nodes in the target part of the network, keeping track
    of all of the nodes we have seen.

    Exit once we have queried all of the `k` closest nodes to the target.

    The concurrency structure here is optimized to minimize the effect of
    unresponsive nodes on the total time it takes to perform the recursive
    lookup.  Some requests will hang for up to 10 seconds.  The
    `adaptive_timeout` combined with the multiple concurrent workers helps
    mitigate the overall slowdown caused by a few unresponsive nodes since the
    other queries can be issued concurrently.

    :param network: provides the routing table, ENR database, logger and
        the `find_nodes` request used for the lookup.
    :param target: the node id being searched for.
    :param concurrency: number of worker tasks to start.
    :param unresponsive_cache: maps a node id to the `trio.current_time()`
        at which it was last deemed unresponsive.  It is updated in place,
        and the default is a shared module level mapping, so entries carry
        over between lookups.

    Yields a single receive channel on which every discovered ENR is
    delivered.  All background tasks are cancelled once the consumer is done
    with the channel.

    NOTE(review): this is an async generator that yields exactly once;
    presumably it is wrapped as an async context manager by a decorator not
    visible here -- confirm.
    """
    network.logger.debug2("Recursive find nodes: %s", target.hex())
    start_at = trio.current_time()

    # The set of NodeID values we have already queried.
    queried_node_ids: Set[NodeID] = set()

    # The set of NodeID that timed out
    #
    # The `local_node_id` is
    # included in this as a convenience mechanism so that we don't have to
    # continually filter it out of the various filters
    unresponsive_node_ids: Set[NodeID] = {network.local_node_id}

    # We maintain a cache of nodes that were recently deemed unresponsive.
    # Only entries from within the last 300 seconds (5 minutes) are honoured.
    unresponsive_node_ids.update(
        node_id
        for node_id, last_unresponsive_at in unresponsive_cache.items()
        if trio.current_time() - last_unresponsive_at < 300
    )

    # Accumulator of the node_ids we have seen
    received_node_ids: Set[NodeID] = set()

    # Tracker for node_ids that are actively being requested.
    in_flight: Set[NodeID] = set()

    # Guards the shared sets above and lets workers sleep until another
    # worker has finished a round of lookups.
    condition = trio.Condition()

    def get_unqueried_node_ids() -> Tuple[NodeID, ...]:
        """
        Get the three nodes that are closest to the target such that the node
        is in the closest `k` nodes which haven't been deemed unresponsive.

        Nodes that were already queried or are currently in flight are
        excluded, so fewer than three (possibly zero) ids may be returned.
        """
        # Construct an iterable of *all* the nodes we know about ordered by
        # closeness to the target.
        candidates = iter_closest_nodes(
            target, network.routing_table, received_node_ids
        )
        # Remove any unresponsive nodes from that iterable
        responsive_candidates = itertools.filterfalse(
            lambda node_id: node_id in unresponsive_node_ids, candidates
        )
        # Grab the closest K
        closest_k_candidates = take(
            network.routing_table.bucket_size, responsive_candidates,
        )
        # Filter out any from the closest K that we've already queried or that are in-flight
        closest_k_unqueried = itertools.filterfalse(
            lambda node_id: node_id in queried_node_ids or node_id in in_flight,
            closest_k_candidates,
        )

        return tuple(take(3, closest_k_unqueried))

    async def do_lookup(
        node_id: NodeID, send_channel: trio.abc.SendChannel[ENRAPI]
    ) -> None:
        """
        Perform an individual lookup on the target part of the network from the
        given `node_id`

        Every returned ENR is stored in the ENR database and those not seen
        before are sent over `send_channel`.
        """
        # Distance zero is used when the queried node is the target itself.
        if node_id == target:
            distance = 0
        else:
            distance = compute_log_distance(node_id, target)

        try:
            found_enrs = await network.find_nodes(node_id, distance)
        except (trio.TooSlowError, MissingEndpointFields, ValidationError):
            unresponsive_node_ids.add(node_id)
            unresponsive_cache[node_id] = trio.current_time()
            return
        except trio.Cancelled:
            # We don't add these to the unresponsive cache since they didn't
            # necessarily exceed the full 10s request/response timeout.
            unresponsive_node_ids.add(node_id)
            raise

        for enr in found_enrs:
            try:
                network.enr_db.set_enr(enr)
            except OldSequenceNumber:
                # NOTE(review): presumably the database already has an equal
                # or newer record for this node -- confirm.
                pass

        # Determine which records are new while holding the condition so
        # that concurrent lookups do not report the same node twice.
        async with condition:
            new_enrs = tuple(
                enr for enr in found_enrs if enr.node_id not in received_node_ids
            )
            received_node_ids.update(enr.node_id for enr in new_enrs)

        for enr in new_enrs:
            try:
                await send_channel.send(enr)
            except (trio.BrokenResourceError, trio.ClosedResourceError):
                # In the event that the consumer of `recursive_find_nodes`
                # exits early before the lookup has completed we can end up
                # operating on a closed channel.
                return

    async def worker(
        worker_id: int, send_channel: trio.abc.SendChannel[ENRAPI]
    ) -> None:
        """
        Pulls unqueried nodes from the closest k nodes and performs a
        concurrent lookup on them.

        `worker_id` is the index assigned when the worker is started; it is
        not used inside the worker.
        """
        # NOTE(review): the loop variable is unused and shadows the builtin
        # `round`.
        for round in itertools.count():
            async with condition:
                node_ids = get_unqueried_node_ids()

                if not node_ids:
                    # Nothing to do right now: sleep until another worker
                    # finishes a round and notifies the condition.
                    await condition.wait()
                    continue

                # Mark the node_ids as having been queried.
                queried_node_ids.update(node_ids)
                # Mark the node_ids as being in-flight.
                in_flight.update(node_ids)

                # Some of the node ids may have come from our routing table.
                # These won't be present in the `received_node_ids` so we
                # detect this here and send them over the channel.
                try:
                    for node_id in node_ids:
                        if node_id not in received_node_ids:
                            enr = network.enr_db.get_enr(node_id)
                            received_node_ids.add(node_id)
                            await send_channel.send(enr)
                except (trio.BrokenResourceError, trio.ClosedResourceError):
                    # In the event that the consumer of `recursive_find_nodes`
                    # exits early before the lookup has completed we can end up
                    # operating on a closed channel.
                    return

            # The lookups run outside of the condition so that other workers
            # can make progress in the meantime.
            if len(node_ids) == 1:
                await do_lookup(node_ids[0], send_channel)
            else:
                tasks = tuple(
                    (do_lookup, (node_id, send_channel)) for node_id in node_ids
                )
                try:
                    await adaptive_timeout(*tasks, threshold=1, variance=2.0)
                except trio.TooSlowError:
                    pass

            async with condition:
                # Remove the `node_ids` from the in_flight set.
                in_flight.difference_update(node_ids)

                condition.notify_all()

    async def _monitor_done(send_channel: trio.abc.SendChannel[ENRAPI]) -> None:
        """
        Close `send_channel` once there is nothing left to query and no
        lookup is in flight.
        """
        async with send_channel:
            async with condition:
                while True:
                    # this `move_on_after` is a failsafe to prevent deadlock situations
                    # which are possible with `Condition` objects.
                    with trio.move_on_after(60) as scope:
                        node_ids = get_unqueried_node_ids()

                        if not node_ids and not in_flight:
                            break
                        else:
                            await condition.wait()

                    # Reached only when the 60 second wait expired without a
                    # notification; the state is then simply checked again.
                    if scope.cancelled_caught:
                        network.logger.error("Deadlock")

    send_channel, receive_channel = trio.open_memory_channel[ENRAPI](256)

    async with trio.open_nursery() as nursery:
        nursery.start_soon(_monitor_done, send_channel)

        for worker_id in range(concurrency):
            nursery.start_soon(worker, worker_id, send_channel)

        async with receive_channel:
            yield receive_channel

        # The consumer is done with the channel: stop the monitor and any
        # workers that are still running.
        nursery.cancel_scope.cancel()

    elapsed = trio.current_time() - start_at

    network.logger.debug(
        "Lookup for %s finished in %f seconds: seen=%d  queried=%d  unresponsive=%d",
        target.hex(),
        elapsed,
        len(received_node_ids),
        len(queried_node_ids),
        len(unresponsive_node_ids),
    )
Пример #22
0
 def from_rpc_response(cls, response: NodeInfoResponse) -> "NodeInfo":
     """
     Build an instance from its RPC dictionary form: the node id is hex
     decoded and the ENR is parsed from its text representation.
     """
     decoded_node_id = NodeID(decode_hex(response["node_id"]))
     parsed_enr = ENR.from_repr(response["enr"])
     return cls(node_id=decoded_node_id, enr=parsed_enr)
Пример #23
0
 async def run(self) -> None:
     """
     On every tick of `every(ROUTING_TABLE_LOOKUP_INTERVAL)`, perform a
     lookup for a freshly generated random 32 byte node id.
     """
     async for _ in every(ROUTING_TABLE_LOOKUP_INTERVAL):
         await self.lookup(NodeID(secrets.token_bytes(32)))
Пример #24
0
 def node_id(self) -> NodeID:
     """Return the keccak hash of the public key's bytes as a `NodeID`."""
     public_key_bytes = self.public_key.to_bytes()
     digest = keccak(public_key_bytes)
     return NodeID(digest)
Пример #25
0
    async def _periodically_advertise_content(self) -> None:
        """
        Periodically feed every stored content key to the broadcast workers.

        After the routing table is ready, `self._concurrency` daemon tasks
        running `self._broadcast_worker` are started on the receiving end of
        a memory channel.  On every tick of `every(30 * 60)` the content
        storage is walked once: starting from the key closest to a random
        node id, keys are enumerated in batches up to the end of the
        storage, the walk wraps around to the beginning, and it stops once
        it reaches the starting key again.

        The walk also stops if the end of the storage is reached a second
        time, which happens when keys were removed while walking.
        """
        await self._network.routing_table_ready()

        send_channel, receive_channel = trio.open_memory_channel[ContentKey](
            self._concurrency)

        for _ in range(self._concurrency):
            self.manager.run_daemon_task(self._broadcast_worker,
                                         receive_channel)

        async for _ in every(30 * 60):
            start_at = trio.current_time()

            total_keys = len(self.content_storage)
            if not total_keys:
                continue

            # Start from a random position so that successive passes do not
            # always begin with the same keys.
            first_key = first(
                self.content_storage.iter_closest(
                    NodeID(secrets.token_bytes(32))))

            self.logger.info(
                "content-processing-starting: total=%d  start=%s",
                total_keys,
                first_key.hex(),
            )

            processed_keys = 0

            last_key = first_key
            has_wrapped_around = False

            while self.manager.is_running:
                elapsed = trio.current_time() - start_at
                content_keys = tuple(
                    take(
                        self._concurrency * 2,
                        self.content_storage.enumerate_keys(
                            start_key=last_key),
                    ))

                # TODO: We need to adjust the
                # `ContentStorageAPI.enumerate_keys` to allow a
                # non-inclusive left bound so we can query all the keys
                # **after** the last key we processed.
                if content_keys and content_keys[0] == last_key:
                    content_keys = content_keys[1:]

                if not content_keys:
                    if has_wrapped_around:
                        # We already restarted from the beginning once and
                        # ran out of keys again without reaching
                        # `first_key`.  The whole key space has been covered
                        # (or the storage is now empty), so stop instead of
                        # looping forever.
                        break
                    # Reached the end of the storage: restart from the
                    # beginning.
                    last_key = None
                    has_wrapped_around = True
                    continue

                for content_key in content_keys:
                    await send_channel.send(content_key)

                last_key = content_keys[-1]
                if has_wrapped_around and last_key >= first_key:
                    break

                processed_keys += len(content_keys)
                progress = processed_keys * 100 / total_keys

                self.logger.debug(
                    "content-processing: progress=%0.1f  processed=%d  "
                    "total=%d  at=%s  elapsed=%s",
                    progress,
                    processed_keys,
                    total_keys,
                    "None" if last_key is None else last_key.hex(),
                    humanize_seconds(int(elapsed)),
                )

            # Measure the duration here: the value computed inside the loop
            # is unset when the loop body never ran and stale otherwise.
            elapsed = trio.current_time() - start_at

            self.logger.info(
                "content-processing-finished: processed=%d/%d  elapsed=%s",
                processed_keys,
                total_keys,
                humanize_seconds(int(elapsed)),
            )