def testToTransitionHandlesTrajectoryFromDriverCorrectly(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(
      env.time_step_spec(), env.action_spec())
  replay_buffer = driver_test_utils.make_replay_buffer(policy)

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

  run_driver = driver.run()
  rb_gather_all = replay_buffer.gather_all()

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)
  trajectories = self.evaluate(rb_gather_all)

  # to_transition() slices a [batch, time] trajectory into overlapping
  # (time_step, policy_step, next_time_step) transitions.
  time_steps, policy_step, next_time_steps = trajectory.to_transition(
      trajectories)

  self.assertAllEqual(time_steps.observation,
                      trajectories.observation[:, :-1])
  self.assertAllEqual(time_steps.step_type, trajectories.step_type[:, :-1])
  self.assertAllEqual(next_time_steps.observation,
                      trajectories.observation[:, 1:])
  self.assertAllEqual(next_time_steps.step_type,
                      trajectories.step_type[:, 1:])
  # Reward and discount stored at time t describe the transition t -> t+1,
  # so they align with the [:, :-1] slice rather than [:, 1:].
  self.assertAllEqual(next_time_steps.reward, trajectories.reward[:, :-1])
  self.assertAllEqual(next_time_steps.discount,
                      trajectories.discount[:, :-1])
  self.assertAllEqual(policy_step.action, trajectories.action[:, :-1])
  self.assertAllEqual(policy_step.info, trajectories.policy_info[:, :-1])

def testMultiStepReplayBufferObservers(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(
      env.time_step_spec(), env.action_spec())
  replay_buffer = driver_test_utils.make_replay_buffer(policy)

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

  run_driver = driver.run()
  rb_gather_all = replay_buffer.gather_all()

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)
  trajectories = self.evaluate(rb_gather_all)

  self.assertAllEqual(trajectories.step_type, [[0, 1, 2, 0, 1, 2, 0, 1, 2]])
  self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2, 1]])
  self.assertAllEqual(trajectories.observation,
                      [[0, 1, 3, 0, 1, 3, 0, 1, 3]])
  self.assertAllEqual(trajectories.policy_info,
                      [[2, 4, 2, 2, 4, 2, 2, 4, 2]])
  self.assertAllEqual(trajectories.next_step_type,
                      [[1, 2, 0, 1, 2, 0, 1, 2, 0]])
  self.assertAllEqual(trajectories.reward,
                      [[1., 1., 0., 1., 1., 0., 1., 1., 0.]])
  self.assertAllEqual(trajectories.discount,
                      [[1., 0., 1., 1., 0., 1., 1., 0., 1.]])

def testOneStepReplayBufferObservers(self):
  if tf.executing_eagerly():
    self.skipTest('b/123880556')

  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(
      env.time_step_spec(), env.action_spec())
  policy_state_ph = tensor_spec.to_nest_placeholder(
      policy.policy_state_spec,
      default=0,
      name_scope='policy_state_ph',
      outer_dims=(1,))
  replay_buffer = driver_test_utils.make_replay_buffer(policy)

  driver = dynamic_step_driver.DynamicStepDriver(
      env, policy, num_steps=1, observers=[replay_buffer.add_batch])

  run_driver = driver.run(policy_state=policy_state_ph)
  rb_gather_all = replay_buffer.gather_all()

  with self.session() as session:
    session.run(tf.compat.v1.global_variables_initializer())
    # Thread the policy state through successive one-step driver runs.
    _, policy_state = session.run(run_driver)
    for _ in range(5):
      _, policy_state = session.run(
          run_driver, feed_dict={policy_state_ph: policy_state})

    trajectories = self.evaluate(rb_gather_all)

    self.assertAllEqual(trajectories.step_type, [[0, 1, 2, 0, 1, 2, 0, 1]])
    self.assertAllEqual(trajectories.observation,
                        [[0, 1, 3, 0, 1, 3, 0, 1]])
    self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2]])
    self.assertAllEqual(trajectories.policy_info,
                        [[2, 4, 2, 2, 4, 2, 2, 4]])
    self.assertAllEqual(trajectories.next_step_type,
                        [[1, 2, 0, 1, 2, 0, 1, 2]])
    self.assertAllEqual(trajectories.reward,
                        [[1., 1., 0., 1., 1., 0., 1., 1.]])
    self.assertAllEqual(trajectories.discount,
                        [[1., 0., 1., 1., 0., 1., 1., 0.]])