Ejemplo n.º 1
0
def test_request_tracker_reserve_request_id_generated():
    """A tracker-generated request id is active only while it is reserved."""
    request_tracker = RequestTracker()
    remote_node_id = NodeIDFactory()

    # No explicit id is passed, so the tracker picks one for us.
    reservation = request_tracker.reserve_request_id(remote_node_id)
    with reservation as reserved_id:
        assert request_tracker.is_request_id_active(remote_node_id, reserved_id)

    # Leaving the context manager must release the id again.
    assert not request_tracker.is_request_id_active(remote_node_id, reserved_id)
Ejemplo n.º 2
0
    def __init__(self, network: NetworkAPI) -> None:
        """
        Bind this Alexandria client to ``network``.

        Creates the protocol-local request tracker and subscription manager,
        then registers ``self`` as a talk protocol on ``network``.
        """
        self.logger = get_extended_debug_logger("ddht.AlexandriaClient")
        self.network = network

        # Protocol-local bookkeeping: which in-flight request ids belong to
        # this protocol, and who is subscribed to inbound messages.
        tracker = RequestTracker()
        subscriptions = SubscriptionManager()
        self.request_tracker = tracker
        self.subscription_manager = subscriptions

        self.network.add_talk_protocol(self)

        self._active_request_ids = set()
Ejemplo n.º 3
0
def test_request_tracker_reserve_request_id_provided():
    """A caller-supplied request id is reserved verbatim and released on exit."""
    tracker = RequestTracker()

    node_id = NodeIDFactory()

    # Every byte is spelled as a hex escape for readability; mixing in an
    # octal escape such as ``\04`` yields the same byte but looks like a typo.
    request_id = b"\x01\x02\x03\x04"

    assert not tracker.is_request_id_active(node_id, request_id)

    with tracker.reserve_request_id(node_id, request_id) as actual_request_id:
        # The tracker must hand back exactly the id we asked for.
        assert actual_request_id == request_id
        assert tracker.is_request_id_active(node_id, request_id)
    assert not tracker.is_request_id_active(node_id, request_id)
Ejemplo n.º 4
0
class AlexandriaClient(Service, AlexandriaClientAPI):
    """
    Client for the Alexandria sub-protocol.

    Alexandria messages are carried in TALKREQ / TALKRESP messages of the
    base protocol under ``ALEXANDRIA_PROTOCOL_ID``. This class offers:

    * low level ``send_*`` methods that send a single message and return,
    * high level request/response methods (``ping``, ``find_nodes``,
      ``get_content``, ``advertise``, ``locate``) that wait for replies,
    * a subscription API (``subscribe``) fed by the ``_feed_talk_requests``
      and ``_feed_talk_responses`` daemons started in ``run``.
    """

    protocol_id = ALEXANDRIA_PROTOCOL_ID

    # NOTE(review): initialized in ``__init__`` but never read inside this
    # class; presumably used by callers or subclasses -- confirm.
    _active_request_ids: Set[bytes]

    def __init__(self, network: NetworkAPI) -> None:
        self.logger = get_extended_debug_logger("ddht.AlexandriaClient")

        self.network = network
        # Local tracker of in-flight *Alexandria* request ids, used to
        # attribute incoming TALKRESP messages to this protocol.
        self.request_tracker = RequestTracker()
        self.subscription_manager = SubscriptionManager()

        network.add_talk_protocol(self)

        self._active_request_ids = set()

    @property
    def local_private_key(self) -> keys.PrivateKey:
        """The private key of the underlying base protocol client."""
        return self.network.client.local_private_key

    #
    # Service API
    #
    async def run(self) -> None:
        """Run the TALKREQ/TALKRESP feed daemons until the service finishes."""
        self.manager.run_daemon_task(self._feed_talk_requests)
        self.manager.run_daemon_task(self._feed_talk_responses)

        await self.manager.wait_finished()

    #
    # Request/Response message sending primitives
    #
    @staticmethod
    def _validate_payload_size(payload: bytes) -> None:
        """
        Ensure an encoded message fits into a single TALK payload.

        :raise Exception: if ``payload`` is longer than ``MAX_PAYLOAD_SIZE``;
            the message reports both the actual and the maximum size.
        """
        if len(payload) > MAX_PAYLOAD_SIZE:
            raise Exception(
                f"Payload too large:  payload={len(payload)}  max_size={MAX_PAYLOAD_SIZE}"
            )

    async def _send_request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        message: AlexandriaMessage[Any],
        *,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send ``message`` as a TALKREQ without waiting for a response.

        :param request_id: explicit request id; when ``None`` the base
            protocol client chooses one.
        :return: the request id actually used.
        :raise Exception: if the encoded message exceeds ``MAX_PAYLOAD_SIZE``.
        """
        data_payload = message.to_wire_bytes()
        self._validate_payload_size(data_payload)
        request_id = await self.network.client.send_talk_request(
            node_id,
            endpoint,
            protocol=ALEXANDRIA_PROTOCOL_ID,
            payload=data_payload,
            request_id=request_id,
        )
        return request_id

    async def _send_response(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        message: AlexandriaMessage[Any],
        *,
        request_id: bytes,
    ) -> None:
        """
        Send ``message`` as a TALKRESP answering ``request_id``.

        :raise TypeError: if ``message`` is not a RESPONSE type message.
        :raise Exception: if the encoded message exceeds ``MAX_PAYLOAD_SIZE``.
        """
        if message.type != AlexandriaMessageType.RESPONSE:
            raise TypeError(f"{message} is not of type RESPONSE")
        data_payload = message.to_wire_bytes()
        self._validate_payload_size(data_payload)
        await self.network.client.send_talk_response(
            node_id,
            endpoint,
            payload=data_payload,
            request_id=request_id,
        )

    async def _request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        request: AlexandriaMessage[Any],
        response_class: Type[TAlexandriaMessage],
        request_id: Optional[bytes] = None,
    ) -> TAlexandriaMessage:
        """
        Send a single REQUEST message and return its single decoded response.

        :raise TypeError: if ``request`` is not a REQUEST type message.
        :raise Exception: if the encoded request exceeds ``MAX_PAYLOAD_SIZE``.
        :raise DecodingError: if the response is not of ``response_class``.
        """
        #
        # Request ID Shenanigans
        #
        # We need a subscription API for alexandria messages in order to be
        # able to cleanly implement functionality like responding to ping
        # messages with pong messages.
        #
        # We accomplish this by monitoring all TALKREQUEST and TALKRESPONSE
        # messages that occur on the alexandria protocol. For TALKREQUEST based
        # messages we can naively monitor incoming messages and match them
        # against the protocol_id. For the TALKRESPONSE however, the only way
        # for us to know that an incoming message belongs to this protocol is
        # to match it against an in-flight request (which we implicitly know
        # is part of the alexandria protocol).
        #
        # In order to do this, we need to know which request ids are active for
        # Alexandria protocol messages. We do this by using a separate
        # RequestTrackerAPI. The `request_id` for messages is still acquired
        # from the core tracker located on the base protocol `ClientAPI`.  We
        # then feed this into our local tracker, which allows us to query it
        # upon receiving an incoming TALKRESPONSE to see if the response is to
        # a message from this protocol.
        if request.type != AlexandriaMessageType.REQUEST:
            raise TypeError(f"{request} is not of type REQUEST")
        if request_id is None:
            request_id = self.network.client.request_tracker.get_free_request_id(
                node_id)
        request_data = request.to_wire_bytes()
        self._validate_payload_size(request_data)
        with self.request_tracker.reserve_request_id(node_id, request_id):
            response_data = await self.network.talk(
                node_id,
                protocol=ALEXANDRIA_PROTOCOL_ID,
                payload=request_data,
                endpoint=endpoint,
                request_id=request_id,
            )

        response = decode_message(response_data)
        if type(response) is not response_class:
            raise DecodingError(
                f"Invalid response. expected={response_class}  got={type(response)}"
            )
        return response  # type: ignore

    def subscribe(
        self,
        message_type: Type[TAlexandriaMessage],
        endpoint: Optional[Endpoint] = None,
        node_id: Optional[NodeID] = None,
    ) -> AsyncContextManager[trio.abc.ReceiveChannel[
            InboundMessage[TAlexandriaMessage]]]:
        """
        Subscribe to inbound messages of ``message_type``.

        ``endpoint`` and ``node_id`` are forwarded to the subscription
        manager as optional filters.
        """
        return self.subscription_manager.subscribe(
            message_type,
            endpoint,
            node_id,
        )  # type: ignore

    @asynccontextmanager
    async def subscribe_request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        request: AlexandriaMessage[Any],
        response_message_type: Type[TAlexandriaMessage],
        request_id: Optional[bytes],
    ) -> AsyncIterator[trio.abc.ReceiveChannel[
            InboundMessage[TAlexandriaMessage]]]:
        """
        Send ``request`` and yield a channel of matching responses.

        Only responses of ``response_message_type`` from ``node_id`` /
        ``endpoint`` carrying the same request id are delivered. When the
        background monitor gives up after ``REQUEST_RESPONSE_TIMEOUT``, the
        channel is closed. A resulting ``trio.EndOfChannel`` inside the body
        is re-raised as ``trio.TooSlowError``.
        """
        send_channel, receive_channel = trio.open_memory_channel[
            InboundMessage[TAlexandriaMessage]](256)

        #
        # START
        #
        # There cannot be any `await/async` calls between here and the `END`
        # marker, otherwise we will be subject to a race condition where
        # another request could collide with this request id.
        #
        if request_id is None:
            request_id = self.network.client.request_tracker.get_free_request_id(
                node_id)

        self.logger.debug2(
            "Sending request: %s with request id %s",
            request,
            request_id.hex(),
        )

        with self.request_tracker.reserve_request_id(node_id, request_id):
            #
            # END
            #
            async with trio.open_nursery() as nursery:
                # The use of `functools.partial` below is due to an inadequacy
                # in the type hinting of `trio.Nursery.start_soon` which
                # doesn't support more than 4 positional arguments.
                nursery.start_soon(
                    functools.partial(
                        self._manage_request_response,
                        node_id,
                        endpoint,
                        request,
                        response_message_type,
                        send_channel,
                        request_id,
                    ))
                try:
                    async with receive_channel:
                        try:
                            yield receive_channel
                        # Wrap EOC error with TSE to make the timeouts obvious
                        except trio.EndOfChannel as err:
                            raise trio.TooSlowError(
                                f"Timeout: request={request}  request_id={request_id.hex()}"
                            ) from err
                finally:
                    nursery.cancel_scope.cancel()

    async def _manage_request_response(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        request: AlexandriaMessage[Any],
        response_message_type: Type[AlexandriaMessage[Any]],
        send_channel: trio.abc.SendChannel[InboundMessage[
            AlexandriaMessage[Any]]],
        request_id: bytes,
    ) -> None:
        """
        Send ``request``, then forward matching responses into ``send_channel``.

        The subscription is opened *before* sending so a fast response cannot
        be missed. The whole exchange is bounded by
        ``REQUEST_RESPONSE_TIMEOUT``.
        """
        with trio.move_on_after(REQUEST_RESPONSE_TIMEOUT) as scope:
            subscription_ctx = self.subscription_manager.subscribe(
                response_message_type,
                endpoint,
                node_id,
            )
            async with subscription_ctx as subscription:
                self.logger.debug2(
                    "Sending request with request id %s",
                    request_id.hex(),
                )
                # Send the request
                await self._send_request(node_id,
                                         endpoint,
                                         request,
                                         request_id=request_id)

                # Wait for the response
                async with send_channel:
                    async for response in subscription:
                        if response.request_id != request_id:
                            continue
                        else:
                            await send_channel.send(response)
        if scope.cancelled_caught:
            self.logger.debug(
                "Abandoned request response monitor: request=%s message_type=%s",
                request,
                response_message_type,
            )

    #
    # Low Level Message Sending
    #
    async def send_ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: int,
        advertisement_radius: int,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """Send a PING request; return the request id used."""
        message = PingMessage(PingPayload(enr_seq, advertisement_radius))
        return await self._send_request(node_id,
                                        endpoint,
                                        message,
                                        request_id=request_id)

    async def send_pong(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: int,
        advertisement_radius: int,
        request_id: bytes,
    ) -> None:
        """Send a PONG response to the request identified by ``request_id``."""
        message = PongMessage(PongPayload(enr_seq, advertisement_radius))
        await self._send_response(node_id,
                                  endpoint,
                                  message,
                                  request_id=request_id)

    async def send_find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        distances: Collection[int],
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """Send a FINDNODES request; return the request id used."""
        message = FindNodesMessage(FindNodesPayload(tuple(distances)))
        return await self._send_request(node_id,
                                        endpoint,
                                        message,
                                        request_id=request_id)

    async def send_found_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enrs: Sequence[ENRAPI],
        request_id: bytes,
    ) -> int:
        """
        Send ``enrs`` as one or more FOUNDNODES responses.

        The ENRs are partitioned so that each message fits within
        ``MAX_PAYLOAD_SIZE``. Every message carries the total batch count.

        :return: the number of messages sent.
        """
        enr_batches = partition_enrs(
            enrs,
            max_payload_size=MAX_PAYLOAD_SIZE,
        )
        num_batches = len(enr_batches)
        for batch in enr_batches:
            message = FoundNodesMessage(
                FoundNodesPayload.from_enrs(num_batches, batch))
            await self._send_response(node_id,
                                      endpoint,
                                      message,
                                      request_id=request_id)

        return num_batches

    async def send_get_content(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        start_chunk_index: int,
        max_chunks: int,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """Send a GETCONTENT request; return the request id used."""
        message = GetContentMessage(
            GetContentPayload(content_key, start_chunk_index, max_chunks))
        return await self._send_request(node_id,
                                        endpoint,
                                        message,
                                        request_id=request_id)

    async def send_content(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        is_proof: bool,
        payload: bytes,
        request_id: bytes,
    ) -> None:
        """Send a CONTENT response to the request identified by ``request_id``."""
        message = ContentMessage(ContentPayload(is_proof, payload))
        return await self._send_response(node_id,
                                         endpoint,
                                         message,
                                         request_id=request_id)

    async def send_advertisements(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisements: Sequence[Advertisement],
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """Send an ADVERTISE request; return the request id used."""
        message = AdvertiseMessage(tuple(advertisements))
        return await self._send_request(node_id,
                                        endpoint,
                                        message,
                                        request_id=request_id)

    async def send_ack(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisement_radius: int,
        acked: Tuple[bool, ...],
        request_id: bytes,
    ) -> None:
        """Send an ACK response to the request identified by ``request_id``."""
        message = AckMessage(AckPayload(advertisement_radius, acked))
        return await self._send_response(node_id,
                                         endpoint,
                                         message,
                                         request_id=request_id)

    async def send_locate(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a LOCATE request; return the request id used.

        ``request_id`` may be omitted, as for the other ``send_*`` request
        methods; the base protocol client then picks one.
        """
        message = LocateMessage(LocatePayload(content_key))
        return await self._send_request(node_id,
                                        endpoint,
                                        message,
                                        request_id=request_id)

    async def send_locations(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisements: Sequence[Advertisement],
        request_id: bytes,
    ) -> int:
        """
        Send ``advertisements`` as one or more LOCATIONS responses.

        :return: the number of messages sent.
        """
        # Here we use a slightly smaller MAX_PAYLOAD_SIZE to account for the
        # other fields in the message.
        advertisement_batches = partition_advertisements(
            advertisements,
            max_payload_size=MAX_PAYLOAD_SIZE - 8,
        )
        num_batches = len(advertisement_batches)
        for batch in advertisement_batches:
            message = LocationsMessage(LocationsPayload(num_batches, batch))
            await self._send_response(node_id,
                                      endpoint,
                                      message,
                                      request_id=request_id)

        return num_batches

    #
    # High Level Request/Response
    #
    async def ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        enr_seq: int,
        advertisement_radius: int,
        request_id: Optional[bytes] = None,
    ) -> PongMessage:
        """Send a PING and wait for the PONG."""
        request = PingMessage(PingPayload(enr_seq, advertisement_radius))
        response = await self._request(node_id, endpoint, request, PongMessage,
                                       request_id)
        return response

    async def find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        distances: Collection[int],
        *,
        request_id: Optional[bytes] = None,
    ) -> Tuple[InboundMessage[FoundNodesMessage], ...]:
        """
        Send FINDNODES and collect every FOUNDNODES response message.

        :raise Exception: if the first response announces a ``total`` < 1.
        :raise ValidationError: if any returned ENR is not at one of the
            requested ``distances`` from ``node_id``.
        """
        request = FindNodesMessage(FindNodesPayload(tuple(distances)))

        subscription: trio.abc.ReceiveChannel[
            InboundMessage[FoundNodesMessage]]
        # unclear why `subscribe_request` isn't properly carrying the type information
        async with self.subscribe_request(  # type: ignore
            node_id,
            endpoint,
            request,
            FoundNodesMessage,
            request_id=request_id,
        ) as subscription:
            head_response = await subscription.receive()
            total = head_response.message.payload.total
            responses: Tuple[InboundMessage[FoundNodesMessage], ...]
            if total == 1:
                responses = (head_response, )
            elif total > 1:
                tail_responses: List[InboundMessage[FoundNodesMessage]] = []
                for _ in range(total - 1):
                    tail_responses.append(await subscription.receive())
                responses = (head_response, ) + tuple(tail_responses)
            else:
                # TODO: this code path needs to be exercised and
                # probably replaced with some sort of
                # `SessionTerminated` exception.
                raise Exception("Invalid `total` counter in response")

            # Validate that all responses are indeed at one of the
            # specified distances.
            for response in responses:
                for enr in response.message.payload.enrs:
                    if enr.node_id == node_id:
                        if 0 not in distances:
                            raise ValidationError(
                                f"Invalid response: distance=0  expected={distances}"
                            )
                    else:
                        distance = compute_log_distance(enr.node_id, node_id)
                        if distance not in distances:
                            raise ValidationError(
                                f"Invalid response: distance={distance}  expected={distances}"
                            )

            return responses

    async def get_content(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        start_chunk_index: int,
        max_chunks: int,
        request_id: Optional[bytes] = None,
    ) -> ContentMessage:
        """Send GETCONTENT and wait for the CONTENT response."""
        request = GetContentMessage(
            GetContentPayload(content_key, start_chunk_index, max_chunks))
        response = await self._request(node_id, endpoint, request,
                                       ContentMessage, request_id)
        return response

    async def advertise(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisements: Collection[Advertisement],
        request_id: Optional[bytes] = None,
    ) -> AckMessage:
        """
        Send ADVERTISE and wait for the ACK response.

        :param request_id: optional explicit request id, as accepted by the
            other high level request methods.
        :raise Exception: if ``advertisements`` is empty.
        """
        if not advertisements:
            raise Exception("Must send at least one advertisement")
        message = AdvertiseMessage(tuple(advertisements))
        return await self._request(node_id, endpoint, message, AckMessage,
                                   request_id)

    async def locate(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        request_id: Optional[bytes] = None,
    ) -> Tuple[InboundMessage[LocationsMessage], ...]:
        """Send LOCATE and collect every LOCATIONS response into a tuple."""
        stream_locate_ctx = self.stream_locate(
            node_id,
            endpoint,
            content_key=content_key,
            request_id=request_id,
        )
        async with stream_locate_ctx as response_aiter:
            return tuple([response async for response in response_aiter])

    @asynccontextmanager
    async def stream_locate(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        request_id: Optional[bytes] = None,
    ) -> AsyncIterator[trio.abc.ReceiveChannel[
            InboundMessage[LocationsMessage]]]:
        """
        Send LOCATE and yield a channel streaming the LOCATIONS responses.

        The channel is closed after the number of messages announced by the
        first response's ``total`` has been delivered.
        """
        request = LocateMessage(LocatePayload(content_key))

        async def _feed_responses(
            send_channel: trio.abc.SendChannel[
                InboundMessage[LocationsMessage]],
        ) -> None:
            subscription: trio.abc.ReceiveChannel[
                InboundMessage[LocationsMessage]]
            # unclear why `subscribe_request` isn't properly carrying the type information
            async with self.subscribe_request(  # type: ignore
                node_id,
                endpoint,
                request,
                LocationsMessage,
                request_id=request_id,
            ) as subscription:
                async with send_channel:
                    head_response = await subscription.receive()
                    await send_channel.send(head_response)
                    total = head_response.message.payload.total
                    for _ in range(total - 1):
                        response = await subscription.receive()
                        await send_channel.send(response)

        send_channel, receive_channel = trio.open_memory_channel[
            InboundMessage[LocationsMessage]](4)
        async with trio.open_nursery() as nursery:
            nursery.start_soon(_feed_responses, send_channel)

            # Always stop the feeder once the caller is done, even when the
            # caller's body raises (mirrors ``subscribe_request``).
            try:
                async with receive_channel:
                    yield receive_channel
            finally:
                nursery.cancel_scope.cancel()

    #
    # Long Running Processes to manage subscriptions
    #
    async def _feed_talk_requests(self) -> None:
        """
        Forward Alexandria REQUEST messages received via TALKREQ to the
        subscription manager.

        Undecodable payloads and non-REQUEST messages are logged and dropped.
        """
        async with self.network.client.dispatcher.subscribe(
                TalkRequestMessage) as subscription:
            async for request in subscription:
                if request.message.protocol != self.protocol_id:
                    continue

                try:
                    message = decode_message(request.message.payload)
                except DecodingError as err:
                    # Best effort: a malformed payload from a peer must not
                    # stop this daemon, but it should not vanish silently.
                    self.logger.debug(
                        "Dropping undecodable TALKREQ payload from %s: %s",
                        request.sender_node_id.hex(),
                        err,
                    )
                    continue

                if message.type != AlexandriaMessageType.REQUEST:
                    self.logger.debug(
                        "Received non-REQUEST msg via TALKREQ: %s",
                        message)
                    continue

                self.subscription_manager.feed_subscriptions(
                    InboundMessage(
                        message=message,
                        sender_node_id=request.sender_node_id,
                        sender_endpoint=request.sender_endpoint,
                        explicit_request_id=request.message.request_id,
                    ))

    async def _feed_talk_responses(self) -> None:
        """
        Forward Alexandria RESPONSE messages received via TALKRESP to the
        subscription manager.

        A TALKRESP is only treated as Alexandria traffic when its request id
        is active in the local ``request_tracker``. Empty payloads are
        skipped. Undecodable payloads and non-RESPONSE messages are logged and
        dropped.
        """
        async with self.network.client.dispatcher.subscribe(
                TalkResponseMessage) as subscription:
            async for response in subscription:
                is_known_request_id = self.request_tracker.is_request_id_active(
                    response.sender_node_id,
                    response.request_id,
                )
                if not is_known_request_id:
                    continue
                elif not response.message.payload:
                    continue

                try:
                    message = decode_message(response.message.payload)
                except DecodingError as err:
                    # Best effort: a malformed payload from a peer must not
                    # stop this daemon, but it should not vanish silently.
                    self.logger.debug(
                        "Dropping undecodable TALKRESP payload from %s: %s",
                        response.sender_node_id.hex(),
                        err,
                    )
                    continue

                if message.type != AlexandriaMessageType.RESPONSE:
                    self.logger.debug(
                        "Received non-RESPONSE msg via TALKRESP: %s",
                        message)
                    continue

                self.subscription_manager.feed_subscriptions(
                    InboundMessage(
                        message=message,
                        sender_node_id=response.sender_node_id,
                        sender_endpoint=response.sender_endpoint,
                        explicit_request_id=response.request_id,
                    ))
Ejemplo n.º 5
0
    def __init__(
        self,
        local_private_key: keys.PrivateKey,
        listen_on: Endpoint,
        enr_db: QueryableENRDatabaseAPI,
        session_cache_size: int,
        events: EventsAPI = None,
        message_type_registry: MessageTypeRegistry = v51_registry,
    ) -> None:
        """
        Build the client's internal processing pipeline without starting it.

        Wires together an ``ENRManager``, a ``RequestTracker``, three pairs
        of trio memory channels (datagrams, envelopes, messages; each
        direction buffered to 256 items), a session ``Pool``, a
        ``Dispatcher`` and the envelope encoder/decoder stages. No I/O is
        performed here.

        :param local_private_key: key for the local ENR and the session pool.
        :param listen_on: stored as ``self.listen_on``; presumably the address
            to bind later -- nothing is bound in this constructor.
        :param enr_db: ENR database shared by the ENR manager, pool and
            dispatcher.
        :param session_cache_size: forwarded to ``Pool``.
        :param events: event bus shared with the pool and dispatcher; a new
            ``Events()`` is created when ``None``.
            NOTE(review): annotation should be ``Optional[EventsAPI]``.
        :param message_type_registry: forwarded to ``Pool``; defaults to
            ``v51_registry``.
        """
        self.local_private_key = local_private_key

        self.listen_on = listen_on
        # Not set here; presumably set once the socket is listening -- confirm.
        self._listening = trio.Event()

        self.enr_manager = ENRManager(
            private_key=local_private_key,
            enr_db=enr_db,
        )
        self.enr_db = enr_db
        self._registry = message_type_registry

        # Base-protocol request id tracker.
        self.request_tracker = RequestTracker()

        # Datagrams
        (
            self._outbound_datagram_send_channel,
            self._outbound_datagram_receive_channel,
        ) = trio.open_memory_channel[OutboundDatagram](256)
        (
            self._inbound_datagram_send_channel,
            self._inbound_datagram_receive_channel,
        ) = trio.open_memory_channel[InboundDatagram](256)

        # EnvelopePair
        (
            self._outbound_envelope_send_channel,
            self._outbound_envelope_receive_channel,
        ) = trio.open_memory_channel[OutboundEnvelope](256)
        (
            self._inbound_envelope_send_channel,
            self._inbound_envelope_receive_channel,
        ) = trio.open_memory_channel[InboundEnvelope](256)

        # Messages
        (
            self._outbound_message_send_channel,
            self._outbound_message_receive_channel,
        ) = trio.open_memory_channel[AnyOutboundMessage](256)
        (
            self._inbound_message_send_channel,
            self._inbound_message_receive_channel,
        ) = trio.open_memory_channel[AnyInboundMessage](256)

        if events is None:
            events = Events()
        self.events = events

        # The pool writes outbound envelopes and inbound messages into the
        # channels created above.
        self.pool = Pool(
            local_private_key=self.local_private_key,
            local_node_id=self.enr_manager.enr.node_id,
            enr_db=self.enr_db,
            outbound_envelope_send_channel=self.
            _outbound_envelope_send_channel,
            inbound_message_send_channel=self._inbound_message_send_channel,
            session_cache_size=session_cache_size,
            message_type_registry=self._registry,
            events=self.events,
        )

        # The dispatcher reads inbound envelopes and inbound messages.
        self.dispatcher = Dispatcher(
            self._inbound_envelope_receive_channel,
            self._inbound_message_receive_channel,
            self.pool,
            self.enr_db,
            self.events,
        )
        # NOTE(review): the two attribute names below look swapped.
        # ``envelope_decoder`` holds an ``EnvelopeEncoder`` (reads outbound
        # envelopes, writes outbound datagrams), and ``envelope_encoder``
        # holds an ``EnvelopeDecoder`` (reads inbound datagrams, writes
        # inbound envelopes). They are left as-is because other code may
        # reference these attributes by name; confirm before renaming.
        self.envelope_decoder = EnvelopeEncoder(
            self._outbound_envelope_receive_channel,
            self._outbound_datagram_send_channel,
        )
        self.envelope_encoder = EnvelopeDecoder(
            self._inbound_datagram_receive_channel,
            self._inbound_envelope_send_channel,
            self.enr_manager.enr.node_id,
        )

        # Not set here; presumably signals the client is fully started -- confirm.
        self._ready = trio.Event()
Ejemplo n.º 6
0
class Client(Service, ClientAPI):
    logger = logging.getLogger("ddht.Client")

    def __init__(
        self,
        local_private_key: keys.PrivateKey,
        listen_on: Endpoint,
        enr_db: QueryableENRDatabaseAPI,
        session_cache_size: int,
        events: Optional[EventsAPI] = None,
        message_type_registry: MessageTypeRegistry = v51_registry,
    ) -> None:
        """
        Build the client's processing pipeline without starting it.

        Creates the ENR manager and request tracker, the bounded (256 item)
        memory channels that connect each stage, and the stage objects
        themselves: the envelope encoder/decoder, the session ``Pool`` and
        the ``Dispatcher``.  Nothing runs until :meth:`run` is invoked.

        :param local_private_key: key for the local ENR and the session pool.
        :param listen_on: UDP endpoint bound later by :meth:`_do_listen`.
        :param enr_db: ENR database shared by the ENR manager, the pool and
            the dispatcher.
        :param session_cache_size: forwarded to :class:`Pool`.
        :param events: event hub; a fresh :class:`Events` is created when
            omitted.
        :param message_type_registry: message registry forwarded to
            :class:`Pool`.
        """
        self.local_private_key = local_private_key

        self.listen_on = listen_on
        # Set by `_do_listen` once the UDP socket has been bound.
        self._listening = trio.Event()

        self.enr_manager = ENRManager(
            private_key=local_private_key,
            enr_db=enr_db,
        )
        self.enr_db = enr_db
        self._registry = message_type_registry

        self.request_tracker = RequestTracker()

        # Datagrams: socket services <-> envelope encoder/decoder.
        (
            self._outbound_datagram_send_channel,
            self._outbound_datagram_receive_channel,
        ) = trio.open_memory_channel[OutboundDatagram](256)
        (
            self._inbound_datagram_send_channel,
            self._inbound_datagram_receive_channel,
        ) = trio.open_memory_channel[InboundDatagram](256)

        # Envelopes: envelope encoder/decoder <-> pool/dispatcher.
        (
            self._outbound_envelope_send_channel,
            self._outbound_envelope_receive_channel,
        ) = trio.open_memory_channel[OutboundEnvelope](256)
        (
            self._inbound_envelope_send_channel,
            self._inbound_envelope_receive_channel,
        ) = trio.open_memory_channel[InboundEnvelope](256)

        # Messages
        (
            self._outbound_message_send_channel,
            self._outbound_message_receive_channel,
        ) = trio.open_memory_channel[AnyOutboundMessage](256)
        (
            self._inbound_message_send_channel,
            self._inbound_message_receive_channel,
        ) = trio.open_memory_channel[AnyInboundMessage](256)

        if events is None:
            events = Events()
        self.events = events

        self.pool = Pool(
            local_private_key=self.local_private_key,
            local_node_id=self.enr_manager.enr.node_id,
            enr_db=self.enr_db,
            outbound_envelope_send_channel=self._outbound_envelope_send_channel,
            inbound_message_send_channel=self._inbound_message_send_channel,
            session_cache_size=session_cache_size,
            message_type_registry=self._registry,
            events=self.events,
        )

        self.dispatcher = Dispatcher(
            self._inbound_envelope_receive_channel,
            self._inbound_message_receive_channel,
            self.pool,
            self.enr_db,
            self.events,
        )
        # Outbound envelopes -> outbound datagrams.  The attribute name now
        # matches the class it holds (it used to be stored under
        # `envelope_decoder`).
        self.envelope_encoder = EnvelopeEncoder(
            self._outbound_envelope_receive_channel,
            self._outbound_datagram_send_channel,
        )
        # Inbound datagrams -> inbound envelopes (addressed to our node id).
        self.envelope_decoder = EnvelopeDecoder(
            self._inbound_datagram_receive_channel,
            self._inbound_envelope_send_channel,
            self.enr_manager.enr.node_id,
        )

        self._ready = trio.Event()

    @property
    def local_node_id(self) -> NodeID:
        """Node id of the local node, as recorded by the session pool."""
        session_pool = self.pool
        return session_pool.local_node_id

    async def run(self) -> None:
        """
        Service entry point.

        Starts the envelope/dispatcher services and the UDP listener as
        daemon tasks, then blocks until this service's manager finishes.
        """
        start_daemon = self.manager.run_daemon_task
        start_daemon(self._run_envelope_and_dispatcher_services)
        start_daemon(self._do_listen, self.listen_on)

        await self.manager.wait_finished()

    async def _run_envelope_and_dispatcher_services(self) -> None:
        """
        Run the envelope services and the dispatcher as background services
        until this client's manager finishes.

        The services are entered in the order ``self.envelope_encoder``,
        ``self.envelope_decoder``, ``self.dispatcher`` and exited in the
        reverse order, so the dispatcher's context is torn down first and
        ``self.envelope_encoder``'s last::

            run()
              |
              ---self.envelope_encoder
                    |
                    ---self.envelope_decoder
                          |
                          ---self.dispatcher
        """
        # A multi-item `async with` enters left-to-right and exits
        # right-to-left, exactly like the equivalent nested statements.
        async with background_trio_service(self.envelope_encoder), \
                background_trio_service(self.envelope_decoder), \
                background_trio_service(self.dispatcher):
            await self.manager.wait_finished()

    async def wait_listening(self) -> None:
        """Block until `_do_listen` has bound the UDP socket."""
        bound = self._listening
        await bound.wait()

    async def _do_listen(self, listen_on: Endpoint) -> None:
        """
        Bind a UDP socket to ``listen_on`` and run the datagram services.

        Once the socket is bound, sets ``self._listening`` and triggers
        ``self.events.listening``.  It then starts a ``DatagramSender`` and a
        ``DatagramReceiver`` on that socket as daemon child services and
        waits for this client's manager to finish.

        :param listen_on: endpoint to bind.  Its ``ip_address`` must be a
            packed IPv4 address, since it is passed through
            ``socket.inet_ntoa``.

        Any exception raised while binding (e.g. an ``OSError`` when the
        address is already in use) propagates after the socket is closed.

        NOTE(review): the bound socket is not explicitly closed on shutdown;
        it is shared with the child datagram services.
        """
        ip_address, port = listen_on

        sock = trio.socket.socket(
            family=trio.socket.AF_INET,
            type=trio.socket.SOCK_DGRAM,
        )
        try:
            await sock.bind((socket.inet_ntoa(ip_address), port))
        except BaseException:
            # Without this, the file descriptor would leak (until garbage
            # collection) whenever binding fails or is cancelled.
            sock.close()
            raise

        self._listening.set()
        await self.events.listening.trigger(listen_on)

        self.logger.debug("Network connection listening on %s", listen_on)

        # TODO: the datagram services need to use the `EventsAPI`
        datagram_sender = DatagramSender(
            self._outbound_datagram_receive_channel,
            sock)  # type: ignore  # noqa: E501
        self.manager.run_daemon_child_service(datagram_sender)

        datagram_receiver = DatagramReceiver(
            sock,
            self._inbound_datagram_send_channel)  # type: ignore  # noqa: E501
        self.manager.run_daemon_child_service(datagram_receiver)

        await self.manager.wait_finished()

    #
    # Message API
    #
    async def send_ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: Optional[int] = None,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a single PING to ``node_id`` at ``endpoint`` without waiting for
        a reply.

        :param enr_seq: ENR sequence number to put in the PING; defaults to
            the local ENR's current sequence number.
        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: the request id the PING was sent with.  It is reserved in
            the request tracker only while the message is being sent.
        """
        if enr_seq is None:
            advertised_seq = self.enr_manager.enr.sequence_number
        else:
            advertised_seq = enr_seq

        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as reserved_id:
            ping = PingMessage(reserved_id, advertised_seq)
            await self.dispatcher.send_message(
                AnyOutboundMessage(ping, endpoint, node_id))
            return reserved_id

    async def send_pong(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: Optional[int] = None,
        request_id: bytes,
    ) -> None:
        """
        Send a PONG carrying ``request_id`` to ``node_id`` at ``endpoint``.

        The PONG includes ``endpoint.ip_address`` and ``endpoint.port``.

        :param enr_seq: ENR sequence number to put in the PONG; defaults to
            the local ENR's current sequence number.
        :param request_id: id of the request being answered (not reserved in
            the request tracker).
        """
        seq = (self.enr_manager.enr.sequence_number
               if enr_seq is None else enr_seq)

        pong = PongMessage(request_id, seq, endpoint.ip_address,
                           endpoint.port)
        outbound = AnyOutboundMessage(pong, endpoint, node_id)
        await self.dispatcher.send_message(outbound)

    async def send_find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        distances: Collection[int],
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a single FINDNODE for ``distances`` without waiting for
        responses.

        :param distances: distances to query; copied into a tuple.
        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: the request id used.  It is reserved only while the message
            is being sent.
        """
        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as reserved_id:
            find_node = FindNodeMessage(reserved_id, tuple(distances))
            await self.dispatcher.send_message(
                AnyOutboundMessage(find_node, endpoint, node_id))
            return reserved_id

    async def send_found_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enrs: Sequence[ENRAPI],
        request_id: bytes,
    ) -> int:
        """
        Send ``enrs`` to ``node_id`` as a series of FOUNDNODES messages.

        ``enrs`` is split with ``partition_enrs`` using
        ``FOUND_NODES_MAX_PAYLOAD_SIZE`` as the maximum payload size.  One
        message is sent per batch, in order, and each message carries the
        total batch count.

        :param request_id: id of the request being answered.
        :return: the number of messages sent (the batch count).

        NOTE(review): if ``partition_enrs`` ever returns no batches, nothing
        is sent and 0 is returned.
        """
        batches = partition_enrs(
            enrs, max_payload_size=FOUND_NODES_MAX_PAYLOAD_SIZE)
        total = len(batches)

        for enr_batch in batches:
            found_nodes = FoundNodesMessage(request_id, total, enr_batch)
            await self.dispatcher.send_message(
                AnyOutboundMessage(found_nodes, endpoint, node_id))

        return total

    async def send_talk_request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        protocol: bytes,
        payload: bytes,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a single TALKREQ for ``protocol`` without waiting for a reply.

        :param protocol: protocol identifier placed in the request.
        :param payload: opaque request payload.
        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: the request id used.  It is reserved only while the message
            is being sent.
        """
        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as reserved_id:
            talk_request = TalkRequestMessage(reserved_id, protocol, payload)
            await self.dispatcher.send_message(
                AnyOutboundMessage(talk_request, endpoint, node_id))
            return reserved_id

    async def send_talk_response(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        payload: bytes,
        request_id: bytes,
    ) -> None:
        """
        Send a TALKRESP with ``payload`` in answer to ``request_id``.

        :param request_id: id of the TALKREQ being answered.
        """
        response = TalkResponseMessage(request_id, payload)
        outbound = AnyOutboundMessage(response, endpoint, node_id)
        await self.dispatcher.send_message(outbound)

    async def send_register_topic(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        topic: bytes,
        enr: ENRAPI,
        ticket: bytes = b"",
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a single REGTOPIC for ``topic`` without waiting for a reply.

        :param enr: ENR included in the registration request.
        :param ticket: ticket to include; empty by default.
        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: the request id used.  It is reserved only while the message
            is being sent.
        """
        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as reserved_id:
            register = RegisterTopicMessage(reserved_id, topic, enr, ticket)
            await self.dispatcher.send_message(
                AnyOutboundMessage(register, endpoint, node_id))
            return reserved_id

    async def send_ticket(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        ticket: bytes,
        wait_time: int,
        request_id: bytes,
    ) -> None:
        """
        Send a TICKET message answering ``request_id``.

        :param ticket: opaque ticket bytes.
        :param wait_time: wait time placed in the message (units are not
            established here; presumably seconds, TODO confirm).
        :param request_id: id of the request being answered.
        """
        ticket_message = TicketMessage(request_id, ticket, wait_time)
        outbound = AnyOutboundMessage(ticket_message, endpoint, node_id)
        await self.dispatcher.send_message(outbound)

    async def send_registration_confirmation(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        topic: bytes,
        request_id: bytes,
    ) -> None:
        """
        Send a REGCONFIRMATION for ``topic`` answering ``request_id``.

        :param request_id: id of the request being answered.
        """
        confirmation = RegistrationConfirmationMessage(request_id, topic)
        outbound = AnyOutboundMessage(confirmation, endpoint, node_id)
        await self.dispatcher.send_message(outbound)

    async def send_topic_query(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        topic: bytes,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a single TOPICQUERY for ``topic`` without waiting for a reply.

        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: the request id used.  It is reserved only while the message
            is being sent.
        """
        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as reserved_id:
            query = TopicQueryMessage(reserved_id, topic)
            await self.dispatcher.send_message(
                AnyOutboundMessage(query, endpoint, node_id))
            return reserved_id

    #
    # Request/Response API
    #
    async def ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        request_id: Optional[bytes] = None,
    ) -> InboundMessage[PongMessage]:
        """
        PING ``node_id`` and wait for the matching PONG.

        The PING carries the local ENR's current sequence number and is
        issued through ``dispatcher.subscribe_request``.  The request id
        stays reserved until the first PONG has been received.

        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: the first PONG received for the request.
        """
        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as active_id:
            enr_seq = self.enr_manager.enr.sequence_number
            outbound = AnyOutboundMessage(
                PingMessage(active_id, enr_seq), endpoint, node_id)
            async with self.dispatcher.subscribe_request(
                    outbound, PongMessage) as pongs:
                return await pongs.receive()

    async def find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        distances: Collection[int],
        *,
        request_id: Optional[bytes] = None,
    ) -> Tuple[InboundMessage[FoundNodesMessage], ...]:
        """
        Send FINDNODE for ``distances`` and collect every FOUNDNODES reply.

        The first response's ``total`` field decides how many responses are
        awaited in all.  The request id stays reserved until they have all
        arrived.

        :param distances: distances to query; copied into a tuple.
        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: all responses, in arrival order.
        :raises ValidationError: if the first response's ``total`` is not a
            positive count.

        NOTE(review): only the first response's ``total`` is inspected; the
        totals of later responses are not cross-checked.
        """
        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as active_id:
            outbound = AnyOutboundMessage(
                FindNodeMessage(active_id, tuple(distances)),
                endpoint,
                node_id,
            )
            async with self.dispatcher.subscribe_request(
                    outbound, FoundNodesMessage) as replies:
                first = await replies.receive()
                total = first.message.total

                if total == 1:
                    return (first, )
                if not total > 1:
                    raise ValidationError(
                        f"Invalid `total` counter in response: total={total}")

                # Await the remaining replies one at a time, in order.
                rest: List[InboundMessage[FoundNodesMessage]] = [
                    await replies.receive() for _ in range(total - 1)
                ]
                return (first, *rest)

    def stream_find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        distances: Collection[int],
        *,
        request_id: Optional[bytes] = None,
    ) -> AsyncContextManager[
            trio.abc.ReceiveChannel[InboundMessage[FoundNodesMessage]]]:
        """
        Return an async context manager yielding a channel of FOUNDNODES
        responses to a FINDNODE for ``distances``.

        This is a thin delegation to ``common_client_stream_find_nodes``;
        the streaming behaviour is defined there.
        """
        stream = common_client_stream_find_nodes(
            self,
            node_id,
            endpoint,
            distances,
            request_id=request_id,
        )
        return stream

    async def talk(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        protocol: bytes,
        payload: bytes,
        *,
        request_id: Optional[bytes] = None,
    ) -> InboundMessage[TalkResponseMessage]:
        """
        Send a TALKREQ for ``protocol`` and wait for the matching TALKRESP.

        The request is issued through ``dispatcher.subscribe_request``, and
        the request id stays reserved until the first response arrives.

        :param protocol: protocol identifier placed in the request.
        :param payload: opaque request payload.
        :param request_id: request id to use; when ``None`` the request
            tracker generates one.
        :return: the first TALKRESP received for the request.
        """
        reservation = self.request_tracker.reserve_request_id(
            node_id, request_id)
        with reservation as active_id:
            outbound = AnyOutboundMessage(
                TalkRequestMessage(active_id, protocol, payload),
                endpoint,
                node_id,
            )
            async with self.dispatcher.subscribe_request(
                    outbound, TalkResponseMessage) as responses:
                return await responses.receive()

    async def register_topic(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        topic: bytes,
        ticket: Optional[bytes] = None,
        *,
        request_id: Optional[bytes] = None,
    ) -> Tuple[
            InboundMessage[TicketMessage],
            Optional[InboundMessage[RegistrationConfirmationMessage]],
    ]:
        """
        Topic registration round trip -- not implemented yet.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError

    async def topic_query(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        topic: bytes,
        *,
        request_id: Optional[bytes] = None,
    ) -> InboundMessage[FoundNodesMessage]:
        """
        Topic query round trip -- not implemented yet.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError