Example #1
from absl import flags

from dqn_zoo import processors  # Presumed source of `processors`, given the atari() signature.

FLAGS = flags.FLAGS


def preprocessor_builder():
  # Build the standard Atari preprocessor; the FLAGS values are defined by
  # the surrounding script via absl.flags.
  return processors.atari(
      additional_discount=FLAGS.additional_discount,
      max_abs_reward=FLAGS.max_abs_reward,
      resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
      num_action_repeats=FLAGS.num_action_repeats,
      num_pooled_frames=2,
      zero_discount_on_life_loss=True,
      num_stacked_frames=FLAGS.num_stacked_frames,
      grayscaling=True,
  )
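A minimal usage sketch (not part of the original file; it assumes FLAGS have
been parsed and `env` is a dm_env-style Atari environment such as
gym_atari.GymAtari): the builder returns a stateful callable that maps raw
timesteps to processed ones, and it returns None for timesteps consumed by
action repeat and frame pooling, so callers must check before acting.

    processor = preprocessor_builder()
    timestep = env.reset()
    processed = processor(timestep)      # Processed FIRST timestep.
    while not timestep.last():
        timestep = env.step(0)           # Always act 0, purely for illustration.
        processed = processor(timestep)  # None while frames are being pooled.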
Example #2
# Module-level context this test relies on (the method below is defined inside
# a test case class in the original file; import paths presumed from dqn_zoo):
import hashlib
import typing

import dm_env
import numpy as np

from dqn_zoo import gym_atari
from dqn_zoo import processors

# Shorthands for the dm_env step types used in the fixed input sequence.
F = dm_env.StepType.FIRST
M = dm_env.StepType.MID

    def test_default_on_fixed_input(self):
        """End-to-end test on fixed input.

        This is to test that (mainly observation) processors do not change
        due to updates in underlying library functions.
        """
        # Create environment just for the observation spec.
        env = gym_atari.GymAtari('pong', seed=1)
        rgb_spec, unused_lives_spec = env.observation_spec()
        random_state = np.random.RandomState(seed=1)

        # Generate timesteps with fixed data to feed into processor.
        def generate_rgb_obs():
            return random_state.randint(0,
                                        256,
                                        size=rgb_spec.shape,
                                        dtype=rgb_spec.dtype)

        step_types = [F, M, M, M, M]
        rewards = [None, 0.5, 0.2, 0, 0.1]
        discounts = [None, 0.9, 0.9, 0.9, 0.9]
        rgb_obs = [generate_rgb_obs() for _ in range(len(step_types))]
        lives_obs = [3, 3, 3, 3, 3]

        timesteps = []
        for i in range(len(step_types)):
            timesteps.append(
                dm_env.TimeStep(step_type=step_types[i],
                                reward=rewards[i],
                                discount=discounts[i],
                                observation=(rgb_obs[i], lives_obs[i])))

        def hash_array(array):
            return hashlib.sha256(array).hexdigest()

        # Make sure the generated observation data is fixed and the random
        # number generator has not changed from underneath us, which would
        # cause the test to fail.
        hash_rgb_obs = [hash_array(obs) for obs in rgb_obs]
        expected_hashes = [
            '250557b2184381fc2ec541fc313127050098fce825a6e98a728c2993874db300',
            'db8054ca287971a0e1264bfbc5642233085f1b27efbca9082a29f5be8a24c552',
            '7016e737a257fcdb77e5f23daf96d94f9820bd7361766ca7b1401ec90984ef71',
            '356dfcf0c6eaa4e2b5e80f4611375c0131435cc22e6a413b573818d7d084e9b2',
            '73078bedd438422ad1c3dda6718aa1b54f6163f571d2c26ed714c515a6372159',
        ]
        assert hash_rgb_obs == expected_hashes, (hash_rgb_obs, expected_hashes)

        # Run timesteps through processor.
        processor = processors.atari()
        for timestep in timesteps:
            processed = processor(timestep)

        # Assert the returned timestep is not None, and tell pytype.
        self.assertIsNotNone(processed)
        processed = typing.cast(dm_env.TimeStep, processed)

        # Compare with the expected timestep, using just a hash for the
        # observation. Rewards are summed over the pooled steps; discounts
        # multiply over the four pooled steps and are further scaled by the
        # default additional_discount of 0.99.
        self.assertEqual(dm_env.StepType.MID, processed.step_type)
        self.assertAlmostEqual(0.5 + 0.2 + 0. + 0.1, processed.reward)
        self.assertAlmostEqual(0.9**4 * 0.99, processed.discount)
        processed_obs_hash = hash_array(processed.observation.flatten())

        # Note the algorithm used for image resizing can have a noticeable impact on
        # learning performance. This test helps ensure changes to image processing
        # are intentional.
        self.assertEqual(
            '0d158a8f45aa09aa6fad0354d2eb1fc0e3f57add88e772f3b71f54819d8200aa',
            processed_obs_hash)
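For reference, a worked check of the values asserted above (assuming
processors.atari() keeps its defaults of num_action_repeats=4 and
additional_discount=0.99): the FIRST step plus four MID steps collapse into a
single processed MID step.

    expected_reward = 0.5 + 0.2 + 0.0 + 0.1  # Sum over pooled steps: 0.8.
    expected_discount = 0.9 ** 4 * 0.99      # Product of per-step discounts,
                                             # times additional_discount: ~0.6495.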
Example #3
def __init__(self, num_actions):
    # Default Atari preprocessing; the processor is stateful across timesteps.
    self._processor = processors.atari()
    self._num_actions = num_actions
    self._action = None
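A plausible continuation of this class (a sketch, not the original code; the
processors.reset helper and the random action choice are assumptions): the
agent routes raw timesteps through the processor and picks a new action only
when a fully processed timestep is emitted, repeating the last action
otherwise.

    def step(self, timestep):
        processed = self._processor(timestep)
        if processed is not None:
            # A full processed step is ready; choose a fresh random action.
            self._action = np.random.randint(self._num_actions)
        # While the processor is still pooling frames, repeat the last action.
        return self._action

    def reset(self):
        processors.reset(self._processor)  # Assumed helper that clears frame buffers.
        self._action = None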