コード例 #1
0
ファイル: __init__.py プロジェクト: MaximKsh/hivemind
    async def _step(self, *, future: MPFuture, gather: DataForGather, allow_retries: bool, timeout: Optional[float]):
        """Run one averaging step (matchmaking + all-reduce + tensor update) and resolve ``future``.

        On success, ``future`` resolves to a dict that maps each endpoint in
        ``allreduce_group.ordered_group_endpoints`` to the corresponding deserialized entry of
        ``allreduce_group.gathered``.

        If an AllreduceException or MatchmakingException occurs, the step is retried while
        ``allow_retries`` is True and ``timeout`` (if given) has not elapsed since the start, as
        measured with ``get_dht_time``. Otherwise ``future`` resolves to None.

        Any other exception is forwarded to ``future`` (if it is still pending) and re-raised.
        """
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                gather_binary = self.serializer.dumps(gather)
                allreduce_group = await self._matchmaking.look_for_group(timeout=timeout, data_for_gather=gather_binary)
                if allreduce_group is None:
                    raise AllreduceException("Averaging step failed: could not find a group.")

                group_id = allreduce_group.group_id
                self._running_groups[group_id] = allreduce_group
                # presumably signals waiters that the group is now registered in _running_groups
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_group.run(), self._allreduce_timeout)
                # update_tensors is run in the default executor, i.e. off the event-loop thread
                await loop.run_in_executor(None, self.update_tensors, allreduce_group)

                # averaging is finished, exit the loop
                gathered_items = map(self.serializer.loads, allreduce_group.gathered)
                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                future.set_result(gathered_data_by_peer)

            except (AllreduceException, MatchmakingException):
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None and timeout < time_elapsed):
                    # the future may already be done (e.g. cancelled while we were awaiting);
                    # resolving it again would raise from inside this handler
                    if not future.done():
                        future.set_result(None)

            except Exception as e:
                # guard so that an InvalidStateError from an already-done future does not
                # replace the original exception, which is re-raised below
                if not future.done():
                    future.set_exception(e)
                raise
            finally:
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()
コード例 #2
0
 async def _set_group_bits(self, group_bits: str, future: MPFuture):
     """Assign ``group_bits`` to the matchmaking group key manager and report the outcome via ``future``.

     On success ``future`` resolves to None. On failure the exception is forwarded to ``future``
     unless it is already done; the exception is not re-raised.
     """
     try:
         key_manager = self._matchmaking.group_key_manager
         key_manager.group_bits = group_bits
         future.set_result(None)
     except Exception as error:
         if future.done():
             return
         future.set_exception(error)
コード例 #3
0
 async def _get(self, node: DHTNode, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
     """Look up ``key`` on ``node`` and deliver the outcome through ``future``.

     Extra keyword arguments are forwarded to ``node.get``. The result is set only if ``future``
     is still pending. Any failure, including BaseException subclasses, is forwarded to a pending
     ``future`` and then re-raised.
     """
     try:
         found = await node.get(key, latest=latest, **kwargs)
         if future.done():
             return
         future.set_result(found)
     except BaseException as failure:
         if not future.done():
             future.set_exception(failure)
         raise
コード例 #4
0
 async def _store(self, node: DHTNode, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                  subkey: Optional[Subkey], future: MPFuture, **kwargs):
     """Store ``value`` under ``key`` (and ``subkey``) on ``node``; deliver the outcome through ``future``.

     Extra keyword arguments are forwarded to ``node.store``. The return value of ``node.store``
     is set on ``future`` only if it is still pending. Any failure, including BaseException
     subclasses, is forwarded to a pending ``future`` and then re-raised.
     """
     try:
         outcome = await node.store(key, value, expiration_time, subkey=subkey, **kwargs)
         if future.done():
             return
         future.set_result(outcome)
     except BaseException as failure:
         if not future.done():
             future.set_exception(failure)
         raise
コード例 #5
0
 async def _declare_averager(self, node: DHTNode, *, group_key: str,
                             endpoint: Endpoint,
                             expiration_time: DHTExpiration,
                             looking_for_group: bool, future: MPFuture):
     """Store ``looking_for_group`` under ``group_key``/``endpoint`` on ``node`` and report via ``future``.

     ``future`` resolves to whatever ``node.store`` returns. On failure the exception is forwarded
     to ``future`` unless it is already done; the exception is not re-raised.
     """
     try:
         if not looking_for_group:
             # declaring the averager inactive: bump the expiration time by the smallest representable
             # step so the new entry overwrites the pre-existing one
             expiration_time = float(nextafter(expiration_time, float('inf')))
         store_ok = await node.store(
             key=group_key,
             subkey=endpoint,
             value=looking_for_group,
             expiration_time=expiration_time,
         )
         future.set_result(store_ok)
     except Exception as e:
         if not future.done():
             future.set_exception(e)
コード例 #6
0
    async def _step(self, *, future: MPFuture, gather_binary: bytes,
                    weight: float, allow_retries: bool,
                    timeout: Optional[float]):
        """Run one averaging step (matchmaking + all-reduce + tensor update) and resolve ``future``.

        :param future: resolved with ``allreduce_runner.gathered`` on success; with None if a
          recoverable error occurs and retries are disallowed or ``timeout`` has elapsed; with the
          exception itself for any other error (which is also re-raised)
        :param gather_binary: opaque bytes serialized together with ``weight``, ``self._throughput``
          and ``self.listen``, then passed to matchmaking as ``data_for_gather``
        :param weight: this peer's weight, included in ``data_for_gather``
        :param allow_retries: if True, recoverable errors (see the first ``except`` clause) start a
          new attempt instead of resolving ``future`` with None
        :param timeout: passed to matchmaking; retries stop once this many seconds (measured with
          ``get_dht_time``) have elapsed since the step started
        """
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()
        group_id = None

        while not future.done():
            try:
                self._pending_group_assembled.clear()
                data_for_gather = self.serializer.dumps(
                    [weight, self._throughput, self.listen, gather_binary])
                group_info = await self._matchmaking.look_for_group(
                    timeout=timeout, data_for_gather=data_for_gather)
                if group_info is None:
                    raise AllreduceException(
                        "Averaging step failed: could not find a group.")
                group_id = group_info.group_id
                allreduce_runner = await self._make_allreduce_runner(
                    group_info, **self.allreduce_kwargs)
                self._running_groups[group_id] = allreduce_runner
                # presumably signals waiters that the runner is now registered in _running_groups
                self._pending_group_assembled.set()
                await asyncio.wait_for(allreduce_runner.run(),
                                       self._allreduce_timeout)
                # update_tensors is run in the default executor, i.e. off the event-loop thread
                await loop.run_in_executor(None, self.update_tensors,
                                           allreduce_runner)

                # averaging is finished, exit the loop
                future.set_result(allreduce_runner.gathered)

            except (AllreduceException, MatchmakingException, AssertionError,
                    asyncio.InvalidStateError, grpc.RpcError,
                    grpc.aio.AioRpcError, InternalError) as e:
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None
                                         and timeout < time_elapsed):
                    logger.warning(f"Averager caught {e}")
                    # the future may already be done (e.g. InvalidStateError came from resolving it);
                    # resolving it again would raise from inside this handler
                    if not future.done():
                        future.set_result(None)
                else:
                    logger.warning(f"Averager caught {e}, retrying")

            except Exception as e:
                # guard so that an InvalidStateError from an already-done future does not
                # replace the original exception, which is re-raised below
                if not future.done():
                    future.set_exception(e)
                raise
            finally:
                _ = self._running_groups.pop(group_id, None)
                self._pending_group_assembled.set()
コード例 #7
0
 async def _get_averagers(self, node: DHTNode, *, group_key: str,
                          only_active: bool, future: MPFuture):
     """Fetch the averagers registered under ``group_key`` and resolve ``future`` with them.

     ``future`` resolves to a list of ``(endpoint, expiration_time)`` pairs, one per subkey of the
     stored dictionary. If ``only_active`` is True, only entries whose value is exactly ``True`` are
     kept. If nothing is stored under ``group_key``, ``future`` resolves to an empty list.

     If the stored value is not a dict, AssertionError is raised. Any exception is forwarded to
     ``future`` if it is still pending; the exception is not re-raised.
     """
     try:
         result = await node.get(group_key, latest=True)
         if result is None:
             logger.debug(
                 f"Allreduce group not found: {group_key}, creating new group."
             )
             if not future.done():
                 future.set_result([])
             return
         # explicit check instead of `assert`, so validation survives `python -O`; it still raises
         # AssertionError because callers (e.g. the averaging step) may treat that type as retryable
         if not isinstance(result.value, dict):
             raise AssertionError(f"expected {group_key} to be a Dict[Endpoint, is_active], "
                                  f"but got {result.value} of type {type(result.value)}.")
         averagers = [(endpoint, entry.expiration_time)
                      for endpoint, entry in result.value.items()
                      if not only_active or entry.value is True]
         if not future.done():
             future.set_result(averagers)
     except Exception as e:
         if not future.done():
             future.set_exception(e)
コード例 #8
0
    async def _step(self, *, future: MPFuture, timeout: Optional[float]):
        """Find an all-reduce group, run it, and resolve ``future`` with the result of ``run()``.

        :param future: receives the value returned by ``allreduce_group.run()``, or the exception
          if any step fails (the exception is also re-raised)
        :param timeout: passed to matchmaking as the search timeout
        :raises AllreduceException: if matchmaking returns no group
        """
        group_id = None
        try:
            self._pending_group_assembled.clear()
            allreduce_group = await self._matchmaking.look_for_group(
                timeout=timeout)
            # check for None *before* touching .group_id, otherwise a failed matchmaking
            # surfaces as AttributeError instead of the intended AllreduceException
            if allreduce_group is None:
                raise AllreduceException(
                    f"{self} - group_allreduce failed, unable to find a group")

            group_id = allreduce_group.group_id
            self._running_groups[group_id] = allreduce_group
            # presumably signals waiters that the group is now registered in _running_groups
            self._pending_group_assembled.set()
            future.set_result(await
                              asyncio.wait_for(allreduce_group.run(),
                                               self.allreduce_timeout))

        except Exception as e:
            # guard so that an InvalidStateError from an already-done future does not
            # replace the original exception, which is re-raised below
            if not future.done():
                future.set_exception(e)
            raise
        finally:
            self._pending_group_assembled.set()
            if group_id is not None:
                _ = self._running_groups.pop(group_id, None)