class ReplayBufferServer(replay_buffer_pb2_grpc.ReplayBufferServicer):
    """gRPC servicer that fronts a single `ReplayBuffer`.

    Each RPC delegates to the `ReplayBuffer` created in `__init__`,
    converting between protobuf messages and the buffer's own objects via
    `history_from_protobuf` / `history_to_protobuf`.
    """

    def __init__(self, config: MuZeroConfig) -> None:
        # One buffer per servicer, constructed from the supplied config.
        self.replay_buffer = ReplayBuffer(config=config)

    def NumGames(self, request: replay_buffer_pb2.Empty,
                 context) -> replay_buffer_pb2.NumGamesResponse:
        """Return the value of the buffer's `num_games()` in a response."""
        stored = self.replay_buffer.num_games()
        return replay_buffer_pb2.NumGamesResponse(num_games=stored)

    def SaveHistory(self, request: replay_buffer_pb2.GameHistory,
                    context) -> replay_buffer_pb2.NumGamesResponse:
        """Store one game history; the response always reports one game."""
        history = history_from_protobuf(request)
        self.replay_buffer.save_history(history)
        return replay_buffer_pb2.NumGamesResponse(num_games=1)

    def SaveMultipleHistory(
            self,
            request_iterator: Iterable[replay_buffer_pb2.GameHistory],
            context) -> replay_buffer_pb2.NumGamesResponse:
        """Store every streamed game history and report how many were saved.

        An empty stream yields a response with `num_games=0`.
        """
        saved = 0
        # enumerate(start=1) leaves `saved` equal to the number of messages
        # processed once the stream is exhausted.
        for saved, message in enumerate(request_iterator, start=1):
            self.replay_buffer.save_history(history_from_protobuf(message))
        return replay_buffer_pb2.NumGamesResponse(num_games=saved)

    def SampleBatch(
            self,
            request: replay_buffer_pb2.MiniBatchRequest,
            context) -> Iterable[replay_buffer_pb2.MiniBatchResponse]:
        """Stream mini-batches built from `as_dataset(request.batch_size)`.

        Each dataset element is unpacked as an `(inputs, outputs)` pair, where
        inputs are `(observations, actions)` and outputs are
        `(target_rewards, target_values, target_policies)`. Every tensor is
        serialized with `tf.make_tensor_proto` into the matching response
        field, and one `MiniBatchResponse` is yielded per dataset element.
        """
        dataset = self.replay_buffer.as_dataset(batch_size=request.batch_size)
        for features, targets in dataset:
            observations, actions = features
            target_rewards, target_values, target_policies = targets
            # (response field, tensor) pairs, filled in this fixed order.
            field_tensors = (
                ('batch_observations', observations),
                ('batch_actions', actions),
                ('batch_target_rewards', target_rewards),
                ('batch_target_values', target_values),
                ('batch_target_policies', target_policies),
            )
            response = replay_buffer_pb2.MiniBatchResponse()
            for field_name, tensor in field_tensors:
                getattr(response, field_name).CopyFrom(
                    tf.make_tensor_proto(tensor))
            yield response

    def Stats(self, request: replay_buffer_pb2.StatsRequest,
              context) -> replay_buffer_pb2.StatsResponse:
        """Return buffer metrics.

        Uses `detailed_stats()` when `request.detailed` is truthy, otherwise
        `stats()`.
        """
        metrics = (self.replay_buffer.detailed_stats()
                   if request.detailed else self.replay_buffer.stats())
        return replay_buffer_pb2.StatsResponse(metrics=metrics)

    def BackupBuffer(self, request: replay_buffer_pb2.Empty,
                     context) -> Iterable[replay_buffer_pb2.GameHistory]:
        """Stream every entry of the buffer's `buffer` as a protobuf message."""
        # NOTE(review): this iterates the live `buffer` container; if other
        # RPCs can append to it concurrently, confirm that is safe.
        yield from map(history_to_protobuf, self.replay_buffer.buffer)