예제 #1
0
 def test_atari_grayscaling_observation_spec(self, grayscaling,
                                             expected_name, expected_shape):
     """Checks the wrapped environment's observation spec.

     Args:
       grayscaling: value forwarded as `grayscaling=` to
         `processors.AtariEnvironmentWrapper`.
       expected_name: name the observation spec must report.
       expected_shape: shape the observation spec must report.
     """
     raw_env = gym_atari.GymAtari('pong', seed=1)
     wrapped_env = processors.AtariEnvironmentWrapper(
         environment=raw_env, grayscaling=grayscaling)
     observation_spec = wrapped_env.observation_spec()
     self.assertEqual(observation_spec.shape, expected_shape)
     self.assertEqual(observation_spec.name, expected_name)
예제 #2
0
    def test_atari_grayscaling_observation_shape(self, grayscaling,
                                                 expected_shape):
        """Checks observation shapes over the first steps of an episode.

        Resets a seeded Pong environment wrapped by
        `processors.AtariEnvironmentWrapper`. Then, for 10 iterations, it
        checks that the current timestep is not the last one and that its
        observation has `expected_shape`, and steps with action 0.

        Args:
          grayscaling: value forwarded as `grayscaling=` to
            `processors.AtariEnvironmentWrapper`.
          expected_shape: shape every checked observation must have.
        """
        env = gym_atari.GymAtari('pong', seed=1)
        env = processors.AtariEnvironmentWrapper(environment=env,
                                                 grayscaling=grayscaling)

        timestep = env.reset()
        for _ in range(10):
            # Use the TestCase assertion instead of a bare `assert`: bare
            # asserts are stripped under `python -O` and would silently pass.
            self.assertFalse(timestep.step_type.last())
            chex.assert_shape(timestep.observation, expected_shape)
            timestep = env.step(0)
예제 #3
0
 def make_object_under_test(self):
     """Returns an AtariEnvironmentWrapper around a seeded Pong environment."""
     return processors.AtariEnvironmentWrapper(
         environment=gym_atari.GymAtari('pong', seed=1))