def test_get_expired():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
    time.sleep(0.5)
    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
    print("Test get expired passed")
def test_maxsize_cache():
    d = LocalStorage(maxsize=1)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with larger expiration time must be kept"
    assert d.get(DHTID.generate("key1"))[0] is None, "Value with smaller expiration time must be evicted"
def test_change_expiration_time():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
    time.sleep(1)
    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")
def test_ids_basic():
    # basic functionality tests
    for i in range(100):
        id1, id2 = DHTID.generate(), DHTID.generate()
        assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 <= DHTID.MAX
        assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
        assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
        assert len(PickleSerializer.dumps(id1)) - len(PickleSerializer.dumps(int(id1))) < 40
        assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
def test_routing_table_search():
    for table_size, lower_active, upper_active in [(10, 10, 10), (10_000, 800, 1100)]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
        num_added = 0
        total_nodes = 0

        for phony_neighbor_port in random.sample(range(1_000_000), table_size):
            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
            new_total = sum(len(bucket.nodes_to_endpoint) for bucket in routing_table.buckets)
            num_added += new_total > total_nodes
            total_nodes = new_total
def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_port, peer_id, peer_started = find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
    peer_proc.start(), peer_started.wait()
    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        protocol = loop.run_until_complete(DHTProtocol.create(
            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
        assert all(loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer_port}', [key], [MSGPackSerializer.dumps(value)], expiration))), "peer rejected store"

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        recv_value = MSGPackSerializer.loads(recv_value_bytes)
        assert len(nodes_found) == 0
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{find_open_port()}')) is None
        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()
    peer_proc.terminate()
def test_routing_table_parameters():
    for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
        (20, 5, 45, 65),
        (50, 5, 35, 45),
        (20, 10, 650, 800),
        (20, 1, 7, 15),
    ]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
        for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
        for bucket in routing_table.buckets:
            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_endpoint) <= bucket.size
        assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
            f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")
def test_routing_table_basic():
    node_id = DHTID.generate()
    routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
    added_nodes = []

    for phony_neighbor_port in random.sample(range(10000), 100):
        phony_id = DHTID.generate()
        routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
        assert phony_id in routing_table
        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
        added_nodes.append(phony_id)

    assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
    for bucket in routing_table.buckets:
        assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
    assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)

    random_node = random.choice(added_nodes)
    assert routing_table.get(node_id=random_node) == routing_table[random_node]
    dummy_node = DHTID.generate()
    assert (dummy_node not in routing_table) == (routing_table.get(node_id=dummy_node) is None)

    for node in added_nodes:
        found_bucket_index = routing_table.get_bucket_index(node)
        for bucket_index, bucket in enumerate(routing_table.buckets):
            if bucket.lower <= node < bucket.upper:
                break
        else:
            raise ValueError("Naive search could not find bucket. Universe has gone crazy.")
        assert bucket_index == found_bucket_index
async def create(cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
                 bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5,
                 parallel_rpc: int = None, wait_timeout: float = 5, refresh_timeout: Optional[float] = None,
                 bootstrap_timeout: Optional[float] = None, num_workers: int = 1,
                 cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
                 listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
    """
    :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
    :param initial_peers: connects to these peers to populate routing table, defaults to no peers
    :param bucket_size: max number of nodes in one k-bucket (k). Trying to add the {k+1}st node will cause a bucket
        to either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
        Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
    :param num_replicas: number of nearest nodes that will be asked to store a given key (≈k)
    :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
    :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol.
        Reduce this value if your RPC requests register no response despite the peer sending the response.
    :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
    :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds;
        if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
    :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
    :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
    :param cache_locally: if True, caches all values (stored or found) in a node-local cache
    :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
        nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
    :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
    :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen";
        if False, this node will refuse any incoming requests, effectively being only a "client"
    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
        see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
    :param kwargs: extra parameters used in grpc.aio.server
    """
    self = cls(_initialized_with_create=True)
    self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
    self.num_replicas, self.num_workers = num_replicas, num_workers
    self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
    self.refresh_timeout = refresh_timeout

    self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                             parallel_rpc, cache_size, listen, listen_on, **kwargs)
    self.port = self.protocol.port

    if initial_peers:
        # stage 1: ping initial_peers, add each other to the routing table
        bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
        start_time = get_dht_time()
        ping_tasks = map(self.protocol.call_ping, initial_peers)
        finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)

        # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
        if unfinished_pings:
            finished_in_time, stragglers = await asyncio.wait(
                unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
            for straggler in stragglers:
                straggler.cancel()
            finished_pings |= finished_in_time

        if not finished_pings:
            warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

        # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
        # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
        # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
        await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                            asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
                           return_when=asyncio.FIRST_COMPLETED)

    if self.refresh_timeout is not None:
        asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
    return self
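# A minimal usage sketch (not part of the library code above): how DHTNode.create and the
# store/get coroutines exercised by the tests below might be combined. The run_example()
# wrapper and the 60-second expiration are illustrative assumptions, not values from this codebase.
async def run_example():
    # a listening node on any free port: a full DHT "citizen"
    first_node = await DHTNode.create(listen_on="0.0.0.0:*")
    # a client-only node that bootstraps from the first node and rejects incoming requests
    client = await DHTNode.create(initial_peers=[f"{LOCALHOST}:{first_node.port}"], listen=False)
    assert await client.store("example_key", "example_value", get_dht_time() + 60)
    value, expiration_time = await client.get("example_key")
    return value, expiration_time

# asyncio.get_event_loop().run_until_complete(run_example())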
async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
                   num_workers: Optional[int] = None, beam_size: Optional[int] = None
                   ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
    """
    :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if the key is not found)
    :param sufficient_expiration_time: if the search finds a value that expires after this time,
        default = time of call, find any value that did not expire by the time of call
        If sufficient_expiration_time=float('inf'), this method will find a value with _latest_ expiration
    :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
    :param num_workers: override for default num_workers, see traverse_dht num_workers param
    :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
    :note: in order to check if get returned a value, please check (expiration_time is None)
    """
    key_ids = [DHTID.generate(key) for key in keys]
    id_to_original_key = dict(zip(key_ids, keys))
    sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
    beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
    num_workers = num_workers if num_workers is not None else self.num_workers

    # search metadata
    unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
    node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries

    SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
    latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}

    # stage 1: value can be stored in our local cache
    for key_id in key_ids:
        maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
        if maybe_expiration_time is None:
            maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
        if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
            latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
            if maybe_expiration_time >= sufficient_expiration_time:
                unfinished_key_ids.remove(key_id)

    # stage 2: traverse the DHT for any unfinished keys
    for key_id in unfinished_key_ids:
        node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
            key_id, self.protocol.bucket_size, exclude=self.node_id))

    async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
        queries = list(queries)
        response = await self.protocol.call_find(node_to_endpoint[peer], queries)
        if not response:
            return {query: ([], False) for query in queries}

        output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
        for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
            node_to_endpoint.update(peers)
            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
            should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
            output[key_id] = list(peers.keys()), should_interrupt
        return output

    nearest_nodes_per_query, visited_nodes = await traverse_dht(
        queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
        beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
        get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})

    # stage 3: cache any new results depending on caching parameters
    for key_id, nearest_nodes in nearest_nodes_per_query.items():
        latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
        should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
        if should_cache and self.cache_locally:
            self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)
        if should_cache and self.cache_nearest:
            num_cached_nodes = 0
            for node_id in nearest_nodes:
                if node_id == latest_node_id:
                    continue
                asyncio.create_task(self.protocol.call_store(
                    node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
                    in_cache=True))
                num_cached_nodes += 1
                if num_cached_nodes >= self.cache_nearest:
                    break

    # stage 4: deserialize data and assemble function output
    find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
    for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
        if latest_expiration_time != -float('inf'):
            latest_value = self.serializer.loads(latest_value_bytes)
            find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
        else:
            find_result[id_to_original_key[key_id]] = None, None
    return find_result
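# A minimal usage sketch (not part of the library code above): get_many returns one
# (value, expiration_time) pair per original key, with (None, None) for keys that were not found,
# so the documented way to detect a miss is to check expiration_time for None.
# The node argument and key names are illustrative assumptions.
async def get_many_example(node: DHTNode):
    results = await node.get_many(["existing_key", "missing_key"])
    for key, (value, expiration_time) in results.items():
        if expiration_time is None:
            print(f"{key}: not found")
        else:
            print(f"{key}: {value!r} (expires at {expiration_time})")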
def _tester():
    # note: we run everything in a separate process to re-initialize all global states from scratch
    # this helps us avoid undesirable side-effects when running multiple tests in sequence
    loop = asyncio.get_event_loop()
    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
        protocol = loop.run_until_complete(DHTProtocol.create(
            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
        print(f"Self id={protocol.node_id}", flush=True)

        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        store_ok = loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer1_port}', [key], [MSGPackSerializer.dumps(value)], expiration))
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
        recv_value = MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
        assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

        # cause a non-response by querying a nonexistent peer
        dummy_port = find_open_port()
        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

        if listen:
            loop.run_until_complete(protocol.shutdown())

    print("DHTProtocol test finished successfully!")
    test_success.set()
def test_get_empty():
    d = LocalStorage()
    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
    print("Test get empty passed")
def test_store():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
    print("Test store passed")
def _tester():
    # note: we run everything in a separate process to re-initialize all global states from scratch
    # this helps us avoid undesirable side-effects when running multiple tests in sequence
    loop = asyncio.get_event_loop()
    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10))

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(100):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 20)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    other_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
    nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    for node in [me, other_node]:
        val, expiration_time = loop.run_until_complete(me.get("mykey"))
        assert expiration_time == true_time, "Wrong time"
        assert val == ["Value", 10], "Wrong value"

    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    test_success.set()
            num_added += new_total > total_nodes
            total_nodes = new_total

        num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)

        all_active_neighbors = list(chain(
            *(bucket.nodes_to_endpoint.keys() for bucket in routing_table.buckets)))
        assert lower_active <= len(all_active_neighbors) <= upper_active
        assert len(all_active_neighbors) == num_added
        assert num_added + num_replacements == table_size

        # random queries
        for i in range(1000):
            k = random.randint(1, 100)
            query_id = DHTID.generate()
            exclude = query_id if random.random() < 0.5 else None
            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
            reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
            assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
            assert all(our_endpoint == routing_table[our_node]
                       for our_node, our_endpoint in zip(our_knn, our_endpoints))

        # queries from table
        for i in range(1000):
            k = random.randint(1, 100)