import os
import threading
from collections import deque

import cv2
import gym
import numpy as np
from ale_python_interface import ALEInterface
from gym import spaces

# ACTION_MEANING and ROM_URL are assumed to be defined elsewhere in this module.
# Lock that serializes access to the ALE emulator across threads.
_ALE_LOCK = threading.Lock()


class AtariWrapper:
    """
    ALE wrapper that tries to mimic the options in the DQN paper, including the
    preprocessing (except resizing/cropping).
    """
    action_words = [
        'NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'UPRIGHT', 'UPLEFT',
        'DOWNRIGHT', 'DOWNLEFT'
    ]
    _action_set = [0, 2, 3, 4, 5, 6, 7, 8, 9]  # Valid actions for ALE.

    # Possible actions are just the list 0..num_valid_actions-1.
    # We still need to map from the latter to the former when acting.
    possible_actions = list(range(len(_action_set)))

    def __init__(self, rom_path, seed=123, frameskip=4, show_display=False,
                 stack_num_states=4, concatenate_state_every=4):
        """
        Parameters:
            frameskip: either a tuple (a random range to choose from, with the
                top value excluded) or an int. Also known as action repeat.
            stack_num_states: number of dimensions/channels the state should have.
            concatenate_state_every: after how many frames one channel should be
                appended to the state. The number is in absolute frames,
                independent of frameskip.
        """
        self.stack_num_states = stack_num_states
        self.concatenate_state_every = concatenate_state_every

        self.game_path = rom_path
        if not os.path.exists(self.game_path):
            raise IOError('You asked for the ROM at %s but that path does not exist'
                          % self.game_path)
        self.frameskip = frameskip

        try:
            self.ale = ALEInterface()
        except Exception as e:
            print("ALEInterface could not be loaded. "
                  "ale_python_interface import failed")
            raise e

        # Set some default options.
        self.ale.setInt(b'random_seed', seed)
        self.ale.setBool(b'sound', False)
        self.ale.setBool(b'display_screen', show_display)
        self.ale.setFloat(b'repeat_action_probability', 0.)

        # Load the ROM.
        self.ale.loadROM(self.game_path)

        (self.screen_width, self.screen_height) = self.ale.getScreenDims()

        # Holds the two most recent frames to max-pool over.
        self.latest_frame_fifo = deque(maxlen=2)
        self.state_fifo = deque(maxlen=stack_num_states)

    def _step(self, a, force_noop=False):
        """Perform one step of the environment.
        Automatically repeats the step self.frameskip number of times.

        Parameters:
            force_noop: force a no-op, ignoring the action supplied.
        """
        assert a in self.possible_actions + [0]

        if force_noop:
            action, num_steps = 0, 1
        else:
            action = self._action_set[a]
            if isinstance(self.frameskip, int):
                num_steps = self.frameskip
            else:
                num_steps = np.random.randint(self.frameskip[0],
                                              self.frameskip[1])

        reward = 0.0
        for i in range(num_steps):
            reward += self.ale.act(action)
            cur_frame = self.observe_raw(get_rgb=True)
            cur_frame_cropped = self.crop_frame(cur_frame)
            self.latest_frame_fifo.append(cur_frame_cropped)

            if i % self.concatenate_state_every == 0:
                # Max-pool over the two most recent frames, keep the luminance channel.
                curmax_frame = np.amax(self.latest_frame_fifo, axis=0)
                frame_lumi = self.convert_to_gray(curmax_frame)
                self.state_fifo.append(frame_lumi)

        # Transpose so we get HxWxC instead of CxHxW.
        self.current_frame = np.array(np.transpose(self.state_fifo, (1, 2, 0)))
        self.current_frame = cv2.resize(self.current_frame, (84, 84))
        return self.current_frame, reward, self.ale.game_over(), {
            "ale.lives": self.ale.lives()
        }

    def step(self, *args, **kwargs):
        """Perform one step of the environment."""
        lives_before = self.ale.lives()
        next_state, reward, done, info = self._step(*args, **kwargs)
        lives_after = self.ale.lives()

        # End the episode when a life is lost.
        if lives_before > lives_after:
            done = True

        return next_state, reward, done, info

    def observe_raw(self, get_rgb=False):
        """Observe either RGB or grayscale frames.
        Initializing fresh arrays here keeps ALE from writing into stale buffers.
        """
        if get_rgb:
            cur_frame_rgb = np.zeros(
                (self.screen_height, self.screen_width, 3), dtype=np.uint8)
            self.ale.getScreenRGB(cur_frame_rgb)
            return cur_frame_rgb
        else:
            cur_frame_gray = np.zeros((self.screen_height, self.screen_width),
                                      dtype=np.uint8)
            self.ale.getScreenGrayscale(cur_frame_gray)
            return cur_frame_gray

    def crop_frame(self, frame):
        """Simply crop a frame. Does nothing by default."""
        return frame

    def convert_to_gray(self, img):
        """Get the luminance channel."""
        img_f = np.float32(img)
        img_lumi = 0.299 * img_f[:, :, 0] + \
                   0.587 * img_f[:, :, 1] + \
                   0.114 * img_f[:, :, 2]
        return np.uint8(img_lumi)

    def reset(self):
        """Reset the game."""
        self.ale.reset_game()
        s = self.observe_raw(get_rgb=True)
        s = self.crop_frame(s)

        # Populate missing frames with blank ones.
        for _ in range(self.stack_num_states - 1):
            self.state_fifo.append(
                np.zeros(shape=(s.shape[0], s.shape[1]), dtype=np.uint8))

        self.latest_frame_fifo.append(s)

        # Push the latest frame.
        frame_lumi = self.convert_to_gray(s)
        self.state_fifo.append(frame_lumi)

        self.state = np.transpose(self.state_fifo, (1, 2, 0))
        self.state = cv2.resize(self.state, (84, 84))
        return self.state

    def get_action_meanings(self):
        """Return, in text, what the actions correspond to."""
        return [ACTION_MEANING[i] for i in self._action_set]

    def save_state(self):
        """Save the current state and return an identifier for the saved state."""
        return self.ale.cloneSystemState()

    def restore_state(self, ident):
        """Restore game state.
        Restores the saved state of the system and performs a no-op so that a
        new frame can be generated in case a restore is followed by an observe().
        """
        self.ale.restoreSystemState(ident)
        self.step(0, force_noop=True)
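

# A minimal usage sketch for AtariWrapper, not part of the original module:
# the ROM path and the random policy below are placeholders, only meant to
# illustrate the reset()/step() contract (84x84xstack_num_states uint8 states,
# episodes ending when a life is lost).
def _demo_atari_wrapper(rom_path="roms/breakout.bin"):
    env = AtariWrapper(rom_path, seed=0, frameskip=4)
    state = env.reset()                                   # stacked (84, 84, 4) state
    done, episode_reward = False, 0.0
    while not done:
        action = np.random.choice(env.possible_actions)   # random policy
        state, reward, done, info = env.step(action)
        episode_reward += reward
    return episode_reward, info["ale.lives"]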
class AtariPlayer(gym.Env):
    """
    A wrapper for the ALE emulator, with configurations to mimic DeepMind DQN settings.

    Info:
        score: the accumulated reward in the current game
        gameOver: True when the current game is over
    """

    def __init__(self, rom_file, viz=0,
                 frame_skip=4, nullop_start=30,
                 live_lost_as_eoe=True, max_num_frames=0):
        """
        Args:
            rom_file: path to the ROM
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done. Set to 0 to disable. Set to a
                positive number to be the delay between frames to show.
                Set to a string to be a directory in which to store frames.
            nullop_start: start with a random number of null ops.
            live_lost_as_eoe: consider loss of a life as end of episode.
                Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        assert os.path.isfile(rom_file), \
            "rom {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            print("You're not using the latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.ale.setInt(b"random_seed", np.random.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, str):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                cv2.startWindowThread()
                cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
        self._restart_episode()

    def get_action_meanings(self):
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        :returns: a grayscale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # avoid the missing-frame issue: max-pool over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                cv2.imshow(self.windowname, ret)
                cv2.waitKey(int(self.viz * 1000))
        ret = ret.astype('float32')
        # Luminance weights 0.299, 0.587, 0.114, same as rgb2y in torch/image.
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start
        n = np.random.randint(self.nullop_start)
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def reset(self):
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def step(self, act):
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info
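

# A minimal usage sketch for AtariPlayer, not part of the original module: the
# ROM file name is a placeholder. It shows the standard gym loop; with
# live_lost_as_eoe=True each life counts as one episode, so a full game spans
# several reset()/rollout cycles.
def _demo_atari_player(rom_file="breakout.bin"):
    env = AtariPlayer(rom_file, viz=0, frame_skip=4, live_lost_as_eoe=True)
    obs = env.reset()                         # grayscale (h, w) uint8 screen
    done, episode_reward = False, 0.0
    while not done:
        act = env.action_space.sample()       # random action from the minimal set
        obs, reward, done, info = env.step(act)
        episode_reward += reward
    return episode_reward, info['ale.lives']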