예제 #1
0
    def _load_trajectories(self, initial):
        """Loads trajectories, keeping the initial ones cached in memory.

        The initial trajectories can take a long time to load and do not
        change, so they are read once and then served from
        `self._initial_trajectories`. Non-initial trajectories are re-read
        from `self._trajectory_dump_root_dir` on every call.

        Args:
          initial: If truthy, load (or reuse) the trajectories stored under
            `self._initial_trajectory_dir`; otherwise load the ones under
            `self._trajectory_dump_root_dir`.

        Returns:
          The result of `simple.load_trajectories` for the selected
          directory and `self._data_eval_frac`, or the cached result of an
          earlier initial load.
        """
        if not initial:
            # Nothing is cached for this directory; always go to disk.
            return simple.load_trajectories(
                self._trajectory_dump_root_dir, self._data_eval_frac)

        cached = self._initial_trajectories
        if cached is None:
            # First request for the initial trajectories: load and remember.
            cached = simple.load_trajectories(
                self._initial_trajectory_dir, self._data_eval_frac)
            self._initial_trajectories = cached
        return cached
예제 #2
0
  def test_loads_trajectories(self):
    """Checks how `simple.load_trajectories` splits pickles into train/eval."""
    pickle_dir = self.get_temp_dir()

    # Two trajectory pickles, each dumped with four known observations.
    fixtures = (
        ('0.pkl', [0, 1, 2, 3]),
        ('1.pkl', [4, 5, 6, 7]),
    )
    for (file_name, obs) in fixtures:
      self._dump_trajectory_pickle(
          observations=obs, path=os.path.join(pickle_dir, file_name))

    (train_trajs, eval_trajs) = simple.load_trajectories(
        pickle_dir, eval_frac=0.25)

    # Pickles may be read in any order, so the comparison is done on sets.
    train_obs = {traj.last_time_step.observation for traj in train_trajs}
    eval_obs = {traj.last_time_step.observation for traj in eval_trajs}

    # With eval_frac=0.25 the first 3 trajectories of each pickle are expected
    # in train and the last one in eval.
    self.assertEqual(train_obs, {0, 1, 2, 4, 5, 6})
    self.assertEqual(eval_obs, {3, 7})