def _init_ale(rand_seed, rom_file): #assert os.path.exists(rom_file), '%s does not exists.' ale = ALEInterface() ale.setInt('random_seed', rand_seed) ale.setBool('showinfo', False) ale.setInt('frame_skip', 1) ale.setFloat('repeat_action_probability', 0.0) ale.setBool('color_averaging', False) ale.loadROM(rom_file) return ale
class AtariEnv(object):
    """Minimal ALE wrapper whose observations are resized grayscale frames.

    Args:
        frame_skip: ``None`` for one emulator frame per step, an ``int`` for a
            fixed number of frames, or a ``(low, high)`` pair from which the
            frame count is drawn with ``rng.randint(low, high)``.
        repeat_action_probability: forwarded to ALE's setting of that name.
        state_shape: ``dsize`` passed to ``cv2.resize`` (OpenCV interprets it
            as ``(width, height)``).
        rom_path: directory containing the ROM files; required.
        game_name: ROM base name; ``<rom_path>/<game_name>.bin`` is loaded.
        random_state: random generator; defaults to
            ``np.random.RandomState(1234)``. Also seeds ALE.
        rendering: show the emulator screen (sound enabled on Linux only).
        record_dir: when rendering, directory ALE records screens and
            ``sound.wav`` into.
        obs_showing: stored on the instance; not read inside this class.
        channel_weights: weights applied to the R, G, B channels to build the
            grayscale image.
            NOTE(review): the default weights R by 0.5870 and G by 0.2989, i.e.
            the usual BT.601 R/G luma coefficients swapped. Kept unchanged so
            existing preprocessing stays identical -- confirm intent.
    """

    def __init__(self, frame_skip=None, repeat_action_probability=0.0,
                 state_shape=[84, 84], rom_path=None, game_name='pong',
                 random_state=None, rendering=False, record_dir=None,
                 obs_showing=False, channel_weights=[0.5870, 0.2989, 0.1140]):
        # Resolve the ROM path up front so a bad argument fails before any
        # emulator/pygame setup happens.
        rom_file = self._rom_file(rom_path, game_name)
        self.ale = ALEInterface()
        self.frame_skip = frame_skip
        # Copy: the list defaults are shared by every call, so never let an
        # instance alias (and possibly mutate) them.
        self.state_shape = list(state_shape)
        if random_state is None:
            random_state = np.random.RandomState(1234)
        self.rng = random_state
        self.channel_weights = list(channel_weights)
        self.ale.setInt(b'random_seed', self.rng.randint(1000))
        self.ale.setFloat(b'repeat_action_probability', repeat_action_probability)
        self.ale.setBool(b'color_averaging', False)
        if rendering:
            if sys.platform == 'darwin':
                import pygame
                pygame.init()
                self.ale.setBool(b'sound', False)  # Sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                self.ale.setBool(b'sound', True)
            self.ale.setBool(b'display_screen', True)
        if rendering and record_dir is not None:
            # Recording options must be set before loadROM.
            self.ale.setString(b'record_screen_dir', record_dir.encode())
            # Join a *relative* file name: an absolute '/sound.wav' makes
            # os.path.join drop record_dir and write to the filesystem root.
            self.ale.setString(b'record_sound_filename',
                               os.path.join(record_dir, 'sound.wav').encode())
            self.ale.setInt(b'fragsize', 64)  # to ensure proper sound sync (see ALE doc)
        self.ale.loadROM(rom_file.encode())
        self.legal_actions = self.ale.getMinimalActionSet()
        self.nb_actions = len(self.legal_actions)
        (self.screen_width, self.screen_height) = self.ale.getScreenDims()
        # Preallocated RGB buffer that getScreenRGB fills on every frame.
        self._buffer = np.empty((self.screen_height, self.screen_width, 3), dtype=np.uint8)
        self.obs_showing = obs_showing

    @staticmethod
    def _rom_file(rom_path, game_name):
        """Return the path ``<rom_path>/<game_name>.bin``.

        ``os.path.join`` makes ``rom_path`` work with or without a trailing
        separator; plain concatenation produced paths like ``romspong.bin``.

        Raises:
            TypeError: if ``rom_path`` is ``None``.
        """
        if rom_path is None:
            raise TypeError('rom_path must be the directory containing the ROM files')
        return os.path.join(rom_path, game_name + '.bin')

    def reset(self):
        """Reset the ALE game and return the first observation."""
        self.ale.reset_game()
        return self.get_state()

    def step(self, action):
        """Repeat ``action`` for the configured number of emulator frames.

        Args:
            action: index into ``self.legal_actions``.

        Returns:
            ``(state, summed_reward, game_over, {})``.
        """
        if self.frame_skip is None:
            num_steps = 1
        elif isinstance(self.frame_skip, int):
            num_steps = self.frame_skip
        else:
            # NOTE(review): with a numpy RandomState the upper bound is
            # exclusive, so ``frame_skip[1]`` itself is never drawn -- confirm.
            num_steps = self.rng.randint(self.frame_skip[0], self.frame_skip[1])
        reward = 0.0
        for _ in range(num_steps):
            reward += self.ale.act(self.legal_actions[action])
        return self.get_state(), reward, self.ale.game_over(), {}

    def _get_image(self):
        """Return the current screen as a resized floating-point grayscale image."""
        self.ale.getScreenRGB(self._buffer)
        red, green, blue = (self._buffer[:, :, c] for c in range(3))
        gray = (self.channel_weights[0] * red
                + self.channel_weights[1] * green
                + self.channel_weights[2] * blue)
        return cv2.resize(gray, tuple(self.state_shape), interpolation=cv2.INTER_LINEAR)

    def get_state(self):
        """Return the current observation."""
        return self._get_image()

    def get_lives(self):
        """Return ALE's remaining-lives counter."""
        return self.ale.lives()
class AtariPlayer(gym.Env):
    """
    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.

    Observations are max-pooled (over two raw screens) grayscale uint8 frames
    of shape (height, width). ``step`` returns ``info = {'ale.lives': lives}``
    with the number of lives ALE reports after the step.
    """

    def __init__(self, rom_file, viz=0, frame_skip=4, nullop_start=30,
                 live_lost_as_eoe=True, max_num_frames=0):
        """
        Args:
            rom_file: path to the rom (a bare file name is looked up in the
                'atari_rom' dataset directory).
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to plot every observation with
                matplotlib (the value is not used as a delay).
                Set to a string to be a directory to store frames.
            nullop_start: at each episode start, perform a random number of
                no-op actions drawn from [0, nullop_start). 0 disables this.
            live_lost_as_eoe: consider loss of a life as end of episode.
                Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "ROM {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise (a non-zero value); kept at 0.0.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                # Encode like the other ALE string arguments in this block;
                # ctypes-based ALE bindings reject a Python 3 str here.
                self.ale.setString(b'record_screen_dir', viz.encode('utf-8'))
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
        self._restart_episode()

    def get_action_meanings(self):
        """Map each action of the minimal action set through ACTION_MEANING."""
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image, shape (height, width, 3)
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz and isinstance(self.viz, float):
            # NOTE(review): a new matplotlib figure is created on every call.
            plt.figure()
            plt.imshow(ret)
            plt.show()
        ret = ret.astype('float32')
        # weights 0.299, 0.587, 0.114: same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        """Reset the game, then play a random number of no-op actions."""
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start; randint(0) raises, so treat 0 as "disabled".
        n = self.rng.randint(self.nullop_start) if self.nullop_start > 0 else 0
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                # Screen before the final no-op, used for max-pooling.
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def reset(self):
        """Restart only if the game is over (a lost life alone does not restart)."""
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def render(self, *args, **kwargs):
        pass  # visualization for this env is through the viz= argument when creating the player

    def step(self, act):
        """Repeat action ``act`` for up to ``frame_skip`` frames.

        Stops early on game over, or on a lost life when ``live_lost_as_eoe``.

        Returns:
            ``(state, summed_reward, is_over, {'ale.lives': lives})``.
        """
        oldlives = self.ale.lives()
        newlives = oldlives  # defined even if the loop body never runs
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                # NOTE(review): on an early break this is not refreshed, so
                # the max-pool below uses an older screen.
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info
# Tail of the command-line interface. `parser`, and the `--game` / `--scale`
# options read below, are presumably defined before this fragment.
parser.add_argument('--query_args', type=str, default="null")
parser.add_argument(
    '--seed',
    help='The random seed for the Atari env (default is 1234).',
    default=1234,
    type=int)
# Text file opened for writing (stdout unless given); presumably where the
# code that follows logs actions -- confirm.
parser.add_argument('--write_actions', type=argparse.FileType('w'), default=sys.stdout)
args = parser.parse_args()

# Emulator setup: seed ALE from --seed and set repeat_action_probability to 0.0.
ale = ALEInterface()
__USE_SDL = True  # NOTE(review): not read anywhere in the visible code
ale.setInt(b'random_seed', args.seed)
ale.setFloat(b'repeat_action_probability', 0.0)

print('Starting up: ' + args.game)
# Map the game name to a ROM file via atari_py and load it (path passed as UTF-8 bytes).
game_path = atari_py.get_game_path(args.game)
ale.loadROM(str.encode(game_path))
print('Legal Actions: ', ale.getLegalActionSet())

# Display setup: a pygame window of ALE's (width, height) scaled by --scale.
pygame.init()
(w, h) = ale.getScreenDims()
dim = (w * args.scale, h * args.scale)
pygame.display.set_mode(dim)
clock = pygame.time.Clock()
FPS = 32  # presumably the target frame rate for `clock` in the loop that follows
quit = False  # NOTE(review): shadows the builtin quit(); likely a main-loop exit flag