Exemplo n.º 1
0
 def test_frame_stack(self):
     """Exercise FrameStack over a wrapped CartPole-v1 environment.

     Verifies the wrapper's reported metadata, that the stack buffer is a
     zero-filled float32 array before the first reset, and that specific
     columns of the stacked observation match recorded values after a
     fixed action sequence run with seed 1.
     """
     stacked = FrameStack(GymWrapper(gym.make('CartPole-v1')), num_stack=4)
     stacked.seed(1)

     # Wrapper identity and advertised configuration.
     assert isinstance(stacked, Env)
     assert isinstance(stacked, FrameStack)
     assert stacked.num_stack == 4
     assert stacked.observation_space.shape == (4, 4)

     # Buffer state as inspected before reset() is called.
     buffer = stacked.stack_buffer
     assert isinstance(buffer, np.ndarray)
     assert buffer.shape == (4, 4)
     assert np.all(buffer == 0.0)
     assert buffer.dtype == np.float32

     assert stacked.reset().shape == (4, 4)

     # One step after reset: columns 0 and 1 are non-zero and differ,
     # while the remaining columns are still all zeros.
     frames, *_ = stacked.step(0)
     first_col, second_col = frames[:, 0], frames[:, 1]
     assert first_col.sum() != 0.0
     assert second_col.sum() != 0.0
     assert np.all(frames[:, 2:] == 0.0)
     assert np.any(first_col != second_col)

     # After two more steps, compare the last and third columns against
     # the recorded observations.
     for _ in range(2):
         frames, *_ = stacked.step(1)
     last_col_expected = [0.03073904, 0.00145001, -0.03088818, -0.03131252]
     third_col_expected = [0.03076804, -0.19321568, -0.03151444, 0.25146705]
     assert np.allclose(frames[:, -1], last_col_expected)
     assert np.allclose(frames[:, 2], third_col_expected)

     # One further step: the values previously in column 2 are now
     # expected in the last column.
     frames, *_ = stacked.step(1)
     assert np.allclose(frames[:, -1], third_col_expected)
Exemplo n.º 2
0
def test_frame_stack(env_id, num_stack):
    """Check FrameStack shapes and frame contents for a given environment.

    Args:
        env_id: Identifier passed to ``gym.make``.
        num_stack: Number of frames handed to ``FrameStack``.

    After ``reset`` every adjacent pair of stacked frames must be close;
    after one sampled step all frames except the last must still be
    pairwise close, and the last frame must differ from the one before it.
    """
    base_env = gym.make(env_id)
    frame_shape = base_env.observation_space.shape
    env = FrameStack(base_env, num_stack)
    stacked_shape = (num_stack, ) + frame_shape
    assert env.observation_space.shape == stacked_shape

    frames = np.asarray(env.reset())
    assert frames.shape == stacked_shape
    # Compare each frame with its successor across the whole stack.
    assert all(
        np.allclose(earlier, later)
        for earlier, later in zip(frames[:-1], frames[1:]))

    action = env.action_space.sample()
    stepped, _reward, _done, _info = env.step(action)
    frames = np.asarray(stepped)
    assert frames.shape == stacked_shape
    # Same pairwise comparison, but leaving the final frame out.
    assert all(
        np.allclose(earlier, later)
        for earlier, later in zip(frames[:-2], frames[1:-1]))
    assert not np.allclose(frames[-1], frames[-2])
Exemplo n.º 3
0
def test_get_all_wrappers(env_id):
    """Check that get_all_wrappers lists the wrapper names of a wrapped env.

    The environment is wrapped with ClipReward, FlattenObservation and
    FrameStack, in that order. The expected list names them in reverse
    order of application, followed by 'TimeLimit', which this test does
    not apply itself (presumably added by ``gym.make`` -- confirm).

    Args:
        env_id: Identifier passed to ``gym.make``.
    """
    wrapped = FrameStack(
        FlattenObservation(ClipReward(gym.make(env_id), 0.1, 0.5)), 4)
    expected_names = [
        'FrameStack', 'FlattenObservation', 'ClipReward', 'TimeLimit'
    ]
    assert get_all_wrappers(wrapped) == expected_names
Exemplo n.º 4
0
def test_get_wrapper(env_id):
    """Check get_wrapper lookups on a wrapped env and on a vectorized env.

    A lookup by name must return an object whose class name matches for
    wrappers that were applied, and ``None`` for names that were not.

    Args:
        env_id: Identifier passed to ``gym.make``.
    """
    def make_env():
        return gym.make(env_id)

    env = FrameStack(FlattenObservation(ClipReward(make_env(), 0.1, 0.5)), 4)

    # Wrappers that were applied are found under their own class name.
    for wrapper_name in ('ClipReward', 'FlattenObservation'):
        found = get_wrapper(env, wrapper_name)
        assert found.__class__.__name__ == wrapper_name
    # 'Env' was not applied as a wrapper here, so the lookup yields None.
    assert get_wrapper(env, 'Env') is None

    # Drop the single-env reference before building the vectorized one.
    del env

    # vec_env
    vec_env = VecMonitor(make_vec_env(make_env, 3, 0))
    monitor = get_wrapper(vec_env, 'VecMonitor')
    assert monitor.__class__.__name__ == 'VecMonitor'
    assert get_wrapper(vec_env, 'ClipReward') is None