Example 1
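This and the two examples that follow are test methods in the style of the
TF-Agents driver tests; they run inside a tf.test.TestCase subclass and
assume module-level imports along the following lines (module paths as in
the TF-Agents source tree; adjust to your version):

import tensorflow as tf

from tf_agents.drivers import dynamic_episode_driver
from tf_agents.drivers import dynamic_step_driver
from tf_agents.drivers import test_utils as driver_test_utils
from tf_agents.environments import tf_py_environment
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory

This first test checks that trajectory.to_transition splits a batched
trajectory of N steps into (time_steps, policy_steps, next_time_steps)
covering the N - 1 overlapping transitions: every field it returns is
either the [:, :-1] or the [:, 1:] slice of the corresponding trajectory
field, as the assertions verify.
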
    def testToTransitionHandlesTrajectoryFromDriverCorrectly(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        replay_buffer = driver_test_utils.make_replay_buffer(policy)

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

        run_driver = driver.run()
        rb_gather_all = replay_buffer.gather_all()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        trajectories = self.evaluate(rb_gather_all)

        time_steps, policy_step, next_time_steps = trajectory.to_transition(
            trajectories)

        self.assertAllEqual(time_steps.observation,
                            trajectories.observation[:, :-1])
        self.assertAllEqual(time_steps.step_type,
                            trajectories.step_type[:, :-1])
        self.assertAllEqual(next_time_steps.observation,
                            trajectories.observation[:, 1:])
        self.assertAllEqual(next_time_steps.step_type,
                            trajectories.step_type[:, 1:])
        self.assertAllEqual(next_time_steps.reward,
                            trajectories.reward[:, :-1])
        self.assertAllEqual(next_time_steps.discount,
                            trajectories.discount[:, :-1])

        self.assertAllEqual(policy_step.action, trajectories.action[:, :-1])
        self.assertAllEqual(policy_step.info, trajectories.policy_info[:, :-1])
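
Example 2

The same DynamicEpisodeDriver setup, but this test inspects the raw
replay-buffer contents instead: with num_episodes=3 a single run() call
collects three complete episodes, so every field asserted below repeats
the mock environment's deterministic three-step pattern three times.
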
    def testMultiStepReplayBufferObservers(self):
        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        replay_buffer = driver_test_utils.make_replay_buffer(policy)

        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

        run_driver = driver.run()
        rb_gather_all = replay_buffer.gather_all()

        self.evaluate(tf.compat.v1.global_variables_initializer())
        self.evaluate(run_driver)
        trajectories = self.evaluate(rb_gather_all)

        self.assertAllEqual(trajectories.step_type,
                            [[0, 1, 2, 0, 1, 2, 0, 1, 2]])
        self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2, 1]])
        self.assertAllEqual(trajectories.observation,
                            [[0, 1, 3, 0, 1, 3, 0, 1, 3]])
        self.assertAllEqual(trajectories.policy_info,
                            [[2, 4, 2, 2, 4, 2, 2, 4, 2]])
        self.assertAllEqual(trajectories.next_step_type,
                            [[1, 2, 0, 1, 2, 0, 1, 2, 0]])
        self.assertAllEqual(trajectories.reward,
                            [[1., 1., 0., 1., 1., 0., 1., 1., 0.]])
        self.assertAllEqual(trajectories.discount,
                            [[1., 0., 1., 1., 0., 1., 1., 0., 1.]])
Example 3
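This test exercises DynamicStepDriver in graph mode (it is skipped under
eager execution, per b/123880556). The policy state is exposed through a
placeholder so that the state returned by one one-step run() can be fed
back into the next; after the initial run plus five fed-back runs, the
buffer holds eight trajectory steps spanning two full episodes and the
start of a third.
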
    def testOneStepReplayBufferObservers(self):
        if tf.executing_eagerly():
            self.skipTest('b/123880556')

        env = tf_py_environment.TFPyEnvironment(
            driver_test_utils.PyEnvironmentMock())
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec())
        # Feed the policy state in through a placeholder so it can be
        # threaded across successive driver.run() calls below.
        policy_state_ph = tensor_spec.to_nest_placeholder(
            policy.policy_state_spec,
            default=0,
            name_scope='policy_state_ph',
            outer_dims=(1,))
        replay_buffer = driver_test_utils.make_replay_buffer(policy)

        driver = dynamic_step_driver.DynamicStepDriver(
            env, policy, num_steps=1, observers=[replay_buffer.add_batch])

        run_driver = driver.run(policy_state=policy_state_ph)
        rb_gather_all = replay_buffer.gather_all()

        with self.session() as session:
            session.run(tf.compat.v1.global_variables_initializer())
            _, policy_state = session.run(run_driver)
            for _ in range(5):
                _, policy_state = session.run(
                    run_driver, feed_dict={policy_state_ph: policy_state})

            trajectories = self.evaluate(rb_gather_all)

        self.assertAllEqual(trajectories.step_type, [[0, 1, 2, 0, 1, 2, 0, 1]])
        self.assertAllEqual(trajectories.observation,
                            [[0, 1, 3, 0, 1, 3, 0, 1]])
        self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2]])
        self.assertAllEqual(trajectories.policy_info,
                            [[2, 4, 2, 2, 4, 2, 2, 4]])
        self.assertAllEqual(trajectories.next_step_type,
                            [[1, 2, 0, 1, 2, 0, 1, 2]])
        self.assertAllEqual(trajectories.reward,
                            [[1., 1., 0., 1., 1., 0., 1., 1.]])
        self.assertAllEqual(trajectories.discount,
                            [[1., 0., 1., 1., 0., 1., 1., 0.]])
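
Across all three tests the mock components are deterministic, which is
what makes the exact assertions possible: episodes are three steps long
(step types 0/1/2, i.e. FIRST/MID/LAST), observations run 0, 1, 3, the
mock policy alternates actions 1 and 2 with policy_info equal to twice
the action, and the zero entries in reward and discount mark episode
boundaries (zero reward on the boundary trajectory between episodes,
zero discount on the transition into the terminal step).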