예제 #1
0
 async def send_locations(self,
                          node: Node,
                          *,
                          request_id: int,
                          locations: Collection[Node]) -> int:
     """
     Send ``locations`` to ``node`` as one or more ``Locations`` messages.

     The nodes are split into chunks of at most ``NODES_PER_PAYLOAD`` and each
     chunk is sent as its own message, tagged with ``request_id`` and the total
     number of messages.  An empty ``locations`` still results in exactly one
     message with no nodes in it.

     Returns the number of messages sent.
     Raises ``ValueError`` when ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     # Fall back to a single empty chunk so the recipient always gets a message.
     chunks = tuple(partition_all(NODES_PER_PAYLOAD, locations)) or ((),)
     self.logger.debug("Sending Locations with %d nodes to %s", len(locations), node)

     message_count = len(chunks)
     for chunk in chunks:
         # ``location`` (not ``node``) so the recipient parameter is never shadowed.
         encoded_nodes = tuple(location.to_payload() for location in chunk)
         outbound = Message(
             Locations(request_id, message_count, encoded_nodes),
             node,
         )
         await self.message_dispatcher.send_message(outbound)
         await self.events.sent_locations.trigger(outbound)
     return message_count
예제 #2
0
async def test_client_locate_request_single_response(alice_and_bob_clients):
    """
    Alice's ``locate`` yields exactly one ``Locations`` message carrying the
    single node that Bob answers with.
    """
    alice, bob = alice_and_bob_clients

    async with trio.open_nursery() as nursery:
        expected_location = NodeFactory()

        async with bob.message_dispatcher.subscribe(Locate) as locate_subscription:

            async def answer_locate_request():
                # Bob waits at most one second for Alice's Locate, then replies
                # with a single location.
                with trio.fail_after(1):
                    request = await locate_subscription.receive()

                assert isinstance(request.payload, Locate)
                await bob.send_locations(
                    request.node,
                    request_id=request.payload.request_id,
                    locations=(expected_location,),
                )

            nursery.start_soon(answer_locate_request)

            with trio.fail_after(1):
                responses = await alice.locate(bob.local_node, key=b'key')

            assert len(responses) == 1
            response_payload = responses[0].payload

            assert isinstance(response_payload, Locations)
            assert response_payload.total == 1
            assert len(response_payload.nodes) == 1
            assert Node.from_payload(response_payload.nodes[0]) == expected_location
예제 #3
0
    async def _ping_occasionally(self) -> None:
        """
        Periodically check liveness of the least recently updated bucket.

        Every ``config.PING_INTERVAL`` the nodes at the routing table's least
        recently updated log distance are pinged, in reverse order.  A node that
        does not answer within ``PING_TIMEOUT`` is removed from the routing
        table; pinging for that interval stops at the first node that answers.
        """
        async for _ in every(self.config.PING_INTERVAL):  # noqa: F841
            if self.routing_table.is_empty:
                self.logger.warning("Routing table is empty, no one to ping")
                continue

            stale_distance = (
                self.routing_table.get_least_recently_updated_log_distance()
            )
            bucket = self.routing_table.get_nodes_at_log_distance(stale_distance)

            for candidate_id in reversed(bucket):
                # NOTE(review): endpoint_db.get_endpoint is not guarded here; other
                # callers in this codebase catch KeyError from it — confirm every
                # routing-table entry is guaranteed to have an endpoint.
                candidate = Node(candidate_id, self.endpoint_db.get_endpoint(candidate_id))

                with trio.move_on_after(PING_TIMEOUT) as ping_scope:
                    await self.client.ping(candidate)

                if not ping_scope.cancelled_caught:
                    # First responsive node: stop pinging this bucket.
                    break

                self.logger.debug(
                    'Node %s did not respond to ping.  Removing from routing table',
                    candidate_id,
                )
                self.routing_table.remove(candidate_id)
예제 #4
0
 async def do_announce(key: bytes) -> None:
     """
     Call ``self.network.announce`` for ``key`` with the local node (id plus
     external endpoint), silently giving up after ``ANNOUNCE_TIMEOUT``.
     """
     local_node = Node(self.client.local_node_id, self.client.external_endpoint)
     with trio.move_on_after(ANNOUNCE_TIMEOUT):
         await self.network.announce(key, local_node)
예제 #5
0
    async def _handle_session_packet(self, datagram: Datagram) -> None:
        """
        Decode ``datagram`` and feed the packet to the session for its sender.

        The sender's node id is recovered from the packet tag and paired with
        the datagram's source endpoint.  If the session cannot handle the
        packet (``DecryptionError`` or ``CorruptSession``), it is removed from
        the pool and the packet is retried once with the session obtained
        afterwards.  If that retry also fails with either error, that session
        is removed as well and the packet is dropped.
        """
        packet = decode_packet(datagram.data)
        remote_node_id = recover_source_id_from_tag(packet.tag, self.local_node_id)
        remote_node = Node(remote_node_id, datagram.endpoint)
        session = await self._get_or_create_session(
            remote_node,
            is_initiator=False,
        )

        try:
            await session.handle_inbound_packet(packet)
        except (DecryptionError, CorruptSession):
            self.pool.remove_session(session.session_id)
            self.logger.debug('Removed defunkt session: %s', session)

            fresh_session = await self._get_or_create_session(
                remote_node,
                is_initiator=False,
            )
            # Now try again with a fresh session.  ``CorruptSession`` is handled
            # exactly like on the first attempt, so an unusable fresh session is
            # also removed from the pool instead of lingering there while the
            # exception escapes this handler.
            try:
                await fresh_session.handle_inbound_packet(packet)
            except (DecryptionError, CorruptSession):
                self.pool.remove_session(fresh_session.session_id)
                self.logger.debug(
                    'Unable to read packet after resetting session: %s',
                    fresh_session,
                )
예제 #6
0
    async def _handle_locate_requests(self) -> None:
        """
        Answer inbound ``Locate`` requests while the service is running.

        For each request, the content key is mapped to a content id and the
        locally indexed location node ids are resolved to endpoints.  The
        resulting nodes are sent back to the requester via
        ``client.send_locations``.  A location whose endpoint cannot be resolved
        is left out of the reply (and logged) rather than raising out of this
        loop and ending the handler.
        """
        def get_endpoint(node_id: NodeID) -> Endpoint:
            # Prefer the endpoint database; fall back to our external endpoint
            # when the id is the local node.  Any other unknown id raises KeyError.
            try:
                return self.endpoint_db.get_endpoint(node_id)
            except KeyError:
                if node_id == self.client.local_node_id:
                    return self.client.external_endpoint
                else:
                    raise

        async with self.client.message_dispatcher.subscribe(
                Locate) as subscription:
            while self.manager.is_running:
                request = await subscription.receive()
                payload = request.payload
                content_id = content_key_to_node_id(payload.key)
                # TODO: ping the node to ensure it is available (unless it is the sending node).
                # TODO: verify content is actually available
                # TODO: check distance of key and store conditionally
                location_ids = self.content_manager.get_index(content_id)

                locations = []
                for node_id in location_ids:
                    try:
                        endpoint = get_endpoint(node_id)
                    except KeyError:
                        # One unresolvable location must not stop the handler.
                        self.logger.debug(
                            "Omitting location %s from Locations reply: unknown endpoint",
                            node_id,
                        )
                        continue
                    locations.append(Node(node_id, endpoint))

                await self.client.send_locations(
                    request.node,
                    request_id=payload.request_id,
                    locations=tuple(locations),
                )
예제 #7
0
def node_from_rpc(rpc_node: Tuple[str, str, int]) -> Node:
    """
    Build a ``Node`` from a JSON-RPC ``(node_id_hex, ip_address, port)`` triple.

    The node id is parsed from its hex string into an integer ``NodeID``.  The
    address is parsed with ``ipaddress.IPv4Address``, so only IPv4 is accepted.
    """
    hex_node_id, ip_text, port = rpc_node
    node_id = NodeID(to_int(hexstr=hex_node_id))
    return Node(node_id, Endpoint(ipaddress.IPv4Address(ip_text), port))
예제 #8
0
 async def advertise(self, node: Node, *, key: bytes, who: Node) -> MessageAPI[Ack]:
     """
     Send an ``Advertise`` for ``key`` naming ``who`` to ``node`` and await the ``Ack``.

     A fresh request id is allocated for ``node``.  The request is registered
     with ``message_dispatcher.subscribe_request`` for an ``Ack`` reply, the
     ``sent_advertise`` event is fired, and the first ``Ack`` received is
     returned.

     Raises ``ValueError`` when ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     advertisement = Advertise(
         self.message_dispatcher.get_free_request_id(node.node_id),
         key,
         who.to_payload(),
     )
     request = Message(advertisement, node)
     async with self.message_dispatcher.subscribe_request(request, Ack) as ack_subscription:
         await self.events.sent_advertise.trigger(request)
         return await ack_subscription.receive()
예제 #9
0
    def _get_nodes_at_distance(self, distance: int) -> Iterator[Node]:
        """
        Yield the routing-table nodes at log distance ``distance`` as ``Node`` objects.

        Each node id returned by ``routing_table.get_nodes_at_log_distance`` is
        paired with its endpoint from ``endpoint_db``.  Nothing is sent over the
        network; callers decide what to do with the yielded nodes.

        NOTE(review): ``endpoint_db.get_endpoint`` is not guarded, so an id with
        no stored endpoint propagates whatever it raises (KeyError elsewhere in
        this codebase) — confirm that cannot happen for routing-table entries.
        """
        for node_id in self.routing_table.get_nodes_at_log_distance(distance):
            yield Node(node_id, self.endpoint_db.get_endpoint(node_id))
예제 #10
0
    def __init__(self,
                 private_key: keys.PrivateKey,
                 listen_on: Endpoint,
                 ) -> None:
        """
        Initialise node identity, the internal memory channels and the
        components wired to them.

        :param private_key: this node's private key; ``local_node_id`` is
            derived from its public key.
        :param listen_on: endpoint this node listens on; also used as the
            endpoint of ``local_node``.
        """
        # Identity
        self._private_key = private_key
        self.public_key = private_key.public_key
        self.local_node_id = public_key_to_node_id(self.public_key)
        self.listen_on = listen_on
        self.local_node = Node(self.local_node_id, self.listen_on)

        # Every channel below is opened with a buffer size of 0 (unbuffered).

        # Datagrams
        self._outbound_datagram_send_channel, self._outbound_datagram_receive_channel = (
            trio.open_memory_channel[Datagram](0)
        )
        self._inbound_datagram_send_channel, self._inbound_datagram_receive_channel = (
            trio.open_memory_channel[Datagram](0)
        )

        # Packets
        self._outbound_packet_send_channel, self._outbound_packet_receive_channel = (
            trio.open_memory_channel[NetworkPacket](0)
        )
        self._inbound_packet_send_channel, self._inbound_packet_receive_channel = (
            trio.open_memory_channel[NetworkPacket](0)
        )

        # Messages
        self._outbound_message_send_channel, self._outbound_message_receive_channel = (
            trio.open_memory_channel[MessageAPI[sedes.Serializable]](0)
        )
        self._inbound_message_send_channel, self._inbound_message_receive_channel = (
            trio.open_memory_channel[MessageAPI[sedes.Serializable]](0)
        )

        # The pool receives the send sides of the outbound-packet and
        # inbound-message channels; the dispatcher receives the outbound-message
        # send side and the inbound-message receive side.
        self.events = Events()
        self.pool = Pool(
            private_key=self._private_key,
            events=self.events,
            outbound_packet_send_channel=self._outbound_packet_send_channel,
            inbound_message_send_channel=self._inbound_message_send_channel,
        )
        self.message_dispatcher = MessageDispatcher(
            self._outbound_message_send_channel,
            self._inbound_message_receive_channel,
        )

        self._ready = trio.Event()
예제 #11
0
 async def send_advertise(self, node: Node, *, key: bytes, who: Node) -> int:
     """
     Send an ``Advertise`` for ``key`` naming ``who`` to ``node`` without waiting for a reply.

     Returns the request id allocated for the message.
     Raises ``ValueError`` when ``node`` is the local node.
     """
     if node.node_id == self.local_node_id:
         raise ValueError("Cannot send to self")

     request_id = self.message_dispatcher.get_free_request_id(node.node_id)
     advertisement = Message(Advertise(request_id, key, who.to_payload()), node)

     self.logger.debug("Sending %s", advertisement)
     await self.message_dispatcher.send_message(advertisement)
     await self.events.sent_advertise.trigger(advertisement)
     return request_id
예제 #12
0
    async def _handle_advertise_requests(self) -> None:
        """
        Acknowledge inbound ``Advertise`` messages and queue them for ingestion.

        Each request is acked first.  The ``(advertised node, key)`` pair is
        then pushed onto the inbound content channel without blocking.  When
        that channel is full, the advertisement is discarded and an error is
        logged.
        """
        async with self.client.message_dispatcher.subscribe(Advertise) as subscription:
            while self.manager.is_running:
                advertisement = await subscription.receive()
                body = advertisement.payload
                await self.client.send_ack(
                    advertisement.node,
                    request_id=body.request_id,
                )

                advertised_node = Node.from_payload(body.node)

                # Hand off to the content manager without blocking this loop.
                try:
                    self._inbound_content_send_channel.send_nowait(
                        (advertised_node, body.key),
                    )
                except trio.WouldBlock:
                    self.logger.error(
                        "Content processing channel is full.  Discarding "
                        "advertised content: %s@%s",
                        encode_hex(body.key),
                        advertisement.node,
                    )
예제 #13
0
async def test_application(bootnode, base_db_path):
    """
    Long-running scenario: start 16 application clients against ``bootnode``.

    Clients are started one after another with a random delay of up to one
    second between them.  Each gets its own durable database under
    ``base_db_path`` and a small storage configuration.  A background task
    logs every node that completes a handshake with the bootnode.  The test
    then sleeps forever.
    """
    bootnodes = (Node(bootnode.client.local_node_id, bootnode.client.listen_on),)

    connected_nodes = []

    async def monitor_bootnode():
        async with bootnode.client.events.handshake_complete.subscribe() as subscription:
            async for session in subscription:
                logger.info('NODE_CONNECTED_TO_BOOTNODE: %s', humanize_node_id(session.remote_node_id))  # noqa: E501
                connected_nodes.append(session.remote_node_id)

    config = KademliaConfig(
        LOOKUP_INTERVAL=20,
        ANNOUNCE_INTERVAL=30,
        ANNOUNCE_CONCURRENCY=1,
        storage_config=StorageConfig(
            ephemeral_storage_size=1024,
            ephemeral_index_size=500,
            cache_storage_size=1024,
            cache_index_size=100
        ),
    )

    async with trio.open_nursery() as nursery:
        # Run the bootnode monitor alongside the clients; without this task
        # the handshake logging above never executes.
        nursery.start_soon(monitor_bootnode)

        async with AsyncExitStack() as stack:
            for i in range(16):
                # small delay between starting each client.
                await trio.sleep(random.random())
                # content database
                durable_db = make_durable_db(base_db_path / f"client-{i}")
                app = ApplicationFactory(
                    bootnodes=bootnodes,
                    durable_db=durable_db,
                    config=config,
                )
                logger.info('CLIENT-%d: %s', i, humanize_node_id(app.client.local_node_id))
                await stack.enter_async_context(background_trio_service(app))
            await trio.sleep_forever()
예제 #14
0
 async def single_lookup(self, node: Node, *,
                         distance: int) -> Tuple[Node, ...]:
     """
     Ask ``node`` for the nodes it knows at log distance ``distance``.

     The node payloads of every response message returned by
     ``client.find_nodes`` are decoded into ``Node`` objects and returned as a
     single flat tuple, in message order.
     """
     responses = await self.client.find_nodes(node, distance=distance)
     decoded = []
     for response in responses:
         decoded.extend(Node.from_payload(entry) for entry in response.payload.nodes)
     return tuple(decoded)
예제 #15
0
 async def locate(self, node: Node, *, key: bytes) -> Tuple[Node, ...]:
     """
     Ask ``node`` where content ``key`` is located.

     The node payloads of every response message returned by
     ``client.locate`` are decoded into ``Node`` objects and returned as a
     single flat tuple, in message order.
     """
     responses = await self.client.locate(node, key=key)
     decoded = []
     for response in responses:
         for entry in response.payload.nodes:
             decoded.append(Node.from_payload(entry))
     return tuple(decoded)
예제 #16
0
    async def iterative_lookup(
        self,
        target_id: NodeID,
        filter_self: bool = True,
    ) -> Tuple[Node, ...]:
        """
        Iteratively look up the nodes closest to ``target_id``.

        Each round takes the ``routing_table.bucket_size`` closest candidates
        (from the routing table plus every node learned so far), excluding any
        that proved unresponsive.  Of those, up to ``LOOKUP_CONCURRENCY_FACTOR``
        not-yet-queried candidates are queried concurrently with
        ``single_lookup``.  The lookup ends when a round has nothing left to
        query.

        :param target_id: the id being looked up.
        :param filter_self: when true, the local node is excluded from the result.
        :returns: every ``(node_id, endpoint)`` pair learned during the lookup as
            ``Node`` objects, ordered by ``compute_distance`` to ``target_id``
            (closest first).
        """
        self.logger.debug("Starting looking up @ %s",
                          humanize_node_id(target_id))

        # tracks the nodes that have already been queried
        queried_node_ids: Set[NodeID] = set()
        # keeps track of the nodes that are unresponsive
        unresponsive_node_ids: Set[NodeID] = set()
        # accumulator of all of the valid responses received
        received_nodes: DefaultDict[
            NodeID, Set[Endpoint]] = collections.defaultdict(set)

        async def do_lookup(peer: Node) -> None:
            self.logger.debug(
                "Looking up %s via node %s",
                humanize_node_id(target_id),
                humanize_node_id(peer.node_id),
            )
            distance = compute_log_distance(peer.node_id, target_id)

            try:
                with trio.fail_after(FIND_NODES_TIMEOUT):
                    found_nodes = await self.single_lookup(
                        peer,
                        distance=distance,
                    )
            except trio.TooSlowError:
                unresponsive_node_ids.add(peer.node_id)
            else:
                if len(found_nodes) == 0:
                    unresponsive_node_ids.add(peer.node_id)
                else:
                    received_nodes[peer.node_id].add(peer.endpoint)
                    for node in found_nodes:
                        received_nodes[node.node_id].add(node.endpoint)

        @to_tuple
        def get_endpoints(node_id: NodeID) -> Iterator[Endpoint]:
            # The stored endpoint first, then any learned during this lookup.
            # ``.get`` avoids inserting empty entries into the defaultdict, and
            # the stored endpoint is not yielded a second time (which would
            # query the same peer twice).
            try:
                known_endpoint = self.endpoint_db.get_endpoint(node_id)
            except KeyError:
                known_endpoint = None
            else:
                yield known_endpoint

            for endpoint in received_nodes.get(node_id, ()):
                if endpoint != known_endpoint:
                    yield endpoint

        for lookup_round_number in itertools.count():
            received_node_ids = tuple(received_nodes.keys())
            candidates = iter_closest_nodes(target_id, self.routing_table,
                                            received_node_ids)
            # Exclude every unresponsive node, wherever it falls in the
            # ordering, so they do not occupy slots among the closest k.
            responsive_candidates = (
                candidate for candidate in candidates
                if candidate not in unresponsive_node_ids
            )
            closest_k_candidates = take(self.routing_table.bucket_size,
                                        responsive_candidates)
            closest_k_unqueried_candidates = (
                candidate for candidate in closest_k_candidates
                if candidate not in queried_node_ids)
            nodes_to_query = tuple(
                take(
                    LOOKUP_CONCURRENCY_FACTOR,
                    closest_k_unqueried_candidates,
                ))

            if nodes_to_query:
                self.logger.debug(
                    "Starting lookup round %d for %s",
                    lookup_round_number + 1,
                    humanize_node_id(target_id),
                )
                queried_node_ids.update(nodes_to_query)
                async with trio.open_nursery() as nursery:
                    for peer_id in nodes_to_query:
                        if peer_id == self.client.local_node_id:
                            continue
                        for endpoint in get_endpoints(peer_id):
                            nursery.start_soon(do_lookup,
                                               Node(peer_id, endpoint))
            else:
                self.logger.debug(
                    "Lookup for %s finished in %d rounds",
                    humanize_node_id(target_id),
                    lookup_round_number,
                )
                break

        found_nodes = tuple(
            Node(node_id, endpoint)
            for node_id, endpoints in received_nodes.items()
            for endpoint in endpoints
            if (not filter_self or node_id != self.client.local_node_id))
        # Order by closeness to the lookup target, not to the local node.
        sorted_found_nodes = tuple(
            sorted(
                found_nodes,
                key=lambda node: compute_distance(target_id, node.node_id),
            ))
        self.logger.debug(
            "Finished looking up %s in %d rounds: Found %d nodes after querying %d nodes",
            humanize_node_id(target_id),
            lookup_round_number,
            len(found_nodes),
            len(queried_node_ids),
        )
        return sorted_found_nodes
예제 #17
0
async def main() -> None:
    """
    Entry point: parse CLI arguments, build an ``Alexandria`` service and run it.

    The private key is derived with ``sha256`` from ``args.private_key_seed``
    when one is given.  Otherwise it is read from ``<root>/nodekey``; on first
    run that file is created, owner-readable only, holding a fresh random
    32-byte key.
    """
    DEFAULT_LISTEN_ON = Endpoint(ipaddress.IPv4Address('0.0.0.0'), 30314)

    args = parser.parse_args()
    setup_logging(args.log_level)

    if args.port is not None:
        listen_on = Endpoint(ipaddress.IPv4Address('0.0.0.0'), args.port)
    else:
        listen_on = DEFAULT_LISTEN_ON

    logger = logging.getLogger()

    if args.bootnodes is not None:
        bootnodes = tuple(
            Node.from_node_uri(node_uri) for node_uri in args.bootnodes)
    else:
        bootnodes = DEFAULT_BOOTNODES

    application_root_dir = get_xdg_alexandria_root()
    if not application_root_dir.exists():
        application_root_dir.mkdir(parents=True, exist_ok=True)

    ipc_path = application_root_dir / 'jsonrpc.ipc'

    if args.private_key_seed is None:
        node_key_path = application_root_dir / 'nodekey'
        if node_key_path.exists():
            private_key_hex = node_key_path.read_text().strip()
            private_key_bytes = decode_hex(private_key_hex)
        else:
            private_key_bytes = secrets.token_bytes(32)
            # The node key is a secret.  Create the file with owner-only
            # permissions instead of the umask default.  O_EXCL refuses to
            # overwrite a file that appeared after the exists() check.
            key_fd = os.open(
                node_key_path,
                os.O_WRONLY | os.O_CREAT | os.O_EXCL,
                0o600,
            )
            with os.fdopen(key_fd, 'w') as key_file:
                key_file.write(encode_hex(private_key_bytes))
        private_key = keys.PrivateKey(private_key_bytes)
    else:
        private_key = keys.PrivateKey(sha256(args.private_key_seed))

    durable_db_path = application_root_dir / 'durable-db'
    durable_db = DurableDB(durable_db_path)

    metrics_args: Optional[argparse.Namespace]
    if args.enable_metrics:
        metrics_args = args
    else:
        metrics_args = None

    config = KademliaConfig.from_args(args)

    alexandria = Alexandria(
        private_key=private_key,
        listen_on=listen_on,
        bootnodes=bootnodes,
        durable_db=durable_db,
        kademlia_config=config,
        ipc_path=ipc_path,
        metrics_args=metrics_args,
    )

    logger.info(ALEXANDRIA_HEADER)
    logger.info("Started main process (pid=%d)", os.getpid())
    async with background_trio_service(alexandria) as manager:
        await manager.wait_finished()
예제 #18
0
                        self.application.client.listen_on.port,
                        UPNP_PORTMAP_DURATION,
                    )
                except PortMapFailed:
                    continue
                external_endpoint = Endpoint(
                    ipaddress.IPv4Address(external_ip),
                    self.application.client.listen_on.port,
                )
                await self.application.client.events.new_external_ip.trigger(
                    external_endpoint)


# Bootstrap nodes that ``main`` falls back to when no bootnodes are passed on
# the command line.  The ``noqa`` markers sit on the long URI lines themselves,
# which is where flake8 reports E501.
DEFAULT_BOOTNODES = (
    Node.from_node_uri(
        'node://157d841a79faa0dc11180724ecca44322fa07f9b5b8950e4f10c13dcbac9e074@74.207.253.18:30314'  # noqa: E501
    ),
    Node.from_node_uri(
        'node://b8fe2b71b9138e65ee48e6a0ab3ebd63622c8e2e46c963cbf71bc351132b39af@192.155.84.246:30314'  # noqa: E501
    ),
)


async def main() -> None:
    DEFAULT_LISTEN_ON = Endpoint(ipaddress.IPv4Address('0.0.0.0'), 30314)

    args = parser.parse_args()
    setup_logging(args.log_level)

    if args.port is not None:
        listen_on = Endpoint(ipaddress.IPv4Address('0.0.0.0'), args.port)