Example #1
def create_atari_environment(game_name, sticky_actions=True):
    """Build a Gym Atari 2600 environment with standard preprocessing.

    Follows the evaluation protocol of Machado et al. (2017), "Revisiting the
    Arcade Learning Environment: Evaluation Protocols and Open Problems for
    General Agents". The result is the Gym wrapper around the Arcade Learning
    Environment.

    Sticky actions (enabled by default, as prescribed by Machado et al.) make
    the ALE repeat the previous action with probability 0.25, injecting a mild
    form of stochasticity; Gym exposes them through the 'v0' ROM versions,
    while 'v4' is deterministic.

    Args:
      game_name: str, the name of the Atari 2600 domain.
      sticky_actions: bool, whether to use sticky actions per Machado et al.

    Returns:
      An Atari 2600 environment with some standard preprocessing.
    """
    suffix = 'v0' if sticky_actions else 'v4'
    raw_env = gym.make('{}NoFrameskip-{}'.format(game_name, suffix))
    # Drop Gym's TimeLimit wrapper, which caps episodes at 100k frames: the
    # 108k-frame (30-minute) limit is enforced internally instead, and
    # TimeLimit interacts badly with saving/restoring emulator state.
    return preprocessing.AtariPreprocessing(raw_env.env)
    def testResetPassesObservation(self):
        """reset() returns an observation of shape (screen_size, screen_size, 1)."""
        wrapped = preprocessing.AtariPreprocessing(
            MockEnvironment(), frame_skip=1, screen_size=16)

        obs = wrapped.reset()
        self.assertEqual((16, 16, 1), obs.shape)
    def testMaxFramePooling(self):
        """step() max-pools the observation over the skipped frames."""
        wrapped = preprocessing.AtariPreprocessing(
            MockEnvironment(), frame_skip=2)
        wrapped.reset()

        # NOTE(review): the original inline comment claimed the pooled max is
        # 2, but the assertion checks 8 — the expected value depends on what
        # MockEnvironment emits per frame; confirm against the mock.
        obs, _, _, _ = wrapped.step(0)
        self.assertTrue((obs == 8).all())
    def testFrameSkipAccumulatesReward(self):
        """Reward from a single step is summed across all skipped frames."""
        frame_skip = 2
        wrapped = preprocessing.AtariPreprocessing(
            MockEnvironment(), frame_skip=frame_skip)
        wrapped.reset()

        # MockEnvironment pays reward 1 per frame for action 0, so the total
        # accumulated over one wrapped step equals frame_skip.
        _, total_reward, _, _ = wrapped.step(0)
        self.assertEqual(total_reward, frame_skip)
    def testTerminalPassedThrough(self):
        """The terminal flag from the wrapped env is passed through unchanged."""
        max_steps = 10
        wrapped = preprocessing.AtariPreprocessing(
            MockEnvironment(max_steps=max_steps), frame_skip=1)
        wrapped.reset()

        # Every step except the last must report non-terminal.
        for _ in range(max_steps - 1):
            _, _, terminal, _ = wrapped.step(0)
            self.assertFalse(terminal)

        # The final step flips the terminal flag.
        _, _, terminal, _ = wrapped.step(0)
        self.assertTrue(terminal)
    def __call__(self, *args, **kwargs):
        """Construct the video-recorded, Dopamine-preprocessed Atari env."""
        print("GOOD", self.game_name, self.sticky_actions)

        # Same env-id convention as Dopamine's create_atari_environment:
        # sticky actions -> 'v0' ROMs, deterministic -> 'v4'.
        version = 'v0' if self.sticky_actions else 'v4'
        env_id = '{}NoFrameskip-{}'.format(self.game_name, version)
        env = gym.make(env_id)

        # Fire the recorder every 500 episodes.
        trigger = RecordVideoTriggerEpisodeFreq(episode_freq=500)

        # Following Dopamine's create_atari_environment: strip out Gym's
        # TimeLimit wrapper, which caps episodes at 100k frames. The time
        # limit is handled internally instead (108k frames / 30 minutes),
        # and TimeLimit plays poorly with saving and restoring states.
        base_env = env.env

        env = VideoRecorderWrapper(
            base_env,
            directory=str(Path(self.video_directory) / 'videos'),
            record_video_trigger=trigger,
            video_length=2000000)

        class DopamineWrapper(Wrapper):
            """Re-exposes the underlying ALE handle through the wrapper stack."""

            def __init__(self, env, base_env):
                super().__init__(env)
                self.base_env = base_env

            @property
            def ale(self):
                # AtariPreprocessing reaches for .ale; forward it to the
                # unwrapped environment.
                return self.base_env.ale

        wrapped = DopamineWrapper(env, base_env)
        return preprocessing.AtariPreprocessing(wrapped)