Ejemplo n.º 1
0
async def test_alexandria_client_send_found_nodes(bob, bob_client, enrs,
                                                  alice_alexandria_client):
    """Check that FOUNDNODES responses deliver every ENR exactly once.

    The ENRs may arrive spread over several TALKRESP messages; the first
    decoded message's ``payload.total`` says how many messages to expect.
    """
    with trio.fail_after(2):
        async with bob_client.dispatcher.subscribe(
                TalkResponseMessage) as subscription:
            await alice_alexandria_client.send_found_nodes(
                bob.node_id,
                bob.endpoint,
                enrs=enrs,
                request_id=b"\x01\x02",
            )

            with trio.fail_after(2):
                first_message = await subscription.receive()
                decoded_first_message = decode_message(
                    first_message.message.payload)
                remaining_messages = []
                for _ in range(decoded_first_message.payload.total - 1):
                    remaining_messages.append(await subscription.receive())

        all_decoded_messages = (decoded_first_message, ) + tuple(
            decode_message(message.message.payload)
            for message in remaining_messages)
        all_received_enrs = tuple(
            itertools.chain.from_iterable(
                message.payload.enrs for message in all_decoded_messages))

        # Fixture sanity check: the expected ENRs have distinct node ids.
        expected_enrs_by_node_id = {enr.node_id: enr for enr in enrs}
        assert len(expected_enrs_by_node_id) == len(enrs)

        # The dict comparison below collapses duplicates, so also check the
        # raw count: each ENR must be delivered exactly once.
        assert len(all_received_enrs) == len(enrs)
        actual_enrs_by_node_id = {
            enr.node_id: enr
            for enr in all_received_enrs
        }

        assert expected_enrs_by_node_id == actual_enrs_by_node_id
Ejemplo n.º 2
0
    async def _feed_talk_requests(self) -> None:
        """Forward inbound Alexandria REQUEST messages to subscribers.

        Subscribes to TALKREQ messages on the base network client. It skips
        messages addressed to other protocols and decodes the payload of the
        rest. REQUEST-type messages are fed into ``self.subscription_manager``
        as ``InboundMessage`` objects. Undecodable payloads and non-REQUEST
        messages are logged at debug level and dropped.
        """
        async with self.network.client.dispatcher.subscribe(
                TalkRequestMessage) as subscription:
            async for request in subscription:
                if request.message.protocol != self.protocol_id:
                    continue

                try:
                    message = decode_message(request.message.payload)
                except DecodingError as err:
                    # Malformed peer payloads are dropped, but leave a trace
                    # instead of swallowing them silently.
                    self.logger.debug(
                        "Failed to decode TALKREQ payload from %s: %s",
                        request.sender_node_id,
                        err,
                    )
                    continue

                if message.type != AlexandriaMessageType.REQUEST:
                    self.logger.debug(
                        "Received non-REQUEST msg via TALKREQ: %s",
                        message)
                    continue

                self.subscription_manager.feed_subscriptions(
                    InboundMessage(
                        message=message,
                        sender_node_id=request.sender_node_id,
                        sender_endpoint=request.sender_endpoint,
                        explicit_request_id=request.message.request_id,
                    ))
Ejemplo n.º 3
0
    async def _feed_talk_responses(self) -> None:
        """Forward inbound Alexandria RESPONSE messages to subscribers.

        A TALKRESP carries no protocol id of its own. A response is therefore
        only treated as Alexandria traffic when its request id is active in
        ``self.request_tracker``. Responses with an empty payload are
        skipped. Undecodable payloads and non-RESPONSE messages are logged at
        debug level and dropped. Everything else is fed into
        ``self.subscription_manager`` as an ``InboundMessage``.
        """
        async with self.network.client.dispatcher.subscribe(
                TalkResponseMessage) as subscription:
            async for response in subscription:
                is_known_request_id = self.request_tracker.is_request_id_active(
                    response.sender_node_id,
                    response.request_id,
                )
                if not is_known_request_id:
                    # Not a reply to one of our in-flight Alexandria requests.
                    continue
                elif not response.message.payload:
                    continue

                try:
                    message = decode_message(response.message.payload)
                except DecodingError as err:
                    # Malformed peer payloads are dropped, but leave a trace
                    # instead of swallowing them silently.
                    self.logger.debug(
                        "Failed to decode TALKRESP payload from %s: %s",
                        response.sender_node_id,
                        err,
                    )
                    continue

                if message.type != AlexandriaMessageType.RESPONSE:
                    self.logger.debug(
                        "Received non-RESPONSE msg via TALKRESP: %s",
                        message)
                    continue

                self.subscription_manager.feed_subscriptions(
                    InboundMessage(
                        message=message,
                        sender_node_id=response.sender_node_id,
                        sender_endpoint=response.sender_endpoint,
                        explicit_request_id=response.request_id,
                    ))
Ejemplo n.º 4
0
def test_pong_message_encoding_round_trip(enr_seq, advertisement_radius):
    """A PONG message is unchanged by a wire-encode/decode round trip."""
    original = PongMessage(
        PongPayload(enr_seq=enr_seq, advertisement_radius=advertisement_radius))
    assert decode_message(original.to_wire_bytes()) == original
Ejemplo n.º 5
0
def test_found_nodes_message_encoding_round_trip(num_enr_records):
    """A FOUNDNODES message decodes back to the same class and payload."""
    enrs = tuple(ENRFactory() for _ in range(num_enr_records))
    encoded_enrs = tuple(rlp.encode(enr) for enr in enrs)
    payload = FoundNodesPayload(num_enr_records, encoded_enrs)
    message = FoundNodesMessage(payload)
    encoded = message.to_wire_bytes()
    result = decode_message(encoded)
    # Comparing payloads alone would also pass if the bytes decoded to a
    # different message class carrying an equal payload.
    assert isinstance(result, FoundNodesMessage)
    assert result.payload == message.payload
Ejemplo n.º 6
0
    async def _request(
        self,
        node_id: NodeID,
        endpoint: Endpoint,
        request: AlexandriaMessage[Any],
        response_class: Type[TAlexandriaMessage],
        request_id: Optional[bytes] = None,
    ) -> TAlexandriaMessage:
        """Send ``request`` via TALKREQ and return the decoded reply.

        :param node_id: node id of the peer to query.
        :param endpoint: network endpoint of the peer.
        :param request: message to send; its type must be REQUEST.
        :param response_class: exact message class the reply must decode to.
        :param request_id: explicit request id. When ``None``, a free id is
            taken from the base client's request tracker.
        :raises TypeError: if ``request`` is not a REQUEST message.
        :raises ValueError: if the encoded request exceeds
            ``MAX_PAYLOAD_SIZE`` bytes.
        :raises DecodingError: if the reply is not exactly of type
            ``response_class`` (``decode_message`` may also raise it for an
            undecodable reply).
        """
        #
        # Request ID Shenanigans
        #
        # We need a subscription API for alexandria messages in order to be
        # able to cleanly implement functionality like responding to ping
        # messages with pong messages.
        #
        # We accomplish this by monitoring all TALKREQUEST and TALKRESPONSE
        # messages that occur on the alexandria protocol. For TALKREQUEST based
        # messages we can naively monitor incoming messages and match them
        # against the protocol_id. For the TALKRESPONSE however, the only way
        # for us to know that an incoming message belongs to this protocol is
        # to match it against an in-flight request (which we implicitly know
        # is part of the alexandria protocol).
        #
        # In order to do this, we need to know which request ids are active for
        # Alexandria protocol messages. We do this by using a separate
        # RequestTrackerAPI. The `request_id` for messages is still acquired
        # from the core tracker located on the base protocol `ClientAPI`.  We
        # then feed this into our local tracker, which allows us to query it
        # upon receiving an incoming TALKRESPONSE to see if the response is to
        # a message from this protocol.
        if request.type != AlexandriaMessageType.REQUEST:
            # Exception args are not %-formatted, so build the message here.
            raise TypeError(f"{request} is not of type REQUEST")
        if request_id is None:
            request_id = self.network.client.request_tracker.get_free_request_id(
                node_id)
        request_data = request.to_wire_bytes()
        if len(request_data) > MAX_PAYLOAD_SIZE:
            # ValueError subclasses Exception, so existing
            # ``except Exception`` handlers still catch this.
            raise ValueError(
                f"Payload too large:  payload={len(request_data)}  "
                f"max_size={MAX_PAYLOAD_SIZE}"
            )
        # The id stays reserved in the local tracker only while the TALK
        # round trip is in flight.
        with self.request_tracker.reserve_request_id(node_id, request_id):
            response_data = await self.network.talk(
                node_id,
                protocol=ALEXANDRIA_PROTOCOL_ID,
                payload=request_data,
                endpoint=endpoint,
                request_id=request_id,
            )

        response = decode_message(response_data)
        if type(response) is not response_class:
            raise DecodingError(
                f"Invalid response. expected={response_class}  got={type(response)}"
            )
        return response  # type: ignore
Ejemplo n.º 7
0
async def test_alexandria_client_send_find_nodes(alice, bob, bob_client,
                                                 alice_alexandria_client):
    """Alice's FINDNODES request reaches Bob with its distances intact."""
    expected_distances = (0, 255, 254)
    async with bob_client.dispatcher.subscribe(
            TalkRequestMessage) as inbound_requests:
        await alice_alexandria_client.send_find_nodes(
            bob.node_id,
            bob.endpoint,
            distances=expected_distances,
        )
        with trio.fail_after(1):
            talk_request = await inbound_requests.receive()
        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, FindNodesMessage)
        assert decoded.payload.distances == expected_distances
Ejemplo n.º 8
0
async def test_alexandria_client_send_ack(bob, bob_network,
                                          alice_alexandria_client):
    """Alice's ACK response reaches Bob with its request id and radius."""
    async with bob_network.dispatcher.subscribe(
            TalkResponseMessage) as subscription:
        await alice_alexandria_client.send_ack(
            bob.node_id,
            bob.endpoint,
            advertisement_radius=12345,
            acked=(True, ),
            request_id=b"\x01",
        )
        with trio.fail_after(1):
            talk_response = await subscription.receive()
        # The explicit request id must be echoed on the wire.
        assert talk_response.request_id == b"\x01"
        message = decode_message(talk_response.message.payload)
        assert isinstance(message, AckMessage)
        assert message.payload.advertisement_radius == 12345
Ejemplo n.º 9
0
async def test_alexandria_client_send_advertisements(alice, bob, bob_network,
                                                     alice_alexandria_client):
    """Alice's ADVERTISE request reaches Bob carrying her advertisements."""
    advertisements = (AdvertisementFactory(private_key=alice.private_key), )

    async with bob_network.dispatcher.subscribe(
            TalkRequestMessage) as inbound_requests:
        await alice_alexandria_client.send_advertisements(
            bob.node_id,
            bob.endpoint,
            advertisements=advertisements,
        )
        with trio.fail_after(1):
            talk_request = await inbound_requests.receive()
        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, AdvertiseMessage)
        assert decoded.payload == advertisements
Ejemplo n.º 10
0
async def test_alexandria_client_send_find_content(bob, bob_network,
                                                   alice_alexandria_client):
    """Alice's FINDCONTENT request reaches Bob with the requested key."""
    content_key = b"test-content-key"

    async with bob_network.dispatcher.subscribe(
            TalkRequestMessage) as inbound_requests:
        await alice_alexandria_client.send_find_content(
            bob.node_id,
            bob.endpoint,
            content_key=content_key,
        )
        with trio.fail_after(1):
            talk_request = await inbound_requests.receive()
        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, FindContentMessage)
        assert decoded.payload.content_key == content_key
Ejemplo n.º 11
0
def test_found_content_message_encoding_round_trip(data):
    """A FOUNDCONTENT message (content or ENR variant) survives a round trip.

    Hypothesis picks either a 32-byte content body with no ENRs, or up to
    three ENRs with an empty content body.
    """
    is_content = data.draw(st.booleans())
    if is_content:
        content = data.draw(st.binary(min_size=32, max_size=32))
        enrs = ()
    else:
        num_enrs = data.draw(st.integers(min_value=0, max_value=3))
        enrs = tuple(ENRFactory() for _ in range(num_enrs))
        content = b""

    encoded_enrs = tuple(rlp.encode(enr) for enr in enrs)
    payload = FoundContentPayload(encoded_enrs, content)
    message = FoundContentMessage(payload)
    encoded = message.to_wire_bytes()
    result = decode_message(encoded)
    # Comparing payloads alone would also pass if the bytes decoded to a
    # different message class carrying an equal payload.
    assert isinstance(result, FoundContentMessage)
    assert result.payload == message.payload
Ejemplo n.º 12
0
async def test_alexandria_client_send_locate(bob, bob_network,
                                             alice_alexandria_client):
    """Alice's LOCATE request reaches Bob with its request id and key."""
    expected_key = b"\x01test-key"
    async with bob_network.dispatcher.subscribe(
            TalkRequestMessage) as inbound_requests:
        await alice_alexandria_client.send_locate(
            bob.node_id,
            bob.endpoint,
            content_key=expected_key,
            request_id=b"\x01",
        )
        with trio.fail_after(1):
            talk_request = await inbound_requests.receive()
        assert talk_request.request_id == b"\x01"
        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, LocateMessage)
        assert decoded.payload.content_key == expected_key
Ejemplo n.º 13
0
async def test_alexandria_client_send_content(bob, bob_network,
                                              alice_alexandria_client):
    """Alice's CONTENT response reaches Bob with request id and payload."""

    async with bob_network.dispatcher.subscribe(
            TalkResponseMessage) as subscription:
        await alice_alexandria_client.send_content(
            bob.node_id,
            bob.endpoint,
            is_proof=True,
            payload=b"test-payload",
            request_id=b"\x01",
        )
        with trio.fail_after(1):
            talk_response = await subscription.receive()
        # The explicit request id must be echoed on the wire.
        assert talk_response.request_id == b"\x01"
        message = decode_message(talk_response.message.payload)
        assert isinstance(message, ContentMessage)
        assert message.payload.is_proof is True
        assert message.payload.payload == b"test-payload"
Ejemplo n.º 14
0
async def test_alexandria_client_send_locations_single_message(
        alice, bob, bob_network, alice_alexandria_client):
    """A single-advertisement LOCATIONS response arrives as one message."""
    advertisements = (AdvertisementFactory(private_key=alice.private_key), )

    async with bob_network.dispatcher.subscribe(
            TalkResponseMessage) as inbound_responses:
        await alice_alexandria_client.send_locations(
            bob.node_id,
            bob.endpoint,
            advertisements=advertisements,
            request_id=b"\x01")
        with trio.fail_after(1):
            talk_response = await inbound_responses.receive()
        assert talk_response.request_id == b"\x01"
        decoded = decode_message(talk_response.message.payload)
        assert isinstance(decoded, LocationsMessage)
        assert decoded.payload.total == 1
        assert decoded.payload.locations == advertisements
Ejemplo n.º 15
0
async def test_alexandria_client_send_ping(bob, bob_network,
                                           alice_alexandria_client):
    """Alice's PING request reaches Bob with its request id and fields."""
    async with bob_network.dispatcher.subscribe(
            TalkRequestMessage) as inbound_requests:
        await alice_alexandria_client.send_ping(
            bob.node_id,
            bob.endpoint,
            enr_seq=1234,
            advertisement_radius=4321,
            request_id=b"\x01\x02",
        )
        with trio.fail_after(1):
            talk_request = await inbound_requests.receive()

        assert talk_request.request_id == b"\x01\x02"
        ping = decode_message(talk_request.message.payload)
        assert isinstance(ping, PingMessage)
        assert ping.payload.enr_seq == 1234
        assert ping.payload.advertisement_radius == 4321
Ejemplo n.º 16
0
async def test_alexandria_client_send_get_content(bob, bob_network,
                                                  alice_alexandria_client):
    """Alice's GETCONTENT request reaches Bob with key and chunk range."""
    content_key = b"test-content-key"
    start_chunk_index = 5
    max_chunks = 16

    async with bob_network.dispatcher.subscribe(
            TalkRequestMessage) as inbound_requests:
        await alice_alexandria_client.send_get_content(
            bob.node_id,
            bob.endpoint,
            content_key=content_key,
            start_chunk_index=start_chunk_index,
            max_chunks=max_chunks,
        )
        with trio.fail_after(1):
            talk_request = await inbound_requests.receive()
        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, GetContentMessage)
        received = decoded.payload
        assert received.content_key == content_key
        assert received.start_chunk_index == start_chunk_index
        assert received.max_chunks == max_chunks
Ejemplo n.º 17
0
def test_ack_message_encoding_round_trip(advertisement_radius):
    """An ACK message is unchanged by a wire-encode/decode round trip."""
    original = AckMessage(AckPayload(advertisement_radius, (True, False)))
    assert decode_message(original.to_wire_bytes()) == original
Ejemplo n.º 18
0
def test_advertisement_message_encoding_round_trip(advertisements):
    """An ADVERTISE message is unchanged by a wire-encode/decode round trip."""
    original = AdvertiseMessage(advertisements)
    assert decode_message(original.to_wire_bytes()) == original
Ejemplo n.º 19
0
def test_find_nodes_message_encoding_round_trip(distances):
    """A FINDNODES message is unchanged by a wire-encode/decode round trip."""
    original = FindNodesMessage(FindNodesPayload(distances))
    assert decode_message(original.to_wire_bytes()) == original
Ejemplo n.º 20
0
def test_locations_message_encoding_round_trip(advertisements):
    """A LOCATIONS message is unchanged by a wire-encode/decode round trip."""
    original = LocationsMessage(
        LocationsPayload(len(advertisements), advertisements))
    assert decode_message(original.to_wire_bytes()) == original
Ejemplo n.º 21
0
def test_locate_message_encoding_round_trip(content_key):
    """A LOCATE message is unchanged by a wire-encode/decode round trip."""
    original = LocateMessage(LocatePayload(content_key))
    assert decode_message(original.to_wire_bytes()) == original
Ejemplo n.º 22
0
def test_find_content_message_encoding_round_trip(content_key):
    """A FINDCONTENT message is unchanged by a wire-encode/decode round trip."""
    original = FindContentMessage(FindContentPayload(content_key))
    assert decode_message(original.to_wire_bytes()) == original