예제 #1
0
def test_channel_first():
    """Check that ChannelFirst moves the last observation axis to the front.

    Uses the ``Breakout-v0`` gym environment, whose observation space is
    three-dimensional with the channel count in the final position, and
    then verifies that a DQN can be built from the wrapped environment and
    run a prediction on a wrapped observation.
    """
    base_env = gym.make("Breakout-v0")

    # The unwrapped layout is (axis0, axis1, channels); the wrapper is
    # expected to yield (channels, axis0, axis1).
    axis0, axis1, n_channels = base_env.observation_space.shape
    expected_shape = (n_channels, axis0, axis1)

    wrapped_env = ChannelFirst(base_env)

    # Observation returned by reset() must already be channel-first.
    initial_obs = wrapped_env.reset()
    assert initial_obs.shape == expected_shape

    # Observation returned by step() must be channel-first as well.
    random_action = wrapped_env.action_space.sample()
    stepped_obs, _, _, _ = wrapped_env.step(random_action)
    assert stepped_obs.shape == expected_shape

    # The wrapped environment must be usable to build and query an algorithm.
    algo = DQN()
    algo.build_with_env(wrapped_env)
    algo.predict([stepped_obs])
예제 #2
0
def test_channel_first_with_2_dim_obs():
    """Check that ChannelFirst prepends a size-1 axis to 2-D observations.

    Uses ``DummyAtari(squeeze=True)``, whose observation space has exactly
    two dimensions, and then verifies that a DQN can be built from the
    wrapped environment and run a prediction on a wrapped observation.
    """
    base_env = DummyAtari(squeeze=True)

    # Two-dimensional input: the wrapper is expected to add a leading
    # axis of length 1 in front of the existing two.
    axis0, axis1 = base_env.observation_space.shape
    expected_shape = (1, axis0, axis1)

    wrapped_env = ChannelFirst(base_env)

    # Observation returned by reset() must carry the extra leading axis.
    initial_obs = wrapped_env.reset()
    assert initial_obs.shape == expected_shape

    # Observation returned by step() must carry it too.
    random_action = wrapped_env.action_space.sample()
    stepped_obs, _, _, _ = wrapped_env.step(random_action)
    assert stepped_obs.shape == expected_shape

    # The wrapped environment must be usable to build and query an algorithm.
    algo = DQN()
    algo.build_with_env(wrapped_env)
    algo.predict([stepped_obs])
예제 #3
0
def test_channel_first_with_2_dim_obs():
    """Check ChannelFirst on a 2-D observation from a preprocessed Atari env.

    Wraps ``BreakoutNoFrameskip-v4`` in ``AtariPreprocessing`` (which here
    exposes a two-dimensional observation space), applies ``ChannelFirst``
    on top, and verifies that observations gain a leading axis of length 1.
    Finally confirms a DQN can be built from the wrapped environment and
    run a prediction on a wrapped observation.
    """
    raw_env = gym.make("BreakoutNoFrameskip-v4")
    preprocessed_env = AtariPreprocessing(raw_env)

    # Two-dimensional input: expect (1, axis0, axis1) after wrapping.
    axis0, axis1 = preprocessed_env.observation_space.shape
    expected_shape = (1, axis0, axis1)

    wrapped_env = ChannelFirst(preprocessed_env)

    # Observation returned by reset() must carry the extra leading axis.
    initial_obs = wrapped_env.reset()
    assert initial_obs.shape == expected_shape

    # Observation returned by step() must carry it too.
    random_action = wrapped_env.action_space.sample()
    stepped_obs, _, _, _ = wrapped_env.step(random_action)
    assert stepped_obs.shape == expected_shape

    # The wrapped environment must be usable to build and query an algorithm.
    algo = DQN()
    algo.build_with_env(wrapped_env)
    algo.predict([stepped_obs])
예제 #4
0
def test_channel_first():
    """Check that ChannelFirst moves the last observation axis to the front.

    Uses ``DummyAtari(grayscale=False)``, whose observation space is
    three-dimensional with the channel count in the final position, and
    then verifies that a DQN can be built from the wrapped environment and
    run a prediction on a wrapped observation.
    """
    base_env = DummyAtari(grayscale=False)

    # The unwrapped layout is (axis0, axis1, channels); the wrapper is
    # expected to yield (channels, axis0, axis1).
    axis0, axis1, n_channels = base_env.observation_space.shape
    expected_shape = (n_channels, axis0, axis1)

    wrapped_env = ChannelFirst(base_env)

    # Observation returned by reset() must already be channel-first.
    initial_obs = wrapped_env.reset()
    assert initial_obs.shape == expected_shape

    # Observation returned by step() must be channel-first as well.
    random_action = wrapped_env.action_space.sample()
    stepped_obs, _, _, _ = wrapped_env.step(random_action)
    assert stepped_obs.shape == expected_shape

    # The wrapped environment must be usable to build and query an algorithm.
    algo = DQN()
    algo.build_with_env(wrapped_env)
    algo.predict([stepped_obs])