Example 1
    def testTwoStepObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer0 = driver_test_utils.NumEpisodesObserver(
            variable_scope='observer0')
        num_episodes_observer1 = driver_test_utils.NumEpisodesObserver(
            variable_scope='observer1')
        num_steps_transition_observer = (
            driver_test_utils.NumStepsTransitionObserver())

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env,
            policy,
            num_episodes=5,
            observers=[num_episodes_observer0, num_episodes_observer1],
            transition_observers=[num_steps_transition_observer])
        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        self.assertEqual(self.evaluate(num_episodes_observer0.num_episodes), 5)
        self.assertEqual(self.evaluate(num_episodes_observer1.num_episodes), 5)
        self.assertEqual(
            self.evaluate(num_steps_transition_observer.num_steps), 10)
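Each example on this page is a test method excerpted from the TF-Agents driver test suites; run on its own it needs the surrounding module imports. A minimal set, assuming the usual tf_agents package layout (verify the exact paths against your installed version), would be:

import numpy as np
import tensorflow as tf

from tf_agents.drivers import dynamic_episode_driver
from tf_agents.drivers import dynamic_step_driver
from tf_agents.drivers import test_utils as driver_test_utils
from tf_agents.environments import tf_py_environment
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory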
Example 2
    def testOneStepUpdatesObservers(self):
        if tf.executing_eagerly():
            self.skipTest('b/123880556')

        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        policy_state_ph = tensor_spec.to_nest_placeholder(
            policy.policy_state_spec,
            default=0,
            name_scope='policy_state_ph',
            outer_dims=(1, ))
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()

        driver = dynamic_step_driver.DynamicStepDriver(
            env, policy, observers=[num_episodes_observer])
        run_driver = driver.run(policy_state=policy_state_ph)

        with self.session() as session:
            session.run(tf.compat.v1.global_variables_initializer())
            _, policy_state = session.run(run_driver)
            for _ in range(4):
                _, policy_state = session.run(
                    run_driver, feed_dict={policy_state_ph: policy_state})
            self.assertEqual(self.evaluate(num_episodes_observer.num_episodes),
                             2)
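The placeholder feeding in this example only works with TF1-style graph sessions, which is why the test skips itself under eager execution. As a rough, hypothetical sketch (assuming the same mock utilities behave identically under eager/TF2), the same state-carrying loop can be written without placeholders by passing the returned policy_state back into driver.run:

# Hypothetical eager-mode version of the loop above, not part of the original test.
env = tf_py_environment.TFPyEnvironment(driver_test_utils.PyEnvironmentMock())
policy = driver_test_utils.TFPolicyMock(env.time_step_spec(), env.action_spec())
num_episodes_observer = driver_test_utils.NumEpisodesObserver()

driver = dynamic_step_driver.DynamicStepDriver(
    env, policy, observers=[num_episodes_observer])

# The first run uses the policy's default initial state; later runs carry
# forward the state returned by the previous call.
time_step, policy_state = driver.run()
for _ in range(4):
    time_step, policy_state = driver.run(policy_state=policy_state)

print(num_episodes_observer.num_episodes.numpy())  # expected to be 2, as asserted above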
Example 3
    def testOneStepUpdatesObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()
        num_steps_observer = driver_test_utils.NumStepsObserver()
        num_steps_transition_observer = (
            driver_test_utils.NumStepsTransitionObserver())

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env,
            policy,
            observers=[num_episodes_observer, num_steps_observer],
            transition_observers=[num_steps_transition_observer])

        self.evaluate(tf.compat.v1.global_variables_initializer())
        for _ in range(5):
            self.evaluate(driver.run())

        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
        # Two steps per episode.
        self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
        self.assertEqual(
            self.evaluate(num_steps_transition_observer.num_steps), 10)
Example 4
    def testTwoObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        policy_state = policy.get_initial_state(1)
        num_episodes_observer0 = driver_test_utils.NumEpisodesObserver(
            variable_scope='observer0')
        num_episodes_observer1 = driver_test_utils.NumEpisodesObserver(
            variable_scope='observer1')

        driver = dynamic_step_driver.DynamicStepDriver(
            env,
            policy,
            num_steps=5,
            observers=[num_episodes_observer0, num_episodes_observer1])
        run_driver = driver.run(policy_state=policy_state)

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        # PyEnvironmentMock episodes last two steps (see the comment in
        # Example 3), so five driver steps complete two full episodes.
        self.assertEqual(self.evaluate(num_episodes_observer0.num_episodes), 2)
        self.assertEqual(self.evaluate(num_episodes_observer1.num_episodes), 2)
Example 5
    def testMultiStepUpdatesObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()

        driver = dynamic_step_driver.DynamicStepDriver(
            env, policy, num_steps=5, observers=[num_episodes_observer])

        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 2)
Example 6
    def testPolicyState(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())

        num_episodes_observer = driver_test_utils.NumEpisodesObserver()
        num_steps_observer = driver_test_utils.NumStepsObserver()

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, observers=[num_episodes_observer, num_steps_observer])
        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())

        time_step, policy_state = self.evaluate(run_driver)

        self.assertEqual(time_step.step_type, 0)
        self.assertEqual(policy_state, [3])
Example 7
  def testNumEpisodesObserverEpisodeTotal(self, batch_size, traj_len):
    single_trajectory = np.concatenate([[ts.StepType.FIRST],
                                        np.repeat(ts.StepType.MID,
                                                  traj_len - 2),
                                        [ts.StepType.LAST]])
    step_type = np.tile(single_trajectory, (batch_size, 1))

    traj = trajectory.Trajectory(
        observation=np.random.rand(batch_size, traj_len),
        action=np.random.rand(batch_size, traj_len),
        policy_info=(),
        reward=np.random.rand(batch_size, traj_len),
        discount=np.ones((batch_size, traj_len)),
        step_type=step_type,
        next_step_type=np.zeros((batch_size, traj_len)))

    observer = driver_test_utils.NumEpisodesObserver()
    observer(traj)
    self.assertEqual(observer.num_episodes, batch_size)
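Unlike the other examples, testNumEpisodesObserverEpisodeTotal takes batch_size and traj_len arguments, which suggests it is driven by a parameterization decorator. A hypothetical sketch of how such a case is typically wired up with absl.testing.parameterized (the case names and values below are illustrative, not taken from the original suite):

from absl.testing import parameterized
import tensorflow as tf


class ObserverTest(parameterized.TestCase, tf.test.TestCase):

    @parameterized.named_parameters(
        ('single_item_batch', 1, 4),   # batch_size=1, traj_len=4
        ('larger_batch', 3, 6),        # batch_size=3, traj_len=6
    )
    def testNumEpisodesObserverEpisodeTotal(self, batch_size, traj_len):
        ...  # body as shown in Example 7 above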
Example 8
    def testOneStepUpdatesObservers(self):
        if tf.executing_eagerly():
            self.skipTest('b/123880410')
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        num_episodes_observer = driver_test_utils.NumEpisodesObserver()
        num_steps_observer = driver_test_utils.NumStepsObserver()

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, observers=[num_episodes_observer, num_steps_observer])
        run_driver = driver.run()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        for _ in range(5):
            self.evaluate(run_driver)

        self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
        # Two steps per episode.
        self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
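All of these methods assume they run inside a tf.test.TestCase subclass, which supplies self.evaluate, self.session, and self.skipTest. A minimal, hypothetical harness for running one of them standalone (the class name is illustrative; the imports sketched after Example 1 are assumed to be in scope) could look like:

import tensorflow as tf


class DynamicDriverObserversTest(tf.test.TestCase):

    # Paste any of the example methods above here unchanged, e.g. Example 3's
    # testOneStepUpdatesObservers.
    pass


if __name__ == '__main__':
    tf.test.main()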