Example #1
 def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey],
                            binary_values: List[bytes],
                            expirations: List[DHTExpiration],
                            store_ok: List[bool]):
     """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
     store_succeeded = any(store_ok)
     is_dictionary = any(subkey is not None for subkey in subkeys)
     if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
         stored_value_bytes, stored_expiration = max(zip(
             binary_values, expirations),
                                                     key=lambda p: p[1])
         self.protocol.cache.store(key_id, stored_value_bytes,
                                   stored_expiration)
     elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
         rejected_value, rejected_expiration = max(zip(
             binary_values, expirations),
                                                   key=lambda p: p[1])
         if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:
             # the cached value would be rejected as well; fetch a new value in background (asap)
             self._schedule_for_refresh(key_id, refresh_time=get_dht_time())
     elif is_dictionary and key_id in self.protocol.cache:  # there can be other keys and we should update
         for subkey, stored_value_bytes, expiration_time in zip(
                 subkeys, binary_values, expirations):
             self.protocol.cache.store_subkey(key_id, subkey,
                                              stored_value_bytes,
                                              expiration_time)
         self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
Example #2
def test_maxsize_cache():
    d = DHTLocalStorage(maxsize=1)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with larger expiration time must be kept"
    assert d.get(DHTID.generate("key1")) is None, "Value with smaller expiration time must be deleted"
Example #3
def test_change_expiration_time():
    d = DHTLocalStorage()
    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
    time.sleep(1)
    assert d.get(DHTID.generate(
        "key"))[0] == b"val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")
Example #4
    async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
        """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
        while self.is_alive and period is not None:  # if None run once, otherwise run forever
            refresh_time = get_dht_time()
            staleness_threshold = refresh_time - period
            stale_buckets = [
                bucket for bucket in self.protocol.routing_table.buckets
                if bucket.last_updated < staleness_threshold
            ]
            for bucket in stale_buckets:
                refresh_id = DHTID(
                    random.randint(bucket.lower, bucket.upper - 1))
                await self.find_nearest_nodes(refresh_id)

            await asyncio.sleep(
                max(0.0, period - (get_dht_time() - refresh_time)))
Example #5
 def _remove_outdated(self):
     while not self.frozen and self.expiration_heap and (
             self.expiration_heap[ROOT].expiration_time < get_dht_time()
             or len(self.expiration_heap) > self.maxsize):
         heap_entry = heapq.heappop(self.expiration_heap)
         if self.key_to_heap.get(heap_entry.key) == heap_entry:
             del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
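
The loop above implements lazy deletion on a heap: superseded heap entries are left behind as garbage, and only the entry currently referenced by key_to_heap is allowed to delete its key. A minimal self-contained sketch of that pattern (all names below are illustrative, not the library's):

import heapq
import time


class ExpiringDict:
    """ Toy expiring storage: keeps at most `maxsize` heap entries and drops expired keys lazily """

    def __init__(self, maxsize: int = 3):
        self.data, self.key_to_heap, self.expiration_heap, self.maxsize = {}, {}, [], maxsize

    def store(self, key, value, expiration_time: float):
        entry = (expiration_time, key)
        self.key_to_heap[key] = entry                 # the newest entry for this key wins
        heapq.heappush(self.expiration_heap, entry)   # older entries stay in the heap as garbage
        self.data[key] = (value, expiration_time)
        self._remove_outdated()

    def _remove_outdated(self):
        while self.expiration_heap and (self.expiration_heap[0][0] < time.time()
                                        or len(self.expiration_heap) > self.maxsize):
            heap_entry = heapq.heappop(self.expiration_heap)
            if self.key_to_heap.get(heap_entry[1]) == heap_entry:  # skip superseded (stale) entries
                del self.data[heap_entry[1]], self.key_to_heap[heap_entry[1]]


d = ExpiringDict(maxsize=2)
d.store("a", 1, time.time() + 10)
d.store("b", 2, time.time() + 20)
d.store("c", 3, time.time() + 30)  # evicts "a", the entry with the earliest expiration
print(sorted(d.data))  # ['b', 'c']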
Example #6
def test_get_expired():
    d = DHTLocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
    time.sleep(0.5)
    assert d.get(
        DHTID.generate("key")) is None, "Expired value must be deleted"
    print("Test get expired passed")
Example #7
    async def _declare_experts(
            self, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
            future: Optional[MPFuture]) -> Dict[ExpertUID, bool]:
        num_workers = len(uids) if self.max_workers is None else min(
            len(uids), self.max_workers)
        expiration_time = get_dht_time() + self.expiration
        data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]],
                            DHTValue] = {}
        for uid in uids:
            data_to_store[uid, None] = endpoint
            prefix = uid if uid.count(
                UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
            for i in range(prefix.count(UID_DELIMITER) - 1):
                prefix, last_coord = split_uid(prefix)
                data_to_store[prefix, last_coord] = [uid, endpoint]

        keys, maybe_subkeys, values = zip(
            *((key, subkey, value)
              for (key, subkey), value in data_to_store.items()))
        store_ok = await node.store_many(keys,
                                         values,
                                         expiration_time,
                                         subkeys=maybe_subkeys,
                                         num_workers=num_workers)
        if future is not None:
            future.set_result(store_ok)
        return store_ok
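
The loop above registers each expert under every ancestor prefix of its UID so that beam search can later walk the grid level by level. A self-contained sketch of that prefix expansion, assuming '.' as the UID delimiter and a local split_uid helper (both are assumptions for illustration, not the library's definitions):

def split_uid(uid_or_prefix: str):
    """ Split off the last integer coordinate, e.g. 'ffn.3.5.' -> ('ffn.3.', 5) """
    uid_or_prefix = uid_or_prefix.rstrip('.')
    pivot = uid_or_prefix.rindex('.') + 1
    return uid_or_prefix[:pivot], int(uid_or_prefix[pivot:])


uid, endpoint = 'ffn.3.5.7', 'worker0:1337'  # hypothetical expert UID and endpoint
data_to_store = {(uid, None): endpoint}
prefix = uid
for _ in range(prefix.count('.') - 1):
    prefix, last_coord = split_uid(prefix)
    data_to_store[prefix, last_coord] = [uid, endpoint]

print(data_to_store)
# {('ffn.3.5.7', None): 'worker0:1337',
#  ('ffn.3.5.', 7): ['ffn.3.5.7', 'worker0:1337'],
#  ('ffn.3.', 5): ['ffn.3.5.7', 'worker0:1337']}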
Example #8
 async def _get_experts(
         self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
     response = await node.get_many(uids, expiration_time, num_workers=num_workers)
     future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
                        for uid, (expert_data, maybe_expiration_time) in response.items()])
Example #9
 def remove_outdated(self):
     while self.expiration_heap and (
             self.expiration_heap[0][0] < get_dht_time()
             or len(self.expiration_heap) > self.cache_size):
         heap_entry = heapq.heappop(self.expiration_heap)
         key = heap_entry[1]
         if self.key_to_heap[key] == heap_entry:
             del self.data[key], self.key_to_heap[key]
Example #10
def test_localstorage_freeze():
    d = DHTLocalStorage(maxsize=2)

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
        assert DHTID.generate("key1") in d
        time.sleep(0.03)
        assert DHTID.generate("key1") in d
    assert DHTID.generate("key1") not in d

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
        d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
        d.store(DHTID.generate("key3"), b"val3",
                get_dht_time() + 3)  # key3 will push key1 out due to maxsize
        assert DHTID.generate("key1") in d
    assert DHTID.generate("key1") not in d
Example #11
    async def _refresh_stale_cache_entries(self):
        """ periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
        while self.is_alive:
            while len(self.cache_refresh_queue) == 0:
                await self.cache_refresh_evt.wait()
                self.cache_refresh_evt.clear()
            key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()

            try:
                # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
                time_to_wait = nearest_refresh_time - get_dht_time()
                await asyncio.wait_for(self.cache_refresh_evt.wait(),
                                       timeout=time_to_wait)
                # note: the line above will cause TimeoutError when we are ready to refresh cache
                # no TimeoutError => someone added a new entry to the queue that is earlier
                # than nearest_refresh_time, so we should refresh that entry first
                self.cache_refresh_evt.clear()
                continue

            except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
                # step 2: find all keys that we should already refresh and remove them from queue
                current_time = get_dht_time()
                keys_to_refresh = {key_id}
                max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
                del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                while self.cache_refresh_queue and len(keys_to_refresh) < self.chunk_size:
                    key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
                    if nearest_refresh_time > current_time:
                        break
                    del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                    keys_to_refresh.add(key_id)
                    cached_item = self.protocol.cache.get(key_id)
                    if cached_item is not None and cached_item.expiration_time > max_expiration_time:
                        max_expiration_time = cached_item.expiration_time

                # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
                sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
                await self.get_many_by_id(keys_to_refresh,
                                          sufficient_expiration_time,
                                          _is_refresh=True)
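
The try/except above turns asyncio.wait_for into a combined "sleep until the refresh deadline or until a newer entry arrives" primitive: a TimeoutError means it is time to refresh, while a completed wait means the queue changed and should be re-examined first. A minimal self-contained sketch of that pattern (names are illustrative):

import asyncio


async def wait_for_deadline_or_update(update_event: asyncio.Event, time_to_wait: float) -> bool:
    """ Return True if the deadline passed (time to act), False if a newer entry arrived first """
    try:
        await asyncio.wait_for(update_event.wait(), timeout=time_to_wait)
        update_event.clear()
        return False  # the event fired first: re-examine the queue before acting
    except asyncio.TimeoutError:
        return True   # the deadline passed: refresh the earliest entry now


async def demo():
    event = asyncio.Event()
    asyncio.get_running_loop().call_later(0.1, event.set)  # simulate a new queue entry arriving early
    print(await wait_for_deadline_or_update(event, time_to_wait=1.0))  # False: an update arrived first
    print(await wait_for_deadline_or_update(event, time_to_wait=0.2))  # True: the deadline was reached


asyncio.run(demo())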
Example #12
    async def _refresh_stale_cache_entries(self):
        """ periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
        while self.is_alive:
            with self.cache_refresh_queue.freeze():
                while len(self.cache_refresh_queue) == 0:
                    await self.cache_refresh_available.wait()
                    self.cache_refresh_available.clear()
                key_id, _, nearest_expiration = self.cache_refresh_queue.top()

            try:
                # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
                time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
                await asyncio.wait_for(self.cache_refresh_available.wait(),
                                       timeout=time_to_wait)
                # note: the line above will cause TimeoutError when we are ready to refresh cache
                # no TimeoutError => someone added a new entry to the queue that is earlier
                # than nearest_expiration, so we should refresh that entry first
                self.cache_refresh_available.clear()
                continue

            except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
                # step 2: find all keys that we should already refresh and remove them from queue
                with self.cache_refresh_queue.freeze():
                    keys_to_refresh = {key_id}
                    del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                    while self.cache_refresh_queue:
                        key_id, _, nearest_expiration = self.cache_refresh_queue.top()
                        if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
                            break
                        del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                        keys_to_refresh.add(key_id)

                # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
                await self.get_many_by_id(
                    keys_to_refresh,
                    sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
                    _refresh_cache=False)  # if we found the value locally, we shouldn't trigger another refresh
Example #13
    async def _get_experts(self, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration],
                           future: Optional[MPFuture] = None) -> List[Optional[RemoteExpert]]:
        if expiration_time is None:
            expiration_time = get_dht_time()
        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
        found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
        for i, uid in enumerate(uids):
            if found[uid] is not None and isinstance(found[uid].value, Endpoint):
                experts[i] = RemoteExpert(uid, found[uid].value)
        if future:
            future.set_result(experts)
        return experts
Example #14
    async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
        expiration_time = get_dht_time() + self.expiration

        data_to_store = {}
        for uid in uids:
            uid_parts = uid.split(self.UID_DELIMITER)
            for i in range(len(uid_parts)):
                uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
                data_to_store[uid_prefix_i] = {'uid': uid, 'endpoint': endpoint}

        store_keys, store_values = zip(*data_to_store.items())
        store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
        if future is not None:
            future.set_result([store_ok[key] for key in data_to_store.keys()])
Example #15
def test_localstorage_serialize():
    d1 = DictionaryDHTValue()
    d2 = DictionaryDHTValue()

    now = get_dht_time()
    d1.store('key1', b'ololo', now + 1)
    d2.store('key2', b'pysh', now + 1)
    d2.store('key3', b'pyshpysh', now + 2)

    data = MSGPackSerializer.dumps([d1, d2, 123321])
    assert isinstance(data, bytes)
    new_d1, new_d2, new_value = MSGPackSerializer.loads(data)
    assert isinstance(new_d1, DictionaryDHTValue) and isinstance(
        new_d2, DictionaryDHTValue) and new_value == 123321
    assert 'key1' in new_d1 and len(new_d1) == 1
    assert 'key1' not in new_d2 and len(new_d2) == 2 and new_d2.get(
        'key3') == (b'pyshpysh', now + 2)
Example #16
 def store(self, key: DHTID, value: BinaryDHTValue,
           expiration_time: DHTExpiration) -> bool:
     """
     Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
     :returns: True if the new value was stored, False if it was rejected (the current value is newer)
     """
     if expiration_time < get_dht_time():
         return False
     self.key_to_heap[key] = (expiration_time, key)
     heapq.heappush(self.expiration_heap, (expiration_time, key))
     if key in self.data:
         if self.data[key][1] < expiration_time:
             self.data[key] = (value, expiration_time)
             return True
         return False
     self.data[key] = (value, expiration_time)
     self.remove_outdated()
     return True
Example #17
 def store(self, key: KeyType, value: ValueType,
           expiration_time: DHTExpiration) -> bool:
     """
     Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
     :returns: True if the new value was stored, False if it was rejected (the current value is newer)
     """
     if expiration_time < get_dht_time() and not self.frozen:
         return False
     self.key_to_heap[key] = HeapEntry(expiration_time, key)
     heapq.heappush(self.expiration_heap, self.key_to_heap[key])
     if key in self.data:
         if self.data[key].expiration_time < expiration_time:
             self.data[key] = ValueWithExpiration(value, expiration_time)
             return True
         return False
     self.data[key] = ValueWithExpiration(value, expiration_time)
     self._remove_outdated()
     return True
Example #18
def test_localstorage_nested():
    time = get_dht_time()
    d1 = DHTLocalStorage()
    d2 = DictionaryDHTValue()
    d2.store('subkey1', b'value1', time + 2)
    d2.store('subkey2', b'value2', time + 3)
    d2.store('subkey3', b'value3', time + 1)

    assert d2.latest_expiration_time == time + 3
    for subkey, (subvalue, subexpiration) in d2.items():
        assert d1.store_subkey(DHTID.generate('foo'), subkey, subvalue,
                               subexpiration)
    assert d1.store(DHTID.generate('bar'), b'456', time + 2)
    assert d1.get(DHTID.generate('foo'))[0].data == d2.data
    assert d1.get(DHTID.generate('foo'))[1] == d2.latest_expiration_time
    assert d1.get(DHTID.generate('foo'))[0].get('subkey1') == (b'value1',
                                                               time + 2)
    assert len(d1.get(DHTID.generate('foo'))[0]) == 3
    assert d1.store_subkey(DHTID.generate('foo'), 'subkey4', b'value4',
                           time + 4)
    assert len(d1.get(DHTID.generate('foo'))[0]) == 4

    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA',
                           time + 1) is False  # prev has better expiration
    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA',
                           time + 3)  # new value has better expiration
    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyB', b'valueB',
                           time + 4)  # new value has better expiration
    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA+',
                           time + 5)  # overwrite subkeyA under key bar
    assert all(subkey in d1.get(DHTID.generate('bar'))[0]
               for subkey in ('subkeyA', 'subkeyB'))
    assert len(d1.get(DHTID.generate('bar'))[0]) == 2 and d1.get(
        DHTID.generate('bar'))[1] == time + 5

    assert d1.store(DHTID.generate('foo'), b'nothing', time +
                    3.5) is False  # previous value has better expiration
    assert d1.get(DHTID.generate('foo'))[0].get('subkey2') == (b'value2',
                                                               time + 3)
    assert d1.store(DHTID.generate('foo'), b'nothing',
                    time + 5) is True  # new value has better expiration
    assert d1.get(DHTID.generate('foo')) == (b'nothing', time + 5)  # value should be replaced
Example #19
 async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
                                  num_workers: Optional[int] = None, future: Optional[MPFuture] = None
                                  ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
     grid_size = grid_size or float('inf')
     num_workers = num_workers or min(len(prefixes), self.max_workers or len(prefixes))
     dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
     successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
     for prefix, found in dht_responses.items():
         if found and isinstance(found.value, dict):
             successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
                                   if isinstance(coord, Coordinate) and 0 <= coord < grid_size
                                   and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
         else:
             successors[prefix] = {}
             if found is None and self.negative_caching:
                 logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
                 asyncio.create_task(node.store(prefix, subkey=-1, value=None,
                                                expiration_time=get_dht_time() + self.expiration))
     if future:
         future.set_result(successors)
     return successors
Example #20
def test_localstorage_top():
    d = DHTLocalStorage(maxsize=3)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
    d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1"

    d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
    assert d.top()[0] == DHTID.generate("key2") and d.top()[1].value == b"val2"

    del d[DHTID.generate('key2')]
    assert d.top()[0] == DHTID.generate(
        "key1") and d.top()[1].value == b"val1_new"
    d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
    d.store(DHTID.generate("key4"), b"val4",
            get_dht_time() + 6)  # key4 will push out key1 due to maxsize

    assert d.top()[0] == DHTID.generate("key3") and d.top()[1].value == b"val3"
Example #21
    async def _get_initial_beam(self, node, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
                                num_workers: Optional[int] = None, future: Optional[MPFuture] = None
                                ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
        num_workers = num_workers or self.max_workers or beam_size
        beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
        unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
        pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()

        while len(beam) < beam_size and (unattempted_indices or pending_tasks):
            # dispatch additional tasks
            while unattempted_indices and len(pending_tasks) < num_workers:
                next_index = unattempted_indices.pop()  # note: this is best unattempted index because of sort order
                next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
                pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))

            # await the next best prefix to be fetched
            pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
            try:
                maybe_prefix_data = await pending_task
                if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                    successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
                                  if isinstance(coord, Coordinate) and isinstance(getattr(match, 'value', None), list)
                                  and len(match.value) == 2}
                    if successors:
                        beam.append((scores[pending_best_index], pending_best_prefix, successors))
                elif maybe_prefix_data is None and self.negative_caching:
                    logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
                    asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
                                                   expiration_time=get_dht_time() + self.expiration))

            except asyncio.CancelledError:
                for _, pending_task in pending_tasks:
                    pending_task.cancel()
                raise
        if future:
            future.set_result(beam)
        return beam
Example #22
def test_store():
    d = DHTLocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
    print("Test store passed")
Example #23
    async def create(cls,
                     node_id: Optional[DHTID] = None,
                     initial_peers: List[Endpoint] = (),
                     bucket_size: int = 20,
                     num_replicas: int = 5,
                     depth_modulo: int = 5,
                     parallel_rpc: int = None,
                     wait_timeout: float = 5,
                     refresh_timeout: Optional[float] = None,
                     bootstrap_timeout: Optional[float] = None,
                     cache_locally: bool = True,
                     cache_nearest: int = 1,
                     cache_size=None,
                     cache_refresh_before_expiry: float = 5,
                     cache_on_store: bool = True,
                     reuse_get_requests: bool = True,
                     num_workers: int = 1,
                     chunk_size: int = 16,
                     listen: bool = True,
                     listen_on: Endpoint = "0.0.0.0:*",
                     **kwargs) -> DHTNode:
        """
        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
          either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
        :param num_replicas: number of nearest nodes that will be asked to store a given key (default: 5, ≈k)
        :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
          Reduce this value if your RPC requests register no response despite the peer sending the response.
        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
          if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
        :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
        :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many of the
          nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
        :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
        :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
          if they are accessed this many seconds before expiration time.
        :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
          all concurrent get requests for the same key will reuse the procedure that is currently in progress
        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
        :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
        :param listen: if True (default), this node will accept incoming requests and be a full DHT "citizen"
          if False, this node will refuse any incoming requests, effectively being only a "client"
        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
        :param kwargs: extra parameters used in grpc.aio.server
        """
        self = cls(_initialized_with_create=True)
        self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
        self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
        self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh

        self.reuse_get_requests = reuse_get_requests
        self.pending_get_requests = defaultdict(
            partial(SortedList,
                    key=lambda _res: -_res.sufficient_expiration_time))

        # caching policy
        self.refresh_timeout = refresh_timeout
        self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
        self.cache_refresh_before_expiry = cache_refresh_before_expiry
        self.cache_refresh_queue = CacheRefreshQueue()
        self.cache_refresh_evt = asyncio.Event()
        self.cache_refresh_task = None

        self.protocol = await DHTProtocol.create(self.node_id, bucket_size,
                                                 depth_modulo, num_replicas,
                                                 wait_timeout, parallel_rpc,
                                                 cache_size, listen, listen_on,
                                                 **kwargs)
        self.port = self.protocol.port

        if initial_peers:
            # stage 1: ping initial_peers, add each other to the routing table
            bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
            start_time = get_dht_time()
            ping_tasks = map(self.protocol.call_ping, initial_peers)
            finished_pings, unfinished_pings = await asyncio.wait(
                ping_tasks, return_when=asyncio.FIRST_COMPLETED)

            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
            if unfinished_pings:
                finished_in_time, stragglers = await asyncio.wait(
                    unfinished_pings,
                    timeout=bootstrap_timeout - get_dht_time() + start_time)
                for straggler in stragglers:
                    straggler.cancel()
                finished_pings |= finished_in_time

            if not finished_pings:
                warn(
                    "DHTNode bootstrap failed: none of the initial_peers responded to a ping."
                )

            # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
            # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
            # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
            await asyncio.wait([
                asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)
            ],
                               return_when=asyncio.FIRST_COMPLETED)

        if self.refresh_timeout is not None:
            asyncio.create_task(
                self._refresh_routing_table(period=self.refresh_timeout))
        return self
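
A hedged usage sketch of the factory above: create one listening node, bootstrap a second node from it, then store and read back a value with the store_many/get_many coroutines shown in the other examples. Only the signatures visible in these snippets are relied upon; everything else is an assumption.

import asyncio

# assumption: DHTNode and get_dht_time are importable from the surrounding library,
# e.g. `from hivemind.dht.node import DHTNode`; the exact module path may differ


async def demo():
    first = await DHTNode.create(listen_on="127.0.0.1:*")
    second = await DHTNode.create(initial_peers=[f"127.0.0.1:{first.port}"],
                                  listen_on="127.0.0.1:*")

    expiration_time = get_dht_time() + 60
    store_ok = await second.store_many(["some_key"], ["some_value"], expiration_time)
    assert store_ok["some_key"], "store was rejected (e.g. a newer value already exists)"

    found_value, found_expiration = (await first.get_many(["some_key"]))["some_key"]
    print(found_value, found_expiration)


asyncio.run(demo())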
Example #24
    async def get_many_by_id(
        self,
        key_ids: Collection[DHTID],
        sufficient_expiration_time: Optional[DHTExpiration] = None,
        num_workers: Optional[int] = None,
        beam_size: Optional[int] = None,
        return_futures: bool = False,
        _is_refresh=False
    ) -> Dict[DHTID,
              Union[Optional[ValueWithExpiration[DHTValue]],
                    Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
        """
        Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.

        :param key_ids: traverse the DHT and find the value for each of these keys (or (None, None) if the key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time,
            default = time of call, find any value that did not expire by the time of call
            If sufficient_expiration_time=float('inf'), this method will find a value with the _latest_ expiration
        :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
        :param num_workers: override for default num_workers, see traverse_dht num_workers param
        :param return_futures: if True, immediately return an asyncio.Future for every key before interacting with the network.
         The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
         Note: canceling a future will stop the search for the corresponding key
        :param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
        :note: in order to check if get returned a value, check whether expiration_time is not None
        """
        sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers
        search_results: Dict[DHTID, _SearchState] = {
            key_id: _SearchState(key_id,
                                 sufficient_expiration_time,
                                 serializer=self.protocol.serializer)
            for key_id in key_ids
        }

        if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
            for key_id in key_ids:
                search_results[key_id].add_done_callback(
                    self._trigger_cache_refresh)

        # if we have concurrent get request for some of the same keys, subscribe to their results
        if self.reuse_get_requests:
            for key_id, search_result in search_results.items():
                self.pending_get_requests[key_id].add(search_result)
                search_result.add_done_callback(
                    self._reuse_finished_search_result)

        # stage 1: check for value in this node's local storage and cache
        for key_id in key_ids:
            search_results[key_id].add_candidate(
                self.protocol.storage.get(key_id), source_node_id=self.node_id)
            if not _is_refresh:
                search_results[key_id].add_candidate(
                    self.protocol.cache.get(key_id),
                    source_node_id=self.node_id)

        # stage 2: traverse the DHT to get the remaining keys from remote peers
        unfinished_key_ids = [
            key_id for key_id in key_ids if not search_results[key_id].finished
        ]
        node_to_endpoint: Dict[
            DHTID, Endpoint] = dict()  # global routing table for all keys
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(
                self.protocol.routing_table.get_nearest_neighbors(
                    key_id, self.protocol.bucket_size, exclude=self.node_id))

        # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
        async def get_neighbors(
            peer: DHTID, queries: Collection[DHTID]
        ) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
            queries = list(queries)
            response = await self.protocol.call_find(node_to_endpoint[peer],
                                                     queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
            for key_id, (maybe_value_with_expiration,
                         peers) in response.items():
                node_to_endpoint.update(peers)
                search_results[key_id].add_candidate(
                    maybe_value_with_expiration, source_node_id=peer)
                output[key_id] = tuple(
                    peers.keys()), search_results[key_id].finished
                # note: we interrupt the search when the key is found or the search is otherwise finished (e.g. cancelled by user)
            return output

        # V-- this function will be called exactly once when traverse_dht finishes search for a given key
        async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID],
                                 _visited: Set[DHTID]):
            search_results[key_id].finish_search()  # finish the search whether or not we found something
            self._cache_new_result(search_results[key_id],
                                   nearest_nodes,
                                   node_to_endpoint,
                                   _is_refresh=_is_refresh)

        asyncio.create_task(
            traverse_dht(queries=list(unfinished_key_ids),
                         initial_nodes=list(node_to_endpoint),
                         beam_size=beam_size,
                         num_workers=num_workers,
                         queries_per_call=min(
                             int(len(unfinished_key_ids)**0.5),
                             self.chunk_size),
                         get_neighbors=get_neighbors,
                         visited_nodes={
                             key_id: {self.node_id}
                             for key_id in unfinished_key_ids
                         },
                         found_callback=found_callback,
                         await_all_tasks=False))

        if return_futures:
            return {
                key_id: search_result.future
                for key_id, search_result in search_results.items()
            }
        else:
            try:
                # note: this should be first time when we await something, there's no need to "try" the entire function
                return {
                    key_id: await search_result.future
                    for key_id, search_result in search_results.items()
                }
            except asyncio.CancelledError as e:  # terminate remaining tasks ASAP
                for key_id, search_result in search_results.items():
                    search_result.future.cancel()
                raise e
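
The return_futures mode described above allows several lookups to be started at once and consumed as they complete. A short hedged sketch (assumes `node` is an already-bootstrapped DHTNode and that DHTID.generate is available, as in the tests above):

import asyncio


async def fetch_first_available(node, keys):
    """ Hypothetical helper: start lookups for all keys, return whichever result arrives first """
    key_ids = [DHTID.generate(key) for key in keys]
    futures = await node.get_many_by_id(key_ids, return_futures=True)
    done, pending = await asyncio.wait(futures.values(), return_when=asyncio.FIRST_COMPLETED)
    for future in pending:
        future.cancel()  # per the docstring above, cancelling a future stops the search for that key
    return next(iter(done)).result()  # a (value, expiration) record, or None if the key was not found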
Example #25
    async def get_many(
        self,
        keys: Collection[DHTKey],
        sufficient_expiration_time: Optional[DHTExpiration] = None,
        num_workers: Optional[int] = None,
        beam_size: Optional[int] = None
    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
        """
        :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if the key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time,
            default = time of call, find any value that did not expire by the time of call
            If sufficient_expiration_time=float('inf'), this method will find a value with the _latest_ expiration
        :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
        :param num_workers: override for default num_workers, see traverse_dht num_workers param
        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
        :note: in order to check if get returned a value, check whether expiration_time is not None
        """
        key_ids = [DHTID.generate(key) for key in keys]
        id_to_original_key = dict(zip(key_ids, keys))
        sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers

        # search metadata
        unfinished_key_ids = set(
            key_ids)  # track key ids for which the search is not terminated
        node_to_endpoint: Dict[
            DHTID, Endpoint] = dict()  # global routing table for all queries

        SearchResult = namedtuple(
            "SearchResult",
            ["binary_value", "expiration_time", "source_node_id"])
        latest_results = {
            key_id: SearchResult(b'', -float('inf'), None)
            for key_id in key_ids
        }

        # stage 1: value can be stored in our local cache
        for key_id in key_ids:
            maybe_value, maybe_expiration_time = self.protocol.storage.get(
                key_id)
            if maybe_expiration_time is None:
                maybe_value, maybe_expiration_time = self.protocol.cache.get(
                    key_id)
            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[
                    key_id].expiration_time:
                latest_results[key_id] = SearchResult(maybe_value,
                                                      maybe_expiration_time,
                                                      self.node_id)
                if maybe_expiration_time >= sufficient_expiration_time:
                    unfinished_key_ids.remove(key_id)

        # stage 2: traverse the DHT for any unfinished keys
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(
                self.protocol.routing_table.get_nearest_neighbors(
                    key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def get_neighbors(
            peer: DHTID, queries: Collection[DHTID]
        ) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
            queries = list(queries)
            response = await self.protocol.call_find(node_to_endpoint[peer],
                                                     queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
            for key_id, (maybe_value, maybe_expiration_time,
                         peers) in response.items():
                node_to_endpoint.update(peers)
                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[
                        key_id].expiration_time:
                    latest_results[key_id] = SearchResult(
                        maybe_value, maybe_expiration_time, peer)
                should_interrupt = (latest_results[key_id].expiration_time >=
                                    sufficient_expiration_time)
                output[key_id] = list(peers.keys()), should_interrupt
            return output

        nearest_nodes_per_query, visited_nodes = await traverse_dht(
            queries=list(unfinished_key_ids),
            initial_nodes=list(node_to_endpoint),
            beam_size=beam_size,
            num_workers=num_workers,
            queries_per_call=int(len(unfinished_key_ids)**0.5),
            get_neighbors=get_neighbors,
            visited_nodes={
                key_id: {self.node_id}
                for key_id in unfinished_key_ids
            })

        # stage 3: cache any new results depending on caching parameters
        for key_id, nearest_nodes in nearest_nodes_per_query.items():
            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[
                key_id]
            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
            if should_cache and self.cache_locally:
                self.protocol.cache.store(key_id, latest_value_bytes,
                                          latest_expiration_time)

            if should_cache and self.cache_nearest:
                num_cached_nodes = 0
                for node_id in nearest_nodes:
                    if node_id == latest_node_id:
                        continue
                    asyncio.create_task(
                        self.protocol.call_store(node_to_endpoint[node_id],
                                                 [key_id],
                                                 [latest_value_bytes],
                                                 [latest_expiration_time],
                                                 in_cache=True))
                    num_cached_nodes += 1
                    if num_cached_nodes >= self.cache_nearest:
                        break

        # stage 4: deserialize data and assemble function output
        find_result: Dict[DHTKey, Tuple[Optional[DHTValue],
                                        Optional[DHTExpiration]]] = {}
        for key_id, (latest_value_bytes, latest_expiration_time,
                     _) in latest_results.items():
            if latest_expiration_time != -float('inf'):
                latest_value = self.serializer.loads(latest_value_bytes)
                find_result[id_to_original_key[key_id]] = (
                    latest_value, latest_expiration_time)
            else:
                find_result[id_to_original_key[key_id]] = None, None
        return find_result