def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):
  expected_trajectories = [
      trajectory.Trajectory(
          step_type=np.array([0, 0]),
          observation=np.array([0, 0]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([1, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([1., 1.])),
      trajectory.Trajectory(
          step_type=np.array([1, 1]),
          observation=np.array([2, 1]),
          action=np.array([1, 2]),
          policy_info=np.array([2, 4]),
          next_step_type=np.array([2, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([0., 1.])),
      trajectory.Trajectory(
          step_type=np.array([2, 1]),
          observation=np.array([3, 3]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([0, 2]),
          reward=np.array([0., 1.]),
          discount=np.array([1., 0.]))
  ]

  env1 = driver_test_utils.PyEnvironmentMock(final_state=3)
  env2 = driver_test_utils.PyEnvironmentMock(final_state=4)
  env = batched_py_environment.BatchedPyEnvironment([env1, env2])
  tf_env = tf_py_environment.TFPyEnvironment(env)

  policy = driver_test_utils.TFPolicyMock(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      batch_size=2,
      initial_policy_state=tf.constant([1, 2], dtype=tf.int32))

  replay_buffer_observer = MockReplayBufferObserver()
  driver = tf_driver.TFDriver(
      tf_env,
      policy,
      observers=[replay_buffer_observer],
      max_steps=max_steps,
      max_episodes=max_episodes,
  )
  initial_time_step = tf_env.reset()
  initial_policy_state = tf.constant([1, 2], dtype=tf.int32)
  self.evaluate(driver.run(initial_time_step, initial_policy_state))
  trajectories = replay_buffer_observer.gather_all()

  self.assertEqual(
      len(trajectories), len(expected_trajectories[:expected_length]))

  for t1, t2 in zip(trajectories, expected_trajectories[:expected_length]):
    for t1_field, t2_field in zip(t1, t2):
      self.assertAllEqual(t1_field, t2_field)
def testOneStepUpdatesObservers(self):
  if tf.executing_eagerly():
    self.skipTest('b/123880556')

  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  policy_state_ph = tensor_spec.to_nest_placeholder(
      policy.policy_state_spec,
      default=0,
      name_scope='policy_state_ph',
      outer_dims=(1,))
  num_episodes_observer = driver_test_utils.NumEpisodesObserver()

  driver = dynamic_step_driver.DynamicStepDriver(
      env, policy, observers=[num_episodes_observer])
  run_driver = driver.run(policy_state=policy_state_ph)

  with self.session() as session:
    session.run(tf.compat.v1.global_variables_initializer())
    _, policy_state = session.run(run_driver)
    for _ in range(4):
      _, policy_state = session.run(
          run_driver, feed_dict={policy_state_ph: policy_state})
    self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 2)
def testMultipleRunMaxEpisodes(self):
  num_episodes = 2
  num_expected_steps = 6

  env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(env)
  policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                          tf_env.action_spec())

  replay_buffer_observer = MockReplayBufferObserver()
  driver = tf_driver.TFDriver(
      tf_env,
      policy,
      observers=[replay_buffer_observer],
      max_steps=None,
      max_episodes=1,
  )
  time_step = tf_env.reset()
  policy_state = policy.get_initial_state(batch_size=1)
  for _ in range(num_episodes):
    time_step, policy_state = self.evaluate(
        driver.run(time_step, policy_state))
  trajectories = replay_buffer_observer.gather_all()
  self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
def testOneStepUpdatesObservers(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  num_episodes_observer = driver_test_utils.NumEpisodesObserver()
  num_steps_observer = driver_test_utils.NumStepsObserver()
  num_steps_transition_observer = (
      driver_test_utils.NumStepsTransitionObserver())

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env,
      policy,
      observers=[num_episodes_observer, num_steps_observer],
      transition_observers=[num_steps_transition_observer])

  self.evaluate(tf.compat.v1.global_variables_initializer())
  for _ in range(5):
    self.evaluate(driver.run())
  self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
  # Two steps per episode.
  self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
  self.assertEqual(
      self.evaluate(num_steps_transition_observer.num_steps), 10)
def testTwoStepObservers(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  num_episodes_observer0 = driver_test_utils.NumEpisodesObserver(
      variable_scope='observer0')
  num_episodes_observer1 = driver_test_utils.NumEpisodesObserver(
      variable_scope='observer1')
  num_steps_transition_observer = (
      driver_test_utils.NumStepsTransitionObserver())

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env,
      policy,
      num_episodes=5,
      observers=[num_episodes_observer0, num_episodes_observer1],
      transition_observers=[num_steps_transition_observer])
  run_driver = driver.run()

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)
  self.assertEqual(self.evaluate(num_episodes_observer0.num_episodes), 5)
  self.assertEqual(self.evaluate(num_episodes_observer1.num_episodes), 5)
  self.assertEqual(
      self.evaluate(num_steps_transition_observer.num_steps), 10)
def test_with_dynamic_step_driver(self):
  env = driver_test_utils.PyEnvironmentMock()
  tf_env = tf_py_environment.TFPyEnvironment(env)
  policy = driver_test_utils.TFPolicyMock(tf_env.time_step_spec(),
                                          tf_env.action_spec())

  trajectory_spec = trajectory.from_transition(tf_env.time_step_spec(),
                                               policy.policy_step_spec,
                                               tf_env.time_step_spec())
  tfrecord_observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, trajectory_spec)

  driver = dynamic_step_driver.DynamicStepDriver(
      tf_env,
      policy,
      observers=[common.function(tfrecord_observer)],
      num_steps=10)

  self.evaluate(tf.compat.v1.global_variables_initializer())
  time_step = self.evaluate(tf_env.reset())
  initial_policy_state = policy.get_initial_state(batch_size=1)
  self.evaluate(
      common.function(driver.run)(time_step, initial_policy_state))
  tfrecord_observer.flush()
  tfrecord_observer.close()

  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  iterator = eager_utils.dataset_iterator(dataset)
  sample = self.evaluate(eager_utils.get_next(iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def test_with_py_driver(self):
  env = driver_test_utils.PyEnvironmentMock()
  policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                          env.action_spec())

  trajectory_spec = trajectory.from_transition(env.time_step_spec(),
                                               policy.policy_step_spec,
                                               env.time_step_spec())
  trajectory_spec = tensor_spec.from_spec(trajectory_spec)
  tfrecord_observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, trajectory_spec, py_mode=True)

  driver = py_driver.PyDriver(env, policy, [tfrecord_observer], max_steps=10)
  time_step = env.reset()
  driver.run(time_step)
  tfrecord_observer.flush()
  tfrecord_observer.close()

  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  iterator = eager_utils.dataset_iterator(dataset)
  sample = self.evaluate(eager_utils.get_next(iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def testToTransitionHandlesTrajectoryFromDriverCorrectly(self):
  env = tf_py_environment.TFPyEnvironment(
      drivers_test_utils.PyEnvironmentMock())
  policy = drivers_test_utils.TFPolicyMock(env.time_step_spec(),
                                           env.action_spec())
  replay_buffer = drivers_test_utils.make_replay_buffer(policy)

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

  run_driver = driver.run()
  rb_gather_all = replay_buffer.gather_all()

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)

  trajectories = self.evaluate(rb_gather_all)

  time_steps, policy_step, next_time_steps = trajectory.to_transition(
      trajectories)

  self.assertAllEqual(time_steps.observation,
                      trajectories.observation[:, :-1])
  self.assertAllEqual(time_steps.step_type, trajectories.step_type[:, :-1])
  self.assertAllEqual(next_time_steps.observation,
                      trajectories.observation[:, 1:])
  self.assertAllEqual(next_time_steps.step_type,
                      trajectories.step_type[:, 1:])
  self.assertAllEqual(next_time_steps.reward, trajectories.reward[:, :-1])
  self.assertAllEqual(next_time_steps.discount,
                      trajectories.discount[:, :-1])
  self.assertAllEqual(policy_step.action, trajectories.action[:, :-1])
  self.assertAllEqual(policy_step.info, trajectories.policy_info[:, :-1])
def testMultiStepReplayBufferObservers(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  replay_buffer = driver_test_utils.make_replay_buffer(policy)

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env, policy, num_episodes=3, observers=[replay_buffer.add_batch])

  run_driver = driver.run()
  rb_gather_all = replay_buffer.gather_all()

  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)

  trajectories = self.evaluate(rb_gather_all)

  self.assertAllEqual(trajectories.step_type, [[0, 1, 2, 0, 1, 2, 0, 1, 2]])
  self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2, 1]])
  self.assertAllEqual(trajectories.observation,
                      [[0, 1, 3, 0, 1, 3, 0, 1, 3]])
  self.assertAllEqual(trajectories.policy_info,
                      [[2, 4, 2, 2, 4, 2, 2, 4, 2]])
  self.assertAllEqual(trajectories.next_step_type,
                      [[1, 2, 0, 1, 2, 0, 1, 2, 0]])
  self.assertAllEqual(trajectories.reward,
                      [[1., 1., 0., 1., 1., 0., 1., 1., 0.]])
  self.assertAllEqual(trajectories.discount,
                      [[1., 0., 1., 1., 0., 1., 1., 0., 1.]])
def testRunOnce(self, max_steps, max_episodes, expected_steps):
  env = driver_test_utils.PyEnvironmentMock()
  policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  replay_buffer_observer = MockReplayBufferObserver()
  transition_replay_buffer_observer = MockReplayBufferObserver()
  driver = py_driver.PyDriver(
      env,
      policy,
      observers=[replay_buffer_observer],
      transition_observers=[transition_replay_buffer_observer],
      max_steps=max_steps,
      max_episodes=max_episodes,
  )
  initial_time_step = env.reset()
  initial_policy_state = policy.get_initial_state()
  driver.run(initial_time_step, initial_policy_state)
  trajectories = replay_buffer_observer.gather_all()
  self.assertEqual(trajectories, self._trajectories[:expected_steps])
  transitions = transition_replay_buffer_observer.gather_all()
  self.assertLen(transitions, expected_steps)
  # Each transition is a (TimeStep, Action, NextTimeStep) triple.
  self.assertLen(transitions[0], 3)
def testToTransitionSpec(self):
  env = tf_py_environment.TFPyEnvironment(
      drivers_test_utils.PyEnvironmentMock())
  policy = drivers_test_utils.TFPolicyMock(env.time_step_spec(),
                                           env.action_spec())
  trajectory_spec = policy.trajectory_spec

  ts_spec, ps_spec, nts_spec = trajectory.to_transition_spec(
      trajectory_spec)

  self.assertAllEqual(ts_spec, env.time_step_spec())
  self.assertAllEqual(ps_spec.action, env.action_spec())
  self.assertAllEqual(nts_spec, env.time_step_spec())
def testValueErrorOnInvalidArgs(self, max_steps, max_episodes):
  env = driver_test_utils.PyEnvironmentMock()
  policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  replay_buffer_observer = MockReplayBufferObserver()
  with self.assertRaises(ValueError):
    py_driver.PyDriver(
        env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
def testMultiStepUpdatesObservers(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  num_episodes_observer = driver_test_utils.NumEpisodesObserver()

  driver = dynamic_step_driver.DynamicStepDriver(
      env, policy, num_steps=5, observers=[num_episodes_observer])
  run_driver = driver.run()
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)
  self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 2)
def testMultiStepUpdatesObservers(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  num_episodes_observer = driver_test_utils.NumEpisodesObserver()
  num_steps_observer = driver_test_utils.NumStepsObserver()

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env, policy, observers=[num_episodes_observer, num_steps_observer])

  run_driver = driver.run(num_episodes=5)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)
  self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
  # Two steps per episode.
  self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
def testPolicyState(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  num_episodes_observer = driver_test_utils.NumEpisodesObserver()
  num_steps_observer = driver_test_utils.NumStepsObserver()

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env, policy, observers=[num_episodes_observer, num_steps_observer])
  run_driver = driver.run()
  self.evaluate(tf.compat.v1.global_variables_initializer())
  time_step, policy_state = self.evaluate(run_driver)
  self.assertEqual(time_step.step_type, 0)
  self.assertEqual(policy_state, [3])
def testOneStepReplayBufferObservers(self):
  if tf.executing_eagerly():
    self.skipTest('b/123880556')

  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  policy_state_ph = tensor_spec.to_nest_placeholder(
      policy.policy_state_spec,
      default=0,
      name_scope='policy_state_ph',
      outer_dims=(1,))
  replay_buffer = driver_test_utils.make_replay_buffer(policy)

  driver = dynamic_step_driver.DynamicStepDriver(
      env, policy, num_steps=1, observers=[replay_buffer.add_batch])

  run_driver = driver.run(policy_state=policy_state_ph)
  rb_gather_all = replay_buffer.gather_all()

  with self.session() as session:
    session.run(tf.compat.v1.global_variables_initializer())
    _, policy_state = session.run(run_driver)
    for _ in range(5):
      _, policy_state = session.run(
          run_driver, feed_dict={policy_state_ph: policy_state})
    trajectories = self.evaluate(rb_gather_all)

  self.assertAllEqual(trajectories.step_type, [[0, 1, 2, 0, 1, 2, 0, 1]])
  self.assertAllEqual(trajectories.observation, [[0, 1, 3, 0, 1, 3, 0, 1]])
  self.assertAllEqual(trajectories.action, [[1, 2, 1, 1, 2, 1, 1, 2]])
  self.assertAllEqual(trajectories.policy_info, [[2, 4, 2, 2, 4, 2, 2, 4]])
  self.assertAllEqual(trajectories.next_step_type,
                      [[1, 2, 0, 1, 2, 0, 1, 2]])
  self.assertAllEqual(trajectories.reward,
                      [[1., 1., 0., 1., 1., 0., 1., 1.]])
  self.assertAllEqual(trajectories.discount,
                      [[1., 0., 1., 1., 0., 1., 1., 0.]])
def testOneStepUpdatesObservers(self):
  if tf.executing_eagerly():
    self.skipTest('b/123880410')

  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  num_episodes_observer = driver_test_utils.NumEpisodesObserver()
  num_steps_observer = driver_test_utils.NumStepsObserver()

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env, policy, observers=[num_episodes_observer, num_steps_observer])
  run_driver = driver.run()
  self.evaluate(tf.compat.v1.global_variables_initializer())
  for _ in range(5):
    self.evaluate(run_driver)
  self.assertEqual(self.evaluate(num_episodes_observer.num_episodes), 5)
  # Two steps per episode.
  self.assertEqual(self.evaluate(num_steps_observer.num_steps), 10)
def testPolicyStateReset(self):
  num_episodes = 2
  num_expected_steps = 6

  env = driver_test_utils.PyEnvironmentMock()
  policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  replay_buffer_observer = MockReplayBufferObserver()
  driver = py_driver.PyDriver(
      env,
      policy,
      observers=[replay_buffer_observer],
      max_steps=None,
      max_episodes=num_episodes,
  )
  time_step = env.reset()
  policy_state = policy.get_initial_state()
  time_step, policy_state = driver.run(time_step, policy_state)
  trajectories = replay_buffer_observer.gather_all()
  self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
  self.assertEqual(num_episodes, policy.get_initial_state_call_count)
def testTwoObservers(self):
  env = tf_py_environment.TFPyEnvironment(
      driver_test_utils.PyEnvironmentMock())
  policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  policy_state = policy.get_initial_state(1)
  num_episodes_observer0 = driver_test_utils.NumEpisodesObserver(
      variable_scope='observer0')
  num_episodes_observer1 = driver_test_utils.NumEpisodesObserver(
      variable_scope='observer1')

  driver = dynamic_step_driver.DynamicStepDriver(
      env,
      policy,
      num_steps=5,
      observers=[num_episodes_observer0, num_episodes_observer1])
  run_driver = driver.run(policy_state=policy_state)
  self.evaluate(tf.compat.v1.global_variables_initializer())
  self.evaluate(run_driver)
  self.assertEqual(self.evaluate(num_episodes_observer0.num_episodes), 2)
  self.assertEqual(self.evaluate(num_episodes_observer1.num_episodes), 2)
def testMultipleRunMaxSteps(self):
  num_steps = 3
  num_expected_steps = 4

  env = driver_test_utils.PyEnvironmentMock()
  policy = driver_test_utils.PyPolicyMock(env.time_step_spec(),
                                          env.action_spec())
  replay_buffer_observer = MockReplayBufferObserver()
  driver = py_driver.PyDriver(
      env,
      policy,
      observers=[replay_buffer_observer],
      max_steps=1,
      max_episodes=None,
  )
  time_step = env.reset()
  policy_state = policy.get_initial_state()
  for _ in range(num_steps):
    time_step, policy_state = driver.run(time_step, policy_state)
  trajectories = replay_buffer_observer.gather_all()
  self.assertEqual(trajectories, self._trajectories[:num_expected_steps])
def testMultiStepEpisodicReplayBuffer(self):
  num_episodes = 5
  num_driver_episodes = 5

  # Create mock environment.
  py_env = batched_py_environment.BatchedPyEnvironment([
      driver_test_utils.PyEnvironmentMock(final_state=i + 1)
      for i in range(num_episodes)
  ])
  env = tf_py_environment.TFPyEnvironment(py_env)

  # Create mock policy.
  policy = driver_test_utils.TFPolicyMock(
      env.time_step_spec(), env.action_spec(), batch_size=num_episodes)

  # Create replay buffer and driver.
  replay_buffer = self._make_replay_buffer(env)
  stateful_buffer = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
      replay_buffer, num_episodes)

  driver = dynamic_episode_driver.DynamicEpisodeDriver(
      env,
      policy,
      num_episodes=num_driver_episodes,
      observers=[stateful_buffer.add_batch])

  run_driver = driver.run()

  end_episodes = replay_buffer._maybe_end_batch_episodes(
      stateful_buffer.episode_ids, end_episode=True)
  completed_episodes = replay_buffer._completed_episodes()

  self.evaluate([
      tf.compat.v1.local_variables_initializer(),
      tf.compat.v1.global_variables_initializer()
  ])
  self.evaluate(run_driver)
  self.evaluate(end_episodes)

  completed_episodes = self.evaluate(completed_episodes)
  eps = [replay_buffer._get_episode(ep) for ep in completed_episodes]
  eps = self.evaluate(eps)
  episodes_length = [tf.nest.flatten(ep)[0].shape[0] for ep in eps]

  # Compare with expected output.
  self.assertAllEqual(completed_episodes, [3, 4, 5, 6, 7])
  self.assertAllEqual(episodes_length, [4, 4, 2, 1, 1])

  first = ts.StepType.FIRST
  mid = ts.StepType.MID
  last = ts.StepType.LAST

  step_types = [ep.step_type for ep in eps]
  observations = [ep.observation for ep in eps]
  rewards = [ep.reward for ep in eps]
  actions = [ep.action for ep in eps]

  self.assertAllClose([[first, mid, mid, last],
                       [first, mid, mid, mid],
                       [first, last],
                       [first],
                       [first]], step_types)
  self.assertAllClose([
      [0, 1, 3, 4],
      [0, 1, 3, 4],
      [0, 1],
      [0],
      [0],
  ], observations)
  self.assertAllClose([
      [1, 2, 1, 2],
      [1, 2, 1, 2],
      [1, 2],
      [1],
      [1],
  ], actions)
  self.assertAllClose([
      [1, 1, 1, 0],
      [1, 1, 1, 1],
      [1, 0],
      [1],
      [1],
  ], rewards)