예제 #1
0
class StateVisitationHistogram(tf_metric.TFHistogramStepMetric):
    """Histogram metric tracking which states have been visited.

    Every call applies ``state_selection_function`` to the trajectory's
    observation and appends the outcome to an internal ``TFDeque``; the
    contents of that deque are what ``result`` reports.

    NOTE(review): the selected states are handed to the deque without any
    cast, so they presumably must already agree with ``dtype`` and
    ``state_shape`` -- confirm against ``TFDeque``.
    """

    def __init__(self,
                 state_selection_function,
                 state_shape=(),
                 name='StateVisitationHistogram',
                 dtype=tf.float64,
                 buffer_size=100):
        """Set up the metric and its backing buffer.

        Args:
            state_selection_function: Callable invoked on
                ``trajectory.observation``; whatever it returns is stored.
            state_shape: Forwarded to ``TFDeque`` as its ``shape`` argument.
            name: Metric name handed to the parent constructor.
            dtype: Forwarded to ``TFDeque`` as its second argument and also
                kept on the instance.
            buffer_size: Forwarded to ``TFDeque`` as its first argument
                (presumably the buffer capacity -- confirm).
        """
        super(StateVisitationHistogram, self).__init__(name=name)
        self._state_selection_function = state_selection_function
        self._dtype = dtype
        self._buffer = TFDeque(buffer_size, dtype, shape=state_shape)

    @common.function
    def call(self, trajectory):
        """Store the states selected from `trajectory`; return it unchanged."""
        observation = trajectory.observation
        visited_states = self._state_selection_function(observation)
        self._buffer.extend(visited_states)
        return trajectory

    @common.function
    def result(self):
        """Return the data currently held by the buffer."""
        buffered = self._buffer.data
        return buffered

    @common.function
    def reset(self):
        """Clear the buffer."""
        self._buffer.clear()
예제 #2
0
class RewardHistogram(tf_metric.TFHistogramStepMetric):
    """Histogram metric tracking the rewards seen in trajectories.

    Every call appends ``trajectory.reward`` to an internal ``TFDeque``; the
    contents of that deque are what ``result`` reports.

    NOTE(review): rewards are handed to the deque without any cast, so they
    presumably must already agree with ``dtype`` (``tf.int32`` by default)
    -- confirm against ``TFDeque`` and the environment's reward dtype.
    """

    def __init__(self,
                 name='RewardHistogram',
                 dtype=tf.int32,
                 buffer_size=100):
        """Set up the metric and its backing buffer.

        Args:
            name: Metric name handed to the parent constructor.
            dtype: Forwarded to ``TFDeque`` as its second argument and also
                kept on the instance.
            buffer_size: Forwarded to ``TFDeque`` as its first argument
                (presumably the buffer capacity -- confirm).
        """
        super(RewardHistogram, self).__init__(name=name)
        self._dtype = dtype
        self._buffer = TFDeque(buffer_size, dtype)

    @common.function
    def call(self, trajectory):
        """Store the reward(s) of `trajectory`; return it unchanged."""
        rewards = trajectory.reward
        self._buffer.extend(rewards)
        return trajectory

    @common.function
    def result(self):
        """Return the data currently held by the buffer."""
        buffered = self._buffer.data
        return buffered

    @common.function
    def reset(self):
        """Clear the buffer."""
        self._buffer.clear()