def test_with_py_driver(self):
        """Collects with a PyDriver into TFRecords and reads them back as Trajectories."""
        mock_env = driver_test_utils.PyEnvironmentMock()
        mock_policy = driver_test_utils.PyPolicyMock(mock_env.time_step_spec(),
                                                     mock_env.action_spec())
        # Tensor spec describing one recorded trajectory step.
        traj_spec = tensor_spec.from_spec(
            trajectory.from_transition(mock_env.time_step_spec(),
                                       mock_policy.policy_step_spec,
                                       mock_env.time_step_spec()))

        observer = example_encoding_dataset.TFRecordObserver(
            self.dataset_path, traj_spec, py_mode=True)

        collect_driver = py_driver.PyDriver(mock_env,
                                            mock_policy, [observer],
                                            max_steps=10)
        collect_driver.run(mock_env.reset())
        observer.flush()
        observer.close()

        # Reload the written records and confirm they decode to Trajectory tuples.
        loaded = example_encoding_dataset.load_tfrecord_dataset(
            [self.dataset_path], buffer_size=2, as_trajectories=True)
        record_iter = eager_utils.dataset_iterator(loaded)
        first_sample = self.evaluate(eager_utils.get_next(record_iter))
        self.assertIsInstance(first_sample, trajectory.Trajectory)
# Example #2
    def testRunOnce(self, max_steps, max_episodes, expected_steps):
        """Runs the driver once and validates both observer streams."""
        mock_env = driver_test_utils.PyEnvironmentMock()
        mock_policy = driver_test_utils.PyPolicyMock(mock_env.time_step_spec(),
                                                     mock_env.action_spec())
        traj_observer = MockReplayBufferObserver()
        transition_observer = MockReplayBufferObserver()
        test_driver = py_driver.PyDriver(
            mock_env,
            mock_policy,
            observers=[traj_observer],
            transition_observers=[transition_observer],
            max_steps=max_steps,
            max_episodes=max_episodes,
        )

        first_time_step = mock_env.reset()
        first_policy_state = mock_policy.get_initial_state()
        test_driver.run(first_time_step, first_policy_state)

        # One trajectory is observed per collected step.
        self.assertEqual(traj_observer.gather_all(),
                         self._trajectories[:expected_steps])

        # Transition observers receive (TimeStep, Action, NextTimeStep) triples.
        logged_transitions = transition_observer.gather_all()
        self.assertLen(logged_transitions, expected_steps)
        self.assertLen(logged_transitions[0], 3)
# Example #3
 def test_get_collect_data_spec(self):
     """Collect-data spec derived from env + policy names its fields correctly."""
     cartpole_env = suite_gym.load('CartPole-v0')
     mock_policy = driver_test_utils.PyPolicyMock(cartpole_env.time_step_spec(),
                                                  cartpole_env.action_spec())
     data_spec = spec_utils.get_collect_data_spec_from_policy_and_env(
         cartpole_env, mock_policy)
     self.assertEqual('observation', data_spec.observation.name)
     self.assertEqual('reward', data_spec.reward.name)
# Example #4
  def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):
    """Driver over a batched env emits batched trajectories in step order."""

    # Hand-computed batched trajectories for the two mock envs run in lockstep.
    # Field order: step_type, observation, action, policy_info,
    # next_step_type, reward, discount.
    expected_trajectories = [
        trajectory.Trajectory(
            step_type=np.array(st),
            observation=np.array(ob),
            action=np.array(ac),
            policy_info=np.array(pi),
            next_step_type=np.array(nst),
            reward=np.array(rw),
            discount=np.array(di))
        for st, ob, ac, pi, nst, rw, di in [
            ([0, 0], [0, 0], [2, 1], [4, 2], [1, 1], [1., 1.], [1., 1.]),
            ([1, 1], [2, 1], [1, 2], [2, 4], [2, 1], [1., 1.], [0., 1.]),
            ([2, 1], [3, 3], [2, 1], [4, 2], [0, 2], [0., 1.], [1., 0.]),
        ]
    ]

    first_env = driver_test_utils.PyEnvironmentMock(final_state=3)
    second_env = driver_test_utils.PyEnvironmentMock(final_state=4)
    batched_env = batched_py_environment.BatchedPyEnvironment(
        [first_env, second_env])

    mock_policy = driver_test_utils.PyPolicyMock(
        batched_env.time_step_spec(),
        batched_env.action_spec(),
        initial_policy_state=np.array([1, 2]))
    observer = MockReplayBufferObserver()

    test_driver = py_driver.PyDriver(
        batched_env,
        mock_policy,
        observers=[observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
    test_driver.run(batched_env.reset(), mock_policy.get_initial_state())
    collected = observer.gather_all()

    truncated = expected_trajectories[:expected_length]
    self.assertEqual(len(collected), len(truncated))

    # Compare field by field; Trajectory is a namedtuple, so it iterates fields.
    for got, want in zip(collected, truncated):
      for got_field, want_field in zip(got, want):
        self.assertAllEqual(got_field, want_field)
# Example #5
 def testValueErrorOnInvalidArgs(self, max_steps, max_episodes):
     """PyDriver rejects invalid max_steps/max_episodes combinations."""
     mock_env = driver_test_utils.PyEnvironmentMock()
     mock_policy = driver_test_utils.PyPolicyMock(mock_env.time_step_spec(),
                                                  mock_env.action_spec())
     observer = MockReplayBufferObserver()
     with self.assertRaises(ValueError):
         py_driver.PyDriver(
             mock_env,
             mock_policy,
             observers=[observer],
             max_steps=max_steps,
             max_episodes=max_episodes,
         )
# Example #6
  def testMultipleRunMaxSteps(self):
    """Repeated single-step runs accumulate trajectories across calls."""
    num_runs = 3
    # NOTE(review): 3 one-step runs yield 4 recorded trajectories —
    # presumably a boundary step doesn't count toward max_steps; confirm
    # against the driver's step-counting semantics.
    expected_count = 4

    mock_env = driver_test_utils.PyEnvironmentMock()
    mock_policy = driver_test_utils.PyPolicyMock(mock_env.time_step_spec(),
                                                 mock_env.action_spec())
    observer = MockReplayBufferObserver()
    step_driver = py_driver.PyDriver(
        mock_env,
        mock_policy,
        observers=[observer],
        max_steps=1,
        max_episodes=None,
    )

    current_step = mock_env.reset()
    current_state = mock_policy.get_initial_state()
    for _ in range(num_runs):
      current_step, current_state = step_driver.run(current_step, current_state)
    self.assertEqual(observer.gather_all(),
                     self._trajectories[:expected_count])
# Example #7
  def testPolicyStateReset(self):
    """Policy initial state is fetched once per episode during a multi-episode run."""
    episode_limit = 2
    expected_count = 6

    mock_env = driver_test_utils.PyEnvironmentMock()
    mock_policy = driver_test_utils.PyPolicyMock(mock_env.time_step_spec(),
                                                 mock_env.action_spec())
    observer = MockReplayBufferObserver()
    episode_driver = py_driver.PyDriver(
        mock_env,
        mock_policy,
        observers=[observer],
        max_steps=None,
        max_episodes=episode_limit,
    )

    current_step = mock_env.reset()
    current_state = mock_policy.get_initial_state()
    current_step, current_state = episode_driver.run(current_step,
                                                     current_state)
    self.assertEqual(observer.gather_all(),
                     self._trajectories[:expected_count])
    # get_initial_state was invoked once per episode (tracked by the mock).
    self.assertEqual(episode_limit, mock_policy.get_initial_state_call_count)