def test_request_tracker_reserve_request_id_generated():
    """
    A request id generated by ``reserve_request_id`` is active only while the
    reservation context is open.
    """
    request_tracker = RequestTracker()
    peer_id = NodeIDFactory()

    with request_tracker.reserve_request_id(peer_id) as reserved_id:
        assert request_tracker.is_request_id_active(peer_id, reserved_id)

    # Leaving the context releases the reservation.
    assert not request_tracker.is_request_id_active(peer_id, reserved_id)
def __init__(self, network: NetworkAPI) -> None:
    """
    Bind the Alexandria client to ``network`` and register it with the
    network as a TALK protocol handler.
    """
    self.network = network
    self.logger = get_extended_debug_logger("ddht.AlexandriaClient")

    # Local bookkeeping: subscriptions for decoded messages and the set of
    # request ids that belong to this protocol.
    self.subscription_manager = SubscriptionManager()
    self.request_tracker = RequestTracker()
    self._active_request_ids = set()

    network.add_talk_protocol(self)
def test_request_tracker_reserve_request_id_provided():
    """
    An explicitly supplied request id is returned unchanged by
    ``reserve_request_id``. It is active only inside the reservation context.
    """
    tracker = RequestTracker()
    node_id = NodeIDFactory()
    # Written with uniform hex escapes; the previous ``\04`` octal escape
    # encoded the same byte but read like a typo.
    request_id = b"\x01\x02\x03\x04"

    assert not tracker.is_request_id_active(node_id, request_id)

    with tracker.reserve_request_id(node_id, request_id) as actual_request_id:
        assert actual_request_id == request_id
        assert tracker.is_request_id_active(node_id, request_id)

    assert not tracker.is_request_id_active(node_id, request_id)
class AlexandriaClient(Service, AlexandriaClientAPI):
    """
    Client for the Alexandria protocol, carried over TALKREQ/TALKRESP messages.

    Requests are sent through ``network.client``. Inbound TALKREQ/TALKRESP
    messages for this protocol are decoded and fed to
    ``subscription_manager``.
    """

    protocol_id = ALEXANDRIA_PROTOCOL_ID

    # NOTE(review): initialised in ``__init__`` but not read anywhere in this
    # class. In-flight request ids are tracked by ``request_tracker`` instead.
    _active_request_ids: Set[bytes]

    def __init__(self, network: NetworkAPI) -> None:
        self.logger = get_extended_debug_logger("ddht.AlexandriaClient")

        self.network = network
        # Request ids of in-flight Alexandria requests. Used to recognise
        # which incoming TALKRESP messages belong to this protocol.
        self.request_tracker = RequestTracker()
        self.subscription_manager = SubscriptionManager()

        network.add_talk_protocol(self)

        self._active_request_ids = set()

    @property
    def local_private_key(self) -> keys.PrivateKey:
        return self.network.client.local_private_key

    #
    # Service API
    #
    async def run(self) -> None:
        self.manager.run_daemon_task(self._feed_talk_requests)
        self.manager.run_daemon_task(self._feed_talk_responses)

        await self.manager.wait_finished()

    #
    # Request/Response message sending primitives
    #
    async def _send_request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        message: AlexandriaMessage[Any],
        *,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send ``message`` as a TALKREQ without waiting for a response.

        :return: the request id returned by
            ``network.client.send_talk_request``.
        :raise Exception: if the encoded payload exceeds ``MAX_PAYLOAD_SIZE``.
        """
        data_payload = message.to_wire_bytes()
        if len(data_payload) > MAX_PAYLOAD_SIZE:
            # Format eagerly: ``Exception`` does not apply %-style arguments.
            raise Exception(
                f"Payload too large: payload={len(data_payload)} "
                f"max_size={MAX_PAYLOAD_SIZE}"
            )
        request_id = await self.network.client.send_talk_request(
            node_id,
            endpoint,
            protocol=ALEXANDRIA_PROTOCOL_ID,
            payload=data_payload,
            request_id=request_id,
        )
        return request_id

    async def _send_response(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        message: AlexandriaMessage[Any],
        *,
        request_id: bytes,
    ) -> None:
        """
        Send ``message`` as a TALKRESP that answers ``request_id``.

        :raise TypeError: if ``message`` is not a RESPONSE message.
        :raise Exception: if the encoded payload exceeds ``MAX_PAYLOAD_SIZE``.
        """
        if message.type != AlexandriaMessageType.RESPONSE:
            raise TypeError(f"{message} is not of type RESPONSE")

        data_payload = message.to_wire_bytes()
        if len(data_payload) > MAX_PAYLOAD_SIZE:
            raise Exception(
                f"Payload too large: payload={len(data_payload)} max_size={MAX_PAYLOAD_SIZE}"
            )

        await self.network.client.send_talk_response(
            node_id,
            endpoint,
            payload=data_payload,
            request_id=request_id,
        )

    async def _request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        request: AlexandriaMessage[Any],
        response_class: Type[TAlexandriaMessage],
        request_id: Optional[bytes] = None,
    ) -> TAlexandriaMessage:
        """
        Send ``request`` and wait for a single response of ``response_class``.

        :raise TypeError: if ``request`` is not a REQUEST message.
        :raise Exception: if the encoded payload exceeds ``MAX_PAYLOAD_SIZE``.
        :raise DecodingError: if the response is not of ``response_class``.
        """
        #
        # Request ID Shenanigans
        #
        # We need a subscription API for alexandria messages in order to be
        # able to cleanly implement functionality like responding to ping
        # messages with pong messages.
        #
        # We accomplish this by monitoring all TALKREQUEST and TALKRESPONSE
        # messages that occur on the alexandria protocol. For TALKREQUEST based
        # messages we can naively monitor incoming messages and match them
        # against the protocol_id. For the TALKRESPONSE however, the only way
        # for us to know that an incoming message belongs to this protocol is
        # to match it against an in-flight request (which we implicitly know
        # is part of the alexandria protocol).
        #
        # In order to do this, we need to know which request ids are active for
        # Alexandria protocol messages. We do this by using a separate
        # RequestTrackerAPI. The `request_id` for messages is still acquired
        # from the core tracker located on the base protocol `ClientAPI`. We
        # then feed this into our local tracker, which allows us to query it
        # upon receiving an incoming TALKRESPONSE to see if the response is to
        # a message from this protocol.
        if request.type != AlexandriaMessageType.REQUEST:
            raise TypeError(f"{request} is not of type REQUEST")

        if request_id is None:
            request_id = self.network.client.request_tracker.get_free_request_id(
                node_id
            )
        request_data = request.to_wire_bytes()
        if len(request_data) > MAX_PAYLOAD_SIZE:
            raise Exception(
                f"Payload too large: payload={len(request_data)} "
                f"max_size={MAX_PAYLOAD_SIZE}"
            )

        with self.request_tracker.reserve_request_id(node_id, request_id):
            response_data = await self.network.talk(
                node_id,
                protocol=ALEXANDRIA_PROTOCOL_ID,
                payload=request_data,
                endpoint=endpoint,
                request_id=request_id,
            )

        response = decode_message(response_data)
        if type(response) is not response_class:
            raise DecodingError(
                f"Invalid response. expected={response_class} got={type(response)}"
            )
        return response  # type: ignore

    def subscribe(
        self,
        message_type: Type[TAlexandriaMessage],
        endpoint: Optional[Endpoint] = None,
        node_id: Optional[NodeID] = None,
    ) -> AsyncContextManager[
        trio.abc.ReceiveChannel[InboundMessage[TAlexandriaMessage]]
    ]:
        """Subscribe to inbound messages of ``message_type``, optionally filtered."""
        return self.subscription_manager.subscribe(
            message_type,
            endpoint,
            node_id,
        )  # type: ignore

    @asynccontextmanager
    async def subscribe_request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        request: AlexandriaMessage[Any],
        response_message_type: Type[TAlexandriaMessage],
        request_id: Optional[bytes] = None,
    ) -> AsyncIterator[trio.abc.ReceiveChannel[InboundMessage[TAlexandriaMessage]]]:
        """
        Send ``request`` and yield a channel of responses matching its id.

        If the channel is closed by the background monitor, for example after
        the response timeout, the resulting ``trio.EndOfChannel`` raised inside
        the context is re-raised as ``trio.TooSlowError``.
        """
        send_channel, receive_channel = trio.open_memory_channel[
            InboundMessage[TAlexandriaMessage]
        ](256)

        #
        # START
        #
        # There cannot be any `await/async` calls between here and the `END`
        # marker, otherwise we will be subject to a race condition where
        # another request could collide with this request id.
        #
        if request_id is None:
            request_id = self.network.client.request_tracker.get_free_request_id(
                node_id
            )

        self.logger.debug2(
            "Sending request: %s with request id %s",
            request,
            request_id.hex(),
        )

        with self.request_tracker.reserve_request_id(node_id, request_id):
            #
            # END
            #
            async with trio.open_nursery() as nursery:
                # The use of `functools.partial` below is due to an inadequacy
                # in the type hinting of `trio.Nursery.start_soon` which
                # doesn't support more than 4 positional arguments.
                nursery.start_soon(
                    functools.partial(
                        self._manage_request_response,
                        node_id,
                        endpoint,
                        request,
                        response_message_type,
                        send_channel,
                        request_id,
                    )
                )
                try:
                    async with receive_channel:
                        try:
                            yield receive_channel
                        # Wrap EOC error with TSE to make the timeouts obvious
                        except trio.EndOfChannel as err:
                            raise trio.TooSlowError(
                                f"Timeout: request={request} request_id={request_id.hex()}"
                            ) from err
                finally:
                    nursery.cancel_scope.cancel()

    async def _manage_request_response(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        request: AlexandriaMessage[Any],
        response_message_type: Type[AlexandriaMessage[Any]],
        send_channel: trio.abc.SendChannel[InboundMessage[AlexandriaMessage[Any]]],
        request_id: bytes,
    ) -> None:
        """
        Subscribe for responses, send ``request``, and forward each response
        whose request id matches ``request_id`` into ``send_channel``. All of
        this runs under a ``REQUEST_RESPONSE_TIMEOUT`` limit.
        """
        with trio.move_on_after(REQUEST_RESPONSE_TIMEOUT) as scope:
            subscription_ctx = self.subscription_manager.subscribe(
                response_message_type,
                endpoint,
                node_id,
            )
            async with subscription_ctx as subscription:
                self.logger.debug2(
                    "Sending request with request id %s",
                    request_id.hex(),
                )
                # Send the request
                await self._send_request(
                    node_id, endpoint, request, request_id=request_id
                )

                # Wait for the response
                async with send_channel:
                    async for response in subscription:
                        if response.request_id != request_id:
                            continue
                        await send_channel.send(response)

        if scope.cancelled_caught:
            self.logger.debug(
                "Abandoned request response monitor: request=%s message_type=%s",
                request,
                response_message_type,
            )
            # The timeout may fire before ``async with send_channel`` above is
            # entered, for example while subscribing or sending. Close the
            # channel explicitly so the consumer observes ``EndOfChannel``
            # instead of blocking forever. ``aclose`` is a no-op on an
            # already-closed channel.
            await send_channel.aclose()

    #
    # Low Level Message Sending
    #
    async def send_ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: int,
        advertisement_radius: int,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        message = PingMessage(PingPayload(enr_seq, advertisement_radius))
        return await self._send_request(
            node_id, endpoint, message, request_id=request_id
        )

    async def send_pong(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: int,
        advertisement_radius: int,
        request_id: bytes,
    ) -> None:
        message = PongMessage(PongPayload(enr_seq, advertisement_radius))
        await self._send_response(node_id, endpoint, message, request_id=request_id)

    async def send_find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        distances: Collection[int],
        request_id: Optional[bytes] = None,
    ) -> bytes:
        message = FindNodesMessage(FindNodesPayload(tuple(distances)))
        return await self._send_request(
            node_id, endpoint, message, request_id=request_id
        )

    async def send_found_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enrs: Sequence[ENRAPI],
        request_id: bytes,
    ) -> int:
        """
        Send ``enrs`` as one or more FOUNDNODES responses.

        :return: the number of response messages sent.
        """
        enr_batches = partition_enrs(
            enrs,
            max_payload_size=MAX_PAYLOAD_SIZE,
        )
        num_batches = len(enr_batches)
        for batch in enr_batches:
            message = FoundNodesMessage(FoundNodesPayload.from_enrs(num_batches, batch))
            await self._send_response(
                node_id, endpoint, message, request_id=request_id
            )
        return num_batches

    async def send_get_content(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        start_chunk_index: int,
        max_chunks: int,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        message = GetContentMessage(
            GetContentPayload(content_key, start_chunk_index, max_chunks)
        )
        return await self._send_request(
            node_id, endpoint, message, request_id=request_id
        )

    async def send_content(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        is_proof: bool,
        payload: bytes,
        request_id: bytes,
    ) -> None:
        message = ContentMessage(ContentPayload(is_proof, payload))
        await self._send_response(node_id, endpoint, message, request_id=request_id)

    async def send_advertisements(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisements: Sequence[Advertisement],
        request_id: Optional[bytes] = None,
    ) -> bytes:
        message = AdvertiseMessage(tuple(advertisements))
        return await self._send_request(
            node_id, endpoint, message, request_id=request_id
        )

    async def send_ack(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisement_radius: int,
        acked: Tuple[bool, ...],
        request_id: bytes,
    ) -> None:
        message = AckMessage(AckPayload(advertisement_radius, acked))
        await self._send_response(node_id, endpoint, message, request_id=request_id)

    async def send_locate(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        # ``request_id`` is optional here, like every other ``send_*`` request
        # method; ``_send_request`` lets the base client pick one when None.
        message = LocateMessage(LocatePayload(content_key))
        return await self._send_request(
            node_id, endpoint, message, request_id=request_id
        )

    async def send_locations(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisements: Sequence[Advertisement],
        request_id: bytes,
    ) -> int:
        """
        Send ``advertisements`` as one or more LOCATIONS responses.

        :return: the number of response messages sent.
        """
        # Here we use a slightly smaller MAX_PAYLOAD_SIZE to account for the
        # other fields in the message.
        advertisement_batches = partition_advertisements(
            advertisements,
            max_payload_size=MAX_PAYLOAD_SIZE - 8,
        )
        num_batches = len(advertisement_batches)
        for batch in advertisement_batches:
            message = LocationsMessage(LocationsPayload(num_batches, batch))
            await self._send_response(
                node_id, endpoint, message, request_id=request_id
            )
        return num_batches

    #
    # High Level Request/Response
    #
    async def ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        enr_seq: int,
        advertisement_radius: int,
        request_id: Optional[bytes] = None,
    ) -> PongMessage:
        request = PingMessage(PingPayload(enr_seq, advertisement_radius))
        response = await self._request(
            node_id, endpoint, request, PongMessage, request_id
        )
        return response

    async def find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        distances: Collection[int],
        *,
        request_id: Optional[bytes] = None,
    ) -> Tuple[InboundMessage[FoundNodesMessage], ...]:
        """
        Request the ENRs at ``distances`` from ``node_id``.

        Collects all ``total`` FOUNDNODES responses and checks that every
        returned ENR lies at one of the requested distances.

        :raise ValidationError: on an invalid ``total`` counter or an ENR at an
            unrequested distance.
        """
        request = FindNodesMessage(FindNodesPayload(tuple(distances)))

        subscription: trio.abc.ReceiveChannel[InboundMessage[FoundNodesMessage]]
        # unclear why `subscribe_request` isn't properly carrying the type information
        async with self.subscribe_request(  # type: ignore
            node_id,
            endpoint,
            request,
            FoundNodesMessage,
            request_id=request_id,
        ) as subscription:
            head_response = await subscription.receive()
            total = head_response.message.payload.total
            responses: Tuple[InboundMessage[FoundNodesMessage], ...]
            if total == 1:
                responses = (head_response,)
            elif total > 1:
                tail_responses: List[InboundMessage[FoundNodesMessage]] = []
                for _ in range(total - 1):
                    tail_responses.append(await subscription.receive())
                responses = (head_response,) + tuple(tail_responses)
            else:
                # Consistent with ``Client.find_nodes``. ValidationError is
                # still an Exception, so existing broad handlers keep working.
                raise ValidationError(
                    f"Invalid `total` counter in response: total={total}"
                )

            # Validate that all responses are indeed at one of the
            # specified distances.
            for response in responses:
                for enr in response.message.payload.enrs:
                    if enr.node_id == node_id:
                        if 0 not in distances:
                            raise ValidationError(
                                f"Invalid response: distance=0 expected={distances}"
                            )
                    else:
                        distance = compute_log_distance(enr.node_id, node_id)
                        if distance not in distances:
                            raise ValidationError(
                                f"Invalid response: distance={distance} expected={distances}"
                            )

        return responses

    async def get_content(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        start_chunk_index: int,
        max_chunks: int,
        request_id: Optional[bytes] = None,
    ) -> ContentMessage:
        request = GetContentMessage(
            GetContentPayload(content_key, start_chunk_index, max_chunks)
        )
        response = await self._request(
            node_id, endpoint, request, ContentMessage, request_id
        )
        return response

    async def advertise(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        advertisements: Collection[Advertisement],
    ) -> AckMessage:
        """
        Send ``advertisements`` and wait for the ACK.

        :raise Exception: if ``advertisements`` is empty.
        """
        if not advertisements:
            raise Exception("Must send at least one advertisement")
        message = AdvertiseMessage(tuple(advertisements))
        return await self._request(node_id, endpoint, message, AckMessage)

    async def locate(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        request_id: Optional[bytes] = None,
    ) -> Tuple[InboundMessage[LocationsMessage], ...]:
        """Collect every LOCATIONS response produced by ``stream_locate``."""
        stream_locate_ctx = self.stream_locate(
            node_id,
            endpoint,
            content_key=content_key,
            request_id=request_id,
        )
        async with stream_locate_ctx as response_aiter:
            return tuple([response async for response in response_aiter])

    @asynccontextmanager
    async def stream_locate(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        content_key: ContentKey,
        request_id: Optional[bytes] = None,
    ) -> AsyncIterator[trio.abc.ReceiveChannel[InboundMessage[LocationsMessage]]]:
        """
        Send a LOCATE request and yield a channel of its LOCATIONS responses.

        The channel is closed after ``total`` responses (taken from the first
        response) have been forwarded.
        """
        request = LocateMessage(LocatePayload(content_key))

        async def _feed_responses(
            send_channel: trio.abc.SendChannel[InboundMessage[LocationsMessage]],
        ) -> None:
            subscription: trio.abc.ReceiveChannel[InboundMessage[LocationsMessage]]
            # unclear why `subscribe_request` isn't properly carrying the type information
            async with self.subscribe_request(  # type: ignore
                node_id,
                endpoint,
                request,
                LocationsMessage,
                request_id=request_id,
            ) as subscription:
                async with send_channel:
                    head_response = await subscription.receive()
                    await send_channel.send(head_response)
                    total = head_response.message.payload.total

                    for _ in range(total - 1):
                        response = await subscription.receive()
                        await send_channel.send(response)

        send_channel, receive_channel = trio.open_memory_channel[
            InboundMessage[LocationsMessage]
        ](4)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(_feed_responses, send_channel)

            async with receive_channel:
                yield receive_channel

            nursery.cancel_scope.cancel()

    #
    # Long Running Processes to manage subscriptions
    #
    async def _feed_talk_requests(self) -> None:
        """Decode Alexandria TALKREQ payloads and feed them to subscribers."""
        async with self.network.client.dispatcher.subscribe(
            TalkRequestMessage
        ) as subscription:
            async for request in subscription:
                if request.message.protocol != self.protocol_id:
                    continue

                try:
                    message = decode_message(request.message.payload)
                except DecodingError as err:
                    # Best-effort: skip undecodable payloads, but leave a trace.
                    self.logger.debug(
                        "Failed to decode TALKREQ payload: sender=%s error=%s",
                        request.sender_endpoint,
                        err,
                    )
                else:
                    if message.type != AlexandriaMessageType.REQUEST:
                        self.logger.debug(
                            "Received non-REQUEST msg via TALKREQ: %s", message
                        )
                        continue

                    self.subscription_manager.feed_subscriptions(
                        InboundMessage(
                            message=message,
                            sender_node_id=request.sender_node_id,
                            sender_endpoint=request.sender_endpoint,
                            explicit_request_id=request.message.request_id,
                        )
                    )

    async def _feed_talk_responses(self) -> None:
        """
        Decode TALKRESP payloads that answer an in-flight Alexandria request
        and feed them to subscribers.
        """
        async with self.network.client.dispatcher.subscribe(
            TalkResponseMessage
        ) as subscription:
            async for response in subscription:
                # A TALKRESP does not carry the protocol id, so only responses
                # to request ids reserved in our local tracker are considered.
                is_known_request_id = self.request_tracker.is_request_id_active(
                    response.sender_node_id,
                    response.request_id,
                )
                if not is_known_request_id:
                    continue
                elif not response.message.payload:
                    continue

                try:
                    message = decode_message(response.message.payload)
                except DecodingError as err:
                    # Best-effort: skip undecodable payloads, but leave a trace.
                    self.logger.debug(
                        "Failed to decode TALKRESP payload: sender=%s error=%s",
                        response.sender_endpoint,
                        err,
                    )
                else:
                    if message.type != AlexandriaMessageType.RESPONSE:
                        self.logger.debug(
                            "Received non-RESPONSE msg via TALKRESP: %s", message
                        )
                        continue

                    self.subscription_manager.feed_subscriptions(
                        InboundMessage(
                            message=message,
                            sender_node_id=response.sender_node_id,
                            sender_endpoint=response.sender_endpoint,
                            explicit_request_id=response.request_id,
                        )
                    )
def __init__(
    self,
    local_private_key: keys.PrivateKey,
    listen_on: Endpoint,
    enr_db: QueryableENRDatabaseAPI,
    session_cache_size: int,
    events: Optional[EventsAPI] = None,
    message_type_registry: MessageTypeRegistry = v51_registry,
) -> None:
    """
    Wire up the discv5 client pipeline: datagram, envelope and message
    channels, the session ``Pool``, the ``Dispatcher``, and the envelope
    encoder/decoder.

    :param events: event bus to use; a fresh ``Events()`` is created if None.
    """
    self.local_private_key = local_private_key
    self.listen_on = listen_on
    self._listening = trio.Event()

    self.enr_manager = ENRManager(
        private_key=local_private_key,
        enr_db=enr_db,
    )
    self.enr_db = enr_db
    self._registry = message_type_registry

    self.request_tracker = RequestTracker()

    # Datagrams
    (
        self._outbound_datagram_send_channel,
        self._outbound_datagram_receive_channel,
    ) = trio.open_memory_channel[OutboundDatagram](256)
    (
        self._inbound_datagram_send_channel,
        self._inbound_datagram_receive_channel,
    ) = trio.open_memory_channel[InboundDatagram](256)

    # EnvelopePair
    (
        self._outbound_envelope_send_channel,
        self._outbound_envelope_receive_channel,
    ) = trio.open_memory_channel[OutboundEnvelope](256)
    (
        self._inbound_envelope_send_channel,
        self._inbound_envelope_receive_channel,
    ) = trio.open_memory_channel[InboundEnvelope](256)

    # Messages
    (
        self._outbound_message_send_channel,
        self._outbound_message_receive_channel,
    ) = trio.open_memory_channel[AnyOutboundMessage](256)
    (
        self._inbound_message_send_channel,
        self._inbound_message_receive_channel,
    ) = trio.open_memory_channel[AnyInboundMessage](256)

    if events is None:
        events = Events()
    self.events = events

    self.pool = Pool(
        local_private_key=self.local_private_key,
        local_node_id=self.enr_manager.enr.node_id,
        enr_db=self.enr_db,
        outbound_envelope_send_channel=self._outbound_envelope_send_channel,
        inbound_message_send_channel=self._inbound_message_send_channel,
        session_cache_size=session_cache_size,
        message_type_registry=self._registry,
        events=self.events,
    )

    self.dispatcher = Dispatcher(
        self._inbound_envelope_receive_channel,
        self._inbound_message_receive_channel,
        self.pool,
        self.enr_db,
        self.events,
    )
    # Outbound: envelopes -> datagrams. The attribute names now match the
    # classes they hold; they were previously swapped.
    self.envelope_encoder = EnvelopeEncoder(
        self._outbound_envelope_receive_channel,
        self._outbound_datagram_send_channel,
    )
    # Inbound: datagrams -> envelopes.
    self.envelope_decoder = EnvelopeDecoder(
        self._inbound_datagram_receive_channel,
        self._inbound_envelope_send_channel,
        self.enr_manager.enr.node_id,
    )

    self._ready = trio.Event()
class Client(Service, ClientAPI):
    """
    Low-level discv5 client: owns the UDP socket, the envelope
    encoder/decoder, the session pool and the message dispatcher. It exposes
    both fire-and-forget ``send_*`` methods and request/response helpers.
    """

    logger = logging.getLogger("ddht.Client")

    def __init__(
        self,
        local_private_key: keys.PrivateKey,
        listen_on: Endpoint,
        enr_db: QueryableENRDatabaseAPI,
        session_cache_size: int,
        events: Optional[EventsAPI] = None,
        message_type_registry: MessageTypeRegistry = v51_registry,
    ) -> None:
        self.local_private_key = local_private_key
        self.listen_on = listen_on
        self._listening = trio.Event()

        self.enr_manager = ENRManager(
            private_key=local_private_key,
            enr_db=enr_db,
        )
        self.enr_db = enr_db
        self._registry = message_type_registry

        self.request_tracker = RequestTracker()

        # Datagrams
        (
            self._outbound_datagram_send_channel,
            self._outbound_datagram_receive_channel,
        ) = trio.open_memory_channel[OutboundDatagram](256)
        (
            self._inbound_datagram_send_channel,
            self._inbound_datagram_receive_channel,
        ) = trio.open_memory_channel[InboundDatagram](256)

        # EnvelopePair
        (
            self._outbound_envelope_send_channel,
            self._outbound_envelope_receive_channel,
        ) = trio.open_memory_channel[OutboundEnvelope](256)
        (
            self._inbound_envelope_send_channel,
            self._inbound_envelope_receive_channel,
        ) = trio.open_memory_channel[InboundEnvelope](256)

        # Messages
        (
            self._outbound_message_send_channel,
            self._outbound_message_receive_channel,
        ) = trio.open_memory_channel[AnyOutboundMessage](256)
        (
            self._inbound_message_send_channel,
            self._inbound_message_receive_channel,
        ) = trio.open_memory_channel[AnyInboundMessage](256)

        if events is None:
            events = Events()
        self.events = events

        self.pool = Pool(
            local_private_key=self.local_private_key,
            local_node_id=self.enr_manager.enr.node_id,
            enr_db=self.enr_db,
            outbound_envelope_send_channel=self._outbound_envelope_send_channel,
            inbound_message_send_channel=self._inbound_message_send_channel,
            session_cache_size=session_cache_size,
            message_type_registry=self._registry,
            events=self.events,
        )

        self.dispatcher = Dispatcher(
            self._inbound_envelope_receive_channel,
            self._inbound_message_receive_channel,
            self.pool,
            self.enr_db,
            self.events,
        )
        # Outbound: envelopes -> datagrams. The attribute names now match the
        # classes they hold; they were previously swapped.
        self.envelope_encoder = EnvelopeEncoder(
            self._outbound_envelope_receive_channel,
            self._outbound_datagram_send_channel,
        )
        # Inbound: datagrams -> envelopes.
        self.envelope_decoder = EnvelopeDecoder(
            self._inbound_datagram_receive_channel,
            self._inbound_envelope_send_channel,
            self.enr_manager.enr.node_id,
        )

        self._ready = trio.Event()

    @property
    def local_node_id(self) -> NodeID:
        return self.pool.local_node_id

    async def run(self) -> None:
        self.manager.run_daemon_task(self._run_envelope_and_dispatcher_services)
        self.manager.run_daemon_task(self._do_listen, self.listen_on)

        await self.manager.wait_finished()

    async def _run_envelope_and_dispatcher_services(self) -> None:
        """
        Ensure that in the task hierarchy the envelope encoder and decoder
        will be shut down *after* the dispatcher.

        run()
          |
          ---EnvelopeEncoder
               |
               ---EnvelopeDecoder
                    |
                    ---Dispatcher
        """
        async with background_trio_service(self.envelope_encoder):
            async with background_trio_service(self.envelope_decoder):
                async with background_trio_service(self.dispatcher):
                    await self.manager.wait_finished()

    async def wait_listening(self) -> None:
        await self._listening.wait()

    async def _do_listen(self, listen_on: Endpoint) -> None:
        """
        Bind a UDP socket to ``listen_on`` and run the datagram sender and
        receiver services on it.
        """
        sock = trio.socket.socket(
            family=trio.socket.AF_INET,
            type=trio.socket.SOCK_DGRAM,
        )
        ip_address, port = listen_on
        await sock.bind((socket.inet_ntoa(ip_address), port))

        self._listening.set()
        await self.events.listening.trigger(listen_on)

        self.logger.debug("Network connection listening on %s", listen_on)

        # TODO: the datagram services need to use the `EventsAPI`
        datagram_sender = DatagramSender(self._outbound_datagram_receive_channel, sock)  # type: ignore  # noqa: E501
        self.manager.run_daemon_child_service(datagram_sender)

        datagram_receiver = DatagramReceiver(sock, self._inbound_datagram_send_channel)  # type: ignore  # noqa: E501
        self.manager.run_daemon_child_service(datagram_receiver)

        await self.manager.wait_finished()

    #
    # Message API
    #
    async def send_ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: Optional[int] = None,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """Send a PING; ``enr_seq`` defaults to the local ENR sequence number."""
        if enr_seq is None:
            enr_seq = self.enr_manager.enr.sequence_number

        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as message_request_id:
            message = AnyOutboundMessage(
                PingMessage(message_request_id, enr_seq),
                endpoint,
                node_id,
            )
            await self.dispatcher.send_message(message)

        return message_request_id

    async def send_pong(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: Optional[int] = None,
        request_id: bytes,
    ) -> None:
        """Send a PONG; ``enr_seq`` defaults to the local ENR sequence number."""
        if enr_seq is None:
            enr_seq = self.enr_manager.enr.sequence_number

        message = AnyOutboundMessage(
            PongMessage(
                request_id,
                enr_seq,
                endpoint.ip_address,
                endpoint.port,
            ),
            endpoint,
            node_id,
        )
        await self.dispatcher.send_message(message)

    async def send_find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        distances: Collection[int],
        request_id: Optional[bytes] = None,
    ) -> bytes:
        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as message_request_id:
            message = AnyOutboundMessage(
                FindNodeMessage(message_request_id, tuple(distances)),
                endpoint,
                node_id,
            )
            await self.dispatcher.send_message(message)

        return message_request_id

    async def send_found_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enrs: Sequence[ENRAPI],
        request_id: bytes,
    ) -> int:
        """Send ``enrs`` in batches; return the number of messages sent."""
        enr_batches = partition_enrs(
            enrs, max_payload_size=FOUND_NODES_MAX_PAYLOAD_SIZE
        )
        num_batches = len(enr_batches)
        for batch in enr_batches:
            message = AnyOutboundMessage(
                FoundNodesMessage(
                    request_id,
                    num_batches,
                    batch,
                ),
                endpoint,
                node_id,
            )
            await self.dispatcher.send_message(message)

        return num_batches

    async def send_talk_request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        protocol: bytes,
        payload: bytes,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as message_request_id:
            message = AnyOutboundMessage(
                TalkRequestMessage(message_request_id, protocol, payload),
                endpoint,
                node_id,
            )
            await self.dispatcher.send_message(message)

        return message_request_id

    async def send_talk_response(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        payload: bytes,
        request_id: bytes,
    ) -> None:
        message = AnyOutboundMessage(
            TalkResponseMessage(request_id, payload),
            endpoint,
            node_id,
        )
        await self.dispatcher.send_message(message)

    async def send_register_topic(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        topic: bytes,
        enr: ENRAPI,
        ticket: bytes = b"",
        request_id: Optional[bytes] = None,
    ) -> bytes:
        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as message_request_id:
            message = AnyOutboundMessage(
                RegisterTopicMessage(message_request_id, topic, enr, ticket),
                endpoint,
                node_id,
            )
            await self.dispatcher.send_message(message)

        return message_request_id

    async def send_ticket(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        ticket: bytes,
        wait_time: int,
        request_id: bytes,
    ) -> None:
        message = AnyOutboundMessage(
            TicketMessage(request_id, ticket, wait_time),
            endpoint,
            node_id,
        )
        await self.dispatcher.send_message(message)

    async def send_registration_confirmation(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        topic: bytes,
        request_id: bytes,
    ) -> None:
        message = AnyOutboundMessage(
            RegistrationConfirmationMessage(request_id, topic),
            endpoint,
            node_id,
        )
        await self.dispatcher.send_message(message)

    async def send_topic_query(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        topic: bytes,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as message_request_id:
            message = AnyOutboundMessage(
                TopicQueryMessage(message_request_id, topic),
                endpoint,
                node_id,
            )
            await self.dispatcher.send_message(message)
        return message_request_id

    #
    # Request/Response API
    #
    async def ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        request_id: Optional[bytes] = None,
    ) -> InboundMessage[PongMessage]:
        """Send a PING and wait for the matching PONG."""
        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as request_id:
            request = AnyOutboundMessage(
                PingMessage(request_id, self.enr_manager.enr.sequence_number),
                endpoint,
                node_id,
            )
            async with self.dispatcher.subscribe_request(
                request, PongMessage
            ) as subscription:
                return await subscription.receive()

    async def find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        distances: Collection[int],
        *,
        request_id: Optional[bytes] = None,
    ) -> Tuple[InboundMessage[FoundNodesMessage], ...]:
        """
        Send FINDNODE and collect all ``total`` FOUNDNODES responses.

        :raise ValidationError: if the ``total`` counter is less than 1.
        """
        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as request_id:
            request = AnyOutboundMessage(
                FindNodeMessage(request_id, tuple(distances)),
                endpoint,
                node_id,
            )
            async with self.dispatcher.subscribe_request(
                request, FoundNodesMessage
            ) as subscription:
                head_response = await subscription.receive()
                total = head_response.message.total
                responses: Tuple[InboundMessage[FoundNodesMessage], ...]
                if total == 1:
                    responses = (head_response,)
                elif total > 1:
                    tail_responses: List[InboundMessage[FoundNodesMessage]] = []
                    for _ in range(total - 1):
                        tail_responses.append(await subscription.receive())
                    responses = (head_response,) + tuple(tail_responses)
                else:
                    raise ValidationError(
                        f"Invalid `total` counter in response: total={total}"
                    )

                return responses

    def stream_find_nodes(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        distances: Collection[int],
        *,
        request_id: Optional[bytes] = None,
    ) -> AsyncContextManager[
        trio.abc.ReceiveChannel[InboundMessage[FoundNodesMessage]]
    ]:
        return common_client_stream_find_nodes(
            self, node_id, endpoint, distances, request_id=request_id
        )

    async def talk(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        protocol: bytes,
        payload: bytes,
        *,
        request_id: Optional[bytes] = None,
    ) -> InboundMessage[TalkResponseMessage]:
        """Send a TALKREQ and wait for the matching TALKRESP."""
        with self.request_tracker.reserve_request_id(
            node_id, request_id
        ) as request_id:
            request = AnyOutboundMessage(
                TalkRequestMessage(request_id, protocol, payload),
                endpoint,
                node_id,
            )
            async with self.dispatcher.subscribe_request(
                request, TalkResponseMessage
            ) as subscription:
                return await subscription.receive()

    async def register_topic(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        topic: bytes,
        ticket: Optional[bytes] = None,
        *,
        request_id: Optional[bytes] = None,
    ) -> Tuple[
        InboundMessage[TicketMessage],
        Optional[InboundMessage[RegistrationConfirmationMessage]],
    ]:
        raise NotImplementedError

    async def topic_query(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        topic: bytes,
        *,
        request_id: Optional[bytes] = None,
    ) -> InboundMessage[FoundNodesMessage]:
        raise NotImplementedError