def expert_source(env, goal, max_path_length, n_traj):
    """Lazily generate expert demonstrations as time-step batches.

    An ``OptimalPolicy`` targeting ``goal`` is run in ``env`` through a
    ``LocalSampler``. Each iteration requests one batch from the sampler
    and converts it with ``TimeStepBatch.from_trajectory_batch``.

    Args:
        env: Environment to sample from; its ``spec`` is passed to the
            policy and the environment itself is passed to the sampler.
        goal: Goal forwarded to ``OptimalPolicy``.
        max_path_length: Forwarded to ``WorkerFactory`` as
            ``max_path_length`` and also passed as the second argument
            of every ``obtain_samples`` call.
        n_traj: Number of batches to yield.

    Yields:
        The result of ``TimeStepBatch.from_trajectory_batch`` for each
        sampled batch.

    """
    policy = OptimalPolicy(env.spec, goal=goal)
    # The worker seed is fixed at 100.
    worker_factory = WorkerFactory(seed=100, max_path_length=max_path_length)
    sampler = LocalSampler.from_worker_factory(worker_factory, policy, env)
    for _ in range(n_traj):
        # NOTE(review): the positional arguments are presumably the iteration
        # index, the number of samples, and the agent update -- confirm
        # against the sampler's signature.
        yield TimeStepBatch.from_trajectory_batch(
            sampler.obtain_samples(0, max_path_length, None))
def test_time_step_batch_from_trajectory_batch(traj_data):
    """Check the TrajectoryBatch -> TimeStepBatch conversion.

    Verifies that observations are carried over unchanged and that, for
    the first trajectory, ``next_observations`` holds the following
    observation for every non-terminal step and the trajectory's last
    observation for its terminal step.

    Args:
        traj_data (dict): Keyword arguments used to construct the
            ``TrajectoryBatch`` under test.

    """
    traj = TrajectoryBatch(**traj_data)
    timestep_batch = TimeStepBatch.from_trajectory_batch(traj)
    first_len = traj.lengths[0]

    assert (timestep_batch.observations == traj.observations).all()

    # Non-terminal steps of the first trajectory (flat indices
    # 0 .. first_len - 2) must point at the observation that follows them.
    assert (timestep_batch.next_observations[:first_len - 1] ==
            traj.observations[1:first_len]).all()

    # The terminal step of the first trajectory sits at flat index
    # first_len - 1. Index first_len would be the first step of the next
    # trajectory, or out of range for a single-trajectory batch.
    assert (timestep_batch.next_observations[first_len - 1] ==
            traj.last_observations[0]).all()