Ejemplo n.º 1
0
def test_skip():
    """
    Check that FrameSkipEnv matches manually repeated steps.

    A raw SimpleEnv is stepped three times with one action and then
    twice with another (at which point it reports done) to record
    reference observations and summed rewards. The same env is then
    wrapped with num_frames=3, and one wrapped step per action is
    expected to reproduce those observations, rewards and done flags.
    """
    shape = (3, 2, 5)
    # Timestep limit is 5.
    env = SimpleEnv(4, shape, 'float32')
    first_action = np.random.uniform(high=255.0, size=shape)
    second_action = np.random.uniform(high=255.0, size=shape)

    # Reference rollout on the unwrapped environment. Rewards are
    # accumulated with plain += so the exact float comparison below
    # sees the same summation order.
    expected_reset_obs = env.reset()
    expected_first_rew = 0.0
    for _ in range(3):
        expected_first_obs, step_rew, _, _ = env.step(first_action)
        expected_first_rew += step_rew
    expected_second_rew = 0.0
    for _ in range(2):
        expected_second_obs, step_rew, done, _ = env.step(second_action)
        expected_second_rew += step_rew
    assert done

    # Replay the same actions through the frame-skipping wrapper.
    skipper = FrameSkipEnv(env, num_frames=3)
    assert np.allclose(skipper.reset(), expected_reset_obs)

    obs, rew, done, _ = skipper.step(first_action)
    assert not done
    assert rew == expected_first_rew
    assert np.allclose(obs, expected_first_obs)

    # The episode is expected to end during this wrapped step, with
    # only the two remaining frames contributing reward.
    obs, rew, done, _ = skipper.step(second_action)
    assert done
    assert rew == expected_second_rew
    assert np.allclose(obs, expected_second_obs)
Ejemplo n.º 2
0
 def test_resize_even(self):
     """
     Test resizing for an even number of pixels.

     The reset frame of a (13, 5, 3) SimpleEnv is resized to (5, 4)
     by ResizeImageEnv and compared against TensorFlow's AREA resize
     of the same frame.
     """
     env = SimpleEnv(5, (13, 5, 3), 'float32')
     frame = env.reset()
     actual = ResizeImageEnv(env, size=(5, 4)).reset()
     resize_op = tf.image.resize_images(
         frame, [5, 4], method=tf.image.ResizeMethod.AREA)
     # Run inside a context manager so the session is closed and its
     # resources are released even if evaluation raises.
     with tf.Session() as sess:
         expected = sess.run(resize_op)
     self.assertEqual(actual.shape, (5, 4, 3))
     self.assertTrue(np.allclose(actual, expected))
Ejemplo n.º 3
0
 def test_max_2(self):
     """
     Test maxing 2 frames.

     Four sampled actions are played through a raw SimpleEnv to
     collect reference frames, then replayed through
     MaxEnv(num_images=2). The wrapped reset observation must equal
     the raw reset frame, and every later wrapped observation must
     equal the element-wise max of the two most recent raw frames.
     """
     env = SimpleEnv(5, (3, 2, 5), 'float32')
     actions = [env.action_space.sample() for _ in range(4)]

     # Raw frames: the reset observation followed by one per action.
     frames = [env.reset()]
     for action in actions:
         frames.append(env.step(action)[0])

     # Same action sequence through the max-pooling wrapper.
     wrapped = MaxEnv(env, num_images=2)
     maxed = [wrapped.reset()]
     for action in actions:
         maxed.append(wrapped.step(action)[0])

     self.assertTrue((maxed[0] == frames[0]).all())
     for previous, current, observed in zip(frames, frames[1:], maxed[1:]):
         expected = np.max([previous, current], axis=0)
         self.assertTrue((observed == expected).all())