Example 1
def test_get_expired():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
    time.sleep(0.5)
    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
    print("Test get expired passed")
Example 2
def test_maxsize_cache():
    d = LocalStorage(maxsize=1)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger expiration time must be kept"
    assert d.get(DHTID.generate("key1"))[0] is None, "Value with smaller expiration time must be deleted"
Example 3
def test_change_expiration_time():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
    time.sleep(1)
    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")
Example 4
async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], beam_size: int,
                              get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
                              visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
    """
    Traverse the DHT graph using get_neighbors function, find :beam_size: nearest nodes according to DHTID.xor_distance.

    :note: This is a simplified (but working) algorithm provided for documentation purposes. The actual DHTNode uses
       `traverse_dht` - a generalization of this algorithm that allows multiple queries and concurrent workers.

    :param query_id: search query, find k_nearest neighbors of this DHTID
    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
    :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
        Recommended value: A beam size of k_nearest * (2-5) will yield near-perfect results.
    :param get_neighbors: A function that returns neighbors of a given node and controls beam search stopping criteria.
        async def get_neighbors(node: DHTID) -> (neighbors_of_that_node: List[DHTID], should_interrupt: bool)
        If should_interrupt is True, beam search will halt and return k_nearest of whatever it found by then.
    :param visited_nodes: beam search will neither call get_neighbors on these nodes, nor return them as nearest
    :returns: a list of k nearest nodes (nearest to farthest), and a set of all visited nodes (including visited_nodes)
    """
    visited_nodes = set(visited_nodes)  # note: copy visited_nodes because we will add more nodes to this collection.
    initial_nodes = [node_id for node_id in initial_nodes if node_id not in visited_nodes]
    if not initial_nodes:
        return [], visited_nodes

    unvisited_nodes = [(distance, uid) for uid, distance in zip(initial_nodes, query_id.xor_distance(initial_nodes))]
    heapq.heapify(unvisited_nodes)  # nearest-first heap of candidates, unlimited size

    nearest_nodes = [(-distance, node_id) for distance, node_id in heapq.nsmallest(beam_size, unvisited_nodes)]
    heapq.heapify(nearest_nodes)  # farthest-first heap of size beam_size, used for early-stopping and to select results
    while len(nearest_nodes) > beam_size:
        heapq.heappop(nearest_nodes)

    visited_nodes |= set(initial_nodes)
    upper_bound = -nearest_nodes[0][0]  # distance to farthest element that is still in beam
    was_interrupted = False  # will be set to True if the host triggered beam search to stop via get_neighbors

    while (not was_interrupted) and len(unvisited_nodes) != 0 and unvisited_nodes[0][0] <= upper_bound:
        _, node_id = heapq.heappop(unvisited_nodes)  # note: this  --^ is the smallest element in heap (see heapq)
        neighbors, was_interrupted = await get_neighbors(node_id)
        neighbors = [node_id for node_id in neighbors if node_id not in visited_nodes]
        visited_nodes.update(neighbors)

        for neighbor_id, distance in zip(neighbors, query_id.xor_distance(neighbors)):
            if distance <= upper_bound or len(nearest_nodes) < beam_size:
                heapq.heappush(unvisited_nodes, (distance, neighbor_id))

                heapq_add_or_replace = heapq.heappush if len(nearest_nodes) < beam_size else heapq.heappushpop
                heapq_add_or_replace(nearest_nodes, (-distance, neighbor_id))
                upper_bound = -nearest_nodes[0][0]  # distance to beam_size-th nearest element found so far

    return [node_id for _, node_id in heapq.nlargest(beam_size, nearest_nodes)], visited_nodes
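
A minimal usage sketch for simple_traverse_dht, not taken from the source: the toy get_neighbors below serves a hard-coded random topology instead of real network peers, and the DHTID import path is an assumption.

import asyncio
import random
from hivemind.dht.routing import DHTID  # import path is an assumption

async def demo_traverse():
    nodes = [DHTID.generate() for _ in range(32)]
    topology = {node: random.sample(nodes, 4) for node in nodes}  # hypothetical static graph

    async def get_neighbors(node):
        # serve neighbors from the static graph; False = never interrupt the search
        return topology[node], False

    nearest, visited = await simple_traverse_dht(
        DHTID.generate(), initial_nodes=nodes[:3], beam_size=8, get_neighbors=get_neighbors)
    print(f"found {len(nearest)} nearest nodes, visited {len(visited)} nodes total")

asyncio.run(demo_traverse())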
Example 5
def test_routing_table_search():
    for table_size, lower_active, upper_active in [(10, 10, 10),
                                                   (10_000, 800, 1100)]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
        num_added = 0
        total_nodes = 0

        for phony_neighbor_port in random.sample(range(1_000_000), table_size):
            routing_table.add_or_update_node(
                DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
            new_total = sum(
                len(bucket.nodes_to_endpoint)
                for bucket in routing_table.buckets)
            num_added += new_total > total_nodes
            total_nodes = new_total
Example 6
def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_port, peer_id, peer_started = find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener,
                           args=(peer_port, peer_id, peer_started),
                           daemon=True)
    peer_proc.start()
    peer_started.wait()
    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence

        loop = asyncio.get_event_loop()
        protocol = loop.run_until_complete(
            DHTProtocol.create(DHTID.generate(),
                               bucket_size=20,
                               depth_modulo=5,
                               wait_timeout=5,
                               num_replicas=3,
                               listen=False))

        key = DHTID.generate()
        value = [random.random(), {'ololo': 'pyshpysh'}]
        expiration = get_dht_time() + 1e3

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
        assert all(
            loop.run_until_complete(
                protocol.call_store(f'{LOCALHOST}:{peer_port}', [key],
                                    [MSGPackSerializer.dumps(value)],
                                    expiration))), "peer rejected store"

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        recv_value = MSGPackSerializer.loads(recv_value_bytes)
        assert len(nodes_found) == 0
        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"

        assert loop.run_until_complete(
            protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
        assert loop.run_until_complete(
            protocol.call_ping(f'{LOCALHOST}:{find_open_port()}')) is None
        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()
    peer_proc.terminate()
Example 7
    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
                         expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
        """
        Ask a recipient to store several (key, value, expiration_time) items or update their older value

        :param peer: request this peer to store the data
        :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
        :param values: a list of N serialized values (bytes) for each respective key
        :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
        :param in_cache: a list of booleans, True = store i-th key in cache, False = store i-th key in regular storage
        :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
         until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size

        :return: list of [True / False], True = stored, False = failed (found newer value or no response)
         if the peer did not respond (e.g. due to timeout or congestion), returns [False] for every key
        """
        if isinstance(expiration_time, DHTExpiration):
            expiration_time = [expiration_time] * len(keys)
        if in_cache is None:
            in_cache = [False] * len(keys)  # default: store in regular storage
        elif isinstance(in_cache, bool):
            in_cache = [in_cache] * len(keys)  # a single bool applies to all keys
        keys, values, expiration_time, in_cache = map(list, [keys, values, expiration_time, in_cache])
        assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
                                             expiration_time=expiration_time, in_cache=in_cache,
                                             peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get(peer).rpc_store(
                    store_request, timeout=self.wait_timeout)
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(
                    self.update_routing_table(peer_id, peer, responded=True))
            return response.store_ok
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(
                f"DHTProtocol failed to store at {peer}: {error.code()}")
            asyncio.create_task(
                self.update_routing_table(
                    self.routing_table.get(endpoint=peer),
                    peer,
                    responded=False))
            return [False] * len(keys)
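
A hedged usage sketch for call_store; the wrapper coroutine, the peer address, and the key name are assumptions, while the serializer and time helper mirror the tests above.

async def store_example(protocol):
    # ask a (hypothetical) peer to hold one serialized value for 60 seconds
    ok = await protocol.call_store(
        '127.0.0.1:31337',
        keys=[DHTID.generate(source='my_key')],
        values=[MSGPackSerializer.dumps('my_value')],
        expiration_time=get_dht_time() + 60)
    print('stored:', all(ok))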
Example 8
    async def rpc_ping(self, peer_info: dht_pb2.NodeInfo,
                       context: grpc.ServicerContext):
        """ Some node wants us to add it to our routing table. """
        if peer_info.node_id and peer_info.rpc_port:
            sender_id = DHTID.from_bytes(peer_info.node_id)
            rpc_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
            asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
        return self.node_info
Example 9
def test_routing_table_parameters():
    for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
        (20, 5, 45, 65),
        (50, 5, 35, 45),
        (20, 10, 650, 800),
        (20, 1, 7, 15),
    ]:
        node_id = DHTID.generate()
        routing_table = RoutingTable(node_id,
                                     bucket_size=bucket_size,
                                     depth_modulo=modulo)
        for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
            routing_table.add_or_update_node(
                DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
        for bucket in routing_table.buckets:
            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_endpoint) <= bucket.size
        assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
            f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}"
        )
Example 10
    async def create(cls,
                     node_id: DHTID,
                     bucket_size: int,
                     depth_modulo: int,
                     num_replicas: int,
                     wait_timeout: float,
                     parallel_rpc: Optional[int] = None,
                     cache_size: Optional[int] = None,
                     listen=True,
                     listen_on='0.0.0.0:*',
                     channel_options: Optional[Sequence[Tuple[str, Any]]] = None,
                     **kwargs) -> DHTProtocol:
        """
        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
        As a side-effect, DHTProtocol also maintains a routing table as described in
        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf

        See DHTNode (node.py) for a more detailed description.

        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
        """
        self = cls(_initialized_with_create=True)
        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
        self.wait_timeout, self.channel_options = wait_timeout, channel_options
        self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
        self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
        self.rpc_semaphore = asyncio.Semaphore(
            parallel_rpc if parallel_rpc is not None else float('inf'))

        if listen:  # set up server to process incoming rpc requests
            grpc.experimental.aio.init_grpc_aio()
            self.server = grpc.experimental.aio.server(**kwargs)
            dht_grpc.add_DHTServicer_to_server(self, self.server)

            found_port = self.server.add_insecure_port(listen_on)
            assert found_port != 0, f"Failed to listen to {listen_on}"
            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(),
                                              rpc_port=found_port)
            self.port = found_port
            await self.server.start()
        else:  # not listening to incoming requests, client-only mode
            # note: use empty node_info so peers won't add you to their routing tables
            self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
                warn(f"DHTProtocol has no server (due to listen=False), listen_on "
                     f"and kwargs have no effect (unused kwargs: {kwargs})")
        return self
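
A minimal sketch of creating a client-only DHTProtocol; the wrapper coroutine and the peer address are assumptions, and the parameter values mirror the tests above.

import asyncio

async def client_demo():
    protocol = await DHTProtocol.create(
        DHTID.generate(), bucket_size=20, depth_modulo=5,
        num_replicas=3, wait_timeout=5, listen=False)  # client-only mode: no server is started
    peer_id = await protocol.call_ping('127.0.0.1:31337')  # hypothetical peer address
    print('peer responded' if peer_id is not None else 'no response')

asyncio.run(client_demo())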
Example 11
def test_ids_depth():
    for i in range(100):
        ids = [random.randint(0, 4096) for i in range(random.randint(1, 256))]
        ours = DHTID.longest_common_prefix_length(*map(DHTID, ids))

        ids_bitstr = ["".join(bin(byte)[2:].rjust(8, '0') for byte in uid.to_bytes(20, 'big'))
                      for uid in ids]
        reference = len(shared_prefix(*ids_bitstr))
        assert reference == ours, f"ours {ours} != reference {reference}, ids: {ids}"
Example 12
def test_routing_table_basic():
    node_id = DHTID.generate()
    routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
    added_nodes = []

    for phony_neighbor_port in random.sample(range(10000), 100):
        phony_id = DHTID.generate()
        routing_table.add_or_update_node(phony_id,
                                         f'{LOCALHOST}:{phony_neighbor_port}')
        assert phony_id in routing_table
        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
        added_nodes.append(phony_id)

    assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
    for bucket in routing_table.buckets:
        assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
    assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)

    random_node = random.choice(added_nodes)
    assert routing_table.get(node_id=random_node) == routing_table[random_node]
    dummy_node = DHTID.generate()
    assert (dummy_node not in routing_table) == (routing_table.get(node_id=dummy_node) is None)

    for node in added_nodes:
        found_bucket_index = routing_table.get_bucket_index(node)
        for bucket_index, bucket in enumerate(routing_table.buckets):
            if bucket.lower <= node < bucket.upper:
                break
        else:
            raise ValueError(
                "Naive search could not find bucket. Universe has gone crazy.")
        assert bucket_index == found_bucket_index
Example 13
def test_ids_basic():
    # basic functionality tests
    for i in range(100):
        id1, id2 = DHTID.generate(), DHTID.generate()
        assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 <= DHTID.MAX
        assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
        assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
        assert len(PickleSerializer.dumps(id1)) - len(PickleSerializer.dumps(int(id1))) < 40
        assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
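
For intuition, a hedged reference implementation of the metric exercised above: Kademlia's XOR distance between two IDs is their bitwise XOR compared as an unsigned integer (the test treats DHTID as an integer via int(id1), which is what this sketch assumes).

def xor_distance_reference(id1: int, id2: int) -> int:
    # Kademlia metric: d(x, y) = x XOR y; d(x, x) == 0 and d is symmetric
    return id1 ^ id2

assert xor_distance_reference(0b1010, 0b0011) == 0b1001
assert xor_distance_reference(42, 42) == 0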
Example 14
    async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
        """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
        while period is not None:  # if period is None, no refreshes are performed; otherwise run forever
            refresh_time = get_dht_time()
            staleness_threshold = refresh_time - period
            stale_buckets = [
                bucket for bucket in self.protocol.routing_table.buckets
                if bucket.last_updated < staleness_threshold
            ]
            for bucket in stale_buckets:
                refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
                await self.find_nearest_nodes(refresh_id)

            await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
Example 15
    async def rpc_store(self, request: dht_pb2.StoreRequest,
                        context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
        """ Some node wants us to store this (key, value) pair """
        if request.peer:  # if requested, add peer to the routing table
            asyncio.create_task(self.rpc_ping(request.peer, context))
        assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
        response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
        for key_bytes, value_bytes, expiration_time, in_cache in zip(
                request.keys, request.values, request.expiration_time, request.in_cache):
            local_memory = self.cache if in_cache else self.storage
            response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
        return response
Example 16
    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
            Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
        """
        Request keys from a peer. For each key, look for its (value, expiration time) locally and
         k additional peers that are most likely to have this key (ranked by XOR distance)

        :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
         value: value stored by the recipient with that key, or None if peer doesn't have this value
         expiration time: expiration time of the returned value, None if no value was found
         neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
         If peer didn't respond, returns None
        """
        keys = list(keys)
        find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
        try:
            async with self.rpc_semaphore:
                response = await self._get(peer).rpc_find(
                    find_request, timeout=self.wait_timeout)
            if response.peer and response.peer.node_id:
                peer_id = DHTID.from_bytes(response.peer.node_id)
                asyncio.create_task(
                    self.update_routing_table(peer_id, peer, responded=True))
            assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
                "DHTProtocol: response is not aligned with keys and/or expiration times"

            output = {}  # unpack data without special NOT_FOUND_* values
            for key, value, expiration_time, nearest in zip(
                    keys, response.values, response.expiration_time, response.nearest):
                value = value if value != _NOT_FOUND_VALUE else None
                expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
                nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
                output[key] = (value, expiration_time, nearest)
            return output
            return output
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(
                f"DHTProtocol failed to find at {peer}: {error.code()}")
            asyncio.create_task(
                self.update_routing_table(
                    self.routing_table.get(endpoint=peer),
                    peer,
                    responded=False))
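
A hedged sketch of consuming call_find's output; the wrapper coroutine, peer address, and key are assumptions, while the unpacking follows the docstring above.

async def find_example(protocol):
    key_id = DHTID.generate(source='my_key')  # hypothetical key
    results = await protocol.call_find('127.0.0.1:31337', [key_id])
    if results is None:
        print('peer did not respond')
        return
    value_bytes, expiration_time, nearest = results[key_id]
    if value_bytes is not None:
        print('value found, expires at', expiration_time)
    print('peers to query next:', list(nearest.items()))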
Example 17
    async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
        """
        Get peer's node id and add him to the routing table. If peer doesn't respond, return None
        :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table

        :return: node's DHTID, if peer responded and decided to send his node_id
        """
        try:
            async with self.rpc_semaphore:
                peer_info = await self._get(peer).rpc_ping(
                    self.node_info, timeout=self.wait_timeout)
        except grpc.experimental.aio.AioRpcError as error:
            logger.warning(
                f"DHTProtocol failed to ping {peer}: {error.code()}")
            peer_info = None
        responded = bool(peer_info and peer_info.node_id)
        peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
        asyncio.create_task(
            self.update_routing_table(peer_id, peer, responded=responded))
        return peer_id
Example 18
    async def rpc_find(self, request: dht_pb2.FindRequest,
                       context: grpc.ServicerContext) -> dht_pb2.FindResponse:
        """
        Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
        Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
        """
        if request.peer:  # if requested, add peer to the routing table
            asyncio.create_task(self.rpc_ping(request.peer, context))

        response = dht_pb2.FindResponse(values=[],
                                        expiration_time=[],
                                        nearest=[],
                                        peer=self.node_info)
        for key_id in map(DHTID.from_bytes, request.keys):
            maybe_value, maybe_expiration_time = self.storage.get(key_id)
            cached_value, cached_expiration_time = self.cache.get(key_id)
            if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
                maybe_value, maybe_expiration_time = cached_value, cached_expiration_time

            nearest_neighbors = self.routing_table.get_nearest_neighbors(
                key_id,
                k=self.bucket_size,
                exclude=DHTID.from_bytes(request.peer.node_id))
            if nearest_neighbors:
                peer_ids, endpoints = zip(*nearest_neighbors)
            else:
                peer_ids, endpoints = [], []

            response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
            response.expiration_time.append(
                maybe_expiration_time if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
            response.nearest.append(dht_pb2.Peers(
                node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
        return response
Example 19
    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence

        loop = asyncio.get_event_loop()
        # note: order matters, this test assumes that the first run uses listen=False
        for listen in [False, True]:
            protocol = loop.run_until_complete(
                DHTProtocol.create(DHTID.generate(),
                                   bucket_size=20,
                                   depth_modulo=5,
                                   wait_timeout=5,
                                   num_replicas=3,
                                   listen=listen))
            print(f"Self id={protocol.node_id}", flush=True)

            assert loop.run_until_complete(
                protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

            key = DHTID.generate()
            value = [random.random(), {'ololo': 'pyshpysh'}]
            expiration = get_dht_time() + 1e3
            store_ok = loop.run_until_complete(
                protocol.call_store(f'{LOCALHOST}:{peer1_port}', [key],
                                    [MSGPackSerializer.dumps(value)],
                                    expiration))
            assert all(store_ok), "DHT rejected a trivial store"

            # peer 1 must know about peer 2
            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
            recv_value = MSGPackSerializer.loads(recv_value_bytes)
            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"

            assert recv_value == value and recv_expiration == expiration, \
                f"call_find_value expected {value} (expires by {expiration}) " \
                f"but got {recv_value} (expires by {recv_expiration})"

            # peer 2 must know about peer 1, but not have a *random* nonexistent value
            dummy_key = DHTID.generate()
            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer2_port}',
                                   [dummy_key]))[dummy_key]
            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

            # cause a non-response by querying a nonexistent peer
            dummy_port = find_open_port()
            assert loop.run_until_complete(
                protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

            if listen:
                loop.run_until_complete(protocol.shutdown())
            print("DHTProtocol test finished successfully!")
            test_success.set()
Example 20
    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        me = loop.run_until_complete(
            DHTNode.create(initial_peers=random.sample(dht.keys(), 5),
                           parallel_rpc=10))

        # test 1: find self
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

        # test 2: find others
        for i in range(10):
            ref_endpoint, query_id = random.choice(list(dht.items()))
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
            assert len(nearest) == 1
            found_node_id, found_endpoint = next(iter(nearest.items()))
            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

        # test 3: find neighbors to random nodes
        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
        all_node_ids = list(dht.values())

        for i in range(100):
            query_id = DHTID.generate()
            k_nearest = random.randint(1, 20)
            exclude_self = random.random() > 0.5
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id],
                                      k_nearest=k_nearest,
                                      exclude_self=exclude_self))[query_id]
            nearest_nodes = list(nearest)  # keys from ordered dict

            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

            ref_nearest = heapq.nsmallest(k_nearest + 1,
                                          all_node_ids,
                                          key=query_id.xor_distance)
            if exclude_self and me.node_id in ref_nearest:
                ref_nearest.remove(me.node_id)
            if len(ref_nearest) > k_nearest:
                ref_nearest.pop()

            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
            accuracy_denominator += 1

            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
            jaccard_denominator += k_nearest

        accuracy = accuracy_numerator / accuracy_denominator
        print("Top-1 accuracy:", accuracy)  # should be 98-100%
        jaccard_index = jaccard_numerator / jaccard_denominator
        print("Jaccard index (intersection over union):",
              jaccard_index)  # should be 95-100%
        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
        assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

        # test 4: find all nodes
        dummy = DHTID.generate()
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
        assert len(nearest) == len(dht) + 1
        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

        # test 5: node without peers
        other_node = loop.run_until_complete(DHTNode.create())
        nearest = loop.run_until_complete(
            other_node.find_nearest_nodes([dummy]))[dummy]
        assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
        nearest = loop.run_until_complete(
            other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
        assert len(nearest) == 0

        # test 6: store and get value
        true_time = get_dht_time() + 1200
        assert loop.run_until_complete(
            me.store("mykey", ["Value", 10], true_time))
        for node in [me, other_node]:
            val, expiration_time = loop.run_until_complete(node.get("mykey"))
            assert expiration_time == true_time, "Wrong time"
            assert val == ["Value", 10], "Wrong value"

        # test 7: bulk store and bulk get
        keys = 'foo', 'bar', 'baz', 'zzz'
        values = 3, 2, 'batman', [1, 2, 3]
        store_ok = loop.run_until_complete(
            me.store_many(keys, values, expiration_time=get_dht_time() + 999))
        assert all(store_ok.values()), "failed to store one or more keys"
        response = loop.run_until_complete(me.get_many(keys[::-1]))
        for key, value in zip(keys, values):
            assert key in response and response[key][0] == value

        test_success.set()
Example 21
def test_get_empty():
    d = LocalStorage()
    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
    print("Test get empty passed")
Example 22
def test_store():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
    print("Test store passed")
Example 23
    async def create(cls,
                     node_id: Optional[DHTID] = None,
                     initial_peers: List[Endpoint] = (),
                     bucket_size: int = 20,
                     num_replicas: int = 5,
                     depth_modulo: int = 5,
                     parallel_rpc: Optional[int] = None,
                     wait_timeout: float = 5,
                     refresh_timeout: Optional[float] = None,
                     bootstrap_timeout: Optional[float] = None,
                     num_workers: int = 1,
                     cache_locally: bool = True,
                     cache_nearest: int = 1,
                     cache_size=None,
                     listen: bool = True,
                     listen_on: Endpoint = "0.0.0.0:*",
                     **kwargs) -> DHTNode:
        """
        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
          either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
        :param num_replicas: number of nearest nodes that will be asked to store a given key, default = 5 (≈k)
        :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
          Reduce this value if your RPC requests register no response despite the peer sending the response.
        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
          if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
        :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
          nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
        :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
        :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
          if False, this node will refuse any incoming request, effectively being only a "client"
        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
        :param kwargs: extra parameters used in grpc.aio.server
        """
        self = cls(_initialized_with_create=True)
        self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
        self.num_replicas, self.num_workers = num_replicas, num_workers
        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
        self.refresh_timeout = refresh_timeout

        self.protocol = await DHTProtocol.create(self.node_id, bucket_size,
                                                 depth_modulo, num_replicas,
                                                 wait_timeout, parallel_rpc,
                                                 cache_size, listen, listen_on,
                                                 **kwargs)
        self.port = self.protocol.port

        if initial_peers:
            # stage 1: ping initial_peers, add each other to the routing table
            bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
            start_time = get_dht_time()
            ping_tasks = map(self.protocol.call_ping, initial_peers)
            finished_pings, unfinished_pings = await asyncio.wait(
                ping_tasks, return_when=asyncio.FIRST_COMPLETED)

            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
            if unfinished_pings:
                finished_in_time, stragglers = await asyncio.wait(
                    unfinished_pings,
                    timeout=bootstrap_timeout - get_dht_time() + start_time)
                for straggler in stragglers:
                    straggler.cancel()
                finished_pings |= finished_in_time

            if not finished_pings:
                warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

            # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
            # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
            # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
            await asyncio.wait([
                asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)
            ],
                               return_when=asyncio.FIRST_COMPLETED)

        if self.refresh_timeout is not None:
            asyncio.create_task(
                self._refresh_routing_table(period=self.refresh_timeout))
        return self
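
A minimal sketch of bootstrapping a two-node swarm with DHTNode.create; the wrapper coroutine and the loopback address are assumptions, and the store/get semantics follow the tests above.

import asyncio

async def two_node_demo():
    alice = await DHTNode.create()  # first node starts with an empty routing table
    bob = await DHTNode.create(initial_peers=[f'127.0.0.1:{alice.port}'])
    assert await bob.store('my_key', 'my_value', get_dht_time() + 60)
    value, expiration_time = await alice.get('my_key')
    print(value, expiration_time)

asyncio.run(two_node_demo())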
Example 24
            num_added += new_total > total_nodes
            total_nodes = new_total
        num_replacements = sum(
            len(bucket.replacement_nodes) for bucket in routing_table.buckets)

        all_active_neighbors = list(
            chain(*(bucket.nodes_to_endpoint.keys()
                    for bucket in routing_table.buckets)))
        assert lower_active <= len(all_active_neighbors) <= upper_active
        assert len(all_active_neighbors) == num_added
        assert num_added + num_replacements == table_size

        # random queries
        for i in range(1000):
            k = random.randint(1, 100)
            query_id = DHTID.generate()
            exclude = query_id if random.random() < 0.5 else None
            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(
                query_id, k=k, exclude=exclude))
            reference_knn = heapq.nsmallest(k,
                                            all_active_neighbors,
                                            key=query_id.xor_distance)
            assert all(our == ref
                       for our, ref in zip_longest(our_knn, reference_knn))
            assert all(
                our_endpoint == routing_table[our_node]
                for our_node, our_endpoint in zip(our_knn, our_endpoints))

        # queries from table
        for i in range(1000):
            k = random.randint(1, 100)
Example 25
    async def get_many(
        self,
        keys: Collection[DHTKey],
        sufficient_expiration_time: Optional[DHTExpiration] = None,
        num_workers: Optional[int] = None,
        beam_size: Optional[int] = None
    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
        """
        :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if the key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time, it can
            return that value right away; default = time of call, i.e. find any value that has not expired yet
            If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
        :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
        :param num_workers: override for default num_workers, see traverse_dht num_workers param
        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
        :note: in order to check if get returned a value, please check whether (expiration_time is None)
        """
        key_ids = [DHTID.generate(key) for key in keys]
        id_to_original_key = dict(zip(key_ids, keys))
        sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers

        # search metadata
        unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries

        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
        latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}

        # stage 1: value can be stored in our local cache
        for key_id in key_ids:
            maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
            if maybe_expiration_time is None:  # not in local storage, try the local cache
                maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
                if maybe_expiration_time >= sufficient_expiration_time:
                    unfinished_key_ids.remove(key_id)

        # stage 2: traverse the DHT for any unfinished keys
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(
                self.protocol.routing_table.get_nearest_neighbors(
                    key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def get_neighbors(
            peer: DHTID, queries: Collection[DHTID]
        ) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
            queries = list(queries)
            response = await self.protocol.call_find(node_to_endpoint[peer],
                                                     queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
            for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
                node_to_endpoint.update(peers)
                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
                should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
                output[key_id] = list(peers.keys()), should_interrupt
            return output

        nearest_nodes_per_query, visited_nodes = await traverse_dht(
            queries=list(unfinished_key_ids),
            initial_nodes=list(node_to_endpoint),
            beam_size=beam_size,
            num_workers=num_workers,
            queries_per_call=int(len(unfinished_key_ids)**0.5),
            get_neighbors=get_neighbors,
            visited_nodes={
                key_id: {self.node_id}
                for key_id in unfinished_key_ids
            })

        # stage 3: cache any new results depending on caching parameters
        for key_id, nearest_nodes in nearest_nodes_per_query.items():
            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
            if should_cache and self.cache_locally:
                self.protocol.cache.store(key_id, latest_value_bytes,
                                          latest_expiration_time)

            if should_cache and self.cache_nearest:
                num_cached_nodes = 0
                for node_id in nearest_nodes:
                    if node_id == latest_node_id:
                        continue
                    asyncio.create_task(
                        self.protocol.call_store(node_to_endpoint[node_id],
                                                 [key_id],
                                                 [latest_value_bytes],
                                                 [latest_expiration_time],
                                                 in_cache=True))
                    num_cached_nodes += 1
                    if num_cached_nodes >= self.cache_nearest:
                        break

        # stage 4: deserialize data and assemble function output
        find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
        for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
            if latest_expiration_time != -float('inf'):
                latest_value = self.serializer.loads(latest_value_bytes)
                find_result[id_to_original_key[key_id]] = (
                    latest_value, latest_expiration_time)
            else:
                find_result[id_to_original_key[key_id]] = None, None
        return find_result
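
A hedged sketch of calling get_many from application code; the node variable and keys are assumptions, and the (None, None) convention follows the docstring above.

async def get_many_example(node):
    results = await node.get_many(['foo', 'bar', 'baz'])
    for key, (value, expiration_time) in results.items():
        if expiration_time is None:
            print(f"{key}: not found")
        else:
            print(f"{key} = {value!r} (expires at {expiration_time})")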