示例#1
0
 async def send_talk_response(
     self,
     endpoint: Endpoint,
     node_id: NodeID,
     *,
     response: bytes,
     request_id: int,
 ) -> None:
     """
     Send a ``TalkResponseMessage`` carrying ``response`` to ``node_id`` at ``endpoint``.

     The caller-supplied ``request_id`` is placed in the message as-is; no
     request id is reserved locally for a response.
     """
     talk_response = TalkResponseMessage(request_id, response)
     outbound = AnyOutboundMessage(talk_response, endpoint, node_id)
     await self.dispatcher.send_message(outbound)
示例#2
0
    async def send_find_nodes(
        self,
        endpoint: Endpoint,
        node_id: NodeID,
        *,
        distances: Collection[int],
        request_id: Optional[int] = None,
    ) -> int:
        """
        Send a ``FindNodeMessage`` for ``distances`` to ``node_id`` at ``endpoint``.

        ``request_id`` is forwarded to ``self._get_request_id``; the id yielded
        by that context manager is the one put on the wire and returned.
        """
        reservation = self._get_request_id(node_id, request_id)
        with reservation as reserved_id:
            find_nodes = FindNodeMessage(reserved_id, tuple(distances))
            outbound = AnyOutboundMessage(find_nodes, endpoint, node_id)
            await self.dispatcher.send_message(outbound)

        return reserved_id
示例#3
0
    async def send_ping(
        self,
        endpoint: Endpoint,
        node_id: NodeID,
        *,
        request_id: Optional[int] = None,
    ) -> int:
        """
        Send a ``PingMessage`` to ``node_id`` at ``endpoint`` and return its request id.

        The ping carries the local ENR's current sequence number. ``request_id``
        is forwarded to ``self._get_request_id``, which yields the id actually used.
        """
        reservation = self._get_request_id(node_id, request_id)
        with reservation as reserved_id:
            ping_message = PingMessage(
                reserved_id, self.enr_manager.enr.sequence_number
            )
            outbound = AnyOutboundMessage(ping_message, endpoint, node_id)
            await self.dispatcher.send_message(outbound)

        return reserved_id
示例#4
0
 async def ping(
     self,
     node_id: NodeID,
     endpoint: Endpoint,
     *,
     request_id: Optional[bytes] = None,
 ) -> InboundMessage[PongMessage]:
     """
     Send a ``PingMessage`` to ``node_id`` at ``endpoint`` and await the ``PongMessage`` reply.

     ``request_id`` is forwarded to ``self.request_tracker.reserve_request_id``,
     and the reserved id is held for the whole request/response exchange. The
     ping carries the local ENR's current sequence number.

     :return: the first ``PongMessage`` received on the request subscription.
     """
     tracker = self.request_tracker
     with tracker.reserve_request_id(node_id, request_id) as reserved_id:
         ping_message = PingMessage(
             reserved_id, self.enr_manager.enr.sequence_number
         )
         outbound = AnyOutboundMessage(ping_message, endpoint, node_id)
         subscribe = self.dispatcher.subscribe_request
         async with subscribe(outbound, PongMessage) as pong_subscription:
             return await pong_subscription.receive()
示例#5
0
 async def send_topic_query(
     self,
     node_id: NodeID,
     endpoint: Endpoint,
     *,
     topic: bytes,
     request_id: Optional[bytes] = None,
 ) -> bytes:
     """
     Send a ``TopicQueryMessage`` for ``topic`` to ``node_id`` at ``endpoint``.

     ``request_id`` is forwarded to ``self.request_tracker.reserve_request_id``.
     The reserved id is used in the message and returned to the caller.
     """
     tracker = self.request_tracker
     with tracker.reserve_request_id(node_id, request_id) as reserved_id:
         topic_query = TopicQueryMessage(reserved_id, topic)
         outbound = AnyOutboundMessage(topic_query, endpoint, node_id)
         await self.dispatcher.send_message(outbound)
     return reserved_id
示例#6
0
    async def send_talk_request(
        self,
        endpoint: Endpoint,
        node_id: NodeID,
        *,
        protocol: bytes,
        request: bytes,
        request_id: Optional[int] = None,
    ) -> int:
        """
        Send a ``TalkRequestMessage`` for ``protocol`` with payload ``request``.

        ``request_id`` is forwarded to ``self._get_request_id``; the id yielded
        there is placed in the message and returned.
        """
        reservation = self._get_request_id(node_id, request_id)
        with reservation as reserved_id:
            talk_request = TalkRequestMessage(reserved_id, protocol, request)
            outbound = AnyOutboundMessage(talk_request, endpoint, node_id)
            await self.dispatcher.send_message(outbound)

        return reserved_id
示例#7
0
 async def send_pong(
     self,
     endpoint: Endpoint,
     node_id: NodeID,
     *,
     request_id: int,
 ) -> None:
     """
     Reply to a ping from ``node_id`` with a ``PongMessage``.

     The pong carries the local ENR's current sequence number plus the IP
     address and port taken from ``endpoint``. ``request_id`` is echoed as-is.
     """
     enr_sequence = self.enr_manager.enr.sequence_number
     observed_ip = endpoint.ip_address
     observed_port = endpoint.port
     pong_message = PongMessage(
         request_id, enr_sequence, observed_ip, observed_port
     )
     await self.dispatcher.send_message(
         AnyOutboundMessage(pong_message, endpoint, node_id)
     )
示例#8
0
    async def send_register_topic(
        self,
        endpoint: Endpoint,
        node_id: NodeID,
        *,
        topic: bytes,
        enr: ENRAPI,
        ticket: bytes = b"",
        request_id: Optional[int] = None,
    ) -> int:
        """
        Send a ``RegisterTopicMessage`` advertising ``enr`` under ``topic``.

        ``ticket`` defaults to empty bytes. ``request_id`` is forwarded to
        ``self._get_request_id``; the id it yields is sent and returned.
        """
        reservation = self._get_request_id(node_id, request_id)
        with reservation as reserved_id:
            register = RegisterTopicMessage(reserved_id, topic, enr, ticket)
            outbound = AnyOutboundMessage(register, endpoint, node_id)
            await self.dispatcher.send_message(outbound)

        return reserved_id
示例#9
0
    async def find_nodes(
        self, endpoint: Endpoint, node_id: NodeID, distances: Collection[int],
    ) -> Tuple[InboundMessage[FoundNodesMessage], ...]:
        """
        Send a FINDNODE request to ``node_id`` and collect all FOUNDNODES replies.

        The ``total`` field of the first reply tells how many reply messages
        make up the full answer. The remaining ``total - 1`` replies are read
        from the same subscription. Every ENR in the replies is then checked
        against ``distances``. A record whose node id equals ``node_id`` counts
        as distance 0; all other records use ``compute_log_distance``.

        :return: the replies in the order they were received.
        :raises ValidationError: if ``total`` is less than 1, or if any returned
            ENR lies at a distance that was not requested.
        """
        with self._get_request_id(node_id) as request_id:
            request = AnyOutboundMessage(
                FindNodeMessage(request_id, tuple(distances)), endpoint, node_id,
            )
            async with self.dispatcher.subscribe_request(
                request, FoundNodesMessage
            ) as subscription:
                head_response = await subscription.receive()
                total = head_response.message.total
                if total < 1:
                    # The counter includes the head reply itself, so anything
                    # below 1 is a malformed response from the peer. Raise the
                    # same error type as the distance checks below, rather than
                    # a bare ``Exception`` that callers cannot catch selectively.
                    raise ValidationError("Invalid `total` counter in response")

                # When total == 1, range(0) is empty and only the head reply is kept.
                tail_responses = [
                    await subscription.receive() for _ in range(total - 1)
                ]
                responses: Tuple[InboundMessage[FoundNodesMessage], ...] = (
                    head_response,
                    *tail_responses,
                )

                # Validate that all responses are indeed at one of the
                # specified distances.
                for response in responses:
                    for enr in response.message.enrs:
                        if enr.node_id == node_id:
                            if 0 not in distances:
                                raise ValidationError(
                                    f"Invalid response: distance=0  expected={distances}"
                                )
                        else:
                            distance = compute_log_distance(enr.node_id, node_id)
                            if distance not in distances:
                                raise ValidationError(
                                    f"Invalid response: distance={distance}  expected={distances}"
                                )

                return responses
示例#10
0
    async def send_found_nodes(
        self,
        endpoint: Endpoint,
        node_id: NodeID,
        *,
        enrs: Sequence[ENRAPI],
        request_id: int,
    ) -> int:
        """
        Send ``enrs`` to ``node_id`` as one or more ``FoundNodesMessage`` replies.

        The records are split with ``partition_enrs`` so that each batch stays
        within ``FOUND_NODES_MAX_PAYLOAD_SIZE``. Every message carries the
        number of batches as its total.

        :return: the number of messages sent.
        """
        batches = partition_enrs(enrs, max_payload_size=FOUND_NODES_MAX_PAYLOAD_SIZE)
        batch_count = len(batches)
        for enr_batch in batches:
            found_nodes = FoundNodesMessage(request_id, batch_count, enr_batch)
            await self.dispatcher.send_message(
                AnyOutboundMessage(found_nodes, endpoint, node_id)
            )

        return batch_count
示例#11
0
 async def talk(
     self,
     node_id: NodeID,
     endpoint: Endpoint,
     protocol: bytes,
     payload: bytes,
     *,
     request_id: Optional[bytes] = None,
 ) -> InboundMessage[TalkResponseMessage]:
     """
     Send a ``TalkRequestMessage`` and await the matching ``TalkResponseMessage``.

     ``request_id`` is forwarded to ``self.request_tracker.reserve_request_id``,
     and the reserved id is held while waiting for the reply.

     :return: the first ``TalkResponseMessage`` received on the request subscription.
     """
     tracker = self.request_tracker
     with tracker.reserve_request_id(node_id, request_id) as reserved_id:
         talk_request = TalkRequestMessage(reserved_id, protocol, payload)
         outbound = AnyOutboundMessage(talk_request, endpoint, node_id)
         subscribe = self.dispatcher.subscribe_request
         async with subscribe(outbound, TalkResponseMessage) as reply_subscription:
             return await reply_subscription.receive()
示例#12
0
    async def send_talk_request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        protocol: bytes,
        payload: bytes,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a ``TalkRequestMessage`` for ``protocol`` without waiting for a reply.

        ``request_id`` is forwarded to ``self.request_tracker.reserve_request_id``.
        The reserved id is used in the message and returned.
        """
        tracker = self.request_tracker
        with tracker.reserve_request_id(node_id, request_id) as reserved_id:
            talk_request = TalkRequestMessage(reserved_id, protocol, payload)
            outbound = AnyOutboundMessage(talk_request, endpoint, node_id)
            await self.dispatcher.send_message(outbound)

        return reserved_id
示例#13
0
    async def send_ping(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: Optional[int] = None,
        request_id: Optional[bytes] = None,
    ) -> bytes:
        """
        Send a ``PingMessage`` to ``node_id`` at ``endpoint`` and return its request id.

        ``enr_seq`` defaults to the local ENR's current sequence number.
        ``request_id`` is forwarded to ``self.request_tracker.reserve_request_id``.
        """
        sequence_number = (
            self.enr_manager.enr.sequence_number if enr_seq is None else enr_seq
        )

        tracker = self.request_tracker
        with tracker.reserve_request_id(node_id, request_id) as reserved_id:
            ping_message = PingMessage(reserved_id, sequence_number)
            outbound = AnyOutboundMessage(ping_message, endpoint, node_id)
            await self.dispatcher.send_message(outbound)

        return reserved_id
示例#14
0
    async def _stream_find_nodes_response(
        response_message_type: Type[BaseMessage],
        send_channel: trio.abc.SendChannel[InboundMessage[FoundNodesMessage]],
    ) -> None:
        """
        Send a FINDNODE request and forward each reply into ``send_channel``.

        ``client``, ``node_id``, ``endpoint``, ``distances`` and ``request_id``
        are read from the enclosing scope, which is not visible here. The first
        reply's ``total`` determines how many further replies are read. Each
        reply is passed through ``validate_found_nodes_response`` before it is
        forwarded. ``send_channel`` is closed when the exchange ends. The whole
        exchange runs inside ``trio.move_on_after(REQUEST_RESPONSE_TIMEOUT)``.

        :raises trio.TooSlowError: if that deadline cancels the exchange.
        """
        with trio.move_on_after(REQUEST_RESPONSE_TIMEOUT) as scope:
            async with send_channel:
                with client.request_tracker.reserve_request_id(
                        node_id, request_id) as reserved_request_id:
                    request = AnyOutboundMessage(
                        FindNodeMessage(reserved_request_id, tuple(distances)),
                        endpoint,
                        node_id,
                    )

                    async with client.dispatcher.subscribe_request(
                            request, response_message_type) as subscription:
                        head_response = await subscription.receive()
                        expected_total = head_response.message.total
                        validate_found_nodes_response(
                            head_response.message,
                            request,
                            expected_total,
                        )
                        await send_channel.send(head_response)

                        for _ in range(expected_total - 1):
                            response = await subscription.receive()
                            validate_found_nodes_response(
                                response.message, request, expected_total)

                            await send_channel.send(response)

        if scope.cancelled_caught:
            # The second placeholder is labelled ``message_type``, so pass the
            # awaited response type rather than the reserved request id.
            client.logger.debug(
                "Stream find nodes request disconnected: request=%s message_type=%s",
                request,
                response_message_type,
            )
            raise trio.TooSlowError("Timeout in stream_find_nodes")
示例#15
0
    async def send_pong(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        *,
        enr_seq: Optional[int] = None,
        request_id: bytes,
    ) -> None:
        """
        Reply to a ping from ``node_id`` with a ``PongMessage``.

        ``enr_seq`` defaults to the local ENR's current sequence number. The
        pong also carries the IP address and port taken from ``endpoint``.
        ``request_id`` is echoed unchanged.
        """
        sequence_number = (
            self.enr_manager.enr.sequence_number if enr_seq is None else enr_seq
        )

        observed_ip = endpoint.ip_address
        observed_port = endpoint.port
        pong_message = PongMessage(
            request_id, sequence_number, observed_ip, observed_port
        )
        await self.dispatcher.send_message(
            AnyOutboundMessage(pong_message, endpoint, node_id)
        )