def test_skip():
    """
    Test a FrameSkipEnv wrapper.

    The underlying environment is first stepped by hand (three steps with
    one action, then two steps with another) to record the reference
    observations and summed rewards. The same environment is then wrapped
    with ``FrameSkipEnv(num_frames=3)`` and each wrapped step is compared
    against those references.
    """
    # Timestep limit is 5.
    frame_shape = (3, 2, 5)
    base_env = SimpleEnv(4, frame_shape, 'float32')
    first_action = np.random.uniform(high=255.0, size=frame_shape)
    second_action = np.random.uniform(high=255.0, size=frame_shape)

    # Reference rollout on the unwrapped environment.
    expected_reset_obs = base_env.reset()

    expected_reward_1 = 0.0
    for _ in range(3):
        expected_obs_1, step_reward, _, _ = base_env.step(first_action)
        expected_reward_1 += step_reward

    expected_reward_2 = 0.0
    for _ in range(2):
        expected_obs_2, step_reward, finished, _ = base_env.step(second_action)
        expected_reward_2 += step_reward
    # The fifth raw step must end the episode.
    assert finished

    # Same rollout through the wrapper, which wraps the same env instance.
    skip_env = FrameSkipEnv(base_env, num_frames=3)

    reset_obs = skip_env.reset()
    assert np.allclose(reset_obs, expected_reset_obs)

    # First wrapped step should match the three raw steps of first_action.
    obs_after_first, reward_first, finished, _ = skip_env.step(first_action)
    assert not finished
    assert reward_first == expected_reward_1
    assert np.allclose(obs_after_first, expected_obs_1)

    # Second wrapped step should report done and match the two raw steps
    # of second_action.
    obs_after_second, reward_second, finished, _ = skip_env.step(second_action)
    assert finished
    assert reward_second == expected_reward_2
    assert np.allclose(obs_after_second, expected_obs_2)
def test_max_2(self):
    """
    Test maxing 2 frames.

    Four sampled actions are replayed on the raw environment and then on
    the same environment wrapped with ``MaxEnv(num_images=2)``. The reset
    observation must pass through unchanged, and every later wrapped
    observation must equal the element-wise maximum of the corresponding
    raw frame and the raw frame before it.
    """
    base_env = SimpleEnv(5, (3, 2, 5), 'float32')
    sampled_actions = [base_env.action_space.sample() for _ in range(4)]

    # Raw frames: the reset observation followed by one frame per action.
    raw_frames = [base_env.reset()]
    for action in sampled_actions:
        raw_frames.append(base_env.step(action)[0])

    # Replay the same actions through the max wrapper.
    max_env = MaxEnv(base_env, num_images=2)
    maxed_frames = [max_env.reset()]
    for action in sampled_actions:
        maxed_frames.append(max_env.step(action)[0])

    # Only one frame exists right after reset, so it is returned as-is.
    self.assertTrue((maxed_frames[0] == raw_frames[0]).all())

    # Each later output is the max over the current and previous raw frame.
    frame_pairs = zip(raw_frames, raw_frames[1:], maxed_frames[1:])
    for previous_frame, current_frame, maxed in frame_pairs:
        expected = np.max([previous_frame, current_frame], axis=0)
        self.assertTrue((maxed == expected).all())