def test_trajectory(env_id):
    """Check ``Trajectory`` bookkeeping across one complete episode.

    The test rolls out a single episode of ``env_id`` (wrapped in
    ``TimeStepEnv``) with randomly sampled actions, feeding every timestep
    into a ``Trajectory``, and verifies:

    * an empty trajectory reports ``len == 0`` and ``T == 0``;
    * the three ``traj.add`` calls wrapped in ``pytest.raises`` below are
      rejected with ``AssertionError``;
    * while the episode is running, ``len(traj) == traj.T + 1`` and the
      trajectory is not finished;
    * once the episode ends, the finished/time-limit/terminal flags agree
      with the last timestep and the per-step sequences have length ``T``.

    Args:
        env_id: Gym environment id passed to ``gym.make``.
    """
    env = TimeStepEnv(gym.make(env_id))
    traj = Trajectory()

    # Fresh trajectory: nothing stored yet.
    assert len(traj) == 0
    assert traj.T == 0

    # Adding a MID timestep together with an action to an empty trajectory
    # is expected to be rejected.
    with pytest.raises(AssertionError):
        traj.add(TimeStep(StepType.MID, 1, 2, True, {}), 0.5)

    timestep = env.reset()
    traj.add(timestep, None)
    # The reset timestep is stored, but no transition has happened yet.
    assert len(traj) == 1
    assert traj.T == 0

    # After the initial timestep, a MID timestep without an action is
    # expected to be rejected.
    with pytest.raises(AssertionError):
        traj.add(TimeStep(StepType.MID, 1, 2, True, {}), None)

    # Roll out until the environment reports the last timestep.
    while not timestep.last():
        action = env.action_space.sample()
        timestep = env.step(action)
        traj.add(timestep, action)
        if timestep.last():
            continue  # loop condition ends the rollout
        assert len(traj) == traj.T + 1
        assert not traj.finished

    # A finished trajectory must refuse any further additions.
    with pytest.raises(AssertionError):
        traj.add(timestep, 5.3)

    assert traj.finished
    final_step = traj[-1]
    assert traj.reach_time_limit == final_step.time_limit()
    assert traj.reach_terminal == traj[-1].terminal()

    # Observations include the initial one, hence T + 1 entries.
    observed_shape = np.asarray(traj.observations).shape
    expected_shape = (traj.T + 1,) + tuple(env.observation_space.shape)
    assert observed_shape == expected_shape

    # Every per-transition sequence has exactly T entries.
    for field in ('actions', 'rewards', 'dones', 'infos'):
        assert len(getattr(traj, field)) == traj.T

    if traj.reach_time_limit:
        assert len(traj.get_infos('TimeLimit.truncated')) == 1
def test_timestep_env(env_id):
    """Check that ``TimeStepEnv`` mirrors the raw environment step by step.

    Two instances of ``env_id`` are created, one raw and one wrapped in
    ``TimeStepEnv``, both seeded with 0. The same sampled action is applied
    to each, and the wrapped ``timestep`` must carry the same observation,
    reward, done flag and info as the raw environment. The step-type
    predicates (``first``/``mid``/``last``/``time_limit``/``terminal``) are
    checked against the raw ``done`` flag and the ``'TimeLimit.truncated'``
    info entry.

    Args:
        env_id: Gym environment id passed to ``gym.make``.
    """
    raw_env = gym.make(env_id)
    wrapped_env = TimeStepEnv(gym.make(env_id))
    # Seed both copies identically so their rollouts can be compared.
    raw_env.seed(0)
    wrapped_env.seed(0)

    obs = raw_env.reset()
    timestep = wrapped_env.reset()
    assert timestep.first()
    assert np.allclose(timestep.observation, obs)

    for _ in range(raw_env.spec.max_episode_steps):
        action = raw_env.action_space.sample()
        obs, reward, done, info = raw_env.step(action)
        timestep = wrapped_env.step(action)

        # The wrapper must expose exactly what the raw env returned.
        assert np.allclose(timestep.observation, obs)
        assert timestep.reward == reward
        assert timestep.done == done
        assert timestep.info == info

        if not done:
            assert timestep.mid()
            continue

        # Episode ended: distinguish time-limit truncation from a true
        # terminal state using the info flag.
        assert timestep.last()
        if info.get('TimeLimit.truncated', False):
            assert timestep.time_limit()
        else:
            assert timestep.terminal()
        break