コード例 #1
0
    def testAverageOneEpisode(self, metric_class, expected_result):
        """Feed one boundary/mid/mid/last sequence and check the result.

        Args:
            metric_class: Zero-argument callable that builds the metric under
                test.
            expected_result: Value `metric.result()` must equal after the
                sequence has been fed.
        """
        metric = metric_class()

        # (step constructor, fourth positional arg, fifth positional arg);
        # each step is built right before it is handed to the metric.
        step_specs = (
            (trajectory.boundary, 0., 1.),
            (trajectory.mid, 1., 1.),
            (trajectory.mid, 2., 1.),
            (trajectory.last, 3., 0.),
        )
        for build_step, reward, discount in step_specs:
            metric(build_step((), (), (), reward, discount))

        self.assertEqual(expected_result, metric.result())
コード例 #2
0
    def testBatchSizeProvided(self, metric_class, expected_result):
        """Check AverageReturnMetric built with an explicit batch_size of 2.

        Five batches, each stacking two trajectories, are fed to the metric
        and the result is expected to be 5.0.

        Args:
            metric_class: Accepted but not used by this test.
            expected_result: Accepted but not used by this test.
        """
        metric = py_metrics.AverageReturnMetric(batch_size=2)

        # One row per batch; each entry is
        # (step constructor, fourth positional arg, fifth positional arg).
        batch_specs = (
            ((trajectory.boundary, 0., 1.), (trajectory.boundary, 0., 1.)),
            ((trajectory.first, 1., 1.), (trajectory.first, 1., 1.)),
            ((trajectory.mid, 2., 1.), (trajectory.last, 3., 0.)),
            ((trajectory.last, 3., 0.), (trajectory.boundary, 0., 1.)),
            ((trajectory.boundary, 0., 1.), (trajectory.first, 1., 1.)),
        )
        for batch_spec in batch_specs:
            steps = [
                build_step((), (), (), reward, discount)
                for build_step, reward, discount in batch_spec
            ]
            metric(nest_utils.stack_nested_arrays(steps))

        self.assertEqual(metric.result(), 5.0)
コード例 #3
0
    def testBatch(self, metric_class, expected_result):
        """Feed five batches of two stacked trajectories and check the result.

        Args:
            metric_class: Zero-argument callable that builds the metric under
                test.
            expected_result: Value `metric.result()` must equal after all
                batches have been fed.
        """
        metric = metric_class()

        metric(
            nest_utils.stack_nested_arrays([
                trajectory.boundary((), (), (), 0., 1.),
                trajectory.boundary((), (), (), 0., 1.)
            ]))
        metric(
            nest_utils.stack_nested_arrays([
                trajectory.first((), (), (), 1., 1.),
                trajectory.first((), (), (), 1., 1.)
            ]))
        metric(
            nest_utils.stack_nested_arrays([
                trajectory.mid((), (), (), 2., 1.),
                trajectory.last((), (), (), 3., 0.)
            ]))
        metric(
            nest_utils.stack_nested_arrays([
                trajectory.last((), (), (), 3., 0.),
                trajectory.boundary((), (), (), 0., 1.)
            ]))
        metric(
            nest_utils.stack_nested_arrays([
                trajectory.boundary((), (), (), 0., 1.),
                trajectory.first((), (), (), 1., 1.)
            ]))
        # The original call passed a stray third positional argument (5.0),
        # which assertEqual treats as the failure *message*, not as a value to
        # compare. Only expected_result is checked against the metric.
        self.assertEqual(expected_result, metric.result())
コード例 #4
0
    def testAverageOneEpisodeWithReset(self, metric_class, expected_result):
        """Check the metric when a new FIRST step arrives before any LAST.

        Args:
            metric_class: Zero-argument callable that builds the metric under
                test.
            expected_result: Value `metric.result()` must equal after the
                sequence has been fed.
        """
        metric = metric_class()

        def feed(build_step, reward, discount):
            # Build the step and hand it straight to the metric.
            metric(build_step((), (), (), reward, discount))

        feed(trajectory.first, 0., 1.)
        feed(trajectory.mid, 1., 1.)
        feed(trajectory.mid, 2., 1.)

        # The episode is reset here.
        #
        # This can occur when dynamic_episode_driver is used together with
        # parallel_py_environment: parallel episodes may differ in length, so
        # once num_episodes is reached some of them are still in "MID". On the
        # next driver run every environment is reset at the start of the
        # tf.while_loop, so those unfinished episodes receive "FIRST" without
        # ever having seen "LAST".
        feed(trajectory.first, 3., 1.)
        feed(trajectory.last, 4., 1.)

        self.assertEqual(expected_result, metric.result())
コード例 #5
0
    def testSaveRestore(self):
        """Save metrics to a checkpoint, reset them, and restore them.

        Each metric is fed one boundary/mid/mid/last sequence, saved, reset
        (its result must then be 0), and restored from the checkpoint (its
        result must then be greater than 0).
        """
        metrics = [
            py_metrics.AverageReturnMetric(),
            py_metrics.AverageEpisodeLengthMetric(),
            py_metrics.EnvironmentSteps(),
            py_metrics.NumberOfEpisodes()
        ]

        # (step constructor, fourth positional arg, fifth positional arg);
        # fresh steps are built for every metric.
        step_specs = (
            (trajectory.boundary, 0., 1.),
            (trajectory.mid, 1., 1.),
            (trajectory.mid, 2., 1.),
            (trajectory.last, 3., 0.),
        )
        for metric in metrics:
            for build_step, reward, discount in step_specs:
                metric(build_step((), (), (), reward, discount))

        # Track every metric in the checkpoint under its own name.
        tracked = {}
        for metric in metrics:
            tracked[metric.name] = metric
        checkpoint = tf.train.Checkpoint(**tracked)
        save_path = checkpoint.save(self.get_temp_dir() + '/ckpt')

        for metric in metrics:
            metric.reset()
            self.assertEqual(0, metric.result())

        restore_status = checkpoint.restore(save_path)
        restore_status.assert_consumed()

        for metric in metrics:
            self.assertGreater(metric.result(), 0)
コード例 #6
0
    def testAverageTwoEpisode(self, metric_class, expected_result):
        """Feed two episodes, the second being a single FIRST->LAST step.

        Args:
            metric_class: Zero-argument callable that builds the metric under
                test.
            expected_result: Value `metric.result()` must equal after both
                episodes have been fed.
        """
        metric = metric_class()

        # (step constructor, fourth positional arg, fifth positional arg).
        step_specs = (
            (trajectory.boundary, 0., 1.),
            (trajectory.first, 1., 1.),
            (trajectory.mid, 2., 1.),
            (trajectory.last, 3., 0.),
            (trajectory.boundary, 0., 1.),
        )
        for build_step, reward, discount in step_specs:
            metric(build_step((), (), (), reward, discount))

        # TODO(kbanoop): Add optional next_step_type arg to trajectory.first. Or
        # implement trajectory.first_last().
        # Until then the FIRST -> LAST step is built through the Trajectory
        # constructor directly.
        first_to_last_step = trajectory.Trajectory(
            ts.StepType.FIRST, (), (), (), ts.StepType.LAST, -6., 1.)
        metric(first_to_last_step)

        self.assertEqual(expected_result, metric.result())
コード例 #7
0
def trajectory_mid(observation):
    """Return `trajectory.mid` for `observation` with fixed remaining fields.

    The step uses action 1, an empty policy_info tuple, a float32 reward of 1
    and a discount of 1.0.

    Args:
        observation: Passed through unchanged as the step's observation.

    Returns:
        Whatever `trajectory.mid` returns for these arguments.
    """
    unit_reward = np.array(1, dtype=np.float32)
    return trajectory.mid(
        observation=observation,
        action=1,
        policy_info=(),
        reward=unit_reward,
        discount=1.0,
    )