def test_atari_grayscaling_observation_spec(self, grayscaling, expected_name, expected_shape):
    """Check the wrapped environment's observation spec for a grayscaling setting.

    Args:
      grayscaling: Value forwarded to `AtariEnvironmentWrapper(grayscaling=...)`.
      expected_name: Expected `name` of the wrapper's observation spec.
      expected_shape: Expected `shape` of the wrapper's observation spec.
    """
    wrapped_env = processors.AtariEnvironmentWrapper(
        environment=gym_atari.GymAtari('pong', seed=1),
        grayscaling=grayscaling,
    )
    observation_spec = wrapped_env.observation_spec()
    self.assertEqual(observation_spec.shape, expected_shape)
    self.assertEqual(observation_spec.name, expected_name)
def test_atari_grayscaling_observation_shape(self, grayscaling, expected_shape):
    """Check observation shapes over the first steps of an episode.

    Resets the wrapped environment and then, for 10 iterations, checks that
    the current timestep is not the last one and that its observation has
    `expected_shape`, before stepping with action 0.

    Args:
      grayscaling: Value forwarded to `AtariEnvironmentWrapper(grayscaling=...)`.
      expected_shape: Expected shape of every observation checked.
    """
    env = gym_atari.GymAtari('pong', seed=1)
    env = processors.AtariEnvironmentWrapper(environment=env, grayscaling=grayscaling)
    timestep = env.reset()
    for step in range(10):
        # Use the TestCase assertion rather than a bare `assert`: bare asserts
        # are stripped under `python -O`, and this also reports which step failed.
        self.assertFalse(
            timestep.step_type.last(),
            msg=f'Episode ended unexpectedly at step {step}.',
        )
        chex.assert_shape(timestep.observation, expected_shape)
        timestep = env.step(0)
def make_object_under_test(self):
    """Return an `AtariEnvironmentWrapper` around a seeded 'pong' environment.

    The wrapper is built with its default settings; only the underlying
    `GymAtari('pong', seed=1)` environment is supplied.
    """
    pong_env = gym_atari.GymAtari('pong', seed=1)
    wrapper = processors.AtariEnvironmentWrapper(environment=pong_env)
    return wrapper