class ReplayBufferServer(replay_buffer_pb2_grpc.ReplayBufferServicer):
    """
    gRPC servicer that wraps a single ReplayBuffer instance.

    Each RPC handler translates between protobuf messages and the calls
    offered by the wrapped buffer (counting, saving, sampling, statistics
    and backup).
    """

    def __init__(self, config: MuZeroConfig) -> None:
        """Create the wrapped replay buffer from the given configuration."""
        self.replay_buffer = ReplayBuffer(config=config)

    def NumGames(
            self,
            request: replay_buffer_pb2.Empty,
            context) -> replay_buffer_pb2.NumGamesResponse:
        """Report the value returned by the buffer's ``num_games()``."""
        game_count = self.replay_buffer.num_games()
        return replay_buffer_pb2.NumGamesResponse(num_games=game_count)

    def SaveHistory(
            self,
            request: replay_buffer_pb2.GameHistory,
            context) -> replay_buffer_pb2.NumGamesResponse:
        """
        Store one game history in the buffer.

        The response always carries ``num_games=1``, i.e. the number of
        histories handled by this call.
        """
        game = history_from_protobuf(request)
        self.replay_buffer.save_history(game)
        return replay_buffer_pb2.NumGamesResponse(num_games=1)

    def SaveMultipleHistory(
            self,
            request_iterator: Iterable[replay_buffer_pb2.GameHistory],
            context) -> replay_buffer_pb2.NumGamesResponse:
        """
        Store every game history received on the request stream.

        The response carries the number of histories saved by this call
        (0 when the stream is empty).
        """
        saved = 0
        # enumerate(start=1) leaves `saved` equal to the number of messages
        # processed so far; it stays 0 if the stream yields nothing.
        for saved, game_message in enumerate(request_iterator, start=1):
            self.replay_buffer.save_history(
                history_from_protobuf(game_message))
        return replay_buffer_pb2.NumGamesResponse(num_games=saved)

    def SampleBatch(
            self,
            request: replay_buffer_pb2.MiniBatchRequest,
            context) -> Iterable[replay_buffer_pb2.MiniBatchResponse]:
        """
        Stream mini-batches drawn from the buffer's dataset.

        One MiniBatchResponse is yielded per dataset element, with each
        tensor converted to a tensor proto. The stream lasts as long as the
        dataset keeps producing elements.
        NOTE(review): whether ``as_dataset`` is finite is not visible here.
        """
        batches = self.replay_buffer.as_dataset(batch_size=request.batch_size)

        # Each element is ((observations, actions),
        #                  (target_rewards, target_values, target_policies)).
        for (observations, actions), (rewards, values, policies) in batches:
            reply = replay_buffer_pb2.MiniBatchResponse()
            field_tensor_pairs = (
                (reply.batch_observations, observations),
                (reply.batch_actions, actions),
                (reply.batch_target_rewards, rewards),
                (reply.batch_target_values, values),
                (reply.batch_target_policies, policies),
            )
            for field, tensor in field_tensor_pairs:
                field.CopyFrom(tf.make_tensor_proto(tensor))
            yield reply

    def Stats(
            self,
            request: replay_buffer_pb2.StatsRequest,
            context) -> replay_buffer_pb2.StatsResponse:
        """
        Return buffer metrics.

        Uses ``detailed_stats()`` when ``request.detailed`` is set and
        ``stats()`` otherwise.
        """
        buffer = self.replay_buffer
        metrics = buffer.detailed_stats() if request.detailed else buffer.stats()
        return replay_buffer_pb2.StatsResponse(metrics=metrics)

    def BackupBuffer(
            self,
            request: replay_buffer_pb2.Empty,
            context) -> Iterable[replay_buffer_pb2.GameHistory]:
        """Stream every history held in the buffer as a protobuf message."""
        # map() is lazy, so histories are converted one at a time as the
        # response stream is consumed.
        yield from map(history_to_protobuf, self.replay_buffer.buffer)