예제 #1
0
    async def send_chunks(self,
                          node: Node,
                          *,
                          request_id: int,
                          data: bytes) -> int:
        """
        Send ``data`` to ``node`` as one or more ``Chunk`` messages.

        Empty ``data`` still produces exactly one ``Chunk`` message carrying
        an empty body.  Otherwise ``data`` is split with
        ``split_data_to_chunks`` and one message is sent per piece, each
        carrying ``request_id``, the total piece count and its own index.
        The ``sent_chunk`` event is triggered after every message.

        Returns the number of ``Chunk`` messages sent.

        Raises ``ValueError`` if ``node`` is the local node.
        """
        if node.node_id == self.local_node_id:
            raise ValueError("Cannot send to self")

        if data:
            pieces = split_data_to_chunks(CHUNK_MAX_SIZE, data)
            piece_count = len(pieces)
            self.logger.debug("Sending %d chuncks for data payload to %s", piece_count, node)

            for position, piece in enumerate(pieces):
                outbound = Message(Chunk(request_id, piece_count, position, piece), node)
                await self.message_dispatcher.send_message(outbound)
                await self.events.sent_chunk.trigger(outbound)

            return piece_count

        # Nothing to split: a single chunk (1 total, index 0) with no body.
        outbound = Message(Chunk(request_id, 1, 0, b''), node)
        self.logger.debug("Sending Chunks with empty data to %s", node)
        await self.message_dispatcher.send_message(outbound)
        await self.events.sent_chunk.trigger(outbound)
        return 1
예제 #2
0
 async def send_locations(self,
                          node: Node,
                          *,
                          request_id: int,
                          locations: Collection[Node]) -> int:
     """
     Send ``locations`` to ``node`` as one or more ``Locations`` messages.

     The nodes are grouped with ``partition_all`` into batches of at most
     ``NODES_PER_PAYLOAD`` and one message is sent per batch.  When there
     are no batches a single message with an empty payload is sent.  The
     ``sent_locations`` event is triggered after every message.

     Returns the number of ``Locations`` messages sent.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     batches = tuple(partition_all(NODES_PER_PAYLOAD, locations))
     self.logger.debug("Sending Locations with %d nodes to %s", len(locations), node)

     if not batches:
         # No nodes to report: one message with an empty payload.
         outbound = Message(Locations(request_id, 1, ()), node)
         await self.message_dispatcher.send_message(outbound)
         await self.events.sent_locations.trigger(outbound)
         return 1

     batch_count = len(batches)
     for batch in batches:
         batch_payload = tuple(location.to_payload() for location in batch)
         outbound = Message(Locations(request_id, batch_count, batch_payload), node)
         await self.message_dispatcher.send_message(outbound)
         await self.events.sent_locations.trigger(outbound)
     return batch_count
예제 #3
0
 async def locate(self, node: Node, *, key: bytes) -> Tuple[MessageAPI[Locations], ...]:
     """
     Issue a ``Locate`` request for ``key`` to ``node``.

     A fresh request id is obtained from the message dispatcher, the
     ``sent_locate`` event is triggered, and the exchange is delegated to
     ``_do_request_with_multi_response`` expecting ``Locations`` payloads.

     Returns whatever ``_do_request_with_multi_response`` returns.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     request = Message(Locate(free_id, key), node)
     await self.events.sent_locate.trigger(request)
     responses = await self._do_request_with_multi_response(request, Locations)
     return responses
예제 #4
0
 async def find_nodes(self, node: Node, *, distance: int) -> Tuple[MessageAPI[FoundNodes], ...]:
     """
     Issue a ``FindNodes`` request for ``distance`` to ``node``.

     A fresh request id is obtained from the message dispatcher, the
     ``sent_find_nodes`` event is triggered, and the exchange is delegated
     to ``_do_request_with_multi_response`` expecting ``FoundNodes``
     payloads.

     Returns whatever ``_do_request_with_multi_response`` returns.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     request = Message(FindNodes(free_id, distance), node)
     await self.events.sent_find_nodes.trigger(request)
     responses = await self._do_request_with_multi_response(request, FoundNodes)
     return responses
예제 #5
0
 async def send_graph_deleted(self, node: Node, *, request_id: int) -> None:
     """
     Send a ``GraphDeleted`` message for ``request_id`` to ``node``.

     The message is logged, handed to the message dispatcher, and then the
     ``sent_graph_deleted`` event is triggered.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     outbound = Message(GraphDeleted(request_id), node)
     self.logger.debug("Sending %s", outbound)
     await self.message_dispatcher.send_message(outbound)
     await self.events.sent_graph_deleted.trigger(outbound)
예제 #6
0
 async def advertise(self, node: Node, *, key: bytes, who: Node) -> MessageAPI[Ack]:
     """
     Issue an ``Advertise`` request to ``node`` for ``key`` on behalf of ``who``.

     A request subscription expecting ``Ack`` is opened on the message
     dispatcher; inside it the ``sent_advertise`` event is triggered and
     the first message received on the subscription is returned.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     advertisement = Advertise(free_id, key, who.to_payload())
     request = Message(advertisement, node)
     pending = self.message_dispatcher.subscribe_request(request, Ack)
     async with pending as subscription:
         await self.events.sent_advertise.trigger(request)
         ack = await subscription.receive()
         return ack
예제 #7
0
 async def get_graph_introduction(self, node: Node) -> MessageAPI[GraphIntroduction]:
     """
     Issue a ``GraphGetIntroduction`` request to ``node``.

     A request subscription expecting ``GraphIntroduction`` is opened on
     the message dispatcher; inside it the ``sent_graph_get_introduction``
     event is triggered and the first message received on the subscription
     is returned.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     request = Message(GraphGetIntroduction(free_id), node)
     pending = self.message_dispatcher.subscribe_request(request, GraphIntroduction)
     async with pending as subscription:
         await self.events.sent_graph_get_introduction.trigger(request)
         introduction = await subscription.receive()
         return introduction
예제 #8
0
 async def get_graph_node(self, node: Node, *, key: Key) -> MessageAPI[GraphNode]:
     """
     Issue a ``GraphGetNode`` request for ``key`` to ``node``.

     ``key`` is converted with ``graph_key_to_content_key`` before being
     placed in the request.  A request subscription expecting
     ``GraphNode`` is opened on the message dispatcher; inside it the
     ``sent_graph_get_node`` event is triggered and the first message
     received on the subscription is returned.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     content_key = graph_key_to_content_key(key)
     request = Message(GraphGetNode(free_id, content_key), node)
     pending = self.message_dispatcher.subscribe_request(request, GraphNode)
     async with pending as subscription:
         await self.events.sent_graph_get_node.trigger(request)
         graph_node = await subscription.receive()
         return graph_node
예제 #9
0
 async def retrieve(self, node: Node, *, key: bytes) -> Tuple[MessageAPI[Chunk], ...]:
     """
     Issue a ``Retrieve`` request for ``key`` to ``node``.

     A fresh request id is obtained from the message dispatcher and the
     exchange is delegated to ``_do_request_with_multi_response`` expecting
     ``Chunk`` payloads.

     The ``sent_retrieve`` event is triggered *before* the exchange is
     awaited, the same ordering ``locate`` and ``find_nodes`` use.
     Triggering it afterwards meant observers were only notified once
     every response had arrived, and never if the exchange raised.

     Returns the ``Chunk`` responses sorted by ``payload.index``.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")
     request_id = self.message_dispatcher.get_free_request_id(node.node_id)
     message = Message(Retrieve(request_id, key), node)
     await self.events.sent_retrieve.trigger(message)
     responses = await self._do_request_with_multi_response(message, Chunk)
     # Responses are returned in chunk-index order regardless of arrival order.
     return tuple(sorted(responses, key=lambda response: response.payload.index))
예제 #10
0
 async def ping(self, node: Node) -> MessageAPI[Pong]:
     """
     Issue a ``Ping`` request to ``node``.

     A request subscription expecting ``Pong`` is opened on the message
     dispatcher; inside it the ``sent_ping`` event is triggered and the
     first message received on the subscription is returned.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     request = Message(Ping(free_id), node)
     pending = self.message_dispatcher.subscribe_request(request, Pong)
     async with pending as subscription:
         await self.events.sent_ping.trigger(request)
         pong = await subscription.receive()
         return pong
예제 #11
0
 async def send_graph_get_node(self, node: Node, *, key: Key) -> int:
     """
     Send a ``GraphGetNode`` message for ``key`` to ``node`` without waiting
     for a response.

     ``key`` is converted with ``graph_key_to_content_key`` before being
     placed in the message.  The message is logged, handed to the message
     dispatcher, and then the ``sent_graph_get_node`` event is triggered.

     Returns the request id obtained from the message dispatcher.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     content_key = graph_key_to_content_key(key)
     outbound = Message(GraphGetNode(free_id, content_key), node)
     self.logger.debug("Sending %s", outbound)
     await self.message_dispatcher.send_message(outbound)
     await self.events.sent_graph_get_node.trigger(outbound)
     return free_id
예제 #12
0
 async def send_graph_get_introduction(self, node: Node) -> int:
     """
     Send a ``GraphGetIntroduction`` message to ``node`` without waiting for
     a response.

     The message is logged, handed to the message dispatcher, and then the
     ``sent_graph_get_introduction`` event is triggered.

     Returns the request id obtained from the message dispatcher.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     outbound = Message(GraphGetIntroduction(free_id), node)
     self.logger.debug("Sending %s", outbound)
     await self.message_dispatcher.send_message(outbound)
     await self.events.sent_graph_get_introduction.trigger(outbound)
     return free_id
예제 #13
0
 async def send_locate(self, node: Node, *, key: bytes) -> int:
     """
     Send a ``Locate`` message for ``key`` to ``node`` without waiting for a
     response.

     The message is logged, handed to the message dispatcher, and then the
     ``sent_locate`` event is triggered.

     Returns the request id obtained from the message dispatcher.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     outbound = Message(Locate(free_id, key), node)
     self.logger.debug("Sending %s", outbound)
     await self.message_dispatcher.send_message(outbound)
     await self.events.sent_locate.trigger(outbound)
     return free_id
예제 #14
0
 async def send_find_nodes(self, node: Node, *, distance: int) -> int:
     """
     Send a ``FindNodes`` message for ``distance`` to ``node`` without
     waiting for a response.

     The message is logged, handed to the message dispatcher, and then the
     ``sent_find_nodes`` event is triggered.

     Returns the request id obtained from the message dispatcher.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     free_id = self.message_dispatcher.get_free_request_id(node.node_id)
     outbound = Message(FindNodes(free_id, distance), node)
     self.logger.debug("Sending %s", outbound)
     await self.message_dispatcher.send_message(outbound)
     await self.events.sent_find_nodes.trigger(outbound)
     return free_id
예제 #15
0
 async def send_graph_introduction(self,
                                   node: Node,
                                   *,
                                   request_id: int,
                                   graph_nodes: Collection[SGNodeAPI],
                                   ) -> None:
     """
     Send a ``GraphIntroduction`` message for ``request_id`` to ``node``.

     Every entry of ``graph_nodes`` is converted with
     ``SkipGraphNode.from_sg_node`` and all of them are sent in a single
     message.  The message is logged, handed to the message dispatcher,
     and then the ``sent_graph_introduction`` event is triggered.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     converted = tuple(map(SkipGraphNode.from_sg_node, graph_nodes))
     # TODO: ensure payload size is within bounds
     introduction = GraphIntroduction(request_id, converted)
     outbound = Message(introduction, node)
     self.logger.debug("Sending %s", outbound)
     await self.message_dispatcher.send_message(outbound)
     await self.events.sent_graph_introduction.trigger(outbound)
예제 #16
0
 async def send_graph_node(self,
                           node: Node,
                           *,
                           request_id: int,
                           sg_node: Optional[SGNodeAPI]) -> None:
     """
     Send a ``GraphNode`` message for ``request_id`` to ``node``.

     When ``sg_node`` is ``None`` the message carries ``None``; otherwise
     it carries ``SkipGraphNode.from_sg_node(sg_node)``.  The message is
     logged, handed to the message dispatcher, and then the
     ``sent_graph_node`` event is triggered.

     Raises ``ValueError`` if ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     converted: Optional[SkipGraphNode] = (
         None if sg_node is None else SkipGraphNode.from_sg_node(sg_node)
     )
     outbound = Message(GraphNode(request_id, converted), node)
     self.logger.debug("Sending %s", outbound)
     await self.message_dispatcher.send_message(outbound)
     await self.events.sent_graph_node.trigger(outbound)
예제 #17
0
    async def handle_inbound_packet(self, packet: PacketAPI) -> None:
        """
        Process one inbound packet according to the current handshake state.

        * Handshake complete: a ``MessagePacket`` is decrypted with the
          session decryption key and forwarded on the inbound message
          channel; any other packet type is logged and ignored.
        * Before handshake: a ``MessagePacket`` is passed to
          ``receive_handshake_initiation`` and answered via
          ``send_handshake_response``; any other packet is buffered.
        * During handshake: a ``CompleteHandshakePacket`` is passed to
          ``receive_handshake_completion``, the status becomes
          ``SessionStatus.AFTER`` and ``handshake_complete`` is triggered;
          any other packet is buffered.

        A packet that cannot be buffered because the buffer is full is
        logged and discarded in both buffering branches.

        Raises ``Exception`` if the session is in none of the three states.
        """
        # Record activity time for every packet, whatever its fate.
        self.last_message_at = time.monotonic()

        if self.is_handshake_complete:
            if isinstance(packet, MessagePacket):
                payload = packet.decrypt_payload(
                    self._session_keys.decryption_key)
                message = Message(
                    payload=payload,
                    node=self.remote_node,
                )
                self.logger.debug(
                    '%s: processed inbound message packet: %s',
                    self,
                    message,
                )
                await self._inbound_message_send_channel.send(message)
            else:
                self.logger.debug(
                    '%s: Ignoring packet of type %s received after handshake complete',
                    self,
                    type(packet),
                )
        elif self.is_before_handshake:
            if isinstance(packet, MessagePacket):
                await self.receive_handshake_initiation(packet)
                await self.send_handshake_response(packet)
            else:
                # Same policy as the during-handshake branch below: a full
                # buffer must not raise out of the handler, so the packet is
                # logged and dropped.
                # TODO: manage buffer...
                try:
                    self._inbound_packet_buffer_channels[0].send_nowait(packet)
                except trio.WouldBlock:
                    self.logger.error(
                        "Discarding message before handshake.  Buffer full: %s",
                        packet,
                    )
        elif self.is_during_handshake:
            if isinstance(packet, CompleteHandshakePacket):
                self._session_keys = await self.receive_handshake_completion(
                    packet)
                self._status = SessionStatus.AFTER
                await self._events.handshake_complete.trigger(self)
            else:
                try:
                    self._inbound_packet_buffer_channels[0].send_nowait(packet)
                except trio.WouldBlock:
                    self.logger.error(
                        "Discarding message during handshake.  Buffer full: %s",
                        packet,
                    )
        else:
            raise Exception("Invalid state")
예제 #18
0
    async def handle_inbound_packet(self, packet: PacketAPI) -> None:
        """
        Process one inbound packet according to the current handshake state.

        * Handshake complete: a ``MessagePacket`` is decrypted with the
          session decryption key and forwarded on the inbound message
          channel; any other packet type is logged and ignored.
        * Before handshake: the packet is placed in the inbound packet
          buffer.
        * During handshake: a ``HandshakeResponse`` is passed to
          ``receive_handshake_response``, the status becomes
          ``SessionStatus.AFTER``, ``handshake_complete`` is triggered and
          the handshake completion is sent; any other packet raises
          ``CorruptSession``.

        Raises ``Exception`` if the session is in none of the three states.
        """
        self.last_message_at = time.monotonic()

        if self.is_handshake_complete:
            if not isinstance(packet, MessagePacket):
                self.logger.debug(
                    '%s: Ignoring packet of type %s received after handshake complete',
                    self,
                    type(packet),
                )
                return

            decrypted = packet.decrypt_payload(self._session_keys.decryption_key)
            inbound = Message(payload=decrypted, node=self.remote_node)
            self.logger.debug('%s: processed inbound message packet: %s', self, inbound)
            await self._inbound_message_send_channel.send(inbound)
            return

        if self.is_before_handshake:
            # Likely that both nodes are handshaking with each other at the
            # same time...  Put the message in the queue and see what happens.
            # TODO: deal with buffer after handshake...
            buffer_channel = self._inbound_packet_buffer_channels[0]
            buffer_channel.send_nowait(packet)
            return

        if not self.is_during_handshake:
            raise Exception("Invalid state")

        if not isinstance(packet, HandshakeResponse):
            raise CorruptSession(
                f"Suspected corrupted session. Got {packet} packet during handshake initiation."
            )

        self._session_keys, remote_ephemeral_key = await self.receive_handshake_response(packet)
        self._status = SessionStatus.AFTER
        await self._events.handshake_complete.trigger(self)

        await self.send_handshake_completion(self._session_keys, remote_ephemeral_key, packet)
예제 #19
0
    async def receive_handshake_completion(
            self, packet: CompleteHandshakePacket) -> SessionKeys:
        """
        Validate a handshake completion packet and derive the session keys.

        Steps, in order:

        1. Recover the sender's node id from ``packet.tag`` and check it
           against ``self.remote_node_id``.
        2. Store ``packet.header.public_key`` as ``self.remote_public_key``
           and check that the node id derived from it matches the recovered
           node id.
        3. Compute the session keys from the local private key, the packet's
           ephemeral public key and the id nonce of the handshake response
           packet (``is_initiator=False``).
        4. Decrypt and validate the auth response and the message carried by
           the packet, then forward the message on the inbound message
           channel.

        Returns the computed session keys.

        Raises ``ValidationError`` if either node id check fails.
        """
        self.logger.debug('%s: received handshake completion', self)
        remote_node_id = recover_source_id_from_tag(
            packet.tag,
            self.local_node_id,
        )
        if remote_node_id != self.remote_node_id:
            raise ValidationError(
                f"Remote node ids do not match: {remote_node_id} != {self.remote_node_id}"
            )

        self.remote_public_key = packet.header.public_key
        expected_remote_node_id = public_key_to_node_id(self.remote_public_key)
        if expected_remote_node_id != remote_node_id:
            # Report the two values that were actually compared; the previous
            # check already guarantees remote_node_id == self.remote_node_id.
            raise ValidationError(
                f"Remote node ids does not match expected node id: "
                f"{remote_node_id} != {expected_remote_node_id}")

        ephemeral_public_key = packet.header.ephemeral_public_key

        session_keys = compute_session_keys(
            local_private_key=self.private_key,
            remote_public_key=ephemeral_public_key,
            local_node_id=self.local_node_id,
            remote_node_id=self.remote_node_id,
            id_nonce=self.handshake_response_packet.id_nonce,
            is_initiator=False,
        )

        self.decrypt_and_validate_auth_response(
            packet,
            session_keys.auth_response_key,
            self.handshake_response_packet.id_nonce,
        )
        payload = self.decrypt_and_validate_message(
            packet,
            session_keys.decryption_key,
        )
        message = Message(
            payload=payload,
            node=self.remote_node,
        )
        await self._inbound_message_send_channel.send(message)
        return session_keys