def expert_source(env, goal, max_episode_length, n_eps):
    """Generate expert demonstrations as a stream of ``TimeStepBatch`` objects.

    An ``OptimalPolicy`` is built from ``env.spec`` and ``goal`` and rolled
    out in ``env`` through a ``LocalSampler``. The sampler's workers use a
    fixed seed of 100.

    Args:
        env: Environment to sample from; its ``spec`` attribute is handed to
            the expert policy.
        goal: Goal forwarded to ``OptimalPolicy`` as the ``goal`` keyword.
        max_episode_length: Passed both to the worker factory and as the
            second positional argument of each ``obtain_samples`` call.
        n_eps: Number of batches to yield.

    Yields:
        The result of ``TimeStepBatch.from_episode_batch`` for each sampled
        episode batch, ``n_eps`` times in total.

    """
    policy = OptimalPolicy(env.spec, goal=goal)
    sampler = LocalSampler.from_worker_factory(
        WorkerFactory(seed=100, max_episode_length=max_episode_length),
        policy,
        env,
    )
    for _ in range(n_eps):
        # NOTE(review): positional arguments are presumably (iteration,
        # number of samples, agent update) -- confirm against
        # LocalSampler.obtain_samples.
        yield TimeStepBatch.from_episode_batch(
            sampler.obtain_samples(0, max_episode_length, None))
def test_time_step_batch_from_episode_batch(eps_data):
    """Check that ``TimeStepBatch.from_episode_batch`` keeps observations aligned.

    Verifies, for the first episode in the batch, that:

    * the flattened observations are carried over unchanged,
    * every non-final step's next observation is the following step's
      observation, and
    * the final step's next observation is that episode's last observation.

    Args:
        eps_data: Keyword arguments used to construct an ``EpisodeBatch``.

    """
    eps = EpisodeBatch(**eps_data)
    timestep_batch = TimeStepBatch.from_episode_batch(eps)
    first_len = eps.lengths[0]

    assert (timestep_batch.observations == eps.observations).all()

    # Steps 0 .. first_len - 2 of the first episode: next observation is the
    # observation of the following step.
    assert (timestep_batch.next_observations[:first_len - 1] ==
            eps.observations[1:first_len]).all()

    # The final step of the first episode sits at index first_len - 1.
    # Indexing with first_len would read the first step of the next episode
    # (or raise IndexError when the batch holds a single episode).
    assert (timestep_batch.next_observations[first_len - 1] ==
            eps.last_observations[0]).all()