def _load_trajectories(self, initial):
    """Load trajectories, memoizing the initial set.

    Args:
      initial: If true, load from `self._initial_trajectory_dir` and cache the
        result in `self._initial_trajectories`, so later calls return the
        cached object without touching disk. If false, load from
        `self._trajectory_dump_root_dir` on every call (no caching).

    Returns:
      Whatever `simple.load_trajectories` returns for the chosen directory and
      `self._data_eval_frac`. Judging by its use in the test below it is a
      (train, eval) pair; confirm against `simple`.
    """
    if not initial:
        # Dumped trajectories are never cached, so always reload them.
        return simple.load_trajectories(
            self._trajectory_dump_root_dir, self._data_eval_frac)

    # Loading can take a lot of time and the initial trajectories don't
    # change, so load them at most once and keep them in memory.
    if self._initial_trajectories is None:
        self._initial_trajectories = simple.load_trajectories(
            self._initial_trajectory_dir, self._data_eval_frac)
    return self._initial_trajectories
def test_loads_trajectories(self):
    """Checks that load_trajectories splits each pickle by eval_frac."""
    temp_dir = self.get_temp_dir()
    # Dump two trajectory pickles with the given observations.
    pickle_contents = (
        ('0.pkl', [0, 1, 2, 3]),
        ('1.pkl', [4, 5, 6, 7]),
    )
    for file_name, observations in pickle_contents:
        self._dump_trajectory_pickle(
            observations=observations, path=os.path.join(temp_dir, file_name))

    train_trajs, eval_trajs = simple.load_trajectories(
        temp_dir, eval_frac=0.25)

    # Pickles may be read in any order, so compare sets of observations.
    actual_train_obs = {
        traj.last_time_step.observation for traj in train_trajs}
    actual_eval_obs = {
        traj.last_time_step.observation for traj in eval_trajs}

    # The first 3 trajectories from each pickle go to train, the last to eval.
    self.assertEqual(actual_train_obs, {0, 1, 2, 4, 5, 6})
    self.assertEqual(actual_eval_obs, {3, 7})