Example 1
    async def _step(self, *, future: MPFuture, gather_binary: bytes,
                    weight: float, allow_retries: bool,
                    timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        try:
            while not future.done():
                try:
                    self._pending_group_assembled.clear()
                    data_for_gather = self.serializer.dumps([
                        weight, self._throughput, self.mode.value,
                        gather_binary
                    ])
                    group_info = await self._matchmaking.look_for_group(
                        timeout=timeout, data_for_gather=data_for_gather)
                    if group_info is None:
                        raise AllreduceException(
                            "Averaging step failed: could not find a group.")
                    group_id = group_info.group_id
                    allreduce_runner = await self._make_allreduce_runner(
                        group_info, **self.allreduce_kwargs)
                    self._running_groups[group_id] = allreduce_runner
                    self._pending_group_assembled.set()
                    await asyncio.wait_for(allreduce_runner.run(),
                                           self._allreduce_timeout)
                    if self.mode != AveragingMode.AUX:
                        await loop.run_in_executor(None, self.update_tensors,
                                                   allreduce_runner)

                    # averaging is finished, exit the loop
                    future.set_result(allreduce_runner.gathered)

                except (AllreduceException, MatchmakingException,
                        AssertionError, StopAsyncIteration, InternalError,
                        asyncio.CancelledError, asyncio.InvalidStateError,
                        grpc.RpcError, grpc.aio.AioRpcError) as e:
                    time_elapsed = get_dht_time() - start_time
                    if not allow_retries or (timeout is not None
                                             and timeout < time_elapsed):
                        logger.exception(f"Averager caught {repr(e)}")
                        future.set_exception(e)
                    else:
                        logger.warning(f"Averager caught {repr(e)}, retrying")

                finally:
                    _ = self._running_groups.pop(group_id, None)
                    self._pending_group_assembled.set()

        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise
        finally:
            if not future.done():
                future.set_exception(
                    RuntimeError(
                        "Internal sanity check failed: averager.step left future pending."
                        " Please report this to hivemind issues."))
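The control flow above follows a retry-with-deadline pattern: keep attempting until the future is resolved, treat a known set of errors as recoverable, give up only when retries are disabled or the timeout has elapsed, and use a final safety net so the caller is never left with a pending future. Below is a minimal, self-contained sketch of that skeleton; concurrent.futures.Future stands in for MPFuture, and attempt_once / RecoverableError are illustrative placeholders, not hivemind APIs.

import time
from concurrent.futures import Future
from typing import Optional


class RecoverableError(Exception):
    """Stands in for AllreduceException, MatchmakingException and friends."""


async def step_skeleton(future: Future, attempt_once, allow_retries: bool,
                        timeout: Optional[float]) -> None:
    start_time = time.monotonic()
    try:
        while not future.done():
            try:
                result = await attempt_once()
                future.set_result(result)      # success: the loop exits because the future is done
            except RecoverableError as e:
                elapsed = time.monotonic() - start_time
                if not allow_retries or (timeout is not None and timeout < elapsed):
                    future.set_exception(e)    # give up and report the last error to the caller
                # otherwise fall through and retry on the next iteration
    except BaseException as e:
        if not future.done():
            future.set_exception(e)            # unexpected failure: never leave the caller hanging
        raise
    finally:
        if not future.done():
            future.set_exception(RuntimeError("step left its future pending"))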
Example 2
    async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
        try:
            result = await node.get(key, latest=latest, **kwargs)
            if not future.done():
                future.set_result(result)
        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise
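In hivemind, coroutines such as _get receive their MPFuture from a user-facing process and report back through it. As a rough single-process analogy (not hivemind's actual mechanism, which crosses a process boundary via MPFuture), the hand-off can be sketched with a background event loop and asyncio.run_coroutine_threadsafe:

import asyncio
import threading

# background event loop that plays the role of the worker process
_loop = asyncio.new_event_loop()
threading.Thread(target=_loop.run_forever, daemon=True).start()


async def _get(key):
    await asyncio.sleep(0.1)                 # stands in for `await node.get(key, ...)`
    return f"value-of-{key}"


def get(key, timeout=None):
    """Synchronous facade: schedule the coroutine on the background loop and block on its future."""
    future = asyncio.run_coroutine_threadsafe(_get(key), _loop)
    return future.result(timeout=timeout)    # re-raises if the coroutine set an exception


print(get("some_key"))                       # value-of-some_key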
Example 3
    async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                     subkey: Optional[Subkey], future: MPFuture, **kwargs):
        try:
            result = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
            if not future.done():
                future.set_result(result)
        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise
Example 4
    async def _set_group_bits(self, group_bits: str, future: MPFuture):
        try:
            self._matchmaking.group_key_manager.group_bits = group_bits
            return future.set_result(None)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
Example 5
    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                gather_binary = self.serializer.dumps(gather)
                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                if allreduce_group is None:
                    raise AllreduceException("Averaging step failed: could not find a group.")

                group_id = allreduce_group.group_id
                self._running_groups[group_id] = allreduce_group
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
                await loop.run_in_executor(None, self.update_tensors, allreduce_group)

                # averaging is finished, exit the loop
                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                future.set_result(gathered_data_by_peer)

            except (AllreduceException, MatchmakingException):
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None and timeout < time_elapsed):
                    future.set_result(None)

            except Exception as e:
                future.set_exception(e)
                raise
            finally:
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()
Example 6
    async def _declare_averager(self, node: DHTNode, *, group_key: str,
                                endpoint: Endpoint,
                                expiration_time: DHTExpiration,
                                looking_for_group: bool, future: MPFuture):
        try:
            expiration_time = expiration_time if looking_for_group else float(
                nextafter(expiration_time, float('inf')))
            # ^-- when declaring the averager inactive, we increment expiration time to overwrite the pre-existing entry
            store_ok = await node.store(key=group_key,
                                        subkey=endpoint,
                                        value=looking_for_group,
                                        expiration_time=expiration_time)
            future.set_result(store_ok)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
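The nextafter call above bumps the expiration time by the smallest representable float increment when the averager declares itself inactive, so the new record strictly outlives, and therefore overwrites, the earlier "looking for group" entry stored under the same key and subkey. The snippet imports nextafter from its own context; the tiny demonstration below uses math.nextafter (Python 3.9+) purely for illustration.

import math

expiration_time = 1_700_000_000.5            # some DHTExpiration (a Unix timestamp)
bumped = math.nextafter(expiration_time, math.inf)

print(bumped > expiration_time)              # True: the new record wins the "newest expiration" comparison
print(bumped - expiration_time)              # a tiny positive increment, far below a microsecond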
Example 7
    async def _get_averagers(self, node: DHTNode, *, group_key: str,
                             only_active: bool, future: MPFuture):
        try:
            result = await node.get(group_key, latest=True)
            if result is None:
                logger.debug(
                    f"Allreduce group not found: {group_key}, creating new group."
                )
                future.set_result([])
                return
            assert isinstance(result.value, dict), f"expected {group_key} to be a Dict[Endpoint, is_active], " \
                                                   f"but got {result.value} of type {type(result.value)}."
            averagers = [(endpoint, entry.expiration_time)
                         for endpoint, entry in result.value.items()
                         if not only_active or entry.value is True]
            future.set_result(averagers)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
Example 8
    async def _load_state_from_peers(self, future: MPFuture):
        try:
            key_manager = self._matchmaking.group_key_manager
            peer_priority, _ = self.dht.get(
                f"{key_manager.prefix}.all_averagers",
                latest=True) or ({}, None)
            peer_priority = {
                peer: float(info.value)
                for peer, info in peer_priority.items()
                if isinstance(info, ValueWithExpiration)
                and isinstance(info.value, (float, int))
            }

            if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
                logger.info(
                    f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}."
                )
                future.set_result(None)
                return

            metadata = None
            for peer in sorted(peer_priority.keys(),
                               key=peer_priority.get,
                               reverse=True):
                if peer != self.endpoint:
                    logger.info(f"Downloading parameters from peer {peer}")
                    stream = None
                    try:
                        stub = ChannelCache.get_stub(
                            peer,
                            averaging_pb2_grpc.DecentralizedAveragingStub,
                            aio=True)
                        stream = stub.rpc_download_state(
                            averaging_pb2.DownloadRequest())
                        current_tensor_parts, tensors = [], []
                        async for message in stream:
                            if message.metadata:
                                metadata = self.serializer.loads(
                                    message.metadata)
                            if message.tensor_part.dtype and current_tensor_parts:
                                # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                                tensors.append(
                                    deserialize_torch_tensor(
                                        combine_from_streaming(
                                            current_tensor_parts)))
                                current_tensor_parts = []
                            current_tensor_parts.append(message.tensor_part)
                        if current_tensor_parts:
                            tensors.append(
                                deserialize_torch_tensor(
                                    combine_from_streaming(
                                        current_tensor_parts)))

                        if not metadata:
                            logger.debug(
                                f"Peer {peer} did not send its state.")
                            continue

                        logger.info(f"Finished downloading state from {peer}")
                        future.set_result((metadata, tensors))
                        self.last_updated = get_dht_time()
                        return
                    except BaseException as e:
                        logger.exception(
                            f"Failed to download state from {peer} - {repr(e)}"
                        )
                    finally:
                        if stream is not None:
                            await stream.code()

        finally:
            if not future.done():
                logger.warning(
                    "Averager could not load state from peers: all requests have failed."
                )
                future.set_result(None)
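The streaming loop in Example 8 rebuilds tensors from parts: a part that carries a dtype marks the beginning of a new tensor, so the parts accumulated so far are flushed first, and any remainder is flushed once the stream ends. A self-contained sketch of just that regrouping step follows, with a plain dataclass standing in for the protobuf tensor_part message and string payloads instead of serialized tensor data.

from dataclasses import dataclass
from typing import List


@dataclass
class Part:
    dtype: str                               # non-empty only on the first part of each tensor
    payload: str


def regroup(parts: List[Part]) -> List[str]:
    tensors, current = [], []
    for part in parts:
        if part.dtype and current:
            # a dtype marks the start of a new tensor: wrap up the previous one first
            tensors.append("".join(p.payload for p in current))
            current = []
        current.append(part)
    if current:                              # flush whatever is left after the stream ends
        tensors.append("".join(p.payload for p in current))
    return tensors


stream = [Part("float32", "aa"), Part("", "bb"), Part("int64", "cc"), Part("", "dd")]
print(regroup(stream))                       # ['aabb', 'ccdd']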