def test_composite_validator(validators_for_app):
    validator = CompositeValidator(validators_for_app['A'])
    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]

    validator.extend(validators_for_app['B'])
    assert [type(item) for item in validator._validators] == [SchemaValidator, RSASignatureValidator]
    assert len(validator._validators[0]._schemas) == 2

    local_public_key = validators_for_app['A'][0].local_public_key
    record = DHTRecord(key=DHTID.generate(source='field_b').to_bytes(),
                       subkey=DHTProtocol.serializer.dumps(local_public_key),
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    # Expect only one signature since the two RSASignatureValidators have been merged
    assert signed_record.value.count(b'[signature:') == 1
    # Expect successful validation since the second SchemaValidator has been merged into the first
    assert validator.validate(signed_record)
    assert validator.strip_value(signed_record) == record.value

    record = DHTRecord(key=DHTID.generate(source='unknown_key').to_bytes(),
                       subkey=DHTProtocol.IS_REGULAR_VALUE,
                       value=DHTProtocol.serializer.dumps(777),
                       expiration_time=hivemind.get_dht_time() + 10)

    signed_record = dataclasses.replace(record, value=validator.sign_value(record))
    assert signed_record.value.count(b'[signature:') == 0
    # Expect failed validation since `unknown_key` is not a part of any schema
    assert not validator.validate(signed_record)
async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[Dict[
        DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
                     Dict[DHTID, Endpoint]]]]:
    """
    Request keys from a peer. For each key, look for its (value, expiration time) locally and
    k additional peers that are most likely to have this key (ranked by XOR distance)

    :returns: A dict key => (value_with_expiration, nearest neighbors)
        value_with_expiration: the (value, expiration time) stored by the recipient for that key,
            or None if the peer doesn't have a valid value for it
        neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from the peer's routing table
        If the peer didn't respond, returns None
    """
    keys = list(keys)
    find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
    try:
        async with self.rpc_semaphore:
            response = await self._get_dht_stub(peer).rpc_find(find_request, timeout=self.wait_timeout)
        if response.peer and response.peer.node_id:
            peer_id = DHTID.from_bytes(response.peer.node_id)
            asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
        assert len(keys) == len(response.results), "DHTProtocol: response is not aligned with keys"

        output = {}  # unpack data depending on its type
        for key, result in zip(keys, response.results):
            key_bytes = DHTID.to_bytes(key)
            nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids), result.nearest_endpoints))

            if result.type == dht_pb2.NOT_FOUND:
                output[key] = None, nearest
            elif result.type == dht_pb2.FOUND_REGULAR:
                if not self._validate_record(key_bytes, self.IS_REGULAR_VALUE, result.value, result.expiration_time):
                    output[key] = None, nearest
                    continue
                output[key] = ValueWithExpiration(result.value, result.expiration_time), nearest
            elif result.type == dht_pb2.FOUND_DICTIONARY:
                value_dictionary = self.serializer.loads(result.value)
                if not self._validate_dictionary(key_bytes, value_dictionary):
                    output[key] = None, nearest
                    continue
                output[key] = ValueWithExpiration(value_dictionary, result.expiration_time), nearest
            else:
                logger.error(f"Unknown result type: {result.type}")

        return output
    except grpc.aio.AioRpcError as error:
        logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
        asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
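# Usage sketch for call_find above (illustrative only: assumes `protocol` is a live DHTProtocol
# created via DHTProtocol.create(...) and `peer_endpoint` is the address of a reachable DHT node).
async def lookup_example(protocol: DHTProtocol, peer_endpoint: Endpoint):
    key_id = DHTID.generate(source='some_key')
    found = await protocol.call_find(peer_endpoint, [key_id])
    if found is None:
        print("peer did not respond")
        return
    maybe_value, nearest_neighbors = found[key_id]
    if maybe_value is not None:
        print("value:", maybe_value.value, "expires at:", maybe_value.expiration_time)
    print("peers to ask next:", nearest_neighbors)  # Dict[DHTID, Endpoint]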
def __init__(self, schema: pydantic.BaseModel, *,
             allow_extra_keys: bool = True, prefix: Optional[str] = None):
    """
    :param schema: The Pydantic model (a subclass of pydantic.BaseModel).

        You must always use strict types for the number fields
        (e.g. ``StrictInt`` instead of ``int``,
        ``confloat(strict=True, ge=0.0)`` instead of ``confloat(ge=0.0)``, etc.).
        See the validate() docstring for details.

        The model will be patched to adjust it for the schema validation.

    :param allow_extra_keys: Whether to allow keys that are not defined in the schema.

        If a SchemaValidator is merged with another SchemaValidator, this option applies to
        keys that are not defined in any of the schemas.

    :param prefix: (optional) Add ``prefix + '_'`` to the names of all schema fields.
    """
    self._patch_schema(schema)
    self._schemas = [schema]

    self._key_id_to_field_name = {}
    for field in schema.__fields__.values():
        raw_key = f'{prefix}_{field.name}' if prefix is not None else field.name
        self._key_id_to_field_name[DHTID.generate(source=raw_key).to_bytes()] = field.name
    self._allow_extra_keys = allow_extra_keys
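# A minimal construction sketch for the validator above (the model and field names below are
# illustrative, not part of the library).
import pydantic

class ExperimentMetrics(pydantic.BaseModel):
    step: pydantic.StrictInt                      # strict type: the string "7" will NOT be coerced to 7
    loss: pydantic.confloat(strict=True, ge=0.0)  # strict float, must be non-negative

metrics_validator = SchemaValidator(ExperimentMetrics, allow_extra_keys=False, prefix='run42')
# with prefix='run42', the DHT keys are expected to be 'run42_step' and 'run42_loss'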
async def get_many(
        self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None, **kwargs
) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
                        Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
    """
    Traverse the DHT to find a list of keys. For each key, return the latest (value, expiration) or None if not found.

    :param keys: traverse the DHT and find the value for each of these keys (or None if the key is not found)
    :param sufficient_expiration_time: if the search finds a value that expires after this time,
        it can stop searching and return that value. Defaults to the time of the call,
        i.e. any value that has not expired by the time of the call is sufficient.
        If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration.
    :param kwargs: for the full list of parameters, see DHTNode.get_many_by_id
    :returns: for each key: the value and its expiration time. If nothing is found, returns None for that key
    :note: in order to check whether get returned a value, check that the result for that key is not None
    """
    keys = tuple(keys)
    key_ids = [DHTID.generate(key) for key in keys]
    id_to_original_key = dict(zip(key_ids, keys))
    results_by_id = await self.get_many_by_id(key_ids, sufficient_expiration_time, **kwargs)
    return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}
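# Usage sketch for get_many above (assumes `node` is a bootstrapped DHTNode inside a running
# asyncio event loop; the keys are placeholders).
async def read_example(node: DHTNode):
    results = await node.get_many(['alpha', 'beta'])
    for key, found in results.items():
        if found is None:
            print(f"{key}: no value found")
        else:
            print(f"{key}: value={found.value}, expires at {found.expiration_time}")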
async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
    """ Some node wants us to add it to our routing table. """
    response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
                                    dht_time=get_dht_time(), available=False)

    if request.peer and request.peer.node_id and request.peer.rpc_port:
        sender_id = DHTID.from_bytes(request.peer.node_id)
        if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
            sender_endpoint = request.peer.endpoint  # if peer has a preferred endpoint, use it
        else:
            sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)

        response.sender_endpoint = sender_endpoint
        if request.validate:
            response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id

        asyncio.create_task(self.update_routing_table(
            sender_id, sender_endpoint, responded=response.available or not request.validate))

    return response
async def create(cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int,
                 wait_timeout: float, parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
                 listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
                 channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
    """
    A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
    As a side-effect, DHTProtocol also maintains a routing table as described in
    https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf

    See DHTNode (node.py) for a more detailed description.

    :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes.
        For instance, rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine.
        Only the call_* methods are meant to be called publicly, e.g. from DHTNode.
        Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
    """
    self = cls(_initialized_with_create=True)
    self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
    self.wait_timeout, self.channel_options = wait_timeout, tuple(channel_options)
    self.storage, self.cache = DHTLocalStorage(), DHTLocalStorage(maxsize=cache_size)
    self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
    self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))

    if listen:  # set up server to process incoming rpc requests
        grpc.aio.init_grpc_aio()
        self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
        dht_grpc.add_DHTServicer_to_server(self, self.server)

        self.port = self.server.add_insecure_port(listen_on)
        assert self.port != 0, f"Failed to listen to {listen_on}"
        if endpoint is not None and endpoint.endswith('*'):
            endpoint = replace_port(endpoint, self.port)
        self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
                                          endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
        await self.server.start()
    else:  # not listening to incoming requests, client-only mode
        # note: use empty node_info so peers won't add you to their routing tables
        self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
        if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
            logger.warning(f"DHTProtocol has no server (due to listen=False), listen_on "
                           f"and kwargs have no effect (unused kwargs: {kwargs})")
    return self
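# Hypothetical instantiation of the factory above (requires a running asyncio event loop;
# the port below is an arbitrary example).
async def start_protocol_example():
    protocol = await DHTProtocol.create(node_id=DHTID.generate(), bucket_size=20, depth_modulo=5,
                                        num_replicas=3, wait_timeout=5, listen_on='0.0.0.0:31337')
    print("listening on port", protocol.port)
    return protocol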
async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
    """
    Someone wants to find keys in the DHT. For all keys that we have locally, return the value and expiration.
    Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found a value)
    """
    if request.peer:  # if requested, add peer to the routing table
        asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))

    response = dht_pb2.FindResponse(results=[], peer=self.node_info)
    for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
        maybe_item = self.storage.get(key_id)
        cached_item = self.cache.get(key_id)
        if cached_item is not None and (maybe_item is None
                                        or cached_item.expiration_time > maybe_item.expiration_time):
            maybe_item = cached_item

        if maybe_item is None:  # value not found
            item = dht_pb2.FindResult(type=dht_pb2.NOT_FOUND)
        elif isinstance(maybe_item.value, DictionaryDHTValue):
            item = dht_pb2.FindResult(type=dht_pb2.FOUND_DICTIONARY, value=self.serializer.dumps(maybe_item.value),
                                      expiration_time=maybe_item.expiration_time)
        else:  # found regular value
            item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_item.value,
                                      expiration_time=maybe_item.expiration_time)

        for node_id, endpoint in self.routing_table.get_nearest_neighbors(
                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
            item.nearest_node_ids.append(node_id.to_bytes())
            item.nearest_endpoints.append(endpoint)
        response.results.append(item)
    return response
async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
    """ Some node wants us to store this (key, value) pair """
    if request.peer:  # if requested, add peer to the routing table
        asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
    assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
    response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
    for key, tag, value_bytes, expiration_time, in_cache in zip(
            request.keys, request.subkeys, request.values, request.expiration_time, request.in_cache):
        key_id = DHTID.from_bytes(key)
        storage = self.cache if in_cache else self.storage
        if tag == self.IS_DICTIONARY:  # store an entire dictionary with several subkeys
            value_dictionary = self.serializer.loads(value_bytes)
            assert isinstance(value_dictionary, DictionaryDHTValue)
            if not self._validate_dictionary(key, value_dictionary):
                response.store_ok.append(False)
                continue

            response.store_ok.append(all(
                storage.store_subkey(key_id, subkey, item.value, item.expiration_time)
                for subkey, item in value_dictionary.items()))
            continue

        if not self._validate_record(key, tag, value_bytes, expiration_time):
            response.store_ok.append(False)
            continue

        if tag == self.IS_REGULAR_VALUE:  # store a normal value without subkeys
            response.store_ok.append(storage.store(key_id, value_bytes, expiration_time))
        else:  # add a new entry into an existing dictionary value or create a new dictionary with one sub-key
            subkey = self.serializer.loads(tag)
            response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
    return response
def test_routing_table_search():
    for table_size, lower_active, upper_active in [(10, 10, 10), (10_000, 800, 1100)]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
        num_added = 0
        total_nodes = 0

        for phony_neighbor_port in random.sample(range(1_000_000), table_size):
            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
            new_total = sum(len(bucket.nodes_to_endpoint) for bucket in routing_table.buckets)
            num_added += new_total > total_nodes
            total_nodes = new_total
async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
                     expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                     in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
    """
    Ask a recipient to store several (key, value : expiration_time) items or update their older value

    :param peer: request this peer to store the data
    :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
    :param values: a list of N serialized values (bytes) for each respective key
    :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
    :param in_cache: a list of booleans: True = store i-th key in cache, False = store i-th key in regular storage
    :note: the difference between storing normally and in cache is that normal storage is guaranteed to be kept
        until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
    :return: list of [True / False], True = stored, False = failed (found newer value or no response)
        if peer did not respond (e.g. due to timeout or congestion), returns None
    """
    if isinstance(expiration_time, DHTExpiration):
        expiration_time = [expiration_time] * len(keys)
    in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
    in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
    keys, values, expiration_time, in_cache = map(list, [keys, values, expiration_time, in_cache])
    assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
    store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
                                         expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
    try:
        async with self.rpc_semaphore:
            response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
        if response.peer and response.peer.node_id:
            peer_id = DHTID.from_bytes(response.peer.node_id)
            asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
        return response.store_ok
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
        asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
        return [False] * len(keys)
async def rpc_ping(self, peer_info: dht_pb2.NodeInfo, context: grpc.ServicerContext):
    """ Some node wants us to add it to our routing table. """
    if peer_info.node_id and peer_info.rpc_port:
        sender_id = DHTID.from_bytes(peer_info.node_id)
        rpc_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
        asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
    return self.node_info
def test_routing_table_parameters():
    for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
        (20, 5, 45, 65), (50, 5, 35, 45), (20, 10, 650, 800), (20, 1, 7, 15),
    ]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
        for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
        for bucket in routing_table.buckets:
            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_endpoint) <= bucket.size
        assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
            f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")
async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
                     values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
                     expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                     subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
                     in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Optional[List[bool]]:
    """
    Ask a recipient to store several (key, value : expiration_time) items or update their older value

    :param peer: request this peer to store the data
    :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
    :param values: a list of N serialized values (bytes) for each respective key
    :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
    :param subkeys: a list of N optional sub-keys. If None, stores the value normally. If a subkey is not None:
        1) if local storage doesn't have :key:, create a new dictionary {subkey: (value, expiration_time)}
        2) if local storage already has a dictionary under :key:, try to add (subkey, value, exp_time) to that dictionary
        3) if local storage associates :key: with a normal value with smaller expiration, clear :key: and perform (1)
        4) finally, if local storage currently associates :key: with a normal value with larger expiration, do nothing
    :param in_cache: a list of booleans: True = store i-th key in cache, False = store i-th key in regular storage
    :note: the difference between storing normally and in cache is that normal storage is guaranteed to be kept
        until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
    :return: list of [True / False], True = stored, False = failed (found newer value or no response)
        if peer did not respond (e.g. due to timeout or congestion), returns None
    """
    if isinstance(expiration_time, DHTExpiration):
        expiration_time = [expiration_time] * len(keys)
    if subkeys is None:
        subkeys = [None] * len(keys)

    in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
    in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
    keys, subkeys, values, expiration_time, in_cache = map(list, [keys, subkeys, values, expiration_time, in_cache])
    for i in range(len(keys)):
        if subkeys[i] is None:  # add a default sub-key if not specified
            subkeys[i] = self.IS_DICTIONARY if isinstance(values[i], DictionaryDHTValue) else self.IS_REGULAR_VALUE
        else:
            subkeys[i] = self.serializer.dumps(subkeys[i])
        if isinstance(values[i], DictionaryDHTValue):
            assert subkeys[i] == self.IS_DICTIONARY, "Please don't specify subkey when storing an entire dictionary"
            values[i] = self.serializer.dumps(values[i])

    assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
    store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), subkeys=subkeys, values=values,
                                         expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
    try:
        async with self.rpc_semaphore:
            response = await self._get_dht_stub(peer).rpc_store(store_request, timeout=self.wait_timeout)
        if response.peer and response.peer.node_id:
            peer_id = DHTID.from_bytes(response.peer.node_id)
            asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
        return response.store_ok
    except grpc.aio.AioRpcError as error:
        logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
        asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
        return None
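# Usage sketch for call_store above (illustrative: assumes `protocol` is a live DHTProtocol and
# `peer_endpoint` is a reachable peer; the key and subkey names are placeholders).
async def store_example(protocol: DHTProtocol, peer_endpoint: Endpoint):
    key_id = DHTID.generate(source='experiment_progress')
    deadline = get_dht_time() + 60

    # store a regular value (no subkey)
    ok = await protocol.call_store(peer_endpoint, keys=[key_id],
                                   values=[protocol.serializer.dumps(0.42)], expiration_time=[deadline])

    # add one (subkey, value) entry to a dictionary value stored under the same key
    ok_subkey = await protocol.call_store(peer_endpoint, keys=[key_id],
                                          values=[protocol.serializer.dumps(0.17)],
                                          expiration_time=[deadline], subkeys=['worker_7'])
    return ok, ok_subkey  # each is a list of booleans, or None if the peer did not respond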
def test_ids_basic():
    # basic functionality tests
    for i in range(100):
        id1, id2 = DHTID.generate(), DHTID.generate()
        assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 <= DHTID.MAX
        assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
        assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
        assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
def test_ids_depth():
    for i in range(100):
        ids = [random.randint(0, 4096) for i in range(random.randint(1, 256))]
        ours = DHTID.longest_common_prefix_length(*map(DHTID, ids))

        ids_bitstr = ["".join(bin(bite)[2:].rjust(8, '0') for bite in uid.to_bytes(20, 'big'))
                      for uid in ids]
        reference = len(shared_prefix(*ids_bitstr))
        assert reference == ours, f"ours {ours} != reference {reference}, ids: {ids}"
async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
    """
    Get peer's node id and add it to the routing table. If the peer doesn't respond, return None

    :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
    :param validate: if True, validates that this node's endpoint is reachable by the peer
    :param strict: if strict=True, validation will raise an exception on failure, otherwise it will only warn
    :note: if DHTProtocol was created with listen=True, also request the peer to add us to its routing table
    :return: node's DHTID, if the peer responded and decided to send its node_id
    """
    try:
        async with self.rpc_semaphore:
            ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
            time_requested = get_dht_time()
            response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
            time_responded = get_dht_time()
    except grpc.aio.AioRpcError as error:
        logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
        response = None
    responded = bool(response and response.peer and response.peer.node_id)

    if responded and validate:
        try:
            if self.server is not None and not response.available:
                raise ValidationError(f"Peer {peer} couldn't access this node at {response.sender_endpoint}. "
                                      f"Make sure that this port is open for incoming requests.")

            if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
                if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
                        response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
                    raise ValidationError(f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
                                          f"of others (local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
        except ValidationError as e:
            if strict:
                raise
            else:
                logger.warning(repr(e))

    peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
    asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
    return peer_id
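# Sketch of the validation path above (assumes `protocol` is a listening DHTProtocol and
# `peer_endpoint` is a known peer; strict=False turns validation failures into warnings).
async def ping_example(protocol: DHTProtocol, peer_endpoint: Endpoint):
    peer_id = await protocol.call_ping(peer_endpoint, validate=True, strict=False)
    if peer_id is None:
        print("peer did not respond")
    else:
        print("peer responded with node id", peer_id)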
def test_routing_table_basic():
    node_id = DHTID.generate()
    routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
    added_nodes = []

    for phony_neighbor_port in random.sample(range(10000), 100):
        phony_id = DHTID.generate()
        routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
        assert phony_id in routing_table
        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
        added_nodes.append(phony_id)

    assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
    for bucket in routing_table.buckets:
        assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
    assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)

    random_node = random.choice(added_nodes)
    assert routing_table.get(node_id=random_node) == routing_table[random_node]
    dummy_node = DHTID.generate()
    assert (dummy_node not in routing_table) == (routing_table.get(node_id=dummy_node) is None)

    for node in added_nodes:
        found_bucket_index = routing_table.get_bucket_index(node)
        for bucket_index, bucket in enumerate(routing_table.buckets):
            if bucket.lower <= node < bucket.upper:
                break
        else:
            raise ValueError("Naive search could not find bucket. Universe has gone crazy.")
        assert bucket_index == found_bucket_index
async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
    """ Periodically tries to find new nodes for buckets that were not updated for more than :period: seconds """
    while self.is_alive and period is not None:  # if period is None, bucket refresh is disabled entirely
        refresh_time = get_dht_time()
        staleness_threshold = refresh_time - period
        stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
                         if bucket.last_updated < staleness_threshold]
        for bucket in stale_buckets:
            refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
            await self.find_nearest_nodes(refresh_id)

        await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
        Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
    """
    Request keys from a peer. For each key, look for its (value, expiration time) locally and
    k additional peers that are most likely to have this key (ranked by XOR distance)

    :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
        value: value stored by the recipient with that key, or None if peer doesn't have this value
        expiration time: expiration time of the returned value, None if no value was found
        neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
        If peer didn't respond, returns None
    """
    keys = list(keys)
    find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
    try:
        async with self.rpc_semaphore:
            response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
        if response.peer and response.peer.node_id:
            peer_id = DHTID.from_bytes(response.peer.node_id)
            asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
        assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
            "DHTProtocol: response is not aligned with keys and/or expiration times"

        output = {}  # unpack data without special NOT_FOUND_* values
        for key, value, expiration_time, nearest in zip(
                keys, response.values, response.expiration_time, response.nearest):
            value = value if value != _NOT_FOUND_VALUE else None
            expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
            nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
            output[key] = (value, expiration_time, nearest)
        return output
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
        asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
    """ Some node wants us to store this (key, value) pair """
    if request.peer:  # if requested, add peer to the routing table
        asyncio.create_task(self.rpc_ping(request.peer, context))
    assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
    response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
    for key_bytes, value_bytes, expiration_time, in_cache in zip(
            request.keys, request.values, request.expiration_time, request.in_cache):
        local_memory = self.cache if in_cache else self.storage
        response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
    return response
async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
    """
    Get peer's node id and add it to the routing table. If the peer doesn't respond, return None

    :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
    :note: if DHTProtocol was created with listen=True, also request the peer to add us to its routing table
    :return: node's DHTID, if the peer responded and decided to send its node_id
    """
    try:
        async with self.rpc_semaphore:
            peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
    except grpc.experimental.aio.AioRpcError as error:
        logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
        peer_info = None
    responded = bool(peer_info and peer_info.node_id)
    peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
    asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
    return peer_id
async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
    """
    Someone wants to find keys in the DHT. For all keys that we have locally, return the value and expiration.
    Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found a value)
    """
    if request.peer:  # if requested, add peer to the routing table
        asyncio.create_task(self.rpc_ping(request.peer, context))

    response = dht_pb2.FindResponse(values=[], expiration_time=[], nearest=[], peer=self.node_info)
    for key_id in map(DHTID.from_bytes, request.keys):
        maybe_value, maybe_expiration_time = self.storage.get(key_id)
        cached_value, cached_expiration_time = self.cache.get(key_id)
        if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
            maybe_value, maybe_expiration_time = cached_value, cached_expiration_time

        nearest_neighbors = self.routing_table.get_nearest_neighbors(
            key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
        if nearest_neighbors:
            peer_ids, endpoints = zip(*nearest_neighbors)
        else:
            peer_ids, endpoints = [], []

        response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
        response.expiration_time.append(maybe_expiration_time if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
        response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
    return response
async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], beam_size: int,
                              get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
                              visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
    """
    Traverse the DHT graph using the get_neighbors function, find :beam_size: nearest nodes according to
    DHTID.xor_distance.

    :note: This is a simplified (but working) algorithm provided for documentation purposes. The actual DHTNode uses
        `traverse_dht` - a generalization of this algorithm that allows multiple queries and concurrent workers.

    :param query_id: search query, find k_nearest neighbors of this DHTID
    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
    :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
        Recommended value: a beam size of k_nearest * (2-5) will yield near-perfect results.
    :param get_neighbors: A function that returns neighbors of a given node and controls beam search stopping criteria.
        async def get_neighbors(node: DHTID) -> neighbors_of_that_node: List[DHTID], should_continue: bool
        If should_continue is False, beam search will halt and return k_nearest of whatever it found by then.
    :param visited_nodes: beam search will neither call get_neighbors on these nodes, nor return them as nearest
    :returns: a list of k nearest nodes (nearest to farthest), and a set of all visited nodes (including visited_nodes)
    """
    visited_nodes = set(visited_nodes)  # note: copy visited_nodes because we will add more nodes to this collection
    initial_nodes = [node_id for node_id in initial_nodes if node_id not in visited_nodes]
    if not initial_nodes:
        return [], visited_nodes

    unvisited_nodes = [(distance, uid) for uid, distance in zip(initial_nodes, query_id.xor_distance(initial_nodes))]
    heapq.heapify(unvisited_nodes)  # nearest-first heap of candidates, unlimited size

    nearest_nodes = [(-distance, node_id) for distance, node_id in heapq.nsmallest(beam_size, unvisited_nodes)]
    heapq.heapify(nearest_nodes)  # farthest-first heap of size beam_size, used for early-stopping and to select results
    while len(nearest_nodes) > beam_size:
        heapq.heappop(nearest_nodes)

    visited_nodes |= set(initial_nodes)
    upper_bound = -nearest_nodes[0][0]  # distance to the farthest element that is still in the beam
    was_interrupted = False  # will be set to True if the host triggered beam search to stop via get_neighbors

    while (not was_interrupted) and len(unvisited_nodes) != 0 and unvisited_nodes[0][0] <= upper_bound:
        _, node_id = heapq.heappop(unvisited_nodes)  # note: this is the smallest element in the heap (see heapq)
        neighbors, was_interrupted = await get_neighbors(node_id)
        neighbors = [node_id for node_id in neighbors if node_id not in visited_nodes]
        visited_nodes.update(neighbors)

        for neighbor_id, distance in zip(neighbors, query_id.xor_distance(neighbors)):
            if distance <= upper_bound or len(nearest_nodes) < beam_size:
                heapq.heappush(unvisited_nodes, (distance, neighbor_id))

                heapq_add_or_replace = heapq.heappush if len(nearest_nodes) < beam_size else heapq.heappushpop
                heapq_add_or_replace(nearest_nodes, (-distance, neighbor_id))
                upper_bound = -nearest_nodes[0][0]  # distance to the beam_size-th nearest element found so far

    return [node_id for _, node_id in heapq.nlargest(beam_size, nearest_nodes)], visited_nodes
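# A self-contained toy run of the beam search above: neighbors come from a random static graph
# instead of network RPCs (purely illustrative, not part of the library API).
import asyncio
import random

async def beam_search_demo():
    all_ids = [DHTID.generate() for _ in range(200)]
    graph = {uid: random.sample(all_ids, k=10) for uid in all_ids}  # a fixed random "routing table"

    async def get_neighbors(node: DHTID):
        return graph[node], True  # True = let the beam search continue

    query = DHTID.generate()
    nearest, visited = await simple_traverse_dht(query, initial_nodes=all_ids[:3],
                                                 beam_size=20, get_neighbors=get_neighbors)
    print(f"found {len(nearest)} candidates, visited {len(visited)} nodes")
    print("closest distance:", query.xor_distance(nearest[0]))

# asyncio.run(beam_search_demo())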
async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
                   num_workers: Optional[int] = None, beam_size: Optional[int] = None
                   ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
    """
    :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if the key is not found)
    :param sufficient_expiration_time: if the search finds a value that expires after this time,
        it can stop searching and return that value. Defaults to the time of the call,
        i.e. any value that has not expired by the time of the call is sufficient.
        If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
    :param beam_size: maintains up to this many nearest nodes when crawling the dht, default beam_size = bucket_size
    :param num_workers: override for default num_workers, see traverse_dht num_workers param
    :returns: for each key: the value and its expiration time. If nothing is found, returns (None, None) for that key
    :note: in order to check if get returned a value, please check if (expiration_time is None)
    """
    key_ids = [DHTID.generate(key) for key in keys]
    id_to_original_key = dict(zip(key_ids, keys))
    sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
    beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
    num_workers = num_workers if num_workers is not None else self.num_workers

    # search metadata
    unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
    node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries

    SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
    latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}

    # stage 1: the value can be stored in our local cache
    for key_id in key_ids:
        maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
        if maybe_expiration_time is None:
            maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
        if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
            latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
            if maybe_expiration_time >= sufficient_expiration_time:
                unfinished_key_ids.remove(key_id)

    # stage 2: traverse the DHT for any unfinished keys
    for key_id in unfinished_key_ids:
        node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
            key_id, self.protocol.bucket_size, exclude=self.node_id))

    async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
        queries = list(queries)
        response = await self.protocol.call_find(node_to_endpoint[peer], queries)
        if not response:
            return {query: ([], False) for query in queries}

        output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
        for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
            node_to_endpoint.update(peers)
            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
            should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
            output[key_id] = list(peers.keys()), should_interrupt
        return output

    nearest_nodes_per_query, visited_nodes = await traverse_dht(
        queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
        beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
        get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})

    # stage 3: cache any new results depending on caching parameters
    for key_id, nearest_nodes in nearest_nodes_per_query.items():
        latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
        should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
        if should_cache and self.cache_locally:
            self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)
        if should_cache and self.cache_nearest:
            num_cached_nodes = 0
            for node_id in nearest_nodes:
                if node_id == latest_node_id:
                    continue
                asyncio.create_task(self.protocol.call_store(
                    node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
                    in_cache=True))
                num_cached_nodes += 1
                if num_cached_nodes >= self.cache_nearest:
                    break

    # stage 4: deserialize data and assemble function output
    find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
    for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
        if latest_expiration_time != -float('inf'):
            latest_value = self.serializer.loads(latest_value_bytes)
            find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
        else:
            find_result[id_to_original_key[key_id]] = None, None
    return find_result
async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
                     expiration_time: Union[DHTExpiration, List[DHTExpiration]],
                     subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
                     exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
    """
    Traverse the DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.

    :param keys: arbitrary serializable keys associated with each value
    :param values: serializable "payload" for each key
    :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
    :param subkeys: an optional list of the same shape as keys. If specified, each value is stored under its
        corresponding subkey (see DHTProtocol.call_store for subkey semantics)
    :param kwargs: any additional parameters passed to the traverse_dht function (e.g. num workers)
    :param exclude_self: if True, never store the value locally even if you are one of the nearest nodes
    :note: if exclude_self is True and self.cache_locally == True, the value will still be __cached__ locally
    :param await_all_replicas: if False, this function returns after the first store_ok and proceeds in background
        if True, the function will wait for num_replicas successful stores or until it runs out of beam_size nodes
    :returns: for each key: True if store succeeds, False if it fails (due to no response or a newer value)
    """
    if isinstance(expiration_time, DHTExpiration):
        expiration_time = [expiration_time] * len(keys)
    if subkeys is None:
        subkeys = [None] * len(keys)

    assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
        "keys, subkeys, values, and expiration_time must have the same length"

    key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
    for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
        key_id_to_data[DHTID.generate(source=key)].append((key, subkey, value, expiration))

    unfinished_key_ids = set(key_id_to_data.keys())  # use this set to ensure that each store request is finished
    store_ok = {(key, subkey): None for key, subkey in zip(keys, subkeys)}  # outputs, updated during search
    store_finished_events = {(key, subkey): asyncio.Event() for key, subkey in zip(keys, subkeys)}

    # pre-populate node_to_endpoint
    node_to_endpoint: Dict[DHTID, Endpoint] = dict()
    for key_id in unfinished_key_ids:
        node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
            key_id, self.protocol.bucket_size, exclude=self.node_id))

    async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
        """ This will be called once per key when find_nearest_nodes is done for a particular node """
        # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
        assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
        assert self.node_id not in nearest_nodes
        unfinished_key_ids.remove(key_id)

        # ensure k nodes stored the value, optionally include self.node_id as a candidate
        num_successful_stores = 0
        pending_store_tasks = set()
        store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
                                  key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
        [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
        binary_values: List[bytes] = list(map(self.protocol.serializer.dumps, current_values))

        while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
            while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
                node_id: DHTID = store_candidates.pop()  # nearest untried candidate

                if node_id == self.node_id:
                    num_successful_stores += 1
                    for subkey, value, expiration_time in zip(current_subkeys, binary_values, current_expirations):
                        store_ok[original_key, subkey] = self.protocol.storage.store(
                            key_id, value, expiration_time, subkey=subkey)
                        if not await_all_replicas:
                            store_finished_events[original_key, subkey].set()
                else:
                    pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
                        node_to_endpoint[node_id], keys=[key_id] * len(current_values), values=binary_values,
                        expiration_time=current_expirations, subkeys=current_subkeys)))

            # await the nearest task. If it fails, dispatch more on the next iteration
            if pending_store_tasks:
                finished_store_tasks, pending_store_tasks = await asyncio.wait(
                    pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
                for task in finished_store_tasks:
                    if task.result() is not None:
                        num_successful_stores += 1
                        for subkey, store_status in zip(current_subkeys, task.result()):
                            store_ok[original_key, subkey] = store_status
                            if not await_all_replicas:
                                store_finished_events[original_key, subkey].set()

        if self.cache_on_store:
            self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
                                        store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])

        for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
            store_finished_events[original_key, subkey].set()

    store_task = asyncio.create_task(self.find_nearest_nodes(
        queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
        found_callback=on_found, exclude_self=exclude_self, **kwargs))
    try:
        await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # wait for items to be stored
        assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
        return {(key, subkey) if subkey else key: status or False for (key, subkey), status in store_ok.items()}
    except asyncio.CancelledError as e:
        store_task.cancel()
        raise e
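# Usage sketch for store_many above (assumes `node` is a bootstrapped DHTNode; the keys, values
# and the 5-minute expiration are placeholders).
async def store_many_example(node: DHTNode):
    deadline = get_dht_time() + 300
    status = await node.store_many(keys=['alpha', 'beta'], values=[1, 2], expiration_time=deadline)
    # status maps each key to True if the value was successfully stored (see the docstring above)
    print(status)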
async def create(cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
                 bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
                 wait_timeout: float = 5, refresh_timeout: Optional[float] = None,
                 bootstrap_timeout: Optional[float] = None, cache_locally: bool = True, cache_nearest: int = 1,
                 cache_size=None, cache_refresh_before_expiry: float = 5, cache_on_store: bool = True,
                 reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
                 listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
    """
    :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
    :param initial_peers: connects to these peers to populate routing table, defaults to no peers
    :param bucket_size: max number of nodes in one k-bucket (k). Trying to add the {k+1}st node will cause the bucket
        to either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
        Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
    :param num_replicas: number of nearest nodes that will be asked to store a given key, default = bucket_size (≈k)
    :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
    :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
        Reduce this value if your RPC requests register no response despite the peer sending the response.
    :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
    :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
        if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
    :param bootstrap_timeout: after one of the peers responds, await other peers for at most this many seconds
    :param cache_locally: if True, caches all values (stored or found) in a node-local cache
    :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
    :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
        nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
    :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
    :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
        if they are accessed this many seconds before expiration time.
    :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
        all concurrent get requests for the same key will reuse the procedure that is currently in progress
    :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
    :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
    :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
        if False, this node will refuse any incoming requests, effectively being only a "client"
    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
        see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
    :param kwargs: extra parameters used in grpc.aio.server
    """
    self = cls(_initialized_with_create=True)
    self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
    self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
    self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh

    self.reuse_get_requests = reuse_get_requests
    self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: -_res.sufficient_expiration_time))

    # caching policy
    self.refresh_timeout = refresh_timeout
    self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
    self.cache_refresh_before_expiry = cache_refresh_before_expiry
    self.cache_refresh_queue = CacheRefreshQueue()
    self.cache_refresh_evt = asyncio.Event()
    self.cache_refresh_task = None

    self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                             parallel_rpc, cache_size, listen, listen_on, **kwargs)
    self.port = self.protocol.port

    if initial_peers:
        # stage 1: ping initial_peers, add each other to the routing table
        bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
        start_time = get_dht_time()
        ping_tasks = map(self.protocol.call_ping, initial_peers)
        finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)

        # stage 2: gather the remaining peers (those who respond within bootstrap_timeout)
        if unfinished_pings:
            finished_in_time, stragglers = await asyncio.wait(
                unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
            for straggler in stragglers:
                straggler.cancel()
            finished_pings |= finished_in_time

        if not finished_pings:
            warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

        # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
        # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
        # note: using asyncio.wait instead of wait_for because wait_for cancels the task on timeout
        await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                            asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
                           return_when=asyncio.FIRST_COMPLETED)

    if self.refresh_timeout is not None:
        asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
    return self
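# Bootstrap sketch for the factory above (the initial peer endpoint is a placeholder; requires a
# reachable DHT node at that address and a running asyncio event loop).
async def join_dht_example():
    node = await DHTNode.create(initial_peers=['192.168.1.10:31337'], listen_on='0.0.0.0:*')
    ok = await node.store_many(keys=['greeting'], values=['hello'], expiration_time=get_dht_time() + 600)
    return node, ok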
            num_added += new_total > total_nodes
            total_nodes = new_total
        num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)

        all_active_neighbors = list(chain(
            *(bucket.nodes_to_endpoint.keys() for bucket in routing_table.buckets)))
        assert lower_active <= len(all_active_neighbors) <= upper_active
        assert len(all_active_neighbors) == num_added
        assert num_added + num_replacements == table_size

        # random queries
        for i in range(1000):
            k = random.randint(1, 100)
            query_id = DHTID.generate()
            exclude = query_id if random.random() < 0.5 else None
            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
            reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
            assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
            assert all(our_endpoint == routing_table[our_node]
                       for our_node, our_endpoint in zip(our_knn, our_endpoints))

        # queries from table
        for i in range(1000):
            k = random.randint(1, 100)