def make_gym_atari_env(game, seed=1):
  """Builds an Atari environment with random no-op starts.

  Args:
    game: Name of the Atari game to load.
    seed: Seed shared by the base environment and the no-op wrapper.

  Returns:
    The game environment wrapped in a `RandomNoopsEnvironmentWrapper`.
  """
  base_environment = gym_atari.GymAtari(game, seed=seed)
  return gym_atari.RandomNoopsEnvironmentWrapper(
      base_environment, min_noop_steps=1, max_noop_steps=30, seed=seed)
def test_atari_grayscaling_observation_spec(self, grayscaling, expected_name, expected_shape):
  """Checks the wrapped environment reports the expected observation spec."""
  wrapped_env = processors.AtariEnvironmentWrapper(
      environment=gym_atari.GymAtari('pong', seed=1),
      grayscaling=grayscaling)
  observation_spec = wrapped_env.observation_spec()
  # Both shape and name depend on whether grayscaling is enabled.
  self.assertEqual(observation_spec.shape, expected_shape)
  self.assertEqual(observation_spec.name, expected_name)
def environment_builder():
  """Creates Atari environment."""

  def draw_seed():
    # Each component gets its own independently drawn seed.
    return random_state.randint(1, 2**32)

  # Note: positional argument (and thus the base environment's seed) is
  # evaluated before the wrapper's keyword seed, matching the original
  # draw order from `random_state`.
  return gym_atari.RandomNoopsEnvironmentWrapper(
      gym_atari.GymAtari(FLAGS.environment_name, seed=draw_seed()),
      min_noop_steps=1,
      max_noop_steps=30,
      seed=draw_seed(),
  )
def test_atari_grayscaling_observation_shape(self, grayscaling, expected_shape):
  """Steps the wrapped environment and verifies each observation's shape."""
  wrapped_env = processors.AtariEnvironmentWrapper(
      environment=gym_atari.GymAtari('pong', seed=1),
      grayscaling=grayscaling)
  timestep = wrapped_env.reset()
  for _ in range(10):
    # The episode should not terminate within these few steps.
    assert not timestep.step_type.last()
    chex.assert_shape(timestep.observation, expected_shape)
    timestep = wrapped_env.step(0)
def test_can_use_in_an_agent(self):
  """Example of using Atari processor on the agent side."""
  env = gym_atari.GymAtari('pong', seed=1)
  agent = AgentWithPreprocessing(num_actions=env.action_spec().num_values)
  agent.reset()
  timestep = env.reset()
  chosen_actions = []
  for _ in range(20):
    action = agent.step(timestep)
    timestep = env.step(action)
    assert not timestep.last()
    chosen_actions.append(action)
  # Each action 0..4 is emitted four times in a row, i.e. the agent's
  # preprocessing repeats every chosen action.
  expected_actions = [a for a in range(5) for _ in range(4)]
  self.assertEqual(expected_actions, chosen_actions)
def make_object_under_test(self):
  """Returns a wrapped Atari environment to be checked by the harness."""
  return processors.AtariEnvironmentWrapper(
      environment=gym_atari.GymAtari('pong', seed=1))
def test_default_on_fixed_input(self):
  """End-to-end test on fixed input.

  This is to test (mainly observation) processors do not change due to
  updates in underlying library functions.
  """
  # Create environment just for the observation spec.
  env = gym_atari.GymAtari('pong', seed=1)
  rgb_spec, unused_lives_spec = env.observation_spec()
  random_state = np.random.RandomState(seed=1)

  # Generate timesteps with fixed data to feed into processor.
  def generate_rgb_obs():
    # Random uint8 frames with the exact shape/dtype of the RGB spec.
    return random_state.randint(
        0, 256, size=rgb_spec.shape, dtype=rgb_spec.dtype)

  # F/M are presumably FIRST/MID step-type aliases defined at module
  # level — TODO confirm against the file's constants.
  step_types = [F, M, M, M, M]
  rewards = [None, 0.5, 0.2, 0, 0.1]
  discounts = [None, 0.9, 0.9, 0.9, 0.9]
  rgb_obs = [generate_rgb_obs() for _ in range(len(step_types))]
  lives_obs = [3, 3, 3, 3, 3]
  timesteps = []
  for i in range(len(step_types)):
    timesteps.append(
        dm_env.TimeStep(
            step_type=step_types[i],
            reward=rewards[i],
            discount=discounts[i],
            observation=(rgb_obs[i], lives_obs[i])))

  def hash_array(array):
    # Stable fingerprint of array contents.
    return hashlib.sha256(array).hexdigest()

  # Make sure generated observation data is fixed and the random number
  # generator has not changed from underneath us, causing the test to fail.
  hash_rgb_obs = [hash_array(obs) for obs in rgb_obs]
  expected_hashes = [
      '250557b2184381fc2ec541fc313127050098fce825a6e98a728c2993874db300',
      'db8054ca287971a0e1264bfbc5642233085f1b27efbca9082a29f5be8a24c552',
      '7016e737a257fcdb77e5f23daf96d94f9820bd7361766ca7b1401ec90984ef71',
      '356dfcf0c6eaa4e2b5e80f4611375c0131435cc22e6a413b573818d7d084e9b2',
      '73078bedd438422ad1c3dda6718aa1b54f6163f571d2c26ed714c515a6372159',
  ]
  assert hash_rgb_obs == expected_hashes, (hash_rgb_obs, expected_hashes)

  # Run timesteps through processor. Only the value from the final call is
  # inspected below; earlier calls may yield intermediate results.
  processor = processors.atari()
  for timestep in timesteps:
    processed = processor(timestep)

  # Assert the returned timestep is not None, and tell pytype.
  self.assertIsNotNone(processed)
  processed = typing.cast(dm_env.TimeStep, processed)

  # Compare with expected timestep, just the hash for the observation.
  self.assertEqual(dm_env.StepType.MID, processed.step_type)
  # Reward is the sum over the four MID steps; discount is the product of the
  # per-step 0.9 discounts times 0.99 — presumably the processor's own
  # discount factor, TODO confirm against processors.atari defaults.
  self.assertAlmostEqual(0.5 + 0.2 + 0. + 0.1, processed.reward)
  self.assertAlmostEqual(0.9**4 * 0.99, processed.discount)
  processed_obs_hash = hash_array(processed.observation.flatten())
  # Note the algorithm used for image resizing can have a noticeable impact on
  # learning performance. This test helps ensure changes to image processing
  # are intentional.
  self.assertEqual(
      '0d158a8f45aa09aa6fad0354d2eb1fc0e3f57add88e772f3b71f54819d8200aa',
      processed_obs_hash)
def test_can_call_close(self):
  """Closing a freshly constructed environment should not raise."""
  env = gym_atari.GymAtari('pong', seed=1)
  env.close()
def test_seed_range(self):
  """Construction succeeds for boundary values of the seed range."""
  for boundary_seed in (0, 1, 2**32 - 1):
    gym_atari.GymAtari('pong', seed=boundary_seed)