def test_with_py_driver(self):
  """Collects with a PyDriver into a TFRecord file and reads it back."""
  environment = driver_test_utils.PyEnvironmentMock()
  mock_policy = driver_test_utils.PyPolicyMock(environment.time_step_spec(),
                                               environment.action_spec())
  # Build the trajectory spec the observer will serialize, as tensor specs.
  traj_spec = trajectory.from_transition(environment.time_step_spec(),
                                         mock_policy.policy_step_spec,
                                         environment.time_step_spec())
  traj_spec = tensor_spec.from_spec(traj_spec)
  observer = example_encoding_dataset.TFRecordObserver(
      self.dataset_path, traj_spec, py_mode=True)
  collect_driver = py_driver.PyDriver(
      environment, mock_policy, [observer], max_steps=10)
  first_time_step = environment.reset()
  collect_driver.run(first_time_step)
  observer.flush()
  observer.close()
  # Reload the written records and confirm they decode as Trajectory tuples.
  dataset = example_encoding_dataset.load_tfrecord_dataset(
      [self.dataset_path], buffer_size=2, as_trajectories=True)
  iterator = eager_utils.dataset_iterator(dataset)
  sample = self.evaluate(eager_utils.get_next(iterator))
  self.assertIsInstance(sample, trajectory.Trajectory)
def testRunOnce(self, max_steps, max_episodes, expected_steps):
  """Runs the driver once; checks both trajectory and transition observers."""
  environment = driver_test_utils.PyEnvironmentMock()
  mock_policy = driver_test_utils.PyPolicyMock(environment.time_step_spec(),
                                               environment.action_spec())
  traj_observer = MockReplayBufferObserver()
  transition_observer = MockReplayBufferObserver()
  test_driver = py_driver.PyDriver(
      environment,
      mock_policy,
      observers=[traj_observer],
      transition_observers=[transition_observer],
      max_steps=max_steps,
      max_episodes=max_episodes,
  )
  first_time_step = environment.reset()
  first_policy_state = mock_policy.get_initial_state()
  test_driver.run(first_time_step, first_policy_state)
  self.assertEqual(traj_observer.gather_all(),
                   self._trajectories[:expected_steps])
  transitions = transition_observer.gather_all()
  self.assertLen(transitions, expected_steps)
  # Each transition is a (TimeStep, Action, NextTimeStep) triple.
  self.assertLen(transitions[0], 3)
def test_get_collect_data_spec(self):
  """Collect-data spec derived from an env/policy pair has named fields."""
  environment = suite_gym.load('CartPole-v0')
  mock_policy = driver_test_utils.PyPolicyMock(environment.time_step_spec(),
                                               environment.action_spec())
  data_spec = spec_utils.get_collect_data_spec_from_policy_and_env(
      environment, mock_policy)
  self.assertEqual(data_spec.observation.name, 'observation')
  self.assertEqual(data_spec.reward.name, 'reward')
def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):
  """Drives a batch of two mock envs and compares collected trajectories."""
  # Hand-computed batched trajectories for envs with final states 3 and 4.
  want_trajectories = [
      trajectory.Trajectory(
          step_type=np.array([0, 0]),
          observation=np.array([0, 0]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([1, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([1., 1.])),
      trajectory.Trajectory(
          step_type=np.array([1, 1]),
          observation=np.array([2, 1]),
          action=np.array([1, 2]),
          policy_info=np.array([2, 4]),
          next_step_type=np.array([2, 1]),
          reward=np.array([1., 1.]),
          discount=np.array([0., 1.])),
      trajectory.Trajectory(
          step_type=np.array([2, 1]),
          observation=np.array([3, 3]),
          action=np.array([2, 1]),
          policy_info=np.array([4, 2]),
          next_step_type=np.array([0, 2]),
          reward=np.array([0., 1.]),
          discount=np.array([1., 0.]))
  ]
  first_env = driver_test_utils.PyEnvironmentMock(final_state=3)
  second_env = driver_test_utils.PyEnvironmentMock(final_state=4)
  batched_env = batched_py_environment.BatchedPyEnvironment(
      [first_env, second_env])
  mock_policy = driver_test_utils.PyPolicyMock(
      batched_env.time_step_spec(),
      batched_env.action_spec(),
      initial_policy_state=np.array([1, 2]))
  traj_observer = MockReplayBufferObserver()
  test_driver = py_driver.PyDriver(
      batched_env,
      mock_policy,
      observers=[traj_observer],
      max_steps=max_steps,
      max_episodes=max_episodes,
  )
  first_time_step = batched_env.reset()
  first_policy_state = mock_policy.get_initial_state()
  test_driver.run(first_time_step, first_policy_state)
  got_trajectories = traj_observer.gather_all()
  want = want_trajectories[:expected_length]
  self.assertEqual(len(got_trajectories), len(want))
  # Compare every field of every trajectory element-wise.
  for got_traj, want_traj in zip(got_trajectories, want):
    for got_field, want_field in zip(got_traj, want_traj):
      self.assertAllEqual(got_field, want_field)
def testValueErrorOnInvalidArgs(self, max_steps, max_episodes):
  """PyDriver construction rejects invalid max_steps/max_episodes combos."""
  environment = driver_test_utils.PyEnvironmentMock()
  mock_policy = driver_test_utils.PyPolicyMock(environment.time_step_spec(),
                                               environment.action_spec())
  observer = MockReplayBufferObserver()
  with self.assertRaises(ValueError):
    py_driver.PyDriver(
        environment,
        mock_policy,
        observers=[observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
def testMultipleRunMaxSteps(self):
  """Repeated single-step runs accumulate trajectories across calls."""
  num_runs = 3
  num_expected_steps = 4
  environment = driver_test_utils.PyEnvironmentMock()
  mock_policy = driver_test_utils.PyPolicyMock(environment.time_step_spec(),
                                               environment.action_spec())
  traj_observer = MockReplayBufferObserver()
  test_driver = py_driver.PyDriver(
      environment,
      mock_policy,
      observers=[traj_observer],
      max_steps=1,
      max_episodes=None,
  )
  current_time_step = environment.reset()
  current_policy_state = mock_policy.get_initial_state()
  # Each run() takes one step and returns state to resume from.
  for _ in range(num_runs):
    current_time_step, current_policy_state = test_driver.run(
        current_time_step, current_policy_state)
  self.assertEqual(traj_observer.gather_all(),
                   self._trajectories[:num_expected_steps])
def testPolicyStateReset(self):
  """Policy state is re-initialized at each episode boundary."""
  num_episodes = 2
  num_expected_steps = 6
  environment = driver_test_utils.PyEnvironmentMock()
  mock_policy = driver_test_utils.PyPolicyMock(environment.time_step_spec(),
                                               environment.action_spec())
  traj_observer = MockReplayBufferObserver()
  test_driver = py_driver.PyDriver(
      environment,
      mock_policy,
      observers=[traj_observer],
      max_steps=None,
      max_episodes=num_episodes,
  )
  current_time_step = environment.reset()
  current_policy_state = mock_policy.get_initial_state()
  current_time_step, current_policy_state = test_driver.run(
      current_time_step, current_policy_state)
  self.assertEqual(traj_observer.gather_all(),
                   self._trajectories[:num_expected_steps])
  # One get_initial_state() call per episode implies a reset per episode.
  self.assertEqual(num_episodes, mock_policy.get_initial_state_call_count)