Example #1
def test_trajectory(env_id):
    env = gym.make(env_id)
    env = TimeStepEnv(env)
    traj = Trajectory()
    assert len(traj) == 0 and traj.T == 0

    # Adding a mid-episode timestep (with an action) to an empty trajectory should be rejected.
    with pytest.raises(AssertionError):
        traj.add(TimeStep(StepType.MID, 1, 2, True, {}), 0.5)

    timestep = env.reset()
    traj.add(timestep, None)
    assert len(traj) == 1 and traj.T == 0

    # Adding a mid-episode timestep without an action should be rejected.
    with pytest.raises(AssertionError):
        traj.add(TimeStep(StepType.MID, 1, 2, True, {}), None)

    while not timestep.last():
        action = env.action_space.sample()
        timestep = env.step(action)
        traj.add(timestep, action)
        if not timestep.last():
            assert len(traj) == traj.T + 1
            assert not traj.finished
    # Once the trajectory has ended with a LAST timestep, further adds should be rejected.
    with pytest.raises(AssertionError):
        traj.add(timestep, 5.3)
    assert traj.finished
    assert traj.reach_time_limit == traj[-1].time_limit()
    assert traj.reach_terminal == traj[-1].terminal()
    assert np.asarray(traj.observations).shape == (traj.T + 1, *env.observation_space.shape)
    assert len(traj.actions) == traj.T
    assert len(traj.rewards) == traj.T
    assert len(traj.dones) == traj.T
    assert len(traj.infos) == traj.T
    if traj.reach_time_limit:
        assert len(traj.get_infos('TimeLimit.truncated')) == 1
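The test above takes env_id as an argument; in a pytest suite it would typically be driven by parametrization. A minimal sketch, assuming pytest.mark.parametrize and placeholder environment ids (the concrete ids used by the original suite are not shown here):

import pytest

# Hypothetical parametrization driving the test above; env ids are placeholders.
@pytest.mark.parametrize('env_id', ['CartPole-v1', 'Pendulum-v0'])
def test_trajectory(env_id):
    ...  # body as in Example #1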
Example #2
def test_step_runner(env_id, T):
    env = gym.make(env_id)
    env = TimeStepEnv(env)
    agent = RandomAgent(None, env, None)

    # With reset_on_call=True, each call starts collecting from a fresh env.reset().
    runner = StepRunner(reset_on_call=True)
    D = runner(agent, env, T)
    assert runner.observation is None
    assert all([isinstance(traj, Trajectory) for traj in D])
    assert all([traj[0].first() for traj in D])
    assert all([traj[-1].last() for traj in D[:-1]])
    assert all([
        'last_info' in traj.extra_info
        and 'raw_action' in traj.extra_info['last_info'] for traj in D
    ])

    # With reset_on_call=False, a later call continues from where the previous one left off.
    runner = StepRunner(reset_on_call=False)
    D = runner(agent, env, 1)
    assert D[0][0].first()
    assert len(D[0]) == 2
    assert np.allclose(D[0][-1].observation, runner.observation)
    assert all([
        'last_info' in traj.extra_info
        and 'raw_action' in traj.extra_info['last_info'] for traj in D
    ])
    D2 = runner(agent, env, 3)
    assert np.allclose(D2[-1][-1].observation, runner.observation)
    assert np.allclose(D[0][-1].observation, D2[0][0].observation)
    assert D2[0][0].first() and D2[0][0].reward is None
    assert all([
        'last_info' in traj.extra_info
        and 'raw_action' in traj.extra_info['last_info'] for traj in D2
    ])
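A runner returns a list of Trajectory objects, so the accessors exercised in Example #1 (observations, actions, rewards) are enough to flatten a batch for training. A rough sketch, assuming only those accessors; flatten_batch is a hypothetical helper, and the trailing observation of each trajectory is dropped to align the arrays:

import numpy as np

def flatten_batch(D):
    # Drop the last observation of each trajectory so observations align with actions/rewards.
    obs = np.concatenate([np.asarray(traj.observations)[:-1] for traj in D], axis=0)
    actions = np.concatenate([np.asarray(traj.actions) for traj in D], axis=0)
    rewards = np.concatenate([np.asarray(traj.rewards) for traj in D], axis=0)
    return obs, actions, rewards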
Example #3
def make_env(config, seed, mode):
    assert mode in ['train', 'eval']
    env = gym.make(config['env.id'])
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)
    if config['env.clip_action'] and isinstance(env.action_space, gym.spaces.Box):
        env = gym.wrappers.ClipAction(env)  # TODO: use tanh to squash policy output when RescaleAction wrapper merged in gym
    env = TimeStepEnv(env)
    return env
Example #4
def make_env(config, seed, mode):
    assert mode in ['train', 'eval']
    env = gym.make(config['env.id'])
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)
    if mode == 'eval':
        env = RecordEpisodeStatistics(env, deque_size=100)
    env = TimeStepEnv(env)
    return env
Example #5
def make_env(config, seed, mode):
    assert mode in ['train', 'eval']
    env = gym.make(config['env.id'])
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)
    env = NormalizeAction(env)  # TODO: use gym new wrapper RescaleAction when it's merged
    if mode == 'eval':
        env = RecordEpisodeStatistics(env, deque_size=100)
    env = TimeStepEnv(env)
    return env
Example #6
def test_timestep_env(env_id):
    env = gym.make(env_id)
    wrapped_env = TimeStepEnv(gym.make(env_id))

    env.seed(0)
    wrapped_env.seed(0)

    obs = env.reset()
    timestep = wrapped_env.reset()
    assert timestep.first()
    assert np.allclose(timestep.observation, obs)

    for t in range(env.spec.max_episode_steps):
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        timestep = wrapped_env.step(action)
        assert np.allclose(timestep.observation, obs)
        assert timestep.reward == reward
        assert timestep.done == done
        assert timestep.info == info
        if done:
            assert timestep.last()
            if 'TimeLimit.truncated' in info and info['TimeLimit.truncated']:
                assert timestep.time_limit()
            else:
                assert timestep.terminal()
            break
        else:
            assert timestep.mid()
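The assertions above pin down the behaviour expected from a TimeStep: positional fields (step_type, observation, reward, done, info) plus first/mid/last/time_limit/terminal predicates. The following is an inferred sketch of such a structure for illustration only, not necessarily the library's actual implementation:

import enum
from collections import namedtuple

class StepType(enum.Enum):
    FIRST = 0
    MID = 1
    LAST = 2

class TimeStep(namedtuple('TimeStep', ['step_type', 'observation', 'reward', 'done', 'info'])):
    def first(self):
        return self.step_type is StepType.FIRST

    def mid(self):
        return self.step_type is StepType.MID

    def last(self):
        return self.step_type is StepType.LAST

    def time_limit(self):
        # Episode ended because Gym's TimeLimit wrapper truncated it.
        return self.last() and self.info.get('TimeLimit.truncated', False)

    def terminal(self):
        # Episode ended at a genuine terminal state, not by truncation.
        return self.last() and not self.info.get('TimeLimit.truncated', False)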
Example #7
def test_episode_runner(env_id, N):
    env = gym.make(env_id)
    env = TimeStepEnv(env)
    agent = RandomAgent(None, env, None)
    runner = EpisodeRunner()
    D = runner(agent, env, N)
    assert len(D) == N
    assert all([isinstance(d, Trajectory) for d in D])
    assert all([traj.finished for traj in D])
    assert all([traj[0].first() for traj in D])
    assert all([traj[-1].last() for traj in D])
    for traj in D:
        for timestep in traj[1:-1]:
            assert timestep.mid()
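Because EpisodeRunner yields only finished trajectories, per-episode returns follow directly from the rewards accessor shown in Example #1. A small sketch, given D from the runner above; the discount factor is a placeholder:

# Undiscounted episode returns for each collected trajectory.
returns = [sum(traj.rewards) for traj in D]

# Discounted returns, assuming a placeholder discount factor gamma.
gamma = 0.99
discounted_returns = [
    sum(gamma ** t * r for t, r in enumerate(traj.rewards)) for traj in D
]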
Example #8
def make_env(config, seed, mode):
    assert mode in ['train', 'eval']
    env = gym.make(config['env.id'])
    env.seed(seed)
    env.observation_space.seed(seed)
    env.action_space.seed(seed)
    if config['env.clip_action'] and isinstance(env.action_space,
                                                gym.spaces.Box):
        env = gym.wrappers.ClipAction(env)
    if mode == 'train':
        env = RecordEpisodeStatistics(env, deque_size=100)
        if config['env.normalize_obs']:
            env = NormalizeObservation(env, clip=5.)
        if config['env.normalize_reward']:
            env = NormalizeReward(env, clip=10., gamma=config['agent.gamma'])
    env = TimeStepEnv(env)
    return env
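The make_env variants above read a handful of flat config keys. A plain-dict sketch covering exactly the keys they touch; the environment id and values are placeholders:

config = {
    'env.id': 'HalfCheetah-v3',     # placeholder environment id
    'env.clip_action': True,        # wrap with ClipAction for Box action spaces
    'env.normalize_obs': True,      # running observation normalization (train mode only)
    'env.normalize_reward': True,   # running reward normalization (train mode only)
    'agent.gamma': 0.99,            # discount factor, reused by NormalizeReward
}

train_env = make_env(config, seed=0, mode='train')
eval_env = make_env(config, seed=0, mode='eval')
timestep = train_env.reset()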