示例#1
0
    async def _step(self, *, future: MPFuture, gather_binary: bytes,
                    weight: float, allow_retries: bool,
                    timeout: Optional[float]):
        """
        Run one averaging step: find a group, run all-reduce in it and (unless in AUX mode) apply the result.

        Failures listed in the ``except`` clause below are retried until the step succeeds, unless
        ``allow_retries`` is False or more than ``timeout`` seconds have passed since the step started.

        :param future: resolved with ``allreduce_runner.gathered`` on success, or with the exception that
          ended the step; it is never left pending when this coroutine exits
        :param gather_binary: opaque user data; serialized together with ``weight``, this averager's
          throughput and mode, and passed to matchmaking as ``data_for_gather``
        :param weight: this peer's weight, included in ``data_for_gather``
        :param allow_retries: if False, the first caught failure is reported via ``future``
        :param timeout: matchmaking timeout and overall retry budget in seconds (None = no limit)
        """
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()

        try:
            while not future.done():
                # reset per attempt so the `finally` below only ever releases this attempt's group
                group_id = None
                try:
                    self._pending_group_assembled.clear()
                    data_for_gather = self.serializer.dumps([
                        weight, self._throughput, self.mode.value,
                        gather_binary
                    ])
                    group_info = await self._matchmaking.look_for_group(
                        timeout=timeout, data_for_gather=data_for_gather)
                    if group_info is None:
                        raise AllreduceException(
                            "Averaging step failed: could not find a group.")
                    group_id = group_info.group_id
                    allreduce_runner = await self._make_allreduce_runner(
                        group_info, **self.allreduce_kwargs)
                    self._running_groups[group_id] = allreduce_runner
                    self._pending_group_assembled.set()
                    await asyncio.wait_for(allreduce_runner.run(),
                                           self._allreduce_timeout)
                    if self.mode != AveragingMode.AUX:
                        # update_tensors is a plain (non-async) method; run it in the default executor
                        await loop.run_in_executor(None, self.update_tensors,
                                                   allreduce_runner)

                    # averaging is finished, exit the loop
                    future.set_result(allreduce_runner.gathered)

                except (AllreduceException, MatchmakingException,
                        AssertionError, StopAsyncIteration, InternalError,
                        asyncio.CancelledError, asyncio.InvalidStateError,
                        grpc.RpcError, grpc.aio.AioRpcError) as e:
                    time_elapsed = get_dht_time() - start_time
                    if not allow_retries or (timeout is not None
                                             and timeout < time_elapsed):
                        logger.exception(f"Averager caught {repr(e)}")
                        # the future may already be resolved (e.g. set_result above raised InvalidStateError);
                        # resolving it again would raise a second error from inside this handler
                        if not future.done():
                            future.set_exception(e)
                    else:
                        logger.warning(f"Averager caught {repr(e)}, retrying")

                finally:
                    _ = self._running_groups.pop(group_id, None)
                    self._pending_group_assembled.set()

        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise
        finally:
            if not future.done():
                future.set_exception(
                    RuntimeError(
                        "Internal sanity check failed: averager.step left future pending."
                        " Please report this to hivemind issues."))
示例#2
0
    async def validate_request(self, request: AuthorizedRequestBase) -> bool:
        """
        Check that an incoming request comes from an authorized client and is not a replay.

        Checks, in order: the client's access token is valid; the request signature verifies against the
        public key from that token; the request targets this peer (when ``auth.service_public_key`` is set);
        ``auth.time`` is within ``_MAX_CLIENT_SERVICER_TIME_DIFF`` of local DHT time; the nonce was not seen
        recently. On success, the nonce is remembered so that the same request cannot be replayed.

        :returns: True if all checks pass, False otherwise (the reason is logged at debug level)
        """
        await self.refresh_token_if_needed()
        auth = request.auth

        if not self.is_token_valid(auth.client_access_token):
            logger.debug('Client failed to prove that it (still) has access to the network')
            return False

        client_public_key = RSAPublicKey.from_bytes(auth.client_access_token.public_key)
        # The signer signs the serialized request while its signature field is empty, so clear the field to
        # reproduce that payload -- then restore it so validation does not mutate the caller's message.
        signature = auth.signature
        auth.signature = b''
        try:
            signature_is_valid = client_public_key.verify(request.SerializeToString(), signature)
        finally:
            auth.signature = signature
        if not signature_is_valid:
            logger.debug('Request has invalid signature')
            return False

        if auth.service_public_key and auth.service_public_key != self._local_public_key.to_bytes():
            logger.debug('Request is generated for a peer with another public key')
            return False

        with self._recent_nonces.freeze():
            current_time = get_dht_time()
            if abs(auth.time - current_time) > self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds():
                logger.debug('Clocks are not synchronized or a previous request is replayed again')
                return False
            if auth.nonce in self._recent_nonces:
                logger.debug('Previous request is replayed again')
                return False

        # keep the nonce well past the accepted clock window so a replay is rejected by one of the two checks
        self._recent_nonces.store(auth.nonce, None,
                                  current_time + self._MAX_CLIENT_SERVICER_TIME_DIFF.total_seconds() * 3)
        return True
示例#3
0
    async def _step(self, *, future: MPFuture, gather_binary: bytes,
                    allow_retries: bool, timeout: Optional[float]):
        """
        Run one averaging step: find a group, run all-reduce in it and apply the result to local tensors.

        :param future: on success, resolved with a dict mapping each group endpoint to its deserialized
          gathered data; resolved with None when the step gives up on a caught failure; any other
          exception (including cancellation) is set on it and re-raised, so it is never left pending
        :param gather_binary: serialized user data, passed to matchmaking as ``data_for_gather``
        :param allow_retries: if False, give up after the first caught failure
        :param timeout: matchmaking timeout and overall retry budget in seconds (None = no limit)
        """
        loop = asyncio.get_event_loop()
        start_time = get_dht_time()

        try:
            while not future.done():
                # reset per attempt so the `finally` below only ever releases this attempt's group
                group_id = None
                try:
                    self._pending_group_assembled.clear()
                    allreduce_group = await self._matchmaking.look_for_group(
                        timeout=timeout, data_for_gather=gather_binary)
                    if allreduce_group is None:
                        raise AllreduceException(
                            "Averaging step failed: could not find a group.")

                    group_id = allreduce_group.group_id
                    self._running_groups[group_id] = allreduce_group
                    self._pending_group_assembled.set()
                    await asyncio.wait_for(allreduce_group.run(),
                                           self._allreduce_timeout)
                    await loop.run_in_executor(None, self.update_tensors,
                                               allreduce_group)

                    # averaging is finished, exit the loop
                    gathered_items = map(self.serializer.loads,
                                         allreduce_group.gathered)
                    gathered_data_by_peer = dict(
                        zip(allreduce_group.ordered_group_endpoints,
                            gathered_items))
                    future.set_result(gathered_data_by_peer)

                except (AllreduceException, MatchmakingException,
                        asyncio.InvalidStateError, grpc.RpcError,
                        grpc.aio.AioRpcError, InternalError) as e:
                    time_elapsed = get_dht_time() - start_time
                    if not allow_retries or (timeout is not None
                                             and timeout < time_elapsed):
                        # the future may already be resolved (e.g. set_result above raised InvalidStateError)
                        if not future.done():
                            future.set_result(None)
                    else:
                        logger.debug(f"caught {e}, retrying")

                finally:
                    _ = self._running_groups.pop(group_id, None)
                    self._pending_group_assembled.set()

        except BaseException as e:
            # any other error -- including asyncio.CancelledError, a BaseException since Python 3.8 --
            # must still resolve the future, otherwise whoever waits on it would hang forever
            if not future.done():
                future.set_exception(e)
            raise
示例#4
0
 def get_tensors(self) -> Sequence[torch.Tensor]:
     """
     Generator meant to be used as a context manager that exposes the averaged tensors under
     ``lock_averaged_tensors``.

     While the context is open, the averager will not modify these tensors. Once the context is
     exited, do not modify the yielded tensors in-place. ``last_updated`` is refreshed after a
     normal (non-raising) exit.
     """
     guard = self.lock_averaged_tensors
     with guard:
         yield self._averaged_tensors
     # only reached when the caller's `with` body completed without raising
     self.last_updated = get_dht_time()
示例#5
0
    def update_tensors(self, allreduce_group: AllReduceRunner):
        """
        A private (extendable) method that applies changes from a finished all-reduce to local tensors.

        Each local tensor is incremented in-place by its matching averaging delta, scaled by
        ``self._averaging_alpha``; ``last_updated`` is then refreshed.

        :param allreduce_group: a finished runner whose ``return_deltas`` is set
        :raises AssertionError: if the runner is unfinished, does not return deltas, or produced a number
          of deltas different from the number of local tensors (kept as AssertionError because the
          averaging step catches it)
        """
        assert allreduce_group.return_deltas and allreduce_group.future.done()
        averaging_deltas = list(allreduce_group.future.result())

        with torch.no_grad(), self.get_tensors() as local_tensors:
            assert len(local_tensors) == len(self._averaged_tensors)
            # zip() below would silently skip tensors on a length mismatch, so check the deltas explicitly
            assert len(local_tensors) == len(averaging_deltas), (
                f"expected {len(local_tensors)} averaging deltas, got {len(averaging_deltas)}")
            for tensor, update in zip(local_tensors, averaging_deltas):
                tensor.add_(update, alpha=self._averaging_alpha)
        self.last_updated = get_dht_time()
示例#6
0
 async def _declare_for_download_periodically(self):
     """
     Periodically store this averager's ``last_updated`` in the DHT under ``{prefix}.all_averagers``.

     Each period (``averaging_expiration`` seconds), a store keyed by this averager's endpoint is
     launched as a background task with a one-period expiration and a one-period timeout, so a slow
     DHT never delays the schedule. Runs until cancelled.
     """
     download_key = f'{self._matchmaking.group_key_manager.prefix}.all_averagers'
     # The event loop keeps only weak references to tasks, so hold in-flight stores here to keep them
     # from being garbage-collected before they finish.
     pending_stores = set()

     def _on_store_done(task: asyncio.Task):
         pending_stores.discard(task)
         # retrieve the outcome so failures (e.g. a timeout) are logged rather than reported as
         # "Task exception was never retrieved"
         if not task.cancelled() and task.exception() is not None:
             logger.debug(f"Averager failed to declare itself for state download: {task.exception()!r}")

     while True:
         period = self._matchmaking.averaging_expiration
         store_task = asyncio.create_task(
             asyncio.wait_for(
                 self.dht.store(download_key,
                                subkey=self.endpoint,
                                value=self.last_updated,
                                expiration_time=get_dht_time() + period,
                                return_future=True),
                 timeout=period))
         pending_stores.add(store_task)
         store_task.add_done_callback(_on_store_done)
         await asyncio.sleep(period)
示例#7
0
    async def sign_request(self, request: AuthorizedRequestBase, service_public_key: Optional[RSAPublicKey]) -> None:
        """
        Fill in ``request.auth`` and sign the request in-place with this peer's private key.

        Attaches the local access token, the recipient's public key (if given), the current DHT time and a
        fresh random 8-byte nonce, then signs the serialized request while its signature field is empty.

        :param request: outgoing request; its ``auth.signature`` must still be empty (asserted)
        :param service_public_key: public key of the intended recipient, or None to leave that field unset
        """
        await self.refresh_token_if_needed()
        auth_info = request.auth
        auth_info.client_access_token.CopyFrom(self._local_access_token)

        if service_public_key is not None:
            auth_info.service_public_key = service_public_key.to_bytes()
        auth_info.time = get_dht_time()
        auth_info.nonce = secrets.token_bytes(8)

        # the signature covers the whole serialized message, so it is computed while the field is empty
        assert auth_info.signature == b''
        unsigned_payload = request.SerializeToString()
        auth_info.signature = self._local_private_key.sign(unsigned_payload)
示例#8
0
文件: grpc.py 项目: swoopyy/hivemind
    def _evict_stale_channels_in_background(self):
        """
        Background loop that evicts expired cached channels while ``self._is_active`` is set.

        Sleeps until the nearest known expiration time (indefinitely if none). If woken early through
        ``_update_eviction_evt`` (set e.g. by ``get_stub`` when an earlier expiration is registered), it
        recomputes the wait; otherwise it removes outdated entries and records the next expiration time.
        """
        while self._is_active:
            delay = max(0.0, self._nearest_expiration_time - get_dht_time())
            if delay == float('inf'):
                delay = None  # nothing scheduled to expire: wait until notified
            if self._update_eviction_evt.wait(delay):
                self._update_eviction_evt.clear()
                continue

            with self._lock:
                self._remove_outdated()
                _, nearest_entry = super().top()
                if nearest_entry is None:
                    self._nearest_expiration_time = float('inf')
                else:
                    self._nearest_expiration_time = nearest_entry.expiration_time
示例#9
0
文件: grpc.py 项目: swoopyy/hivemind
    def get_stub(cls,
                 target: Endpoint,
                 stub_type: Type[Stub],
                 *,
                 aio: bool,
                 options: Tuple[Tuple[str, Any], ...] = (),
                 channel_credentials: Optional[grpc.ChannelCredentials] = None,
                 compression: Optional[grpc.Compression] = None) -> Stub:
        """
        Create a grpc channel with given options or reuse a pre-existing one, and return a stub for it

        Channels are cached by (target, aio, options, channel_credentials, compression); stubs are cached per
        channel by ``stub_type``. Every call pushes the channel's cache expiration EVICTION_PERIOD_SECONDS into
        the future and wakes the eviction thread if this is now the nearest expiration.

        :param target: the recipient's address and port
        :param stub_type: a gRPC stub (client) to be instantiated
        :param aio: if True, the stub is built on an asyncio channel (grpc.aio.Channel), otherwise on a
          synchronous grpc.Channel
        :param options: see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
        :param channel_credentials: if specified, create a secure channel using these credentials (default = insecure)
        :param compression: see https://github.com/grpc/grpc/tree/master/examples/python/compression
        :returns: a (possibly cached) ``stub_type`` instance bound to the cached channel
        """
        cache = cls.get_singleton()
        with cls._lock:
            key = ChannelInfo(target, aio, tuple(options), channel_credentials,
                              compression)
            entry: ValueWithExpiration = super(cls, cache).get(key)

            if entry is not None:
                channel, stubs = entry.value
            else:
                channel = cls._create_channel(*key)
                stubs = {}

            # try_to_connect=True (private gRPC API): presumably nudges an idle channel to start connecting
            channel._channel.check_connectivity_state(True)

            if stub_type not in stubs:
                stubs[stub_type] = stub_type(channel)

            # either cache channel or update expiration of an existing channel
            expiration_time = get_dht_time() + cls.EVICTION_PERIOD_SECONDS
            super(cls, cache).store(key, (channel, stubs), expiration_time)

            if expiration_time < cache._nearest_expiration_time:
                cache._nearest_expiration_time = expiration_time
                cls._update_eviction_evt.set()

            return stubs[stub_type]
示例#10
0
    async def _load_state_from_peers(self, future: MPFuture):
        """
        Download (metadata, tensors) from another averager, trying peers in descending declared priority.

        Peer priorities are read from the DHT key ``{prefix}.all_averagers``; entries whose value is not a
        number are ignored, and this averager's own endpoint is skipped.

        :param future: resolved with ``(metadata, tensors)`` from the first peer that streams its state
          successfully, or with None if no peer is declared or none could serve the state; any unexpected
          exception is set on it and re-raised, so it is never left pending
        """
        try:
            key_manager = self._matchmaking.group_key_manager
            # NOTE(review): synchronous DHT lookup inside a coroutine -- if it blocks until the DHT responds,
            # it stalls this event loop meanwhile; consider an awaitable (return_future=True) variant.
            declared, _ = self.dht.get(f"{key_manager.prefix}.all_averagers",
                                       latest=True) or ({}, None)
            # validate the container before iterating it: a corrupted (non-dict) value has no .items()
            peer_priority = {}
            if isinstance(declared, dict):
                peer_priority = {
                    peer: float(info.value)
                    for peer, info in declared.items()
                    if isinstance(info, ValueWithExpiration)
                    and isinstance(info.value, (float, int))
                }

            if len(peer_priority) == 0:
                logger.info(
                    f"Averager could not load state from peers: peer dict is absent or corrupted {declared}."
                )
                future.set_result(None)
                return

            for peer in sorted(peer_priority.keys(),
                               key=peer_priority.get,
                               reverse=True):
                if peer == self.endpoint:
                    continue
                logger.info(f"Downloading parameters from peer {peer}")
                # reset per peer so metadata from a peer that failed mid-stream cannot leak into another's result
                metadata = None
                stream = None
                try:
                    leader_stub = ChannelCache.get_stub(
                        peer,
                        averaging_pb2_grpc.DecentralizedAveragingStub,
                        aio=True)
                    stream = leader_stub.rpc_download_state(
                        averaging_pb2.DownloadRequest())
                    current_tensor_parts, tensors = [], []
                    async for message in stream:
                        if message.metadata:
                            metadata = self.serializer.loads(message.metadata)
                        if message.tensor_part.dtype and current_tensor_parts:
                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                            tensors.append(
                                deserialize_torch_tensor(
                                    combine_from_streaming(
                                        current_tensor_parts)))
                            current_tensor_parts = []
                        current_tensor_parts.append(message.tensor_part)
                    if current_tensor_parts:
                        tensors.append(
                            deserialize_torch_tensor(
                                combine_from_streaming(current_tensor_parts)))
                    future.set_result((metadata, tensors))
                    self.last_updated = get_dht_time()
                    return
                except grpc.aio.AioRpcError as e:
                    logger.info(f"Failed to download state from {peer} - {e}")
                finally:
                    if stream is not None:
                        await stream.code()

            logger.warning(
                "Averager could not load state from peers: found no active peers."
            )
            future.set_result(None)

        except BaseException as e:
            if not future.done():
                future.set_exception(e)
            raise