Example #1
0
    async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
        """
        Ask a prospective leader to accept us into its averaging group and wait for allreduce to begin.

        :param leader: request this peer to be your leader for allreduce
        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
        :returns: if the leader accepted us and started AllReduce, return that group's info. Otherwise, return None
        :note: this function does not guarantee that your group leader is the same as :leader: parameter
          The originally specified leader can disband group and redirect us to a different leader
        """
        assert self.is_looking_for_group and self.current_leader is None
        call: Optional[grpc.aio.UnaryStreamCall] = None
        try:
            async with self.lock_request_join_group:
                leader_stub = ChannelCache.get_stub(leader, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
                # open a server-streaming RPC; the leader replies with status messages over time
                call = leader_stub.rpc_join_group(averaging_pb2.JoinRequest(
                    endpoint=self.endpoint, schema_hash=self.schema_hash, expiration=expiration_time,
                    client_mode=self.client_mode, gather=self.data_for_gather))
                # first message tells us whether the leader accepted our join request
                message = await asyncio.wait_for(call.read(), timeout=self.request_timeout)

                if message.code == averaging_pb2.ACCEPTED:
                    logger.debug(f"{self.endpoint} - joining the group of {leader}; waiting for peers")
                    self.current_leader = leader
                    self.was_accepted_to_group.set()
                    # we are now a follower: dissolve any group we were leading ourselves
                    if len(self.current_followers) > 0:
                        await self.leader_disband_group()

            # leader rejected us (e.g. group is full or schema mismatch) - give up on this leader
            if message.code != averaging_pb2.ACCEPTED:
                code = averaging_pb2.MessageCode.Name(message.code)
                logger.debug(f"{self.endpoint} - requested {leader} to be my leader, but got rejected with {code}")
                return None

            # stop searching for other leaders while we wait for this group to assemble
            async with self.potential_leaders.pause_search():
                # wait at most until the declared expiration, plus a grace period for the leader's reply
                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
                message = await asyncio.wait_for(call.read(), time_to_expiration + self.request_timeout)

                if message.code == averaging_pb2.BEGIN_ALLREDUCE:
                    async with self.lock_request_join_group:
                        return await self.follower_assemble_group(leader, message)

            if message.code in (averaging_pb2.GROUP_DISBANDED, averaging_pb2.CANCELLED):
                if message.suggested_leader and message.suggested_leader != self.endpoint:
                    # leader pointed us at another peer - retry recursively with the suggested leader
                    logger.debug(f"{self} - leader disbanded group and redirected us to {message.suggested_leader}")
                    self.current_leader = None
                    call.cancel()
                    return await self.request_join_group(message.suggested_leader, expiration_time)
                else:
                    logger.debug(f"{self} - leader disbanded group")
                    return None

            logger.debug(f"{self} - unexpected message from leader: {averaging_pb2.MessageCode.Name(message.code)}")
            return None
        except asyncio.TimeoutError:
            logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
            if call is not None:
                call.cancel()
            return None
        finally:
            # reset follower state regardless of outcome
            self.was_accepted_to_group.clear()
            self.current_leader = None
            if call is not None:
                # await call.code() waits for the RPC to fully terminate before we leave
                await call.code()
Example #2
0
def _get_expert_stub(endpoint: Endpoint, *extra_options: Tuple[str, Any]):
    """
    Return a gRPC stub for talking to a remote expert, reusing a cached channel when one exists.

    :param endpoint: network address of the expert's connection handler
    :param extra_options: additional (name, value) gRPC channel options appended to the defaults
    """
    # default to unlimited message sizes in both directions; callers may append more options
    base_options: Tuple[Tuple[str, Any], ...] = (
        ('grpc.max_send_message_length', -1),
        ('grpc.max_receive_message_length', -1),
    )
    return ChannelCache.get_stub(
        endpoint,
        runtime_grpc.ConnectionHandlerStub,
        aio=False,
        options=base_options + extra_options,
    )
Example #3
0
    async def _load_state_from_peers(self, future: MPFuture):
        """
        Try to download the latest averager state (metadata and tensors) from other peers.

        Peers are ranked by the priority values published under the ``{prefix}.all_averagers`` DHT key;
        the first successful download wins. ``future`` is fulfilled with ``(metadata, tensors)`` on
        success, or ``None`` if no peer could provide a state.
        """
        key_manager = self._matchmaking.group_key_manager
        # fetch the peer -> priority mapping from the DHT; fall back to an empty dict if absent
        peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers",
                                        latest=True) or ({}, None)
        # keep only well-formed entries: a ValueWithExpiration wrapping a numeric priority
        peer_priority = {
            peer: float(info.value)
            for peer, info in peer_priority.items()
            if isinstance(info, ValueWithExpiration)
            and isinstance(info.value, (float, int))
        }

        if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
            logger.info(
                f"Averager could not load state from peers: peer dict is absent or corrupted {peer_priority}."
            )
            future.set_result(None)
            return

        metadata = None
        # try peers in descending priority order, skipping ourselves
        for peer in sorted(peer_priority.keys(),
                           key=peer_priority.get,
                           reverse=True):
            if peer != self.endpoint:
                logger.info(f"Downloading parameters from peer {peer}")
                stream = None
                try:
                    leader_stub = ChannelCache.get_stub(
                        peer,
                        averaging_pb2_grpc.DecentralizedAveragingStub,
                        aio=True)
                    stream = leader_stub.rpc_download_state(
                        averaging_pb2.DownloadRequest())
                    # reassemble tensors from the streamed parts
                    current_tensor_parts, tensors = [], []
                    async for message in stream:
                        if message.metadata:
                            metadata = self.serializer.loads(message.metadata)
                        if message.tensor_part.dtype and current_tensor_parts:
                            # tensor_part.dtype indicates the start of the new tensor, so we should wrap up this one
                            tensors.append(
                                deserialize_torch_tensor(
                                    combine_from_streaming(
                                        current_tensor_parts)))
                            current_tensor_parts = []
                        current_tensor_parts.append(message.tensor_part)
                    # flush the final in-progress tensor after the stream ends
                    if current_tensor_parts:
                        tensors.append(
                            deserialize_torch_tensor(
                                combine_from_streaming(current_tensor_parts)))
                    future.set_result((metadata, tensors))
                    self.last_updated = get_dht_time()
                    return
                except grpc.aio.AioRpcError as e:
                    # network/RPC failure: log and move on to the next candidate peer
                    logger.info(f"Failed to download state from {peer} - {e}")
                finally:
                    if stream is not None:
                        # wait for the RPC to fully terminate before trying the next peer
                        await stream.code()

        # for/else: runs only if no peer succeeded (every success path returns above)
        else:
            logger.warning(
                "Averager could not load state from peers: found no active peers."
            )
            future.set_result(None)