Code example #1
    async def _step(self, *, future: MPFuture, gather_binary: bytes,
                    weight: float, allow_retries: bool,
                    timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        try:
            while not future.done():
                try:
                    self._pending_group_assembled.clear()
                    data_for_gather = self.serializer.dumps([
                        weight, self._throughput, self.mode.value,
                        gather_binary
                    ])
                    group_info = await self._matchmaking.look_for_group(
                        timeout=timeout, data_for_gather=data_for_gather)
                    if group_info is None:
                        raise AllreduceException(
                            "Averaging step failed: could not find a group.")
                    group_id = group_info.group_id
                    allreduce_runner = await self._make_allreduce_runner(
                        group_info, **self.allreduce_kwargs)
                    self._running_groups[group_id] = allreduce_runner
                    self._pending_group_assembled.set()
                    await asyncio.wait_for(allreduce_runner.run(),
                                           self._allreduce_timeout)
                    if self.mode != AveragingMode.AUX:
                        await loop.run_in_executor(None, self.update_tensors,
                                                   allreduce_runner)

                    # averaging is finished, exit the loop
                    future.set_result(allreduce_runner.gathered)

                except (AllreduceException, MatchmakingException,
                        AssertionError, StopAsyncIteration, InternalError,
                        asyncio.CancelledError, asyncio.InvalidStateError,
                        grpc.RpcError, grpc.aio.AioRpcError) as e:
                    time_elapsed = get_dht_time() - start_time
                    if not allow_retries or (timeout is not None
                                             and timeout < time_elapsed):
                        logger.exception(f"Averager caught {repr(e)}")
                        future.set_exception(e)
                    else:
                        logger.warning(f"Averager caught {repr(e)}, retrying")

                finally:
                    _ = self._running_groups.pop(group_id, None)
                    self._pending_group_assembled.set()

        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise
        finally:
            if not future.done():
                future.set_exception(
                    RuntimeError(
                        "Internal sanity check failed: averager.step left future pending."
                        " Please report this to hivemind issues."))
Code example #2
File: __init__.py  Project: MaximKsh/hivemind
    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                gather_binary = self.serializer.dumps(gather)
                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                if allreduce_group is None:
                    raise AllreduceException("Averaging step failed: could not find a group.")

                group_id = allreduce_group.group_id
                self._running_groups[group_id] = allreduce_group
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
                await loop.run_in_executor(None, self.update_tensors, allreduce_group)

                # averaging is finished, exit the loop
                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                future.set_result(gathered_data_by_peer)

            except (AllreduceException, MatchmakingException):
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None and timeout < time_elapsed):
                    future.set_result(None)

            except Exception as e:
                future.set_exception(e)
                raise
            finally:
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()
Code example #3
File: __init__.py  Project: yuejiesong1900/hivemind
    async def _get_experts(
            self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
        if expiration_time is None:
            expiration_time = get_dht_time()
        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
        response = await node.get_many(uids, expiration_time, num_workers=num_workers)
        future.set_result([RemoteExpert(**expert_data) if maybe_expiration_time else None
                           for uid, (expert_data, maybe_expiration_time) in response.items()])
Code example #4
    async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
        try:
            result = await node.get(key, latest=latest, **kwargs)
            if not future.done():
                future.set_result(result)
        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise
Code example #5
    async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                     subkey: Optional[Subkey], future: MPFuture, **kwargs):
        try:
            result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
            if not future.done():
                future.set_result(result)
        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise
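
Examples #4 and #5 are the event-loop half of a request/reply pattern: the caller hands over a future along with the arguments, and the coroutine resolves that future (or records the exception) once the DHT call finishes. As a rough single-process analogue, and not hivemind's actual MPFuture machinery, the same contract can be sketched with concurrent.futures.Future and asyncio.run_coroutine_threadsafe; FakeNode here is a made-up in-memory stand-in for DHTNode.

import asyncio
import threading
from concurrent.futures import Future


class FakeNode:
    """Made-up in-memory stand-in for a DHT node, used only for this sketch."""

    def __init__(self):
        self._storage = {}

    async def store(self, key, value) -> bool:
        self._storage[key] = value
        return True

    async def get(self, key):
        return self._storage.get(key)


async def _store(node: FakeNode, key, value, future: Future) -> None:
    """Async half: run the operation on the event loop, then resolve the caller's future."""
    try:
        result = await node.store(key, value)
        if not future.done():
            future.set_result(result)
    except BaseException as e:
        if not future.done():
            future.set_exception(e)
        raise


def store_blocking(loop: asyncio.AbstractEventLoop, node: FakeNode, key, value) -> bool:
    """Sync half: submit the coroutine plus a fresh future, then block until it is resolved."""
    future: Future = Future()
    asyncio.run_coroutine_threadsafe(_store(node, key, value, future), loop)
    return future.result(timeout=5)


loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()
print(store_blocking(loop, FakeNode(), "expert.1.2", "worker-0:1337"))  # True
loop.call_soon_threadsafe(loop.stop)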
Code example #6
    async def _set_group_bits(self, group_bits: str, future: MPFuture):
        try:
            self._matchmaking.group_key_manager.group_bits = group_bits
            return future.set_result(None)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
Code example #7
    async def _declare_averager(self, node: DHTNode, *, group_key: str,
                                endpoint: Endpoint,
                                expiration_time: DHTExpiration,
                                looking_for_group: bool, future: MPFuture):
        try:
            expiration_time = expiration_time if looking_for_group else float(
                nextafter(expiration_time, float('inf')))
            # ^-- when declaring averager inactive, we increment expiration time to overwrite the pre-existing entry
            store_ok = await node.store(key=group_key,
                                        subkey=endpoint,
                                        value=looking_for_group,
                                        expiration_time=expiration_time)
            future.set_result(store_ok)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
Code example #8
    async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str],
                              k: int, max_prefetch: int, chunk_size: int,
                              future: MPFuture):
        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
        found: List[Tuple[str, RemoteExpert]] = []

        pending_tasks = deque(
            asyncio.create_task(
                node.get_many(uid_prefixes[chunk_i * chunk_size:(chunk_i + 1) *
                                           chunk_size],
                              num_workers=num_workers_per_chunk))
            for chunk_i in range(min(max_prefetch + 1, total_chunks))
        )  # pre-dispatch first task and up to max_prefetch additional tasks

        for chunk_i in range(total_chunks):
            # parse task results in chronological order, launch additional tasks on demand
            response = await pending_tasks.popleft()
            for uid_prefix in uid_prefixes[chunk_i * chunk_size:(chunk_i + 1) *
                                           chunk_size]:
                maybe_expert_data, maybe_expiration_time = response[uid_prefix]
                if maybe_expiration_time is not None:  # found active peer
                    found.append(
                        (uid_prefix, RemoteExpert(**maybe_expert_data)))
                    # if we found enough active experts, finish immediately
                    if len(found) >= k:
                        break
            if len(found) >= k:
                break

            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
            if pre_dispatch_chunk_i < total_chunks:
                pending_tasks.append(
                    asyncio.create_task(
                        node.get_many(
                            uid_prefixes[pre_dispatch_chunk_i *
                                         chunk_size:(pre_dispatch_chunk_i +
                                                     1) * chunk_size],
                            num_workers=num_workers_per_chunk)))

        for task in pending_tasks:
            task.cancel()

        # return k active prefixes or as many as we could find
        future.set_result(OrderedDict(found))
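
_first_k_active keeps up to max_prefetch + 1 chunk queries in flight while the earliest one is being parsed, so DHT round trips overlap with result processing. The bounded-prefetch pipeline can be shown on its own; in the sketch below, fetch_chunk is a placeholder for node.get_many, not a hivemind API.

import asyncio
from collections import deque
from typing import List


async def fetch_chunk(chunk: List[str]) -> List[str]:
    """Placeholder for node.get_many(): pretend each chunk lookup costs a network round trip."""
    await asyncio.sleep(0.1)
    return [item.upper() for item in chunk]


async def process_with_prefetch(items: List[str], chunk_size: int, max_prefetch: int) -> List[str]:
    chunks = [items[i:i + chunk_size] for i in range(0, len(items), chunk_size)]

    # pre-dispatch the first chunk and up to max_prefetch additional chunks
    pending = deque(asyncio.create_task(fetch_chunk(chunk))
                    for chunk in chunks[:max_prefetch + 1])

    results: List[str] = []
    for chunk_i in range(len(chunks)):
        # consume results in order while later chunks are already in flight
        results.extend(await pending.popleft())

        # keep the pipeline full: dispatch the next chunk that is not yet in flight
        next_chunk_i = chunk_i + len(pending) + 1
        if next_chunk_i < len(chunks):
            pending.append(asyncio.create_task(fetch_chunk(chunks[next_chunk_i])))

    return results


print(asyncio.run(process_with_prefetch(list("abcdefgh"), chunk_size=3, max_prefetch=2)))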
Code example #9
    async def _get_averagers(self, node: DHTNode, *, group_key: str,
                             only_active: bool, future: MPFuture):
        try:
            result = await node.get(group_key, latest=True)
            if result is None:
                logger.debug(
                    f"Allreduce group not found: {group_key}, creating new group."
                )
                future.set_result([])
                return
            assert isinstance(result.value, dict), f"expected {group_key} to be a Dict[Endpoint, is_active], " \
                                                   f"but got {result.value} of type {type(result.value)}."
            averagers = [(endpoint, entry.expiration_time)
                         for endpoint, entry in result.value.items()
                         if not only_active or entry.value is True]
            future.set_result(averagers)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
Code example #10
    async def _step(self, *, future: MPFuture, timeout: Optional[float]):
        group_id = None
        try:
            self._pending_group_assembled.clear()
            allreduce_group = await self._matchmaking.look_for_group(
                timeout=timeout)
            if allreduce_group is None:
                # check for None before reading group_id: matchmaking may fail to find peers
                raise AllreduceException(
                    f"{self} - group_allreduce failed, unable to find a group")

            group_id = allreduce_group.group_id
            self._running_groups[group_id] = allreduce_group
            self._pending_group_assembled.set()
            future.set_result(await asyncio.wait_for(
                allreduce_group.run(), self.allreduce_timeout))

        except Exception as e:
            future.set_exception(e)
            raise
        finally:
            self._pending_group_assembled.set()
            if group_id is not None:
                _ = self._running_groups.pop(group_id, None)
Code example #11
    async def _get_group_bits(self, future: MPFuture):
        future.set_result(self._matchmaking.group_key_manager.group_bits)
Code example #12
    async def _load_state_from_peers(self, future: MPFuture):
        key_manager = self._matchmaking.group_key_manager
        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers",
                                        latest=True) or ({}, None)
        peer_priority = {
            peer: float(info.value)
            for peer, info in peer_priority.items()
            if isinstance(info, ValueWithExpiration)
            and isinstance(info.value, (float, int))
        }

        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
            logger.info(
                f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}."
            )
            future.set_result(None)
            return

        metadata = None
        for peer in sorted(peer_priority.keys(),
                           key=peer_priority.get,
                           reverse=True):
            if peer != self.endpoint:
                logger.info(f"Downloading parameters from peer {peer}")
                stream = None
                try:
                    leader_stub = ChannelCache.get_stub(
                        peer,
                        averaging_pb2_grpc.DecentralizedAveragingStub,
                        aio=True)
                    stream = leader_stub.rpc_download_state(
                        averaging_pb2.DownloadRequest())
                    current_tensor_parts, tensors = [], []
                    async for message in stream:
                        if message.metadata:
                            metadata = self.serializer.loads(message.metadata)
                        if message.tensor_part.dtype and current_tensor_parts:
                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                            tensors.append(
                                deserialize_torch_tensor(
                                    combine_from_streaming(
                                        current_tensor_parts)))
                            current_tensor_parts = []
                        current_tensor_parts.append(message.tensor_part)
                    if current_tensor_parts:
                        tensors.append(
                            deserialize_torch_tensor(
                                combine_from_streaming(current_tensor_parts)))
                    future.set_result((metadata, tensors))
                    self.last_updated = get_dht_time()
                    return
                except grpc.aio.AioRpcError as e:
                    logger.info(f"Failed to download state from {peer} - {e}")
                finally:
                    if stream is not None:
                        await stream.code()

        else:  # for-else: the loop never returned, so no peer managed to send us its state
            logger.warning(
                "Averager could not load state from peers: found no active peers."
            )
            future.set_result(None)
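
In the download loop of the example above, each tensor arrives as a run of tensor_part messages and a non-empty dtype field marks the first part of the next tensor. The grouping rule can be illustrated in isolation; Part below is a made-up record, not the real averaging_pb2 message.

from dataclasses import dataclass
from typing import List


@dataclass
class Part:
    """Made-up stand-in for one streamed tensor_part message."""
    dtype: str    # non-empty only on the first part of each tensor
    payload: str


def regroup_parts(stream: List[Part]) -> List[List[Part]]:
    """Split a flat stream of parts into one group of parts per tensor."""
    groups, current = [], []
    for part in stream:
        if part.dtype and current:
            # a new dtype marks the start of the next tensor, so wrap up the current one
            groups.append(current)
            current = []
        current.append(part)
    if current:
        groups.append(current)  # wrap up the trailing tensor
    return groups


stream = [Part("float32", "a0"), Part("", "a1"),
          Part("float32", "b0"), Part("", "b1"), Part("", "b2")]
print([len(group) for group in regroup_parts(stream)])  # [2, 3]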