def testOneStepUpdatesObservers(self):
    """Runs a default (single-call) episode driver repeatedly and checks observers.

    The driver is built without `num_episodes`. `driver.run()` is invoked
    once per loop iteration, and after all iterations the test expects:
    one episode per call in the episode observer, and two steps per episode
    in both the step observer and the transition observer.
    """
    num_runs = 5
    steps_per_episode = 2  # Two steps per episode.

    py_env = driver_test_utils.PyEnvironmentMock()
    env = tf_py_environment.TFPyEnvironment(py_env)
    policy = driver_test_utils.TFPolicyMock(
        env.time_step_spec(), env.action_spec())

    episode_counter = driver_test_utils.NumEpisodesObserver()
    step_counter = driver_test_utils.NumStepsObserver()
    transition_step_counter = driver_test_utils.NumStepsTransitionObserver()

    driver = dynamic_episode_driver.DynamicEpisodeDriver(
        env,
        policy,
        observers=[episode_counter, step_counter],
        transition_observers=[transition_step_counter],
    )

    self.evaluate(tf.compat.v1.global_variables_initializer())
    # driver.run() is called afresh each iteration rather than evaluating a
    # single cached result; presumably this is what makes each iteration
    # actually execute under eager mode. NOTE(review): confirm.
    for _ in range(num_runs):
        self.evaluate(driver.run())

    expected_steps = num_runs * steps_per_episode
    self.assertEqual(self.evaluate(episode_counter.num_episodes), num_runs)
    self.assertEqual(self.evaluate(step_counter.num_steps), expected_steps)
    self.assertEqual(
        self.evaluate(transition_step_counter.num_steps), expected_steps)
def testTwoStepObservers(self):
    """Checks that two episode observers and a transition observer all update.

    A single `driver.run()` configured with `num_episodes=5` must be seen by
    both `NumEpisodesObserver` instances (5 episodes each) and must produce
    10 transition steps in the transition observer (two per episode).
    """
    num_episodes = 5
    expected_transition_steps = 10

    env = tf_py_environment.TFPyEnvironment(
        driver_test_utils.PyEnvironmentMock())
    policy = driver_test_utils.TFPolicyMock(
        env.time_step_spec(), env.action_spec())

    # Each observer gets its own variable scope, presumably so their counter
    # variables do not collide. NOTE(review): confirm.
    first_observer, second_observer = [
        driver_test_utils.NumEpisodesObserver(variable_scope=scope)
        for scope in ('observer0', 'observer1')
    ]
    transition_observer = driver_test_utils.NumStepsTransitionObserver()

    driver = dynamic_episode_driver.DynamicEpisodeDriver(
        env,
        policy,
        num_episodes=num_episodes,
        observers=[first_observer, second_observer],
        transition_observers=[transition_observer],
    )

    # The run op is created before the variable initializer is evaluated;
    # presumably this ensures any variables it creates in graph mode are
    # initialized too. NOTE(review): confirm.
    run_op = driver.run()
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(run_op)

    for observer in (first_observer, second_observer):
        self.assertEqual(self.evaluate(observer.num_episodes), num_episodes)
    self.assertEqual(
        self.evaluate(transition_observer.num_steps),
        expected_transition_steps)