示例#1
0
def test_trajectory(env_id):
    """Check Trajectory bookkeeping across one episode of a wrapped env.

    The environment built from ``env_id`` is wrapped in ``TimeStepEnv`` and
    rolled out with randomly sampled actions until a last timestep is
    produced.  Along the way the test verifies that:

    * a fresh trajectory is empty (``len == 0`` and ``T == 0``),
    * adding a MID timestep before the initial one is rejected,
    * adding a second timestep with a ``None`` action is rejected,
    * adding anything after the episode has finished is rejected,
    * the per-field containers have lengths consistent with ``T``.
    """
    wrapped = TimeStepEnv(gym.make(env_id))
    trajectory = Trajectory()
    assert len(trajectory) == 0
    assert trajectory.T == 0

    # An empty trajectory must refuse a MID timestep.
    with pytest.raises(AssertionError):
        trajectory.add(TimeStep(StepType.MID, 1, 2, True, {}), 0.5)

    # The timestep returned by reset() is stored together with a None action.
    step = wrapped.reset()
    trajectory.add(step, None)
    assert len(trajectory) == 1
    assert trajectory.T == 0

    # Once the initial timestep is in, a None action is no longer accepted.
    with pytest.raises(AssertionError):
        trajectory.add(TimeStep(StepType.MID, 1, 2, True, {}), None)

    while not step.last():
        action = wrapped.action_space.sample()
        step = wrapped.step(action)
        trajectory.add(step, action)
        if step.last():
            continue
        # Mid-episode: one more stored timestep than transitions.
        assert len(trajectory) == trajectory.T + 1
        assert not trajectory.finished

    # A finished trajectory must reject further additions.
    with pytest.raises(AssertionError):
        trajectory.add(step, 5.3)
    assert trajectory.finished
    assert trajectory.reach_time_limit == trajectory[-1].time_limit()
    assert trajectory.reach_terminal == trajectory[-1].terminal()

    # Observations include the initial one, hence T + 1 rows.
    stacked_obs = np.asarray(trajectory.observations)
    assert stacked_obs.shape == (trajectory.T + 1, *wrapped.observation_space.shape)

    # Every per-transition container holds exactly T entries.
    for field in ('actions', 'rewards', 'dones', 'infos'):
        assert len(getattr(trajectory, field)) == trajectory.T

    if trajectory.reach_time_limit:
        assert len(trajectory.get_infos('TimeLimit.truncated')) == 1
示例#2
0
def test_timestep_env(env_id):
    """Compare a ``TimeStepEnv``-wrapped env against the raw env step by step.

    Two independent environments are created from ``env_id``; one is wrapped
    in ``TimeStepEnv``.  Both are seeded with 0 and driven with the same
    actions, and at every step the wrapped timestep's observation, reward,
    done flag and info are asserted to match what the raw env returned.  The
    timestep's type predicates (``first``/``mid``/``last`` and
    ``time_limit``/``terminal``) are checked against the raw ``done`` flag
    and the ``'TimeLimit.truncated'`` info entry.
    """
    raw_env = gym.make(env_id)
    stepper = TimeStepEnv(gym.make(env_id))

    # Seed both identically so their outputs can be compared directly.
    for each_env in (raw_env, stepper):
        each_env.seed(0)

    initial_obs = raw_env.reset()
    ts = stepper.reset()
    assert ts.first()
    assert np.allclose(ts.observation, initial_obs)

    for _ in range(raw_env.spec.max_episode_steps):
        action = raw_env.action_space.sample()
        obs, reward, done, info = raw_env.step(action)
        ts = stepper.step(action)

        assert np.allclose(ts.observation, obs)
        assert ts.reward == reward
        assert ts.done == done
        assert ts.info == info

        if not done:
            assert ts.mid()
            continue

        # Episode ended: the timestep must be a last step of the right flavour.
        assert ts.last()
        truncated = info.get('TimeLimit.truncated', False)
        if truncated:
            assert ts.time_limit()
        else:
            assert ts.terminal()
        break