def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
                           expirations: List[DHTExpiration], store_ok: List[bool]):
    """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
    store_succeeded = any(store_ok)
    is_dictionary = any(subkey is not None for subkey in subkeys)
    if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
        stored_value_bytes, stored_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
        self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
    elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
        rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
        if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:  # cache would be rejected
            self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
    elif is_dictionary and key_id in self.protocol.cache:  # there can be other keys and we should update
        for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
            self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
        self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
def test_maxsize_cache():
    d = DHTLocalStorage(maxsize=1)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger expiration time must be kept"
    assert d.get(DHTID.generate("key1")) is None, "Value with smaller expiration time must be deleted"
def test_change_expiration_time():
    d = DHTLocalStorage()
    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
    time.sleep(1)
    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")
async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
    """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
    while self.is_alive and period is not None:  # if period is None, never refresh buckets; otherwise run forever
        refresh_time = get_dht_time()
        staleness_threshold = refresh_time - period
        stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
                         if bucket.last_updated < staleness_threshold]

        for bucket in stale_buckets:
            refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
            await self.find_nearest_nodes(refresh_id)

        await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
def _remove_outdated(self):
    while not self.frozen and self.expiration_heap and (self.expiration_heap[ROOT].expiration_time < get_dht_time()
                                                        or len(self.expiration_heap) > self.maxsize):
        heap_entry = heapq.heappop(self.expiration_heap)
        if self.key_to_heap.get(heap_entry.key) == heap_entry:
            del self.data[heap_entry.key], self.key_to_heap[heap_entry.key]
def test_get_expired():
    d = DHTLocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
    time.sleep(0.5)
    assert d.get(DHTID.generate("key")) is None, "Expired value must be deleted"
    print("Test get expired passed")
async def _declare_experts(self, node: DHTNode, uids: List[ExpertUID], endpoint: Endpoint,
                           future: Optional[MPFuture]) -> Dict[ExpertUID, bool]:
    num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
    expiration_time = get_dht_time() + self.expiration
    data_to_store: Dict[Tuple[ExpertPrefix, Optional[Coordinate]], DHTValue] = {}
    for uid in uids:
        data_to_store[uid, None] = endpoint
        prefix = uid if uid.count(UID_DELIMITER) > 1 else f'{uid}{UID_DELIMITER}{FLAT_EXPERT}'
        for i in range(prefix.count(UID_DELIMITER) - 1):
            prefix, last_coord = split_uid(prefix)
            data_to_store[prefix, last_coord] = [uid, endpoint]

    keys, maybe_subkeys, values = zip(*((key, subkey, value) for (key, subkey), value in data_to_store.items()))
    store_ok = await node.store_many(keys, values, expiration_time, subkeys=maybe_subkeys, num_workers=num_workers)
    if future is not None:
        future.set_result(store_ok)
    return store_ok
async def _get_experts(self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration],
                       future: MPFuture):
    if expiration_time is None:
        expiration_time = get_dht_time()
    num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
    response = await node.get_many(uids, expiration_time, num_workers=num_workers)
    future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
                       for uid, (expert_data, maybe_expiration_time) in response.items()])
def remove_outdated(self):
    while self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
                                    or len(self.expiration_heap) > self.cache_size):
        heap_entry = heapq.heappop(self.expiration_heap)
        key = heap_entry[1]
        if self.key_to_heap[key] == heap_entry:
            del self.data[key], self.key_to_heap[key]
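# The cleanup above relies on lazy deletion: re-storing a key orphans its old heap entry, and
# orphans are simply skipped when popped. Below is a minimal, self-contained sketch of that
# pattern (the `ExpiringDict` class and its names are hypothetical, not part of this codebase).
import heapq
import time


class ExpiringDict:
    """ Hypothetical sketch of the lazy-deletion pattern used by remove_outdated above """

    def __init__(self, maxsize: int = 128):
        self.maxsize, self.data, self.key_to_heap, self.expiration_heap = maxsize, {}, {}, []

    def store(self, key, value, expiration_time: float):
        entry = (expiration_time, key)
        self.key_to_heap[key] = entry                 # the newest heap entry for this key
        heapq.heappush(self.expiration_heap, entry)   # any older entry for the same key becomes an orphan
        self.data[key] = (value, expiration_time)
        self.remove_outdated()

    def remove_outdated(self):
        now = time.time()
        while self.expiration_heap and (self.expiration_heap[0][0] < now
                                        or len(self.expiration_heap) > self.maxsize):
            entry = heapq.heappop(self.expiration_heap)
            key = entry[1]
            if self.key_to_heap.get(key) == entry:    # skip orphaned entries left behind by overwrites
                del self.data[key], self.key_to_heap[key]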
def test_localstorage_freeze():
    d = DHTLocalStorage(maxsize=2)

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 0.01)
        assert DHTID.generate("key1") in d
        time.sleep(0.03)
        assert DHTID.generate("key1") in d  # expired, but kept while storage is frozen
    assert DHTID.generate("key1") not in d  # dropped once the freeze is lifted

    with d.freeze():
        d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
        d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
        d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 3)  # key3 will push key1 out due to maxsize
        assert DHTID.generate("key1") in d  # ... but not until the storage is unfrozen
    assert DHTID.generate("key1") not in d
async def _refresh_stale_cache_entries(self):
    """ periodically refresh near-expired keys that were accessed at least once during previous lifetime """
    while self.is_alive:
        while len(self.cache_refresh_queue) == 0:
            await self.cache_refresh_evt.wait()
            self.cache_refresh_evt.clear()
        key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()

        try:
            # step 1: await until :cache_refresh_before_expiry: seconds before the earliest entry expires
            time_to_wait = nearest_refresh_time - get_dht_time()
            await asyncio.wait_for(self.cache_refresh_evt.wait(), timeout=time_to_wait)
            # note: the line above will raise TimeoutError when we are ready to refresh cache
            self.cache_refresh_evt.clear()  # no timeout error => someone added a new entry to the queue and ...
            continue  # ... and this element is earlier than nearest_expiration. we should refresh this entry first

        except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
            # step 2: find all keys that we should already refresh and remove them from the queue
            current_time = get_dht_time()
            keys_to_refresh = {key_id}
            max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
            del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
            while self.cache_refresh_queue and len(keys_to_refresh) < self.chunk_size:
                key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
                if nearest_refresh_time > current_time:
                    break
                del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                keys_to_refresh.add(key_id)
                cached_item = self.protocol.cache.get(key_id)
                if cached_item is not None and cached_item.expiration_time > max_expiration_time:
                    max_expiration_time = cached_item.expiration_time

            # step 3: search for newer versions of these keys, cache them as a side effect of self.get_many_by_id
            sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
            await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)
async def _refresh_stale_cache_entries(self):
    """ periodically refresh near-expired keys that were accessed at least once during previous lifetime """
    while self.is_alive:
        with self.cache_refresh_queue.freeze():
            while len(self.cache_refresh_queue) == 0:
                await self.cache_refresh_available.wait()
                self.cache_refresh_available.clear()
            key_id, _, nearest_expiration = self.cache_refresh_queue.top()

        try:
            # step 1: await until :cache_refresh_before_expiry: seconds before the earliest entry expires
            time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
            await asyncio.wait_for(self.cache_refresh_available.wait(), timeout=time_to_wait)
            # note: the line above will raise TimeoutError when we are ready to refresh cache
            self.cache_refresh_available.clear()  # no timeout error => someone added a new entry to the queue and ...
            continue  # ... and this element is earlier than nearest_expiration. we should refresh this entry first

        except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
            # step 2: find all keys that we should already refresh and remove them from the queue
            with self.cache_refresh_queue.freeze():
                keys_to_refresh = {key_id}
                del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                while self.cache_refresh_queue:
                    key_id, _, nearest_expiration = self.cache_refresh_queue.top()
                    if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
                        break
                    del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                    keys_to_refresh.add(key_id)

            # step 3: search for newer versions of these keys, cache them as a side effect of self.get_many_by_id
            await self.get_many_by_id(
                keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
                _refresh_cache=False)  # if we found the value locally, we shouldn't trigger another refresh
async def _get_experts(self, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration],
                       future: Optional[MPFuture] = None) -> List[Optional[RemoteExpert]]:
    if expiration_time is None:
        expiration_time = get_dht_time()
    num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    experts: List[Optional[RemoteExpert]] = [None] * len(uids)
    for i, uid in enumerate(uids):
        if found[uid] is not None and isinstance(found[uid].value, Endpoint):
            experts[i] = RemoteExpert(uid, found[uid].value)
    if future:
        future.set_result(experts)
    return experts
async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
    num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
    expiration_time = get_dht_time() + self.expiration
    data_to_store = {}
    for uid in uids:
        uid_parts = uid.split(self.UID_DELIMITER)
        for i in range(len(uid_parts)):
            uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
            data_to_store[uid_prefix_i] = {'uid': uid, 'endpoint': endpoint}

    store_keys, store_values = zip(*data_to_store.items())
    store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
    if future is not None:
        future.set_result([store_ok[key] for key in data_to_store.keys()])
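# For intuition, here is a rough standalone sketch of the prefix expansion performed in the loop
# above: every prefix of the uid is mapped to the full uid and its endpoint. The delimiter value,
# the sample uid and the endpoint below are illustrative assumptions, not taken from the code.
UID_DELIMITER_EXAMPLE = '.'  # assumed delimiter, for illustration only


def expand_prefixes(uid: str, endpoint: str) -> dict:
    """ Replicate the loop above for a single uid (hypothetical helper, not part of the codebase) """
    data_to_store = {}
    uid_parts = uid.split(UID_DELIMITER_EXAMPLE)
    for i in range(len(uid_parts)):
        uid_prefix_i = UID_DELIMITER_EXAMPLE.join(uid_parts[:i + 1])
        data_to_store[uid_prefix_i] = {'uid': uid, 'endpoint': endpoint}
    return data_to_store


# prints {'ffn_expert': {...}, 'ffn_expert.3': {...}, 'ffn_expert.3.7': {...}}
print(expand_prefixes('ffn_expert.3.7', '127.0.0.1:1337'))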
def test_localstorage_serialize():
    d1 = DictionaryDHTValue()
    d2 = DictionaryDHTValue()
    now = get_dht_time()
    d1.store('key1', b'ololo', now + 1)
    d2.store('key2', b'pysh', now + 1)
    d2.store('key3', b'pyshpysh', now + 2)
    data = MSGPackSerializer.dumps([d1, d2, 123321])
    assert isinstance(data, bytes)
    new_d1, new_d2, new_value = MSGPackSerializer.loads(data)
    assert isinstance(new_d1, DictionaryDHTValue) and isinstance(new_d2, DictionaryDHTValue) and new_value == 123321
    assert 'key1' in new_d1 and len(new_d1) == 1
    assert 'key1' not in new_d2 and len(new_d2) == 2 and new_d2.get('key3') == (b'pyshpysh', now + 2)
def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
    """
    Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
    :returns: True if new value was stored, False if it was rejected (current value is newer)
    """
    if expiration_time < get_dht_time():
        return False
    self.key_to_heap[key] = (expiration_time, key)
    heapq.heappush(self.expiration_heap, (expiration_time, key))
    if key in self.data:
        if self.data[key][1] < expiration_time:
            self.data[key] = (value, expiration_time)
            return True
        return False
    self.data[key] = (value, expiration_time)
    self.remove_outdated()
    return True
def store(self, key: KeyType, value: ValueType, expiration_time: DHTExpiration) -> bool:
    """
    Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
    :returns: True if new value was stored, False if it was rejected (current value is newer)
    """
    if expiration_time < get_dht_time() and not self.frozen:
        return False
    self.key_to_heap[key] = HeapEntry(expiration_time, key)
    heapq.heappush(self.expiration_heap, self.key_to_heap[key])
    if key in self.data:
        if self.data[key].expiration_time < expiration_time:
            self.data[key] = ValueWithExpiration(value, expiration_time)
            return True
        return False
    self.data[key] = ValueWithExpiration(value, expiration_time)
    self._remove_outdated()
    return True
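# A small usage sketch to make the return-value contract above concrete. It uses the same helpers
# as the tests in this file (DHTLocalStorage, DHTID, get_dht_time); the key name and the
# expiration offsets are illustrative.
storage = DHTLocalStorage()
key = DHTID.generate("some_key")

assert storage.store(key, b"v1", get_dht_time() + 60)       # fresh key: stored, returns True
assert not storage.store(key, b"v2", get_dht_time() + 30)   # earlier expiration: rejected, returns False
assert storage.store(key, b"v3", get_dht_time() + 120)      # later expiration: overwrites, returns True

value, expiration_time = storage.get(key)
assert value == b"v3"  # the entry with the latest expiration time wins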
def test_localstorage_nested():
    time = get_dht_time()
    d1 = DHTLocalStorage()
    d2 = DictionaryDHTValue()
    d2.store('subkey1', b'value1', time + 2)
    d2.store('subkey2', b'value2', time + 3)
    d2.store('subkey3', b'value3', time + 1)
    assert d2.latest_expiration_time == time + 3

    for subkey, (subvalue, subexpiration) in d2.items():
        assert d1.store_subkey(DHTID.generate('foo'), subkey, subvalue, subexpiration)
    assert d1.store(DHTID.generate('bar'), b'456', time + 2)
    assert d1.get(DHTID.generate('foo'))[0].data == d2.data
    assert d1.get(DHTID.generate('foo'))[1] == d2.latest_expiration_time
    assert d1.get(DHTID.generate('foo'))[0].get('subkey1') == (b'value1', time + 2)
    assert len(d1.get(DHTID.generate('foo'))[0]) == 3
    assert d1.store_subkey(DHTID.generate('foo'), 'subkey4', b'value4', time + 4)
    assert len(d1.get(DHTID.generate('foo'))[0]) == 4

    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 1) is False  # prev has better expiration
    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA', time + 3)  # new value has better expiration
    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyB', b'valueB', time + 4)  # new value has better expiration
    assert d1.store_subkey(DHTID.generate('bar'), 'subkeyA', b'valueA+', time + 5)  # overwrite subkeyA under key bar
    assert all(subkey in d1.get(DHTID.generate('bar'))[0] for subkey in ('subkeyA', 'subkeyB'))
    assert len(d1.get(DHTID.generate('bar'))[0]) == 2 and d1.get(DHTID.generate('bar'))[1] == time + 5

    assert d1.store(DHTID.generate('foo'), b'nothing', time + 3.5) is False  # previous value has better expiration
    assert d1.get(DHTID.generate('foo'))[0].get('subkey2') == (b'value2', time + 3)
    assert d1.store(DHTID.generate('foo'), b'nothing', time + 5) is True  # new value has better expiration
    assert d1.get(DHTID.generate('foo')) == (b'nothing', time + 5)  # value should be replaced
async def _get_active_successors(self, node: DHTNode, prefixes: List[ExpertPrefix], grid_size: Optional[int] = None,
                                 num_workers: Optional[int] = None, future: Optional[MPFuture] = None
                                 ) -> Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]]:
    grid_size = grid_size or float('inf')
    num_workers = num_workers or min(len(prefixes), self.max_workers or len(prefixes))
    dht_responses = await node.get_many(keys=prefixes, num_workers=num_workers)
    successors: Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]] = {}
    for prefix, found in dht_responses.items():
        if found and isinstance(found.value, dict):
            successors[prefix] = {coord: UidEndpoint(*match.value) for coord, match in found.value.items()
                                  if isinstance(coord, Coordinate) and 0 <= coord < grid_size
                                  and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
        else:
            successors[prefix] = {}
            if found is None and self.negative_caching:
                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {prefix}")
                asyncio.create_task(node.store(prefix, subkey=-1, value=None,
                                               expiration_time=get_dht_time() + self.expiration))
    if future:
        future.set_result(successors)
    return successors
def test_localstorage_top():
    d = DHTLocalStorage(maxsize=3)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 2)
    d.store(DHTID.generate("key3"), b"val3", get_dht_time() + 4)
    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1"

    d.store(DHTID.generate("key1"), b"val1_new", get_dht_time() + 3)
    assert d.top()[0] == DHTID.generate("key2") and d.top()[1].value == b"val2"

    del d[DHTID.generate('key2')]
    assert d.top()[0] == DHTID.generate("key1") and d.top()[1].value == b"val1_new"
    d.store(DHTID.generate("key2"), b"val2_new", get_dht_time() + 5)
    d.store(DHTID.generate("key4"), b"val4", get_dht_time() + 6)  # key4 will push out key1 due to maxsize
    assert d.top()[0] == DHTID.generate("key3") and d.top()[1].value == b"val3"
async def _get_initial_beam(self, node, prefix: ExpertPrefix, beam_size: int, scores: Tuple[float, ...],
                            num_workers: Optional[int] = None, future: Optional[MPFuture] = None
                            ) -> List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]]:
    num_workers = num_workers or self.max_workers or beam_size
    beam: List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]] = []
    unattempted_indices: List[Coordinate] = sorted(range(len(scores)), key=scores.__getitem__)  # from worst to best
    pending_tasks: Deque[Tuple[Coordinate, ExpertPrefix, asyncio.Task]] = deque()

    while len(beam) < beam_size and (unattempted_indices or pending_tasks):
        # dispatch additional tasks
        while unattempted_indices and len(pending_tasks) < num_workers:
            next_index = unattempted_indices.pop()  # note: this is the best unattempted index because of sort order
            next_best_prefix = f"{prefix}{next_index}{UID_DELIMITER}"
            pending_tasks.append((next_index, next_best_prefix, asyncio.create_task(node.get(next_best_prefix))))

        # await the next best prefix to be fetched
        pending_best_index, pending_best_prefix, pending_task = pending_tasks.popleft()
        try:
            maybe_prefix_data = await pending_task
            if maybe_prefix_data is not None and isinstance(maybe_prefix_data.value, dict):
                successors = {coord: UidEndpoint(*match.value) for coord, match in maybe_prefix_data.value.items()
                              if isinstance(coord, Coordinate)
                              and isinstance(getattr(match, 'value', None), list) and len(match.value) == 2}
                if successors:
                    beam.append((scores[pending_best_index], pending_best_prefix, successors))
            elif maybe_prefix_data is None and self.negative_caching:
                logger.debug(f"DHT negative caching: storing a 'no prefix' entry for {pending_best_prefix}")
                asyncio.create_task(node.store(pending_best_prefix, subkey=-1, value=None,
                                               expiration_time=get_dht_time() + self.expiration))
        except asyncio.CancelledError:
            for _, _, pending_task in pending_tasks:  # each entry is a (index, prefix, task) triple
                pending_task.cancel()
            raise

    if future:
        future.set_result(beam)
    return beam
def test_store():
    d = DHTLocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
    print("Test store passed")
async def create(cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
                 bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5,
                 parallel_rpc: int = None, wait_timeout: float = 5, refresh_timeout: Optional[float] = None,
                 bootstrap_timeout: Optional[float] = None, cache_locally: bool = True, cache_nearest: int = 1,
                 cache_size=None, cache_refresh_before_expiry: float = 5, cache_on_store: bool = True,
                 reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
                 listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
    """
    :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
    :param initial_peers: connects to these peers to populate routing table, defaults to no peers
    :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
      either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
      Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
    :param num_replicas: number of nearest nodes that will be asked to store a given key, default = bucket_size (≈k)
    :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
    :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
      Reduce this value if your RPC requests register no response despite the peer sending the response.
    :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
    :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
      if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
    :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
    :param cache_locally: if True, caches all values (stored or found) in a node-local cache
    :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
    :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
      nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
    :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
    :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
      if they are accessed this many seconds before expiration time.
    :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
      all concurrent get requests for the same key will reuse the procedure that is currently in progress
    :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
    :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
    :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
      if False, this node will refuse any incoming requests, effectively being only a "client"
    :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
    :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
      see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
    :param kwargs: extra parameters used in grpc.aio.server
    """
    self = cls(_initialized_with_create=True)
    self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
    self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
    self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh
    self.reuse_get_requests = reuse_get_requests
    self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: -_res.sufficient_expiration_time))

    # caching policy
    self.refresh_timeout = refresh_timeout
    self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
    self.cache_refresh_before_expiry = cache_refresh_before_expiry
    self.cache_refresh_queue = CacheRefreshQueue()
    self.cache_refresh_evt = asyncio.Event()
    self.cache_refresh_task = None

    self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                             parallel_rpc, cache_size, listen, listen_on, **kwargs)
    self.port = self.protocol.port

    if initial_peers:
        # stage 1: ping initial_peers, add each other to the routing table
        bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
        start_time = get_dht_time()
        ping_tasks = map(self.protocol.call_ping, initial_peers)
        finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)

        # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
        if unfinished_pings:
            finished_in_time, stragglers = await asyncio.wait(
                unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
            for straggler in stragglers:
                straggler.cancel()
            finished_pings |= finished_in_time

        if not finished_pings:
            warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

        # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
        # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
        # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
        await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                            asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
                           return_when=asyncio.FIRST_COMPLETED)

    if self.refresh_timeout is not None:
        asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
    return self
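# A minimal usage sketch for the factory above, using only parameters documented in its docstring.
# The bootstrap endpoint is a hypothetical placeholder; DHTNode is assumed to be importable from
# this package.
import asyncio


async def _example_create_node():
    node = await DHTNode.create(initial_peers=["192.168.1.10:1337"],  # hypothetical bootstrap peer
                                listen_on="0.0.0.0:*",                # pick any free port
                                cache_locally=True, cache_refresh_before_expiry=5)
    print("DHT node is listening on port", node.port)
    return node


# asyncio.run(_example_create_node())  # run inside a fresh event loop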
async def get_many_by_id(
        self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
        num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
        _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
                                                Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
    """
    Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.

    :param key_ids: traverse the DHT and find the value for each of these keys (or (None, None) if key not found)
    :param sufficient_expiration_time: if the search finds a value that expires after this time,
      default = time of call, find any value that did not expire by the time of call
      If sufficient_expiration_time=float('inf'), this method will find a value with _latest_ expiration
    :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
    :param num_workers: override for default num_workers, see traverse_dht num_workers param
    :param return_futures: if True, immediately return an asyncio.Future for every key before interacting with
      the network. The algorithm will populate these futures with (value, expiration) when it finds the
      corresponding key. Note: canceling a future will stop the search for the corresponding key
    :param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
    :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
    :note: in order to check if get returned a value, please check if (expiration_time is None)
    """
    sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
    beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
    num_workers = num_workers if num_workers is not None else self.num_workers
    search_results: Dict[DHTID, _SearchState] = {
        key_id: _SearchState(key_id, sufficient_expiration_time, serializer=self.protocol.serializer)
        for key_id in key_ids}

    if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
        for key_id in key_ids:
            search_results[key_id].add_done_callback(self._trigger_cache_refresh)

    # if we have concurrent get requests for some of the same keys, subscribe to their results
    if self.reuse_get_requests:
        for key_id, search_result in search_results.items():
            self.pending_get_requests[key_id].add(search_result)
            search_result.add_done_callback(self._reuse_finished_search_result)

    # stage 1: check for value in this node's local storage and cache
    for key_id in key_ids:
        search_results[key_id].add_candidate(self.protocol.storage.get(key_id), source_node_id=self.node_id)
        if not _is_refresh:
            search_results[key_id].add_candidate(self.protocol.cache.get(key_id), source_node_id=self.node_id)

    # stage 2: traverse the DHT to get the remaining keys from remote peers
    unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
    node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all keys
    for key_id in unfinished_key_ids:
        node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
            key_id, self.protocol.bucket_size, exclude=self.node_id))

    # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
    async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
        queries = list(queries)
        response = await self.protocol.call_find(node_to_endpoint[peer], queries)
        if not response:
            return {query: ([], False) for query in queries}

        output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
        for key_id, (maybe_value_with_expiration, peers) in response.items():
            node_to_endpoint.update(peers)
            search_results[key_id].add_candidate(maybe_value_with_expiration, source_node_id=peer)
            output[key_id] = tuple(peers.keys()), search_results[key_id].finished
            # note: we interrupt search if the key is either found or finished otherwise (e.g. cancelled by user)
        return output

    # V-- this function will be called exactly once when traverse_dht finishes search for a given key
    async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
        search_results[key_id].finish_search()  # finish search whether or not we found something
        self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)

    asyncio.create_task(traverse_dht(
        queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
        beam_size=beam_size, num_workers=num_workers,
        queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
        get_neighbors=get_neighbors,
        visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
        found_callback=found_callback, await_all_tasks=False))

    if return_futures:
        return {key_id: search_result.future for key_id, search_result in search_results.items()}
    else:
        try:
            # note: this is the first time we await something, so there's no need to "try" the entire function
            return {key_id: await search_result.future for key_id, search_result in search_results.items()}
        except asyncio.CancelledError as e:  # terminate remaining tasks ASAP
            for key_id, search_result in search_results.items():
                search_result.future.cancel()
            raise e
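# A hypothetical usage sketch (not from the source) for return_futures=True: it assumes an async
# context and an already-created `node = await DHTNode.create(...)`; the key name is illustrative.
key_id = DHTID.generate("some_key")
futures = await node.get_many_by_id([key_id], return_futures=True)  # futures are returned before any network I/O
found = await futures[key_id]  # resolves once traverse_dht finishes searching for this key
if found is not None:          # per the docstring, a missing key carries no expiration time
    value, expiration_time = found  # ValueWithExpiration unpacks into (value, expiration_time)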
async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
                   num_workers: Optional[int] = None, beam_size: Optional[int] = None
                   ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
    """
    :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if key not found)
    :param sufficient_expiration_time: if the search finds a value that expires after this time,
      default = time of call, find any value that did not expire by the time of call
      If sufficient_expiration_time=float('inf'), this method will find a value with _latest_ expiration
    :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
    :param num_workers: override for default num_workers, see traverse_dht num_workers param
    :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
    :note: in order to check if get returned a value, please check if (expiration_time is None)
    """
    key_ids = [DHTID.generate(key) for key in keys]
    id_to_original_key = dict(zip(key_ids, keys))
    sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
    beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
    num_workers = num_workers if num_workers is not None else self.num_workers

    # search metadata
    unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
    node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries

    SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
    latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}

    # stage 1: value can be stored in our local cache
    for key_id in key_ids:
        maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
        if maybe_expiration_time is None:
            maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
        if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
            latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
            if maybe_expiration_time >= sufficient_expiration_time:
                unfinished_key_ids.remove(key_id)

    # stage 2: traverse the DHT for any unfinished keys
    for key_id in unfinished_key_ids:
        node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
            key_id, self.protocol.bucket_size, exclude=self.node_id))

    async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
        queries = list(queries)
        response = await self.protocol.call_find(node_to_endpoint[peer], queries)
        if not response:
            return {query: ([], False) for query in queries}

        output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
        for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
            node_to_endpoint.update(peers)
            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
            should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
            output[key_id] = list(peers.keys()), should_interrupt
        return output

    nearest_nodes_per_query, visited_nodes = await traverse_dht(
        queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
        beam_size=beam_size, num_workers=num_workers,
        queries_per_call=int(len(unfinished_key_ids) ** 0.5),
        get_neighbors=get_neighbors,
        visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})

    # stage 3: cache any new results depending on caching parameters
    for key_id, nearest_nodes in nearest_nodes_per_query.items():
        latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
        should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
        if should_cache and self.cache_locally:
            self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)
        if should_cache and self.cache_nearest:
            num_cached_nodes = 0
            for node_id in nearest_nodes:
                if node_id == latest_node_id:
                    continue
                asyncio.create_task(self.protocol.call_store(
                    node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
                    in_cache=True))
                num_cached_nodes += 1
                if num_cached_nodes >= self.cache_nearest:
                    break

    # stage 4: deserialize data and assemble function output
    find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
    for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
        if latest_expiration_time != -float('inf'):
            latest_value = self.serializer.loads(latest_value_bytes)
            find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
        else:
            find_result[id_to_original_key[key_id]] = None, None
    return find_result
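# A hypothetical usage sketch (not from the source) for get_many: it assumes an async context and
# an already-created `node = await DHTNode.create(...)`; the key names are illustrative. Each
# result is a (value, expiration_time) pair, or (None, None) if nothing was found.
results = await node.get_many(["alpha", "beta"])
for key, (value, expiration_time) in results.items():
    if expiration_time is None:  # per the docstring, this is how "not found" is signalled
        print(key, "was not found")
    else:
        print(key, "=", value, "valid until", expiration_time)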