async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int,
                          future: MPFuture):
    """
    Scan uid_prefixes in order, chunk by chunk, and report the first k prefixes that have an active entry.

    :param node: DHT node used to look up prefixes via node.get_many
    :param uid_prefixes: candidate prefixes, examined in the order given
    :param k: stop as soon as this many active prefixes have been collected
    :param max_prefetch: number of extra chunks that are requested ahead of the chunk currently being parsed
    :param chunk_size: number of prefixes requested from the DHT in a single get_many call
    :param future: receives an OrderedDict {uid_prefix: RemoteExpert} with k entries, or as many as were found
    :note: if a lookup or the parsing of its response raises, the exception propagates and the future is not resolved
    """
    num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
    total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
    found: List[Tuple[str, RemoteExpert]] = []

    def get_chunk(index):
        """ Return the slice of uid_prefixes that belongs to chunk number `index`. """
        return uid_prefixes[index * chunk_size:(index + 1) * chunk_size]

    def dispatch(index):
        """ Start a background DHT lookup for chunk number `index`. """
        return asyncio.create_task(node.get_many(get_chunk(index), num_workers=num_workers_per_chunk))

    # pre-dispatch first task and up to max_prefetch additional tasks
    pending_tasks = deque(dispatch(chunk_i) for chunk_i in range(min(max_prefetch + 1, total_chunks)))

    try:
        for chunk_i in range(total_chunks):
            # parse task results in chronological order, launch additional tasks on demand
            response = await pending_tasks.popleft()
            for uid_prefix in get_chunk(chunk_i):
                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
                if maybe_expiration_time is not None:  # found active peer
                    found.append((uid_prefix, RemoteExpert(**maybe_expert_data)))
                    # if we found enough active experts, finish immediately
                    if len(found) >= k:
                        break
            if len(found) >= k:
                break

            # pending_tasks currently covers chunks chunk_i + 1 ... chunk_i + len(pending_tasks)
            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
            if pre_dispatch_chunk_i < total_chunks:
                pending_tasks.append(dispatch(pre_dispatch_chunk_i))
    finally:
        # cancel leftover prefetch tasks on every exit path (early finish, error or cancellation),
        # otherwise they would keep querying the DHT in the background
        for task in pending_tasks:
            task.cancel()

    # return k active prefixes or as many as we could find
    future.set_result(OrderedDict(found))
def run_node(node_id, peers, status_pipe: mp.Pipe):
    """
    Host one DHT node in the current process and keep its event loop running indefinitely.

    :param node_id: passed to DHTNode.create as the first positional argument
    :param peers: passed to DHTNode.create as initial_peers
    :param status_pipe: once the node is created, its port is sent through this pipe
    """
    inherited_loop = asyncio.get_event_loop()
    if inherited_loop.is_running():
        # if we're in jupyter, get rid of its built-in event loop
        inherited_loop.stop()

    # install a fresh loop and use it for everything below
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
    status_pipe.send(node.port)

    # never returns: if the loop gets stopped, it is started again
    while True:
        loop.run_forever()
async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix],
                                 grid_size: Optional[int] = None, num_workers: Optional[int] = None,
                                 future: Optional[MPFuture] = None
                                 ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
    """
    Look up every prefix in the DHT and gather the well-formed successor entries stored under it.

    :param node: DHT node used for get_many and, with negative caching, for store
    :param prefixes: keys to request from the DHT
    :param grid_size: if given (and non-zero), only coordinates in [0, grid_size) are kept; otherwise no upper bound
    :param num_workers: passed to get_many; if not given, defaults to min(len(prefixes), self.max_workers),
      or to len(prefixes) when self.max_workers is not set
    :param future: if given, it also receives the result via set_result
    :returns: for each prefix in the DHT response, a dict {coordinate: UidEndpoint}; empty if nothing valid was found
    """
    coord_limit = grid_size or float('inf')
    if not num_workers:
        num_workers = min(len(prefixes), self.max_workers or len(prefixes))

    responses = await node.get_many(keys=prefixes, num_workers=num_workers)

    successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
    for prefix, found in responses.items():
        if found and isinstance(found.value, dict):
            active: Dict[Coordinate, UidEndpoint] = {}
            for coord, match in found.value.items():
                # skip sub-keys that are not coordinates inside the grid
                if not isinstance(coord, Coordinate) or not (0 <= coord < coord_limit):
                    continue
                # an entry is only usable if its value is a list of exactly two items
                payload = getattr(match, 'value', None)
                if isinstance(payload, list) and len(payload) == 2:
                    active[coord] = UidEndpoint(*payload)
            successors[prefix] = active
            continue

        successors[prefix] = {}
        if found is None and self.negative_caching:
            logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
            # fire-and-forget: the store runs in the background and is not awaited here
            # NOTE(review): no reference to this task is kept - confirm that is acceptable
            asyncio.create_task(node.store(prefix, subkey=-1, value=None,
                                           expiration_time=get_dht_time() + self.expiration))

    if future:
        future.set_result(successors)
    return successors
def test_dht_node():
    """
    End-to-end test of DHTNode against a swarm of 50 nodes, each hosted in its own process.

    Covers nearest-node search (self, known nodes, random ids, whole swarm, node without peers),
    store / get of single values, bulk store / get and dictionary values with sub-keys.
    Child processes are terminated even if one of the assertions fails.
    """
    # create dht with 50 nodes + your 51-st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []

    try:
        for _ in range(50):
            node_id = DHTID.generate()
            # random.sample needs a sequence: dict views raise TypeError on python 3.11+
            peers = random.sample(list(dht.keys()), min(len(dht), 5))
            pipe_recv, pipe_send = mp.Pipe(duplex=False)
            proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
            proc.start()
            processes.append(proc)  # register right away so that the process is cleaned up if recv fails
            port = pipe_recv.recv()
            dht[f"{LOCALHOST}:{port}"] = node_id

        loop = asyncio.get_event_loop()
        me = loop.run_until_complete(DHTNode.create(
            initial_peers=random.sample(list(dht.keys()), 5), parallel_rpc=10, cache_refresh_before_expiry=False))

        # test 1: find self
        nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

        # test 2: find others
        for _ in range(10):
            ref_endpoint, query_id = random.choice(list(dht.items()))
            nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
            assert len(nearest) == 1
            found_node_id, found_endpoint = next(iter(nearest.items()))
            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

        # test 3: find neighbors to random nodes
        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
        all_node_ids = list(dht.values())

        for _ in range(10):
            query_id = DHTID.generate()
            k_nearest = random.randint(1, 10)
            exclude_self = random.random() > 0.5
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
            nearest_nodes = list(nearest)  # keys from ordered dict

            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

            # reference answer: take one spare candidate, then drop whatever is not needed
            ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
            if exclude_self and me.node_id in ref_nearest:
                ref_nearest.remove(me.node_id)
            if len(ref_nearest) > k_nearest:
                ref_nearest.pop()

            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
            accuracy_denominator += 1

            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
            jaccard_denominator += k_nearest

        accuracy = accuracy_numerator / accuracy_denominator
        print("Top-1 accuracy:", accuracy)  # should be 98-100%
        jaccard_index = jaccard_numerator / jaccard_denominator
        print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
        assert jaccard_index >= 0.9, \
            f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

        # test 4: find all nodes
        dummy = DHTID.generate()
        nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
        assert len(nearest) == len(dht) + 1
        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

        # test 5: node without peers
        detached_node = loop.run_until_complete(DHTNode.create())
        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
        assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
        assert len(nearest) == 0

        # test 6 store and get value
        true_time = get_dht_time() + 1200
        assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
        that_guy = loop.run_until_complete(DHTNode.create(
            initial_peers=random.sample(list(dht.keys()), 3), parallel_rpc=10,
            cache_refresh_before_expiry=False, cache_locally=False))

        for node in [me, that_guy]:
            val, expiration_time = loop.run_until_complete(node.get("mykey"))
            assert val == ["Value", 10], "Wrong value"
            assert expiration_time == true_time, "Wrong time"

        assert loop.run_until_complete(detached_node.get("mykey")) is None

        # test 7: bulk store and bulk get
        keys = 'foo', 'bar', 'baz', 'zzz'
        values = 3, 2, 'batman', [1, 2, 3]
        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
        assert all(store_ok.values()), "failed to store one or more keys"
        response = loop.run_until_complete(me.get_many(keys[::-1]))
        for key, value in zip(keys, values):
            assert key in response and response[key][0] == value

        # test 8: store dictionaries as values (with sub-keys)
        upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
        now = get_dht_time()
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
        for node in [that_guy, me]:
            value, expiration = loop.run_until_complete(node.get(upper_key))
            assert isinstance(value, dict) and expiration == now + 20
            assert value[subkey1] == (123, now + 10)
            assert value[subkey2] == (456, now + 20)
            assert len(value) == 2

        # a store to an existing sub-key with an earlier expiration is expected to be rejected
        assert not loop.run_until_complete(
            me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
        loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh

        for node in [that_guy, me]:
            value, expiration = loop.run_until_complete(node.get(upper_key))
            assert isinstance(value, dict) and expiration == now + 50, (value, expiration)
            assert value[subkey1] == (123, now + 10)
            assert value[subkey2] == (567, now + 30)
            assert value[subkey3] == (890, now + 50)
            assert len(value) == 3
    finally:
        # always shut the swarm down, including when an assertion above has failed
        for proc in processes:
            proc.terminate()
def _tester():
    """
    Run the DHTNode checks against the swarm described by `dht` and set `test_success` if all of them pass.

    Both `dht` (endpoint -> node id mapping) and `test_success` come from the enclosing scope.
    """
    # note: we run everything in a separate process to re-initialize all global states from scratch
    # this helps us avoid undesirable side-effects when running multiple tests in sequence
    loop = asyncio.get_event_loop()
    # random.sample needs a sequence: dict views raise TypeError on python 3.11+
    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(list(dht.keys()), 5), parallel_rpc=10))

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

    # test 2: find others
    for _ in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for _ in range(100):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 20)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        # reference answer: take one spare candidate, then drop whatever is not needed
        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1

        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    other_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
    nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6 store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    # NOTE(review): this used to be a loop over [me, other_node] that nevertheless always queried `me`,
    # so its second iteration verified nothing new. other_node was created without peers and presumably
    # cannot see the key - confirm before adding a check that reads the value from a second node.
    val, expiration_time = loop.run_until_complete(me.get("mykey"))
    assert expiration_time == true_time, "Wrong time"
    assert val == ["Value", 10], "Wrong value"

    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # reached only if every assertion above has passed
    test_success.set()