def preprocessor_builder():
  """Builds an Atari preprocessor configured from command-line flags.

  Returns a processor from `processors.atari` whose discounting, reward
  clipping, resize shape, action repeats and frame stacking are all taken
  from FLAGS; pooling (2 frames), life-loss zero discount and grayscaling
  are fixed.
  """
  # Target (height, width) for observation resizing comes from flags.
  resize_shape = (FLAGS.environment_height, FLAGS.environment_width)
  return processors.atari(
      additional_discount=FLAGS.additional_discount,
      max_abs_reward=FLAGS.max_abs_reward,
      resize_shape=resize_shape,
      num_action_repeats=FLAGS.num_action_repeats,
      num_pooled_frames=2,
      zero_discount_on_life_loss=True,
      num_stacked_frames=FLAGS.num_stacked_frames,
      grayscaling=True,
  )
def test_default_on_fixed_input(self):
  """End-to-end test on fixed input.

  This is to test (mainly observation) processors do not change due to
  updates in underlying library functions.
  """
  # Create environment just for the observation spec.
  env = gym_atari.GymAtari('pong', seed=1)
  rgb_spec, unused_lives_spec = env.observation_spec()
  # Fixed seed so the generated observations are reproducible across runs.
  rng = np.random.RandomState(seed=1)

  # Generate timesteps with fixed data to feed into processor.
  def generate_rgb_obs():
    return rng.randint(0, 256, size=rgb_spec.shape, dtype=rgb_spec.dtype)

  step_types = [F, M, M, M, M]
  rewards = [None, 0.5, 0.2, 0, 0.1]
  discounts = [None, 0.9, 0.9, 0.9, 0.9]
  # Draw observations in a single pass so the RNG sequence is fixed.
  rgb_obs = [generate_rgb_obs() for _ in step_types]
  lives_obs = [3] * len(step_types)

  timesteps = [
      dm_env.TimeStep(
          step_type=step_type,
          reward=reward,
          discount=discount,
          observation=(rgb, lives))
      for step_type, reward, discount, rgb, lives in zip(
          step_types, rewards, discounts, rgb_obs, lives_obs)
  ]

  def hash_array(array):
    return hashlib.sha256(array).hexdigest()

  # Make sure generated observation data is fixed and the random number
  # generator has not changed from underneath us, causing the test to fail.
  hash_rgb_obs = [hash_array(obs) for obs in rgb_obs]
  expected_hashes = [
      '250557b2184381fc2ec541fc313127050098fce825a6e98a728c2993874db300',
      'db8054ca287971a0e1264bfbc5642233085f1b27efbca9082a29f5be8a24c552',
      '7016e737a257fcdb77e5f23daf96d94f9820bd7361766ca7b1401ec90984ef71',
      '356dfcf0c6eaa4e2b5e80f4611375c0131435cc22e6a413b573818d7d084e9b2',
      '73078bedd438422ad1c3dda6718aa1b54f6163f571d2c26ed714c515a6372159',
  ]
  assert hash_rgb_obs == expected_hashes, (hash_rgb_obs, expected_hashes)

  # Run timesteps through processor; only the last output matters below.
  processor = processors.atari()
  processed = None
  for timestep in timesteps:
    processed = processor(timestep)

  # Assert the returned timestep is not None, and tell pytype.
  self.assertIsNotNone(processed)
  processed = typing.cast(dm_env.TimeStep, processed)

  # Compare with expected timestep, just the hash for the observation.
  self.assertEqual(dm_env.StepType.MID, processed.step_type)
  self.assertAlmostEqual(0.5 + 0.2 + 0. + 0.1, processed.reward)
  self.assertAlmostEqual(0.9**4 * 0.99, processed.discount)
  processed_obs_hash = hash_array(processed.observation.flatten())
  # Note the algorithm used for image resizing can have a noticeable impact on
  # learning performance. This test helps ensure changes to image processing
  # are intentional.
  self.assertEqual(
      '0d158a8f45aa09aa6fad0354d2eb1fc0e3f57add88e772f3b71f54819d8200aa',
      processed_obs_hash)
def __init__(self, num_actions):
  """Initializes with a default Atari processor and no pending action.

  Args:
    num_actions: size of the discrete action set available to the agent.
  """
  self._num_actions = num_actions
  # Default-configured Atari observation/reward processor.
  self._processor = processors.atari()
  # No action selected until the first step.
  self._action = None