Example #1
    async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str],
                              k: int, max_prefetch: int, chunk_size: int,
                              future: MPFuture):
        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
        found: List[Tuple[str, RemoteExpert]] = []

        # pre-dispatch the first chunk and up to max_prefetch additional chunks
        pending_tasks = deque(
            asyncio.create_task(node.get_many(
                uid_prefixes[chunk_i * chunk_size:(chunk_i + 1) * chunk_size],
                num_workers=num_workers_per_chunk))
            for chunk_i in range(min(max_prefetch + 1, total_chunks)))

        for chunk_i in range(total_chunks):
            # parse task results in chronological order, launch additional tasks on demand
            response = await pending_tasks.popleft()
            for uid_prefix in uid_prefixes[chunk_i * chunk_size:(chunk_i + 1) * chunk_size]:
                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
                if maybe_expiration_time is not None:  # found active peer
                    found.append(
                        (uid_prefix, RemoteExpert(**maybe_expert_data)))
                    # if we found enough active experts, finish immediately
                    if len(found) >= k:
                        break
            if len(found) >= k:
                break

            # keep the prefetch window full: dispatch one more chunk if any remain
            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
            if pre_dispatch_chunk_i < total_chunks:
                pending_tasks.append(asyncio.create_task(node.get_many(
                    uid_prefixes[pre_dispatch_chunk_i * chunk_size:
                                 (pre_dispatch_chunk_i + 1) * chunk_size],
                    num_workers=num_workers_per_chunk)))

        for task in pending_tasks:  # cancel any prefetch tasks that are still pending
            task.cancel()

        # return k active prefixes or as many as we could find
        future.set_result(OrderedDict(found))
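The interesting part of _first_k_active is its bounded prefetch pipeline: at most max_prefetch + 1 chunk queries are in flight at any time, results are consumed in dispatch order, and leftover tasks are cancelled once k hits are found. Below is a minimal, self-contained sketch of the same pattern; fetch_chunk, the expert names, and the "every other prefix is active" rule are placeholders invented for this illustration, not part of the library's API.

import asyncio
from collections import deque

async def fetch_chunk(chunk):
    # placeholder for node.get_many: pretend every other prefix is active
    await asyncio.sleep(0.01)
    return {uid: index % 2 == 0 for index, uid in enumerate(chunk)}

async def first_k_active(prefixes, k, max_prefetch=2, chunk_size=4):
    chunks = [prefixes[i:i + chunk_size] for i in range(0, len(prefixes), chunk_size)]
    # pre-dispatch the first chunk plus up to max_prefetch extra chunks
    pending = deque(asyncio.create_task(fetch_chunk(chunks[i]))
                    for i in range(min(max_prefetch + 1, len(chunks))))
    found = []
    for chunk_i, chunk in enumerate(chunks):
        response = await pending.popleft()  # consume results in dispatch order
        found.extend(uid for uid in chunk if response[uid])
        if len(found) >= k:
            break
        next_chunk = chunk_i + len(pending) + 1  # keep the prefetch window full
        if next_chunk < len(chunks):
            pending.append(asyncio.create_task(fetch_chunk(chunks[next_chunk])))
    for task in pending:  # cancel whatever is still in flight
        task.cancel()
    return found[:k]

print(asyncio.run(first_k_active([f"expert.{i}" for i in range(32)], k=5)))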
Example #2
def run_node(node_id, peers, status_pipe: mp.connection.Connection):
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()
    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
    status_pipe.send(node.port)
    while True:
        loop.run_forever()  # restart the loop if anything ever stops it
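For context, run_node is meant to be spawned as a separate process that reports its DHT port back through the pipe once the node is up (this is exactly how the test in Example #4 uses it). A minimal launch sketch, assuming DHTID.generate() and an empty peer list for the very first node:

import multiprocessing as mp

pipe_recv, pipe_send = mp.Pipe(duplex=False)
proc = mp.Process(target=run_node, args=(DHTID.generate(), [], pipe_send), daemon=True)
proc.start()
port = pipe_recv.recv()  # blocks until run_node sends node.port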
Example #3
    async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
                                     num_workers: Optional[int] = None, future: Optional[MPFuture] = None
                                     ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
        grid_size = grid_size or float('inf')
        num_workers = num_workers or min(len(prefixes), self.max_workers or len(prefixes))
        dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
        successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
        for prefix, found in dht_responses.items():
            if found and isinstance(found.value, dict):
                successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
                                      if isinstance(coord, Coordinate) and 0 <= coord < grid_size
                                      and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
            else:
                successors[prefix] = {}
                if found is None and self.negative_caching:
                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
                    asyncio.create_task(node.store(prefix, subkey=-1, value=None,
                                                   expiration_time=get_dht_time() + self.expiration))
        if future:
            future.set_result(successors)
        return successors
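The annotations above (ExpertPrefix, Coordinate, UidEndpoint) are not defined in this snippet. A plausible set of definitions, assuming the usual convention of string prefixes, integer grid coordinates, and (uid, endpoint) pairs; treat these as illustrative, not the library's exact declarations:

from typing import NamedTuple

ExpertUID = str     # e.g. "expert.4.2"
ExpertPrefix = str  # e.g. "expert.4."
Endpoint = str      # e.g. "127.0.0.1:1337"
Coordinate = int    # position along one axis of the expert grid

class UidEndpoint(NamedTuple):
    uid: ExpertUID
    endpoint: Endpoint

Note that Coordinate must be an actual class (here, int) because the snippet isinstance-checks grid coordinates against it.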
Example #4
def test_dht_node():
    # create a DHT with 50 nodes, then add our own 51st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []

    for i in range(50):
        node_id = DHTID.generate()
        peers = random.sample(list(dht.keys()), min(len(dht), 5))
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node,
                          args=(node_id, peers, pipe_send),
                          daemon=True)
        proc.start()
        port = pipe_recv.recv()
        processes.append(proc)
        dht[f"{LOCALHOST}:{port}"] = node_id

    loop = asyncio.get_event_loop()
    me = loop.run_until_complete(
        DHTNode.create(initial_peers=random.sample(list(dht.keys()), 5),
                       parallel_rpc=10,
                       cache_refresh_before_expiry=False))

    # test 1: find self
    nearest = loop.run_until_complete(
        me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and ':'.join(
        nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and ':'.join(
            found_endpoint.split(':')[-2:]) == ref_endpoint

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(10):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 10)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id],
                                  k_nearest=k_nearest,
                                  exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1,
                                      all_node_ids,
                                      key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1

        jaccard_numerator += len(
            set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):",
          jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(
        me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    detached_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
    nearest = loop.run_until_complete(
        detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    that_guy = loop.run_until_complete(
        DHTNode.create(initial_peers=random.sample(list(dht.keys()), 3),
                       parallel_rpc=10,
                       cache_refresh_before_expiry=False,
                       cache_locally=False))

    for node in [me, that_guy]:
        val, expiration_time = loop.run_until_complete(node.get("mykey"))
        assert val == ["Value", 10], "Wrong value"
        assert expiration_time == true_time, "Wrong time"

    assert loop.run_until_complete(detached_node.get("mykey")) is None

    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(
        me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # test 8: store dictionaries as values (with sub-keys)
    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
    now = get_dht_time()
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey1,
                 value=123,
                 expiration_time=now + 10))
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey2,
                 value=456,
                 expiration_time=now + 20))
    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 20
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (456, now + 20)
        assert len(value) == 2

    # storing a sub-key with an earlier expiration than the existing entry must fail
    assert not loop.run_until_complete(
        me.store(
            upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey2,
                 value=567,
                 expiration_time=now + 30))
    assert loop.run_until_complete(
        me.store(upper_key,
                 subkey=subkey3,
                 value=890,
                 expiration_time=now + 50))
    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh

    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 50, (value, time)
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (567, now + 30)
        assert value[subkey3] == (890, now + 50)
        assert len(value) == 3

    for proc in processes:
        proc.terminate()
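Test 3 above validates beam search against brute-force nearest neighbors under Kademlia's XOR metric. Assuming DHTID values compare as fixed-width integers, the reference computation reduces to a sketch like this (xor_distance and nearest_ids here are illustrative helpers, not the library's API):

import heapq

def xor_distance(id_a: int, id_b: int) -> int:
    # Kademlia distance: bitwise XOR of the two ids, compared as an integer
    return id_a ^ id_b

def nearest_ids(query: int, all_ids, k: int):
    # brute-force reference, analogous to heapq.nsmallest(..., key=query_id.xor_distance)
    return heapq.nsmallest(k, all_ids, key=lambda node_id: xor_distance(query, node_id))

assert nearest_ids(0b1010, [0b1000, 0b0010, 0b1011], k=2) == [0b1011, 0b1000]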
Example #5
    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        me = loop.run_until_complete(
            DHTNode.create(initial_peers=random.sample(list(dht.keys()), 5),
                           parallel_rpc=10))

        # test 1: find self
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
        assert len(nearest) == 1 and ':'.join(
            nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

        # test 2: find others
        for i in range(10):
            ref_endpoint, query_id = random.choice(list(dht.items()))
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
            assert len(nearest) == 1
            found_node_id, found_endpoint = next(iter(nearest.items()))
            assert found_node_id == query_id and ':'.join(
                found_endpoint.split(':')[-2:]) == ref_endpoint

        # test 3: find neighbors to random nodes
        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
        all_node_ids = list(dht.values())

        for i in range(100):
            query_id = DHTID.generate()
            k_nearest = random.randint(1, 20)
            exclude_self = random.random() > 0.5
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id],
                                      k_nearest=k_nearest,
                                      exclude_self=exclude_self))[query_id]
            nearest_nodes = list(nearest)  # keys from ordered dict

            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

            ref_nearest = heapq.nsmallest(k_nearest + 1,
                                          all_node_ids,
                                          key=query_id.xor_distance)
            if exclude_self and me.node_id in ref_nearest:
                ref_nearest.remove(me.node_id)
            if len(ref_nearest) > k_nearest:
                ref_nearest.pop()

            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
            accuracy_denominator += 1

            jaccard_numerator += len(
                set.intersection(set(nearest_nodes), set(ref_nearest)))
            jaccard_denominator += k_nearest

        accuracy = accuracy_numerator / accuracy_denominator
        print("Top-1 accuracy:", accuracy)  # should be 98-100%
        jaccard_index = jaccard_numerator / jaccard_denominator
        print("Jaccard index (intersection over union):",
              jaccard_index)  # should be 95-100%
        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
        assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

        # test 4: find all nodes
        dummy = DHTID.generate()
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
        assert len(nearest) == len(dht) + 1
        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

        # test 5: node without peers
        other_node = loop.run_until_complete(DHTNode.create())
        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
        assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
        nearest = loop.run_until_complete(
            other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
        assert len(nearest) == 0

        # test 6: store and get value
        true_time = get_dht_time() + 1200
        assert loop.run_until_complete(
            me.store("mykey", ["Value", 10], true_time))
        for node in [me, other_node]:
            val, expiration_time = loop.run_until_complete(node.get("mykey"))
            assert expiration_time == true_time, "Wrong time"
            assert val == ["Value", 10], "Wrong value"

        # test 7: bulk store and bulk get
        keys = 'foo', 'bar', 'baz', 'zzz'
        values = 3, 2, 'batman', [1, 2, 3]
        store_ok = loop.run_until_complete(
            me.store_many(keys, values, expiration_time=get_dht_time() + 999))
        assert all(store_ok.values()), "failed to store one or more keys"
        response = loop.run_until_complete(me.get_many(keys[::-1]))
        for key, value in zip(keys, values):
            assert key in response and response[key][0] == value

        test_success.set()
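As the leading comment says, _tester is intended to run in a fresh process. A plausible harness sketch, assuming dht and test_success (an mp.Event) are created in the enclosing test before _tester is defined, so _tester can signal completion across the process boundary:

# inside the enclosing test, after dht is populated and _tester is defined:
tester = mp.Process(target=_tester, daemon=True)
tester.start()
tester.join()
assert test_success.is_set(), "tester process failed or exited early"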