async def test_alexandria_client_send_found_nodes(bob, bob_client, enrs,
                                                  alice_alexandria_client):
    """Every ENR handed to ``send_found_nodes`` arrives at bob.

    The response may span several TALKRESP messages; the ``total`` field of
    the first decoded message tells how many messages to collect overall.
    """
    with trio.fail_after(2):
        async with bob_client.dispatcher.subscribe(TalkResponseMessage) as subscription:
            await alice_alexandria_client.send_found_nodes(
                bob.node_id,
                bob.endpoint,
                enrs=enrs,
                request_id=b"\x01\x02",
            )

            with trio.fail_after(2):
                head_raw = await subscription.receive()
                head = decode_message(head_raw.message.payload)
                # The first message was already consumed, hence ``total - 1``.
                tail_raw = [
                    await subscription.receive()
                    for _ in range(head.payload.total - 1)
                ]

        decoded = [head]
        decoded.extend(decode_message(raw.message.payload) for raw in tail_raw)
        received_enrs = tuple(
            itertools.chain.from_iterable(msg.payload.enrs for msg in decoded)
        )

        expected = {enr.node_id: enr for enr in enrs}
        # Guard against duplicate node ids collapsing in the dict above.
        assert len(expected) == len(enrs)
        actual = {enr.node_id: enr for enr in received_enrs}
        assert expected == actual
async def _feed_talk_requests(self) -> None:
    """Forward inbound TALKREQ messages for this protocol to subscribers.

    Subscribes to every ``TalkRequestMessage`` on the base client dispatcher.
    For each message whose protocol equals ``self.protocol_id`` the payload is
    decoded and fed to ``self.subscription_manager`` wrapped in an
    ``InboundMessage``.  Payloads that fail to decode, and decoded messages
    whose type is not ``REQUEST``, are logged at debug level and dropped.
    """
    async with self.network.client.dispatcher.subscribe(
            TalkRequestMessage) as subscription:
        async for request in subscription:
            if request.message.protocol != self.protocol_id:
                continue

            try:
                message = decode_message(request.message.payload)
            except DecodingError as err:
                # A malformed payload must not stop the feed, but leave a
                # trace so the drop is diagnosable instead of silent.
                self.logger.debug(
                    "Failed to decode TALKREQ payload from %r: %s",
                    request.sender_node_id,
                    err,
                )
                continue

            if message.type != AlexandriaMessageType.REQUEST:
                self.logger.debug(
                    "Received non-REQUEST msg via TALKREQ: %s", message)
                continue

            self.subscription_manager.feed_subscriptions(
                InboundMessage(
                    message=message,
                    sender_node_id=request.sender_node_id,
                    sender_endpoint=request.sender_endpoint,
                    explicit_request_id=request.message.request_id,
                ))
async def _feed_talk_responses(self) -> None:
    """Forward inbound TALKRESP messages for this protocol to subscribers.

    Subscribes to every ``TalkResponseMessage`` on the base client dispatcher.
    A response is only considered when its ``(sender_node_id, request_id)``
    pair is active in ``self.request_tracker`` and its payload is non-empty.
    The payload is then decoded and fed to ``self.subscription_manager``
    wrapped in an ``InboundMessage``.  Payloads that fail to decode, and
    decoded messages whose type is not ``RESPONSE``, are logged at debug
    level and dropped.
    """
    async with self.network.client.dispatcher.subscribe(
            TalkResponseMessage) as subscription:
        async for response in subscription:
            is_known_request_id = self.request_tracker.is_request_id_active(
                response.sender_node_id,
                response.request_id,
            )
            # Skip responses to requests this tracker does not know about,
            # as well as responses carrying an empty payload.
            if not is_known_request_id or not response.message.payload:
                continue

            try:
                message = decode_message(response.message.payload)
            except DecodingError as err:
                # A malformed payload must not stop the feed, but leave a
                # trace so the drop is diagnosable instead of silent.
                self.logger.debug(
                    "Failed to decode TALKRESP payload from %r: %s",
                    response.sender_node_id,
                    err,
                )
                continue

            if message.type != AlexandriaMessageType.RESPONSE:
                self.logger.debug(
                    "Received non-RESPONSE msg via TALKRESP: %s", message)
                continue

            self.subscription_manager.feed_subscriptions(
                InboundMessage(
                    message=message,
                    sender_node_id=response.sender_node_id,
                    sender_endpoint=response.sender_endpoint,
                    explicit_request_id=response.request_id,
                ))
def test_pong_message_encoding_round_trip(enr_seq, advertisement_radius):
    """A ``PongMessage`` decodes from its wire bytes to an equal message."""
    message = PongMessage(
        PongPayload(enr_seq=enr_seq, advertisement_radius=advertisement_radius)
    )
    wire_bytes = message.to_wire_bytes()
    decoded = decode_message(wire_bytes)
    assert decoded == message
def test_found_nodes_message_encoding_round_trip(num_enr_records):
    """A ``FoundNodesMessage`` payload survives an encode/decode round trip."""
    records = [ENRFactory() for _ in range(num_enr_records)]
    # The payload carries the RLP encoding of each record.
    rlp_records = tuple(map(rlp.encode, records))
    message = FoundNodesMessage(FoundNodesPayload(num_enr_records, rlp_records))

    decoded = decode_message(message.to_wire_bytes())

    assert decoded.payload == message.payload
async def _request(
    self,
    node_id: NodeID,
    endpoint: Endpoint,
    request: AlexandriaMessage[Any],
    response_class: Type[TAlexandriaMessage],
    request_id: Optional[bytes] = None,
) -> TAlexandriaMessage:
    """Send ``request`` to a peer over TALKREQ and return the decoded response.

    Args:
        node_id: Node id of the peer the request is sent to.
        endpoint: Endpoint of that peer.
        request: Message to send; its ``type`` must be ``REQUEST``.
        response_class: Exact message class the response must decode to.
        request_id: Request id to use.  When ``None`` a free id is taken
            from the base protocol client's request tracker.

    Returns:
        The decoded response, an instance of exactly ``response_class``.

    Raises:
        TypeError: If ``request`` is not of type ``REQUEST``.
        Exception: If the encoded request exceeds ``MAX_PAYLOAD_SIZE``.
        DecodingError: If the decoded response is not exactly of type
            ``response_class`` (presumably also if the response payload
            itself cannot be decoded).
    """
    #
    # Request ID Shenanigans
    #
    # We need a subscription API for alexandria messages in order to be
    # able to cleanly implement functionality like responding to ping
    # messages with pong messages.
    #
    # We accomplish this by monitoring all TALKREQUEST and TALKRESPONSE
    # messages that occur on the alexandria protocol.  For TALKREQUEST based
    # messages we can naively monitor incoming messages and match them
    # against the protocol_id.  For the TALKRESPONSE however, the only way
    # for us to know that an incoming message belongs to this protocol is
    # to match it against an in-flight request (which we implicitly know
    # is part of the alexandria protocol).
    #
    # In order to do this, we need to know which request ids are active for
    # Alexandria protocol messages.  We do this by using a separate
    # RequestTrackerAPI.  The `request_id` for messages is still acquired
    # from the core tracker located on the base protocol `ClientAPI`.  We
    # then feed this into our local tracker, which allows us to query it
    # upon receiving an incoming TALKRESPONSE to see if the response is to
    # a message from this protocol.
    if request.type != AlexandriaMessageType.REQUEST:
        # Exceptions do not interpolate %-style arguments, so the message
        # is formatted up front.
        raise TypeError(f"{request} is not of type REQUEST")

    if request_id is None:
        request_id = self.network.client.request_tracker.get_free_request_id(
            node_id)

    request_data = request.to_wire_bytes()
    if len(request_data) > MAX_PAYLOAD_SIZE:
        raise Exception(
            f"Payload too large: payload={len(request_data)} "
            f"max_size={MAX_PAYLOAD_SIZE}"
        )

    # Keep the id reserved in the local tracker for the whole exchange so an
    # incoming TALKRESPONSE can be attributed to this protocol.
    with self.request_tracker.reserve_request_id(node_id, request_id):
        response_data = await self.network.talk(
            node_id,
            protocol=ALEXANDRIA_PROTOCOL_ID,
            payload=request_data,
            endpoint=endpoint,
            request_id=request_id,
        )

    response = decode_message(response_data)
    # Exact type match on purpose: a subclass of ``response_class`` is
    # rejected as well.
    if type(response) is not response_class:
        raise DecodingError(
            f"Invalid response. expected={response_class} got={type(response)}"
        )
    return response  # type: ignore
async def test_alexandria_client_send_find_nodes(alice, bob, bob_client,
                                                 alice_alexandria_client):
    """``send_find_nodes`` delivers a FIND_NODES request with the given distances."""
    async with bob_client.dispatcher.subscribe(TalkRequestMessage) as subscription:
        await alice_alexandria_client.send_find_nodes(
            bob.node_id,
            bob.endpoint,
            distances=(0, 255, 254),
        )

        with trio.fail_after(1):
            inbound = await subscription.receive()

        decoded = decode_message(inbound.message.payload)
        assert isinstance(decoded, FindNodesMessage)
        assert decoded.payload.distances == (0, 255, 254)
async def test_alexandria_client_send_ack(bob, bob_network,
                                          alice_alexandria_client):
    """``send_ack`` delivers an ACK response carrying the advertisement radius."""
    async with bob_network.dispatcher.subscribe(TalkResponseMessage) as subscription:
        await alice_alexandria_client.send_ack(
            bob.node_id,
            bob.endpoint,
            advertisement_radius=12345,
            acked=(True, ),
            request_id=b"\x01",
        )

        with trio.fail_after(1):
            inbound = await subscription.receive()

        decoded = decode_message(inbound.message.payload)
        assert isinstance(decoded, AckMessage)
        assert decoded.payload.advertisement_radius == 12345
async def test_alexandria_client_send_advertisements(alice, bob, bob_network,
                                                     alice_alexandria_client):
    """``send_advertisements`` delivers an ADVERTISE request with the same ads."""
    advertisements = (AdvertisementFactory(private_key=alice.private_key), )

    async with bob_network.dispatcher.subscribe(TalkRequestMessage) as subscription:
        await alice_alexandria_client.send_advertisements(
            bob.node_id,
            bob.endpoint,
            advertisements=advertisements,
        )

        with trio.fail_after(1):
            talk_request = await subscription.receive()

        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, AdvertiseMessage)
        assert decoded.payload == advertisements
async def test_alexandria_client_send_find_content(bob, bob_network,
                                                   alice_alexandria_client):
    """``send_find_content`` delivers a FIND_CONTENT request for the given key."""
    content_key = b"test-content-key"

    async with bob_network.dispatcher.subscribe(TalkRequestMessage) as subscription:
        await alice_alexandria_client.send_find_content(
            bob.node_id,
            bob.endpoint,
            content_key=content_key,
        )

        with trio.fail_after(1):
            talk_request = await subscription.receive()

        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, FindContentMessage)
        assert decoded.payload.content_key == content_key
def test_found_content_message_encoding_round_trip(data):
    """A ``FoundContentMessage`` payload survives an encode/decode round trip.

    The drawn boolean picks between a payload holding 32 bytes of content
    with no ENRs, and one holding 0-3 ENRs with empty content.
    """
    if data.draw(st.booleans()):
        content = data.draw(st.binary(min_size=32, max_size=32))
        enrs = ()
    else:
        enr_count = data.draw(st.integers(min_value=0, max_value=3))
        enrs = tuple(ENRFactory() for _ in range(enr_count))
        content = b""

    rlp_enrs = tuple(map(rlp.encode, enrs))
    message = FoundContentMessage(FoundContentPayload(rlp_enrs, content))

    decoded = decode_message(message.to_wire_bytes())

    assert decoded.payload == message.payload
async def test_alexandria_client_send_locate(bob, bob_network,
                                             alice_alexandria_client):
    """``send_locate`` delivers a LOCATE request with the given id and key."""
    async with bob_network.dispatcher.subscribe(TalkRequestMessage) as subscription:
        await alice_alexandria_client.send_locate(
            bob.node_id,
            bob.endpoint,
            content_key=b"\x01test-key",
            request_id=b"\x01",
        )

        with trio.fail_after(1):
            talk_request = await subscription.receive()

        assert talk_request.request_id == b"\x01"
        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, LocateMessage)
        assert decoded.payload.content_key == b"\x01test-key"
async def test_alexandria_client_send_content(bob, bob_network,
                                              alice_alexandria_client):
    """``send_content`` delivers a CONTENT response with the proof flag and payload."""
    async with bob_network.dispatcher.subscribe(TalkResponseMessage) as subscription:
        await alice_alexandria_client.send_content(
            bob.node_id,
            bob.endpoint,
            is_proof=True,
            payload=b"test-payload",
            request_id=b"\x01",
        )

        with trio.fail_after(1):
            inbound = await subscription.receive()

        decoded = decode_message(inbound.message.payload)
        assert isinstance(decoded, ContentMessage)
        assert decoded.payload.is_proof is True
        assert decoded.payload.payload == b"test-payload"
async def test_alexandria_client_send_locations_single_message(
        alice, bob, bob_network, alice_alexandria_client):
    """A single advertisement is delivered as one LOCATIONS response."""
    advertisements = (AdvertisementFactory(private_key=alice.private_key), )

    async with bob_network.dispatcher.subscribe(TalkResponseMessage) as subscription:
        await alice_alexandria_client.send_locations(
            bob.node_id,
            bob.endpoint,
            advertisements=advertisements,
            request_id=b"\x01",
        )

        with trio.fail_after(1):
            talk_response = await subscription.receive()

        assert talk_response.request_id == b"\x01"
        decoded = decode_message(talk_response.message.payload)
        assert isinstance(decoded, LocationsMessage)
        assert decoded.payload.total == 1
        assert decoded.payload.locations == advertisements
async def test_alexandria_client_send_ping(bob, bob_network,
                                           alice_alexandria_client):
    """``send_ping`` delivers a PING request with the given id and payload fields."""
    async with bob_network.dispatcher.subscribe(TalkRequestMessage) as subscription:
        await alice_alexandria_client.send_ping(
            bob.node_id,
            bob.endpoint,
            enr_seq=1234,
            advertisement_radius=4321,
            request_id=b"\x01\x02",
        )

        with trio.fail_after(1):
            talk_request = await subscription.receive()

        assert talk_request.request_id == b"\x01\x02"
        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, PingMessage)
        assert decoded.payload.enr_seq == 1234
        assert decoded.payload.advertisement_radius == 4321
async def test_alexandria_client_send_get_content(bob, bob_network,
                                                  alice_alexandria_client):
    """``send_get_content`` delivers a GET_CONTENT request with all three fields."""
    key = b"test-content-key"
    first_chunk = 5
    chunk_limit = 16

    async with bob_network.dispatcher.subscribe(TalkRequestMessage) as subscription:
        await alice_alexandria_client.send_get_content(
            bob.node_id,
            bob.endpoint,
            content_key=key,
            start_chunk_index=first_chunk,
            max_chunks=chunk_limit,
        )

        with trio.fail_after(1):
            talk_request = await subscription.receive()

        decoded = decode_message(talk_request.message.payload)
        assert isinstance(decoded, GetContentMessage)
        assert decoded.payload.content_key == key
        assert decoded.payload.start_chunk_index == first_chunk
        assert decoded.payload.max_chunks == chunk_limit
def test_ack_message_encoding_round_trip(advertisement_radius):
    """An ``AckMessage`` decodes from its wire bytes to an equal message."""
    message = AckMessage(AckPayload(advertisement_radius, (True, False)))
    wire_bytes = message.to_wire_bytes()
    decoded = decode_message(wire_bytes)
    assert decoded == message
def test_advertisement_message_encoding_round_trip(advertisements):
    """An ``AdvertiseMessage`` decodes from its wire bytes to an equal message."""
    message = AdvertiseMessage(advertisements)
    wire_bytes = message.to_wire_bytes()
    decoded = decode_message(wire_bytes)
    assert decoded == message
def test_find_nodes_message_encoding_round_trip(distances):
    """A ``FindNodesMessage`` decodes from its wire bytes to an equal message."""
    message = FindNodesMessage(FindNodesPayload(distances))
    wire_bytes = message.to_wire_bytes()
    decoded = decode_message(wire_bytes)
    assert decoded == message
def test_locations_message_encoding_round_trip(advertisements):
    """A ``LocationsMessage`` decodes from its wire bytes to an equal message."""
    total = len(advertisements)
    message = LocationsMessage(LocationsPayload(total, advertisements))
    wire_bytes = message.to_wire_bytes()
    decoded = decode_message(wire_bytes)
    assert decoded == message
def test_locate_message_encoding_round_trip(content_key):
    """A ``LocateMessage`` decodes from its wire bytes to an equal message."""
    message = LocateMessage(LocatePayload(content_key))
    wire_bytes = message.to_wire_bytes()
    decoded = decode_message(wire_bytes)
    assert decoded == message
def test_find_content_message_encoding_round_trip(content_key):
    """A ``FindContentMessage`` decodes from its wire bytes to an equal message."""
    message = FindContentMessage(FindContentPayload(content_key))
    wire_bytes = message.to_wire_bytes()
    decoded = decode_message(wire_bytes)
    assert decoded == message