Example #1
 async def _set_group_bits(self, group_bits: str, future: MPFuture):
     try:
         self._matchmaking.group_key_manager.group_bits = group_bits
         future.set_result(None)
     except Exception as e:
         if not future.done():
             future.set_exception(e)
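
For context, handlers like _set_group_bits run inside the background process and resolve an MPFuture that the caller created in the user process. Below is a minimal sketch of the dispatch loop that would invoke them, assuming hivemind's (method_name, args, kwargs) pipe protocol; the loop body itself is an assumption, the real one lives in the process's run method:

import asyncio

async def _process_pipe_commands(self):
    # hypothetical dispatcher: read one command per pipe message and schedule the
    # matching private coroutine; kwargs carries the MPFuture that the handler resolves
    while True:
        method_name, args, kwargs = self.pipe.recv()  # blocking recv; real code would poll or use an executor
        asyncio.create_task(getattr(self, method_name)(*args, **kwargs))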
Example #2
 async def _get_experts(
         self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
     if expiration_time is None:
         expiration_time = get_dht_time()
     num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
     try:
         response = await node.get_many(uids, expiration_time, num_workers=num_workers)
         future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
                            for uid, (expert_data, maybe_expiration_time) in response.items()])
     except Exception as e:
         # mirror the sibling handlers: report errors to the caller instead of leaving its future hanging
         if not future.done():
             future.set_exception(e)
Example #3
    def first_k_active(
        self,
        uid_prefixes: List[str],
        k: int,
        max_prefetch: int = 1,
        chunk_size: Optional[int] = None,
        return_future=False
    ) -> Union[TOrderedDict[str, RemoteExpert], Awaitable[TOrderedDict[
            str, RemoteExpert]]]:
        """
        Find k prefixes with active experts; may return fewer if there aren't enough; used for DMoE beam search

        :param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority
        :param k: return at most *this many* active prefixes
        :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
        :param chunk_size: dispatch this many requests in one task
        :param return_future: if False (default), return when experts are found. Otherwise return MPFuture.
        :returns: an ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
            The keys in the returned dict are ordered the same way as in uid_prefixes.
        """
        assert not isinstance(
            uid_prefixes, str
        ), "please provide a list/tuple of prefixes as the first argument"
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_first_k_active', [],
                        dict(uid_prefixes=uid_prefixes,
                             k=k,
                             max_prefetch=max_prefetch,
                             chunk_size=chunk_size or k,
                             future=_future)))
        return future if return_future else future.result()
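
A hedged usage sketch (the dht handle and the exact prefix naming are assumptions, not part of the source):

# find the two highest-priority prefixes that still have at least one active expert
prefixes = ["expert.0", "expert.1", "expert.2", "expert.3"]
active = dht.first_k_active(prefixes, k=2)  # OrderedDict{uid_prefix -> RemoteExpert}
for uid_prefix, expert in active.items():
    print(uid_prefix, expert.uid, expert.endpoint)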
Example #4
    def declare_averager(self,
                         group_key: GroupKey,
                         endpoint: Endpoint,
                         expiration_time: float,
                         *,
                         looking_for_group: bool = True,
                         return_future: bool = False) -> Union[bool, MPFuture]:
        """
        Add the averager to (or remove it from) a given allreduce bucket

        :param group_key: allreduce group key, e.g. my_averager.0b011011101
        :param endpoint: averager public endpoint for incoming requests
        :param expiration_time: intent to run allreduce before this timestamp
        :param looking_for_group: by default (True), declare the averager as "looking for group" in a given group;
          If False, this will instead mark the averager as no longer looking for group (e.g. it already finished)
        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
        :return: True if declared, False if declaration was rejected by DHT peers
        :note: when leaving (i.e. looking_for_group=False), please specify the same expiration_time as when entering the group
        :note: setting looking_for_group=False does *not* guarantee that others will immediately stop querying you.
        """
        assert is_valid_group(
            group_key
        ), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_declare_averager', [],
                        dict(group_key=group_key,
                             endpoint=endpoint,
                             expiration_time=expiration_time,
                             looking_for_group=looking_for_group,
                             future=_future)))
        return future if return_future else future.result()
Example #5
    def step(
            self,
            allow_retries: bool = True,
            gather: Optional[DataForGather] = None,
            timeout: Optional[float] = None,
            wait=True
    ) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
        """
        Set up the averager to look for a group and run one round of averaging with that group

        :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
          within the specified timeout
        :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
          (this operation is known as all-gather). The gathered data will be available as the output of this function.
        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
        :returns: on success, update averaged_tensors and return group info; on failure, return None
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_step', [],
                        dict(future=_future,
                             gather=gather,
                             allow_retries=allow_retries,
                             timeout=timeout)))
        return future.result() if wait else future
Example #6
    def store(self,
              key: DHTKey,
              value: DHTValue,
              expiration_time: DHTExpiration,
              subkey: Optional[Subkey] = None,
              return_future: bool = False,
              **kwargs) -> Union[bool, MPFuture]:
        """
        Find num_replicas best nodes to store (key, value) and store it there until expiration time.

        :param key: msgpack-serializable key to be associated with value until expiration.
        :param value: msgpack-serializable value to be stored under a given key until expiration.
        :param expiration_time: absolute time when the entry should expire, based on hivemind.get_dht_time()
        :param subkey: if specified, add a value under that subkey instead of overwriting key (see DHTNode.store_many)
        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
        :returns: True if store succeeds, False if it fails (due to no response or newer value)
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_store', [],
                        dict(key=key,
                             value=value,
                             expiration_time=expiration_time,
                             subkey=subkey,
                             future=_future,
                             **kwargs)))
        return future if return_future else future.result()
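
A hedged usage sketch (assumes a running hivemind.DHT instance named dht):

import hivemind

expiration = hivemind.get_dht_time() + 60  # the entry expires in one minute
assert dht.store("my_key", value=[1, 2, 3], expiration_time=expiration)
# non-blocking variant: get an MPFuture back and check the outcome later
pending = dht.store("my_key", value="updated", expiration_time=expiration + 1, return_future=True)
print(pending.result())  # True if enough DHT peers accepted the value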
Example #7
    def get_averagers(
        self,
        group_key: GroupKey,
        *,
        only_active: bool = True,
        return_future: bool = False
    ) -> Union[List[Tuple[Endpoint, DHTExpiration]], MPFuture]:
        """
        Find and return averagers in a specified all-reduce bucket

        :param group_key: finds averagers that have this group key, e.g. my_averager.0b011011101
        :param only_active: if True, return only active averagers that are looking for group (i.e. with value = True)
            if False, return all averagers under a given group_key regardless of value
        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
        :return: endpoints and expirations of every matching averager
        """
        assert is_valid_group(
            group_key
        ), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_get_averagers', [],
                        dict(group_key=group_key,
                             only_active=only_active,
                             future=_future)))
        return future if return_future else future.result()
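
A hedged sketch combining declare_averager with get_averagers (dht, endpoint and the group key are assumptions):

import hivemind

expiration = hivemind.get_dht_time() + 30
ok = dht.declare_averager("my_averager.0b0110", endpoint, expiration_time=expiration)
if ok:
    # a list of (endpoint, declared_expiration_time) for peers still looking for a group
    peers = dht.get_averagers("my_averager.0b0110", only_active=True)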
Example #8
    def step(
        self,
        gather: Optional[DataForGather] = None,
        weight: float = 1.0,
        timeout: Optional[float] = None,
        allow_retries: bool = True,
        wait: bool = True
    ) -> Union[Optional[Dict[Endpoint, DataForGather]], MPFuture]:
        """
        Set up the averager to look for a group and run one round of averaging with that group

        :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
          (this operation is known as all-gather). The gathered data will be available as the output of this function.
        :param weight: averaging weight for this peer, int or float, must be strictly positive
        :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
          within the specified timeout
        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
        :returns: on success, update averaged_tensors and return group info; on failure, return None
        """
        assert isinstance(
            weight, (int, float)
        ) and weight > 0, f"Expected a positive int/float, got {type(weight)}"
        future, _future = MPFuture.make_pair()
        gather_binary = self.serializer.dumps(
            gather
        )  # serialize here to avoid loading modules in the averager process
        self.pipe.send(('_step', [],
                        dict(future=_future,
                             gather_binary=gather_binary,
                             weight=weight,
                             allow_retries=allow_retries,
                             timeout=timeout)))
        return future.result() if wait else future
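
A hedged usage sketch (averager stands for a running DecentralizedAverager; the gather payload is arbitrary msgpack-serializable metadata):

gathered = averager.step(gather={"step": 123}, weight=1.0, timeout=60.0)
if gathered is None:
    print("averaging failed: no group was found within the timeout")
else:
    for endpoint, metadata in gathered.items():
        print(endpoint, metadata)  # the gather payload sent by each groupmate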
Example #9
    def batch_find_best_experts(
            self,
            prefix: str,
            batch_grid_scores: Sequence[Sequence[Sequence[float]]],
            beam_size: int,
            *,
            workers_per_sample: Optional[int] = None,
            return_future=False) -> Union[List[List[RemoteExpert]], MPFuture]:
        """
        Find and return :beam_size: active experts with highest scores, use both local cache and DHT

        :param prefix: common prefix for all expert uids in grid
        :param batch_grid_scores: scores predicted for each batch example and each dimension in the grid,
        :type batch_grid_scores: list of arrays of shape (batch_size, grid_size[i])
        :param beam_size: how many best experts beam search should return
         After the time budget is exhausted, beam search stops querying for more experts and falls back on the local cache
         Please note that any queries that fall outside the budget will still be performed in the background and cached
         for subsequent iterations as long as DHTNode.cache_locally is True
        :param workers_per_sample: use up to this many concurrent workers for every sample in batch
        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
        :returns: for every sample in the batch, a list that contains *up to* beam_size RemoteExpert instances
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_batch_find_best_experts', [],
                        dict(prefix=prefix,
                             batch_grid_scores=batch_grid_scores,
                             beam_size=beam_size,
                             workers_per_sample=workers_per_sample,
                             future=_future)))
        return future if return_future else future.result()
Example #10
 def get_active_successors(
     self,
     prefixes: List[ExpertPrefix],
     grid_size: Optional[int] = None,
     num_workers: Optional[int] = None,
     return_future: bool = False
 ) -> Union[Dict[ExpertPrefix, Dict[Coordinate, UidEndpoint]], MPFuture]:
     """
      :param prefixes: a list of prefixes for which to find active successor uids
      :param grid_size: if specified, only return successors that are in range [0, grid_size)
     :param num_workers: how many parallel workers to use for DHTNode.get_many
     :param return_future: if False (default), find and return successors. Otherwise return MPFuture and fill later.
     :returns: for every expert, return a dict{active_next_coordinate: (matching_expert_uid, matching_endpoint)}
     :note: if a prefix is not found, get_active_successors will return an empty dictionary for that prefix
     """
     assert not isinstance(
         prefixes, str), "Please send a list / tuple of expert prefixes."
     for prefix in prefixes:
         assert is_valid_prefix(
             prefix
         ), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
     future, _future = MPFuture.make_pair()
     self.pipe.send(('_get_active_successors', [],
                     dict(prefixes=list(prefixes),
                          grid_size=grid_size,
                          num_workers=num_workers,
                          future=_future)))
     return future if return_future else future.result()
Example #11
    def find_best_experts(
            self,
            prefix: ExpertPrefix,
            grid_scores: Sequence[Sequence[float]],
            beam_size: int,
            num_workers: Optional[int] = None,
            return_future: bool = False
    ) -> Union[List[RemoteExpert], MPFuture]:
        """
        Find and return :beam_size: active experts with highest scores, use both local cache and DHT

        :param prefix: common prefix for all expert uids in grid
        :param grid_scores: scores predicted for each dimension in the grid,
        :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i]
        :param beam_size: how many best experts beam search should return
         After the time budget is exhausted, beam search stops querying for more experts and falls back on the local cache
         Please note that any queries that fall outside the budget will still be performed in the background and cached
         for subsequent iterations as long as DHTNode.cache_locally is True
        :param num_workers: use up to this many concurrent workers to search DHT
        :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
        :returns: a list that contains *up to* beam_size RemoteExpert instances
        """
        assert len(grid_scores) > 0 and beam_size > 0
        assert is_valid_prefix(
            prefix
        ), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_find_best_experts', [],
                        dict(prefix=prefix,
                             grid_scores=list(map(tuple, grid_scores)),
                             beam_size=beam_size,
                             num_workers=num_workers,
                             future=_future)))
        return future if return_future else future.result()
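
A hedged usage sketch for a hypothetical 4 x 4 grid of experts with uids of the form "expert.i.j" (the dht handle and the scores are assumptions):

grid_scores = [(0.1, 0.9, 0.3, 0.5),   # scores over the first grid dimension
               (0.7, 0.2, 0.6, 0.4)]   # scores over the second grid dimension
best_experts = dht.find_best_experts("expert.", grid_scores, beam_size=4)
print([expert.uid for expert in best_experts])  # up to beam_size active experts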
Example #12
 def get_initial_beam(
     self,
     prefix: ExpertPrefix,
     scores: Sequence[float],
     beam_size: int,
     num_workers: Optional[int] = None,
     return_future: bool = False
 ) -> Union[List[Tuple[Score, ExpertPrefix, Dict[Coordinate, UidEndpoint]]], MPFuture]:
     """
     :param prefix: search for experts whose uids start with this prefix
     :param scores: prefer suffix coordinates that have highest scores
     :param beam_size: select this many active suffixes with highest scores
     :param num_workers: maintain up to this many concurrent DHT searches
     :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
     :returns: a list of up to beam_size tuples of (prefix score, prefix itself, dict{suffix: example expert})
     """
     assert is_valid_prefix(
         prefix
     ), f"prefix '{prefix}' is invalid, it must follow {PREFIX_PATTERN.pattern}"
     future, _future = MPFuture.make_pair()
     self.pipe.send(('_get_initial_beam', [],
                     dict(prefix=prefix,
                          scores=tuple(scores),
                          beam_size=beam_size,
                          num_workers=num_workers,
                          future=_future)))
     return future if return_future else future.result()
Example #13
    def declare_experts(
            self,
            uids: Sequence[ExpertUID],
            endpoint: Endpoint,
            wait: bool = True,
            timeout: Optional[float] = None) -> Optional[Dict[ExpertUID, bool]]:
        """
        Make experts visible to all DHT peers; update timestamps if declared previously.

        :param uids: a list of expert ids to update
        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
        :param wait: if True, awaits for declaration to finish, otherwise runs in background
        :param timeout: waits for the procedure to finish for up to this long, None means wait indefinitely
        :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
        """
        assert not isinstance(
            uids, str), "Please send a list / tuple of expert uids."
        for uid in uids:
            assert is_valid_uid(
                uid
            ), f"{uid} is not a valid expert uid. All uids must follow {UID_PATTERN.pattern}"
        future, _future = MPFuture.make_pair() if wait else (None, None)
        self.pipe.send(('_declare_experts', [],
                        dict(uids=list(uids),
                             endpoint=endpoint,
                             future=_future)))
        if wait:
            return future.result(timeout)
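
A hedged sketch pairing declare_experts with get_experts from Example #19 below (dht and the endpoint are assumptions):

uids = ["expert.0.0", "expert.0.1"]
statuses = dht.declare_experts(uids, endpoint="127.0.0.1:1337")
print(statuses)  # e.g. {"expert.0.0": True, "expert.0.1": True}
experts = dht.get_experts(uids)  # [RemoteExpert if found else None, ...]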
Example #14
 async def _declare_averager(self, node: DHTNode, *, group_key: str,
                             endpoint: Endpoint,
                             expiration_time: DHTExpiration,
                             looking_for_group: bool, future: MPFuture):
     try:
         expiration_time = expiration_time if looking_for_group else float(
             nextafter(expiration_time, float('inf')))
         # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
         store_ok = await node.store(key=group_key,
                                     subkey=endpoint,
                                     value=looking_for_group,
                                     expiration_time=expiration_time)
         future.set_result(store_ok)
     except Exception as e:
         if not future.done():
             future.set_exception(e)
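
The nextafter call above bumps the expiration timestamp to the very next representable float, so the "no longer looking for group" entry strictly supersedes a prior declaration with the same expiration_time. A quick illustration (math.nextafter from Python 3.9+ behaves like the numpy version this snippet appears to use):

from math import nextafter

t = 1234.5
t_next = nextafter(t, float('inf'))
assert t_next > t and t_next - t < 1e-9  # the smallest possible increment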
Example #15
 def get_group_bits(self, wait: bool = True):
     """
     :param wait: if True, return bits synchronously; otherwise return an awaitable MPFuture
     :returns: averager's current group key bits (without prefix)
     """
     future, _future = MPFuture.make_pair()
     self.pipe.send(('_get_group_bits', [], dict(future=_future)))
     return future.result() if wait else future
Example #16
    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                gather_binary = self.serializer.dumps(gather)
                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                if allreduce_group is None:
                    raise AllreduceException("Averaging step failed: could not find a group.")

                group_id = allreduce_group.group_id
                self._running_groups[group_id] = allreduce_group
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
                await loop.run_in_executor(None, self.update_tensors, allreduce_group)

                # averaging is finished, exit the loop
                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                future.set_result(gathered_data_by_peer)

            except (AllreduceException, MatchmakingException):
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None and timeout < time_elapsed):
                    future.set_result(None)

            except Exception as e:
                future.set_exception(e)
                raise
            finally:
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()
Example #17
    async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str],
                              k: int, max_prefetch: int, chunk_size: int,
                              future: MPFuture):
        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
        found: List[Tuple[str, RemoteExpert]] = []

        def get_chunk(i: int) -> List[str]:
            return uid_prefixes[i * chunk_size:(i + 1) * chunk_size]

        # pre-dispatch the first task and up to max_prefetch additional tasks
        pending_tasks = deque(
            asyncio.create_task(node.get_many(get_chunk(chunk_i),
                                              num_workers=num_workers_per_chunk))
            for chunk_i in range(min(max_prefetch + 1, total_chunks)))

        for chunk_i in range(total_chunks):
            # parse task results in chronological order, launch additional tasks on demand
            response = await pending_tasks.popleft()
            for uid_prefix in get_chunk(chunk_i):
                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
                if maybe_expiration_time is not None:  # found active peer
                    found.append(
                        (uid_prefix, RemoteExpert(**maybe_expert_data)))
                    # if we found enough active experts, finish immediately
                    if len(found) >= k:
                        break
            if len(found) >= k:
                break

            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
            if pre_dispatch_chunk_i < total_chunks:
                pending_tasks.append(asyncio.create_task(node.get_many(
                    get_chunk(pre_dispatch_chunk_i),
                    num_workers=num_workers_per_chunk)))

        for task in pending_tasks:
            task.cancel()

        # return k active prefixes or as many as we could find
        future.set_result(OrderedDict(found))
Example #18
 def set_group_bits(self, group_bits: str, wait: bool = True):
     """
     :param group_bits: group bits (string of '0' or '1') to be used in averager's group key
     :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
     """
     future, _future = MPFuture.make_pair()
     assert all(bit in '01' for bit in group_bits)
     self.pipe.send(
         ('_set_group_bits', [], dict(group_bits=group_bits,
                                      future=_future)))
     return future.result() if wait else future
Example #19
 def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
                 return_future=False) -> Union[List[Optional[RemoteExpert]], MPFuture]:
     """
     :param uids: find experts with these ids from across the DHT
     :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
      :param return_future: if False (default), return when experts are found. Otherwise return MPFuture.
     :returns: a list of [RemoteExpert if found else None]
     """
     assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
     future, _future = MPFuture.make_pair()
     self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
     return future if return_future else future.result()
Example #20
 def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
           subkey: Optional[Subkey] = None, return_future: bool = False, **kwargs) -> Union[bool, MPFuture]:
     """
     Find num_replicas best nodes to store (key, value) and store it there until expiration time.
      :note: store is a simplified interface to store_many, all kwargs are forwarded there
     :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
     :returns: True if store succeeds, False if it fails (due to no response or newer value)
     """
     future, _future = MPFuture.make_pair()
     self.pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
                                        future=_future, **kwargs)))
     return future if return_future else future.result()
Example #21
    def step(self,
             timeout: Optional[float] = None,
             return_future=False) -> Union[Sequence[torch.Tensor], MPFuture]:
        """
        Set up the averager to look for a group and run one round of averaging, then return the averaged tensors

        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
        :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_step', [], dict(future=_future, timeout=timeout)))
        return future if return_future else future.result()
Example #22
 def submit_task(self, *args: torch.Tensor) -> Future:
     """ Add task to this pool's queue, return Future for its output """
     future1, future2 = MPFuture.make_pair()
     task = Task(future1, args)
     if self.get_task_size(task) > self.max_batch_size:
         exc = ValueError(
             f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed"
         )
         future2.set_exception(exc)
     else:
         self.tasks.put(task)
         self.undispatched_task_timestamps.put(time.time())
     return future2
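
A hedged usage sketch (pool stands for a running TaskPool owned by an expert backend):

import torch

future = pool.submit_task(torch.randn(1, 512))  # enqueue one input for batched processing
output = future.result()  # blocks until the runtime dispatches and processes the batch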
Example #23
 def get(self, key: DHTKey, latest: bool = False, return_future: bool = False, **kwargs
         ) -> Union[Optional[ValueWithExpiration[DHTValue]], MPFuture]:
     """
     Search for a key across DHT and return either first or latest entry (if found).
     :param key: same key as in node.store(...)
     :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
     :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
     :param kwargs: parameters forwarded to DHTNode.get_many_by_id
     :returns: (value, expiration time); if value was not found, returns None
     """
     future, _future = MPFuture.make_pair()
     self.pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
     return future if return_future else future.result()
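
A hedged usage sketch, continuing the store example from Example #6 (dht is an assumption):

result = dht.get("my_key")               # fast path: any non-expired value
latest = dht.get("my_key", latest=True)  # slower: traverses the DHT for the newest value
if latest is not None:
    value, expiration_time = latest      # unpacks the ValueWithExpiration tuple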
Example #24
    def load_state_from_peers(
            self, wait=True) -> Union[Optional[Tuple[Any, Sequence[torch.Tensor]]], MPFuture]:
        """
        Try to download the latest optimizer state from one of the existing peers.

        :param wait: if True (default), return when finished; otherwise return MPFuture and run in background
        :returns: on success, return a 2-tuple with (metadata, tensors), where

        - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
        - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics

        The exact contents of both metadata and tensors are determined by the get_current_state method
        """
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
        return future.result() if wait else future
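
A hedged usage sketch (averager is an assumption; the contents of metadata and tensors are defined by the averager's get_current_state):

loaded = averager.load_state_from_peers()
if loaded is not None:
    metadata, tensors = loaded
    # e.g. restore the optimizer step from metadata and copy tensors into local parameters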
Example #25
    def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
        """
        Get this machine's visible address by requesting other peers or using pre-specified network addresses.
        If no parameters are specified, this function will check for a manually specified endpoint; if unavailable, it will ask 1 random peer.

        :param num_peers: if specified, ask multiple peers and check that they perceive the same endpoint
        :param peers: if specified, ask these exact peers instead of choosing random known peers
        :note: if this node has no known peers in routing table, one must specify :peers: manually
        """
        assert num_peers is None or peers == (), "please specify either num_peers or a list of peers, not both"
        assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
        future, _future = MPFuture.make_pair()
        self.pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
        return future.result()
Example #26
 async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
     try:
         result = await node.get(key, latest=latest, **kwargs)
         if not future.done():
             future.set_result(result)
     except BaseException as e:
         if not future.done():
             future.set_exception(e)
         raise
Example #27
    async def _step(self, *, future: MPFuture, timeout: Optional[float]):
        group_id = None
        try:
            self._pending_group_assembled.clear()
            allreduce_group = await self._matchmaking.look_for_group(
                timeout=timeout)
            if allreduce_group is not None:
                # read group_id only after the None check; otherwise a failed matchmaking
                # attempt would raise AttributeError instead of the intended AllreduceException
                group_id = allreduce_group.group_id
                self._running_groups[group_id] = allreduce_group
                self._pending_group_assembled.set()
                future.set_result(await asyncio.wait_for(
                    allreduce_group.run(), self.allreduce_timeout))
            else:
                raise AllreduceException(
                    f"{self} - group_allreduce failed, unable to find a group")

        except Exception as e:
            future.set_exception(e)
            raise
        finally:
            self._pending_group_assembled.set()
            if group_id is not None:
                _ = self._running_groups.pop(group_id, None)
Example #28
    def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
        """
        Make experts visible to all DHT peers; update timestamps if declared previously.

        :param uids: a list of expert ids to update
        :param endpoint: endpoint that serves these experts, usually your server endpoint (e.g. "201.111.222.333:1337")
        :param wait: if True, awaits for declaration to finish, otherwise runs in background
        :param timeout: waits for the procedure to finish, None means wait indefinitely
        :returns: if wait, returns a list of booleans, (True = store succeeded, False = store rejected)
        """
        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
        future, _future = MPFuture.make_pair() if wait else (None, None)
        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), endpoint=endpoint, future=_future)))
        if wait:
            return future.result(timeout)
Example #29
 async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                  subkey: Optional[Subkey], future: MPFuture, **kwargs):
     try:
         result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
         if not future.done():
             future.set_result(result)
     except BaseException as e:
         if not future.done():
             future.set_exception(e)
         raise
Example #30
    async def _step(self, *, future: MPFuture, gather_binary: bytes,
                    weight: float, allow_retries: bool,
                    timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                data_for_gather = self.serializer.dumps(
                    [weight, self._throughput, self.listen, gather_binary])
                group_info = await self._matchmaking.look_for_group(
                    timeout=timeout, data_for_gather=data_for_gather)
                if group_info is None:
                    raise AllreduceException(
                        "Averaging step failed: could not find a group.")
                group_id = group_info.group_id
                allreduce_runner = await self._make_allreduce_runner(
                    group_info, **self.allreduce_kwargs)
                self._running_groups[group_id] = allreduce_runner
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_runner.run(),
                                       self._allreduce_timeout)
                await loop.run_in_executor(None, self.update_tensors,
                                           allreduce_runner)

                # averaging is finished, exit the loop
                future.set_result(allreduce_runner.gathered)

            except (AllreduceException, MatchmakingException, AssertionError,
                    asyncio.InvalidStateError, grpc.RpcError,
                    grpc.aio.AioRpcError, InternalError) as e:
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None
                                         and timeout < time_elapsed):
                    logger.warning(f"Averager caught {e}")
                    future.set_result(None)
                else:
                    logger.warning(f"Averager caught {e}, retrying")

            except Exception as e:
                future.set_exception(e)
                raise
            finally:
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()