Exemplo n.º 1
0
    def testOneStepUpdatesObservers(self):
        """Runs the driver five times and checks every observer's counter.

        The driver is built without an explicit `num_episodes` (presumably
        defaulting to one episode per `run()` call; the expectation of 5
        episodes after 5 runs relies on that). The mock environment takes two
        steps per episode, so both step counters are expected to read 10.
        """
        py_env = driver_test_utils.PyEnvironmentMock()
        tf_env = tf_py_environment.TFPyEnvironment(py_env)
        mock_policy = driver_test_utils.TFPolicyMock(
            tf_env.time_step_spec(), tf_env.action_spec())

        # One observer per kind of counter: episodes, steps, and steps seen
        # through the transition-observer hook.
        episode_counter = driver_test_utils.NumEpisodesObserver()
        step_counter = driver_test_utils.NumStepsObserver()
        transition_step_counter = driver_test_utils.NumStepsTransitionObserver()

        episode_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            mock_policy,
            observers=[episode_counter, step_counter],
            transition_observers=[transition_step_counter],
        )

        self.evaluate(tf.compat.v1.global_variables_initializer())
        # Five separate driver runs, each evaluated before the next starts.
        for _ in range(5):
            self.evaluate(episode_driver.run())

        # (counter, expected value) pairs: 5 episodes x 2 steps = 10 steps.
        expectations = (
            (episode_counter.num_episodes, 5),
            (step_counter.num_steps, 10),
            (transition_step_counter.num_steps, 10),
        )
        for counter_value, expected in expectations:
            self.assertEqual(self.evaluate(counter_value), expected)
Exemplo n.º 2
0
    def testTwoStepObservers(self):
        """Checks that two separate episode observers both see every episode.

        The driver is configured with `num_episodes=5`, so one run should
        report five episodes to each of the two `NumEpisodesObserver`s (given
        distinct `variable_scope`s, presumably so their variables do not
        collide) and ten steps to the transition observer.
        """
        tf_env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        mock_policy = driver_test_utils.TFPolicyMock(
            tf_env.time_step_spec(), tf_env.action_spec())

        first_counter, second_counter = (
            driver_test_utils.NumEpisodesObserver(variable_scope=scope)
            for scope in ('observer0', 'observer1')
        )
        transition_step_counter = driver_test_utils.NumStepsTransitionObserver()

        episode_driver = dynamic_episode_driver.DynamicEpisodeDriver(
            tf_env,
            mock_policy,
            num_episodes=5,
            observers=[first_counter, second_counter],
            transition_observers=[transition_step_counter],
        )

        # Obtain the run result before initializing variables, then evaluate
        # it exactly once.
        run_result = episode_driver.run()
        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_result)

        for counter in (first_counter, second_counter):
            self.assertEqual(self.evaluate(counter.num_episodes), 5)
        self.assertEqual(self.evaluate(transition_step_counter.num_steps), 10)