async def test_auth_rpc_wrapper():
    """Round-trip a PingRequest through a client-role and a servicer-role AuthRPCWrapper.

    The inner servicer asserts that the request it receives keeps its endpoint and
    carries a client access token for 'alice'. The test body then asserts that the
    response seen by the caller keeps its sender endpoint and carries a service
    access token for 'bob'.
    """

    class Servicer:
        async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
            # By the time the request reaches the servicer, it must be intact and
            # carry the client's access token.
            assert request.peer.endpoint == '127.0.0.1:1111'
            assert request.auth.client_access_token.username == 'alice'

            return dht_pb2.PingResponse(sender_endpoint='127.0.0.1:2222')

    class Client:
        def __init__(self, servicer: Servicer):
            self._servicer = servicer

        async def rpc_increment(self, request: dht_pb2.PingRequest) -> dht_pb2.PingResponse:
            # Plain pass-through: the "network" is a direct call into the servicer.
            reply = await self._servicer.rpc_increment(request)
            return reply

    wrapped_servicer = AuthRPCWrapper(
        Servicer(), AuthRole.SERVICER, MockAuthorizer(RSAPrivateKey(), 'bob'))
    wrapped_client = AuthRPCWrapper(
        Client(wrapped_servicer), AuthRole.CLIENT, MockAuthorizer(RSAPrivateKey(), 'alice'))

    outgoing = dht_pb2.PingRequest()
    outgoing.peer.endpoint = '127.0.0.1:1111'

    incoming = await wrapped_client.rpc_increment(outgoing)

    assert incoming.sender_endpoint == '127.0.0.1:2222'
    assert incoming.auth.service_access_token.username == 'bob'
async def test_call_unary_handler_error(handle_name="handle"):
    """Check that an exception raised inside a unary handler is reported to the caller.

    A server P2P instance registers a handler that always raises ``ValueError('boom')``.
    A client opens a stream to that handler, sends a ``PingRequest`` and expects
    ``P2P.receive_protobuf`` to return no result and an error whose message is 'boom'.

    Both P2P instances are shut down in a ``finally`` block, so they are not left
    running when an assertion or a call fails midway through the test.

    :param handle_name: name under which the failing handler is registered
    """
    async def error_handler(request, context):
        raise ValueError('boom')

    server = await P2P.create()
    client = None  # created later; stays None if the test fails before that point
    try:
        server_pid = server._child.pid
        await server.add_unary_handler(handle_name, error_handler,
                                       dht_pb2.PingRequest, dht_pb2.PingResponse)
        assert is_process_running(server_pid)

        nodes = bootstrap_from([server])
        client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
        client_pid = client._child.pid
        assert is_process_running(client_pid)
        await client.wait_for_at_least_n_peers(1)

        ping_request = dht_pb2.PingRequest(
            peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
            validate=True)
        libp2p_server_id = PeerID.from_base58(server.id)
        # the stream info returned first is not needed by this test
        _, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))

        await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
        result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)

        # the handler raised, so there is no response, only the error
        assert result is None
        assert err.message == 'boom'
    finally:
        # same teardown order as on the success path: server first, then client
        await server.stop_listening()
        await server.shutdown()
        if client is not None:
            await client.shutdown()
async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
    """Call a unary handler over a raw stream and check either its reply or its cancellation.

    :param should_cancel: if truthy, close the stream right after sending the request and
        expect the handler to observe ``asyncio.CancelledError``; otherwise wait for the
        reply and compare it with the expected ``PingResponse``
    :param replicate: forwarded to ``replicate_if_needed`` for both the server and the client
    :param handle_name: name under which the handler is registered
    """
    handler_cancelled = False

    async def ping_handler(request, context):
        nonlocal handler_cancelled
        try:
            await asyncio.sleep(2)
        except asyncio.CancelledError:
            # remember the cancellation; the handler still falls through to the return below
            handler_cancelled = True
        return dht_pb2.PingResponse(
            peer=dht_pb2.NodeInfo(node_id=context.id.encode(), rpc_port=context.port),
            sender_endpoint=context.handle_name,
            available=True)

    def node_info_of(node):
        # NodeInfo built from a P2P instance's id and host port
        return dht_pb2.NodeInfo(node_id=node.id.encode(), rpc_port=node._host_port)

    # --- server side ---
    server_primary = await P2P.create()
    server = await replicate_if_needed(server_primary, replicate)
    server_pid = server_primary._child.pid
    await server.add_unary_handler(handle_name, ping_handler,
                                   dht_pb2.PingRequest, dht_pb2.PingResponse)
    assert is_process_running(server_pid)

    # --- client side ---
    bootstrap_nodes = bootstrap_from([server])
    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=bootstrap_nodes)
    client = await replicate_if_needed(client_primary, replicate)
    client_pid = client_primary._child.pid
    assert is_process_running(client_pid)

    ping_request = dht_pb2.PingRequest(peer=node_info_of(client), validate=True)
    expected_response = dht_pb2.PingResponse(
        peer=node_info_of(server), sender_endpoint=handle_name, available=True)

    await client.wait_for_at_least_n_peers(1)
    server_peer_id = PeerID.from_base58(server.id)
    _, reader, writer = await client._client.stream_open(server_peer_id, (handle_name,))

    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)

    if not should_cancel:
        result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
        assert err is None
        assert result == expected_response
        assert not handler_cancelled
    else:
        # closing the stream while the handler is still sleeping should cancel it
        writer.close()
        await asyncio.sleep(1)
        assert handler_cancelled

    # --- teardown: both processes must be gone after shutdown ---
    await server.stop_listening()
    await server_primary.shutdown()
    assert not is_process_running(server_pid)

    await client_primary.shutdown()
    assert not is_process_running(client_pid)
async def rpc_store(self, request: dht_pb2.StoreRequest,
                    context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
    """ Some node wants us to store this (key, value) pair

    Each position in the request describes one store operation; its subkey tag selects the mode:
      - ``self.IS_REGULAR_VALUE``: store the value bytes as a plain value;
      - ``self.IS_DICTIONARY``: deserialize the value into a DictionaryDHTValue and store each
        of its subkeys;
      - anything else: the tag is a serialized subkey; store the value under that subkey.

    :param request: parallel lists of keys, subkeys, values, expiration times and in_cache flags
    :param context: passed through to rpc_ping when the sender's peer info is provided
    :returns: StoreResponse with this node's info and one store_ok entry per request position
    :raises AssertionError: if the per-key lists differ in length, or a dictionary tag carries
        a value that is not a DictionaryDHTValue
    """
    if request.peer:  # if requested, add peer to the routing table
        asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))

    # All five lists are zipped together below. zip() silently stops at the shortest one,
    # so subkeys must be checked as well, otherwise store_ok could come out shorter than keys.
    assert len(request.keys) == len(request.subkeys) == len(request.values) \
        == len(request.expiration_time) == len(request.in_cache)

    response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
    keys = map(DHTID.from_bytes, request.keys)
    for key_id, tag, value_bytes, expiration_time, in_cache in zip(
            keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
        storage = self.cache if in_cache else self.storage

        if tag == self.IS_REGULAR_VALUE:  # store normal value without subkeys
            response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
        elif tag == self.IS_DICTIONARY:  # store an entire dictionary with several subkeys
            value_dictionary = self.serializer.loads(value_bytes)
            assert isinstance(value_dictionary, DictionaryDHTValue)
            # NOTE: all() over a generator stops at the first falsy result, so the
            # remaining subkeys are not attempted after one store_subkey call fails.
            response.store_ok.append(all(
                storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
                for subkey, item in value_dictionary.items()))
        else:  # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
            subkey = self.serializer.loads(tag)
            response.store_ok.append(
                storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
    return response
async def rpc_find(self, request: dht_pb2.FindRequest,
                   context: grpc.ServicerContext) -> dht_pb2.FindResponse:
    """
    Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
    Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)

    For every key, both the storage and the cache are consulted; the cached entry is used
    when the storage has none or when the cached entry expires later.

    :param request: keys to look up, optionally with the sender's peer info
    :param context: passed through to rpc_ping when the sender's peer info is provided
    :returns: FindResponse with this node's info and one FindResult per requested key
    """
    if request.peer:  # if requested, add peer to the routing table
        asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))

    response = dht_pb2.FindResponse(results=[], peer=self.node_info)

    for key_id in map(DHTID.from_bytes, request.keys):
        stored_item = self.storage.get(key_id)
        cached_item = self.cache.get(key_id)

        # choose between storage and cache: prefer the entry that expires later
        best_item = stored_item
        if cached_item is not None:
            if stored_item is None or cached_item.expiration_time > stored_item.expiration_time:
                best_item = cached_item

        if best_item is None:  # value not found
            result_fields = dict(type=dht_pb2.NOT_FOUND)
        elif isinstance(best_item.value, DictionaryDHTValue):
            # dictionaries are sent in serialized form
            result_fields = dict(type=dht_pb2.FOUND_DICTIONARY,
                                 value=self.serializer.dumps(best_item.value),
                                 expiration_time=best_item.expiration_time)
        else:  # found regular value
            result_fields = dict(type=dht_pb2.FOUND_REGULAR,
                                 value=best_item.value,
                                 expiration_time=best_item.expiration_time)
        item = dht_pb2.FindResult(**result_fields)

        # neighbors are attached regardless of whether the value was found;
        # the requesting peer itself is excluded from the list
        nearest = self.routing_table.get_nearest_neighbors(
            key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
        for neighbor_id, neighbor_endpoint in nearest:
            item.nearest_node_ids.append(neighbor_id.to_bytes())
            item.nearest_endpoints.append(neighbor_endpoint)

        response.results.append(item)

    return response
async def get_outgoing_request_endpoint(self, peer: Endpoint) -> Optional[Endpoint]:
    """
    ask this peer how it perceives this node's outgoing request address

    :param peer: network address of the peer to ask
    :returns: the sender_endpoint reported in the peer's ping response, or None if that
        field was left at its default value or the RPC failed with AioRpcError
    """
    try:
        async with self.rpc_semaphore:
            # ping without peer info and without validation
            ping_request = dht_pb2.PingRequest(peer=None, validate=False)
            stub = self._get_dht_stub(peer)
            response = await stub.rpc_ping(ping_request, timeout=self.wait_timeout)

            unset_endpoint = dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value
            if response.sender_endpoint == unset_endpoint:
                return None
            return response.sender_endpoint
    except grpc.aio.AioRpcError as error:
        logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
        return None
async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
    """
    Get peer's node id and add him to the routing table. If peer doesn't respond, return None

    :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
    :param validate: if True, validates that node's endpoint is available
    :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
    :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table

    :return: node's DHTID, if peer responded and decided to send his node_id
    :raises ValidationError: if validate and strict are both True and the peer either reports this
        node as unavailable or reports a DHT time outside the allowed discrepancy window
    """
    try:
        async with self.rpc_semaphore:
            ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
            # timestamps taken right before and after the call bound the moment
            # at which the peer read its own clock
            time_requested = get_dht_time()
            response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
            time_responded = get_dht_time()
    except grpc.aio.AioRpcError as error:
        logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
        response = None
    responded = bool(response and response.peer and response.peer.node_id)

    if responded and validate:
        try:
            if self.server is not None and not response.available:
                raise ValidationError(
                    f"Peer {peer} couldn't access this node at {response.sender_endpoint} . "
                    f"Make sure that this port is open for incoming requests.")

            if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                earliest_allowed = time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS
                latest_allowed = time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS
                # Written as "not inside the window" rather than "below or above it":
                # every ordering comparison with NaN is False, so a NaN dht_time would
                # slip through the below/above form without being rejected.
                if not (earliest_allowed <= response.dht_time <= latest_allowed):
                    raise ValidationError(
                        f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
                        f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
        except ValidationError as e:
            if strict:
                raise
            logger.warning(repr(e))

    peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
    # the routing table is updated in a background task, whether or not the peer responded
    asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
    return peer_id
async def test_valid_request_and_response():
    """A request and a response that are signed and left untouched must both validate."""
    client_auth = MockAuthorizer(RSAPrivateKey())
    service_auth = MockAuthorizer(RSAPrivateKey())

    # client -> service
    ping = dht_pb2.PingRequest()
    ping.peer.endpoint = '127.0.0.1:7777'
    await client_auth.sign_request(ping, service_auth.local_public_key)
    request_accepted = await service_auth.validate_request(ping)
    assert request_accepted

    # service -> client
    pong = dht_pb2.PingResponse(sender_endpoint='127.0.0.1:31337')
    await service_auth.sign_response(pong, ping)
    response_accepted = await client_auth.validate_response(pong, ping)
    assert response_accepted
async def test_invalid_access_token():
    """Overwriting an access token's signature must make validation fail, for both the request and the response."""
    client_auth = MockAuthorizer(RSAPrivateKey())
    service_auth = MockAuthorizer(RSAPrivateKey())

    # client -> service, then corrupt the client's access token signature
    ping = dht_pb2.PingRequest()
    ping.peer.endpoint = '127.0.0.1:7777'
    await client_auth.sign_request(ping, service_auth.local_public_key)
    ping.auth.client_access_token.signature = b'broken'
    request_accepted = await service_auth.validate_request(ping)
    assert not request_accepted

    # service -> client, then corrupt the service's access token signature
    pong = dht_pb2.PingResponse(sender_endpoint='127.0.0.1:31337')
    await service_auth.sign_response(pong, ping)
    pong.auth.service_access_token.signature = b'broken'
    response_accepted = await client_auth.validate_response(pong, ping)
    assert not response_accepted
async def test_invalid_signatures():
    """Changing message content after it was signed must make validation fail, for both the request and the response."""
    client_auth = MockAuthorizer(RSAPrivateKey())
    service_auth = MockAuthorizer(RSAPrivateKey())

    ping = dht_pb2.PingRequest()
    ping.peer.endpoint = '127.0.0.1:7777'
    await client_auth.sign_request(ping, service_auth.local_public_key)
    # simulate a man-in-the-middle attacker rewriting the request after signing
    ping.peer.endpoint = '127.0.0.2:7777'
    request_accepted = await service_auth.validate_request(ping)
    assert not request_accepted

    pong = dht_pb2.PingResponse(sender_endpoint='127.0.0.1:31337')
    await service_auth.sign_response(pong, ping)
    # simulate a man-in-the-middle attacker rewriting the response after signing
    pong.sender_endpoint = '127.0.0.2:31337'
    response_accepted = await client_auth.validate_response(pong, ping)
    assert not response_accepted