class AtariWrapper():
    """
    ALE wrapper that tries to mimic the options in the DQN paper including the
    preprocessing (cropping is only a no-op hook, see ``crop_frame``).

    A state is a FIFO of gray-scale frames stacked as channels (HxWxC) and
    resized to 84x84 with OpenCV. Each gray-scale frame is the pixel-wise
    maximum over the two most recently rendered RGB frames.
    """
    # Human readable names, index-aligned with ``_action_set``.
    action_words = [
        'NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', "UPRIGHT", "UPLEFT",
        "DOWNRIGHT", "DOWNLEFT"
    ]
    # Action codes that are handed to ``ale.act``.
    _action_set = [0, 2, 3, 4, 5, 6, 7, 8, 9]
    # The actions exposed to the caller are just the indices
    # 0..len(_action_set)-1. ``_step`` maps such an index back to the ALE
    # code through ``_action_set`` before acting.
    possible_actions = list(range(len(_action_set)))

    def __init__(self,
                 rom_path,
                 seed=123,
                 frameskip=4,
                 show_display=False,
                 stack_num_states=4,
                 concatenate_state_every=4):
        """Create the emulator, configure it and load the ROM.

        Parameters:
            rom_path: Path of the ROM file to load.

            seed: Value written to ALE's ``random_seed`` option.

            frameskip: Either an int, or a pair ``(low, high)`` describing a
                random range to draw from on every step (``high`` excluded).
                It's aka action repeat.

            show_display: Value written to ALE's ``display_screen`` option.

            stack_num_states: Number of dimensions/channels to have.

            concatenate_state_every: Within one step, a channel is appended
                to the state on every repeat whose index is a multiple of
                this value (the first repeat always qualifies).

        Raises:
            IOError: If ``rom_path`` does not exist.
        """

        self.stack_num_states = stack_num_states
        self.concatenate_state_every = concatenate_state_every

        self.game_path = rom_path
        if not os.path.exists(self.game_path):
            # The game name is derived from the path itself.
            raise IOError('You asked for game %s but path %s does not exist' %
                          (os.path.basename(self.game_path), self.game_path))
        self.frameskip = frameskip

        try:
            self.ale = ALEInterface()
        except Exception:
            print(
                "ALEInterface could not be loaded. ale_python_interface import failed"
            )
            # Bare raise keeps the original traceback intact.
            raise

        # Set some default options
        self.ale.setInt(b'random_seed', seed)
        self.ale.setBool(b'sound', False)
        self.ale.setBool(b'display_screen', show_display)
        self.ale.setFloat(b'repeat_action_probability', 0.)

        # Load the rom. Every other string handed to ALE in this class is
        # bytes, so a text path is encoded the same way.
        rom = self.game_path
        if not isinstance(rom, bytes):
            rom = str(rom).encode('utf-8')
        self.ale.loadROM(rom)

        (self.screen_width, self.screen_height) = self.ale.getScreenDims()
        # Holds the two most recent frames; their max becomes one channel.
        self.latest_frame_fifo = deque(maxlen=2)
        # Holds the gray-scale channels that make up the state.
        self.state_fifo = deque(maxlen=stack_num_states)

    def _build_state(self):
        """Stack the channel FIFO into an HxWxC array and resize it to 84x84."""
        # Transpose so we get HxWxC instead of CxHxW
        stacked = np.array(np.transpose(self.state_fifo, (1, 2, 0)))
        return cv2.resize(stacked, (84, 84))

    def _step(self, a, force_noop=False):
        """Perform one step of the environment.
        Automatically repeats the step self.frameskip number of times

        parameters:
            a: Index into ``possible_actions``.
            force_noop: Force it to perform a single no-op frame, ignoring
                both the action supplied and the frameskip setting.

        returns:
            Tuple ``(state, reward, game_over, info)`` where ``reward`` is
            the sum over all repeated frames and ``info`` holds the number
            of remaining lives under ``"ale.lives"``.
        """
        assert a in self.possible_actions + [0]

        if force_noop:
            action, num_steps = 0, 1
        else:
            action = self._action_set[a]
            # Only a regular action is repeated according to frameskip.
            if isinstance(self.frameskip, int):
                num_steps = self.frameskip
            else:
                num_steps = np.random.randint(self.frameskip[0],
                                              self.frameskip[1])

        reward = 0.0
        for i in range(num_steps):
            reward += self.ale.act(action)
            cur_frame = self.observe_raw(get_rgb=True)
            self.latest_frame_fifo.append(self.crop_frame(cur_frame))

            if i % self.concatenate_state_every == 0:
                # Pixel-wise max over the (at most two) latest frames.
                curmax_frame = np.amax(self.latest_frame_fifo, axis=0)
                self.state_fifo.append(self.convert_to_gray(curmax_frame))

        self.current_frame = self._build_state()
        return self.current_frame, reward, self.ale.game_over(), {
            "ale.lives": self.ale.lives()
        }

    def step(self, *args, **kwargs):
        """Performs one step of the environment.

        Takes the same arguments as ``_step`` and returns the same tuple,
        except that ``done`` is also set when a life was lost.
        """
        lives_before = self.ale.lives()
        next_state, reward, done, info = self._step(*args, **kwargs)
        lives_after = self.ale.lives()

        # End the episode when a life is lost
        if lives_before > lives_after:
            done = True

        return next_state, reward, done, info

    def observe_raw(self, get_rgb=False):
        """Observe either RGB or Gray frames.
        A fresh array is allocated on every call so that frames handed out
        earlier are never overwritten.
        """
        if get_rgb:
            cur_frame_rgb = np.zeros(
                (self.screen_height, self.screen_width, 3), dtype=np.uint8)
            self.ale.getScreenRGB(cur_frame_rgb)
            return cur_frame_rgb

        cur_frame_gray = np.zeros((self.screen_height, self.screen_width),
                                  dtype=np.uint8)
        self.ale.getScreenGrayscale(cur_frame_gray)
        return cur_frame_gray

    def crop_frame(self, frame):
        """Simply crops a frame. Does nothing by default.
        """
        return frame

    def convert_to_gray(self, img):
        """Get the luminance channel of an HxWx3 RGB image as uint8.
        """
        img_f = np.float32(img)
        img_lumi = 0.299 * img_f[:, :, 0] + \
                   0.587 * img_f[:, :, 1] + \
                   0.114 * img_f[:, :, 2]
        return np.uint8(img_lumi)

    def reset(self):
        """Reset the game and return the initial 84x84 state.
        """
        self.ale.reset_game()
        s = self.crop_frame(self.observe_raw(get_rgb=True))

        # Forget everything that belongs to the previous episode.
        self.latest_frame_fifo.clear()
        self.state_fifo.clear()

        # Populate missing frames with blank ones. They are uint8 like the
        # real frames so the state keeps one dtype for the whole episode.
        for _ in range(self.stack_num_states - 1):
            self.state_fifo.append(
                np.zeros((s.shape[0], s.shape[1]), dtype=np.uint8))

        self.latest_frame_fifo.append(s)

        # Push the latest frame
        self.state_fifo.append(self.convert_to_gray(s))

        self.state = self._build_state()
        return self.state

    def get_action_meanings(self):
        """Return in text what the actions correspond to.
        """
        return [ACTION_MEANING[i] for i in self._action_set]

    def save_state(self):
        """Saves the current state and returns a identifier to saved state
        """
        return self.ale.cloneSystemState()

    def restore_state(self, ident):
        """Restore game state
        Restores the saved state of the system and perform a no-op
        so a new frame can be generated incase a restore is followed
        by an observe()
        """

        self.ale.restoreSystemState(ident)
        self.step(0, force_noop=True)
# Example #2
class AtariPlayer(gym.Env):
    """
    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.

    Observations are gray-scale (height, width) uint8 images, max-pooled
    over the current and one earlier raw screen.

    Info:
        ale.lives: the number of lives left after the step
    """
    def __init__(self,
                 rom_file,
                 viz=0,
                 frame_skip=4,
                 nullop_start=30,
                 live_lost_as_eoe=True,
                 max_num_frames=0):
        """
        Args:
            rom_file: path to the rom
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with random number of null ops, drawn from
                [0, nullop_start). Zero or less disables the null ops.
            live_lost_as_eoe: consider lost of lives as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            print("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.ale.setInt(b"random_seed", np.random.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            # Frame skipping is done by ``step`` itself, not by the emulator.
            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, str):
                # A string is a directory for recorded frames; on-screen
                # display is turned off in that case.
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(self.height, self.width),
                                            dtype=np.uint8)
        self._restart_episode()

    def get_action_meanings(self):
        """Return the textual meaning of every action in ``self.actions``."""
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # avoid missing frame issue: max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz and isinstance(self.viz, float):
            # ``viz`` is the delay in seconds; waitKey expects milliseconds.
            cv2.imshow(self.windowname, ret)
            cv2.waitKey(int(self.viz * 1000))
        ret = ret.astype('float32')
        # 0.299,0.587.0.114. same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        """Reset the emulator and play a random number of null ops."""
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start. ``randint`` rejects a non-positive bound,
        # so such a setting simply means "no null ops".
        if self.nullop_start > 0:
            n = np.random.randint(self.nullop_start)
        else:
            n = 0
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                # Remember the screen right before the final null op.
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def reset(self):
        """Return the current observation, restarting first if the game is over.

        The emulator is only reset when ALE reports game over; otherwise the
        running game simply continues.
        """
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def step(self, act):
        """Repeat action ``act`` for up to ``frame_skip`` frames.

        Args:
            act: index into ``self.actions``.

        Returns:
            Tuple ``(observation, reward, isOver, info)``. ``reward`` is the
            sum over the repeated frames, ``isOver`` is true on game over or,
            with ``live_lost_as_eoe``, when a life was lost, and ``info``
            holds the remaining lives under ``'ale.lives'``.
        """
        oldlives = self.ale.lives()
        # Stays valid even if the loop below does not run at all.
        newlives = oldlives
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                # Remember the screen right before the final repeat.
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info