# Imports assumed by the listings below; helper utilities referenced later
# (get_rng, get_dataset_path, execute_only_once, logger, _ALE_LOCK, ROM_URL,
# ACTION_MEANING) are expected to come from the surrounding project.
import os
import sys
from collections import Counter

import cv2
import numpy as np
import matplotlib.pyplot as plt
import six

import gym
from gym import spaces
from ale_python_interface import ALEInterface  # or the equivalent ALE binding in use


class AtariEnv(object):
    def __init__(self, frame_skip=None, repeat_action_probability=0.0,
                 state_shape=[84, 84], rom_path=None, game_name='pong',
                 random_state=None, rendering=False, record_dir=None,
                 obs_showing=False, channel_weights=[0.5870, 0.2989, 0.1140]):
        self.ale = ALEInterface()
        self.frame_skip = frame_skip
        self.state_shape = state_shape
        if random_state is None:
            random_state = np.random.RandomState(1234)
        self.rng = random_state
        self.channel_weights = channel_weights
        self.ale.setInt(b'random_seed', self.rng.randint(1000))
        self.ale.setFloat(b'repeat_action_probability', repeat_action_probability)
        self.ale.setBool(b'color_averaging', False)
        if rendering:
            if sys.platform == 'darwin':
                import pygame
                pygame.init()
                self.ale.setBool(b'sound', False)  # sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                self.ale.setBool(b'sound', True)
            self.ale.setBool(b'display_screen', True)
        if rendering and record_dir is not None:
            # recording options must be set before loadROM
            self.ale.setString(b'record_screen_dir', record_dir.encode())
            self.ale.setString(b'record_sound_filename',
                               os.path.join(record_dir, 'sound.wav').encode())
            self.ale.setInt(b'fragsize', 64)  # to ensure proper sound sync (see ALE doc)
        self.ale.loadROM(str.encode(rom_path + game_name + '.bin'))
        self.legal_actions = self.ale.getMinimalActionSet()
        self.nb_actions = len(self.legal_actions)
        (self.screen_width, self.screen_height) = self.ale.getScreenDims()
        self._buffer = np.empty((self.screen_height, self.screen_width, 3), dtype=np.uint8)
        self.obs_showing = obs_showing

    def reset(self):
        self.ale.reset_game()
        return self.get_state()

    def step(self, action):
        reward = 0.0
        if self.frame_skip is None:
            num_steps = 1
        elif isinstance(self.frame_skip, int):
            num_steps = self.frame_skip
        else:
            num_steps = self.rng.randint(self.frame_skip[0], self.frame_skip[1])
        for i in range(num_steps):
            reward += self.ale.act(self.legal_actions[action])
        return self.get_state(), reward, self.ale.game_over(), {}

    def _get_image(self):
        self.ale.getScreenRGB(self._buffer)
        gray = (self.channel_weights[0] * self._buffer[:, :, 0] +
                self.channel_weights[1] * self._buffer[:, :, 1] +
                self.channel_weights[2] * self._buffer[:, :, 2])
        x = cv2.resize(gray, tuple(self.state_shape), interpolation=cv2.INTER_LINEAR)
        return x

    def get_state(self):
        return self._get_image()

    def get_lives(self):
        return self.ale.lives()
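# --- Example usage (a minimal sketch, not part of the original listing) ---
# Drives AtariEnv with a uniform random policy for one episode. The ROM directory
# 'atari_roms/' and the game name are assumptions; adjust them to the local setup.
env = AtariEnv(frame_skip=4, rom_path='atari_roms/', game_name='pong',
               random_state=np.random.RandomState(0))
obs = env.reset()                              # (84, 84) grayscale observation
done = False
episode_reward = 0.0
while not done:
    action = env.rng.randint(env.nb_actions)   # random action index
    obs, reward, done, _ = env.step(action)
    episode_reward += reward
print('episode reward:', episode_reward)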
class AtariPlayer(gym.Env):
    """
    A wrapper for the ALE emulator, with configurations to mimic DeepMind DQN settings.

    Info:
        score: the accumulated reward in the current game
        gameOver: True when the current game is over
    """

    def __init__(self, rom_file, viz=0, frame_skip=4, nullop_start=30,
                 live_lost_as_eoe=True, max_num_frames=0):
        """
        Args:
            rom_file: path to the ROM
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with a random number of null ops.
            live_lost_as_eoe: consider a lost life as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "ROM {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using the latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)
            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)
                # cv2.namedWindow(self.windowname)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
        self._restart_episode()

    def get_action_meanings(self):
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz:
            if isinstance(self.viz, float):
                # alternative: cv2.imshow(self.windowname, ret)
                #              cv2.waitKey(int(self.viz * 1000))
                plt.figure()
                plt.imshow(ret)
                plt.show()
        ret = ret.astype('float32')
        # 0.299, 0.587, 0.114 -- same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :]
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start
        n = self.rng.randint(self.nullop_start)
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def reset(self):
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def render(self, *args, **kwargs):
        # visualization for this env is through the viz= argument when creating the player
        pass

    def step(self, act):
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info
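# --- Example usage (sketch, not part of the original listing) ---
# AtariPlayer implements the gym.Env interface, so it can be driven like any Gym
# environment. The ROM name below is an assumption; with live_lost_as_eoe=True a
# lost life ends the episode, mimicking the DeepMind training setting.
player = AtariPlayer('breakout.bin', frame_skip=4, nullop_start=30,
                     live_lost_as_eoe=True)
obs = player.reset()                                   # (210, 160) uint8 grayscale frame
obs, reward, is_over, info = player.step(player.action_space.sample())
print(player.get_action_meanings(), info['ale.lives'])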
class FreewayStateManager(object):
    """
    A state manager with a goal state and an action space of 2.

    The state manager itself must have *NO* internal state / updating behavior.
    Otherwise we would need deepcopy() or cloneSystemState in get_action_probs,
    making it slower.
    """

    def __init__(self, random_state, rollout_limit=1000):
        self.rollout_limit = rollout_limit
        self.random_state = random_state
        self.ale = ALEInterface()
        self.ale.setInt('random_seed', 123)
        rom_path = "atari_roms/freeway.bin"
        self.ale.loadROM(rom_path)

        # Set USE_SDL to true to display the screen. ALE must be compiled
        # with SDL enabled for this to work. On OSX, pygame init is used to
        # proxy-call SDL_main.
        USE_SDL = False
        if USE_SDL:
            if sys.platform == 'darwin':
                import pygame
                pygame.init()
                self.ale.setBool('sound', False)  # sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                # the sound is awful
                self.ale.setBool('sound', False)
            self.ale.setBool('display_screen', True)

        # Background subtraction: record `num` null-op frames and take the
        # per-pixel mode as the static background.
        s = self.ale.getScreenGrayscale()
        num = 1000
        empty = np.zeros((num, s.shape[0], s.shape[1]))
        for ii in range(num):
            self.ale.act(0)
            s = self.ale.getScreenGrayscale()
            empty[ii] = s[..., 0]
        # max and min are 214, 142
        o = np.zeros((empty.shape[1], empty.shape[2]), np.int)
        for i in range(o.shape[0]):
            for j in range(o.shape[1]):
                this_pixel = empty[:, i, j]
                (values, cnts) = np.unique(this_pixel, return_counts=True)
                o[i, j] = int(values[np.argmax(cnts)])
        self.background = o
        self.color = 233
        self.vert_l = 45
        self.vert_r = 50
        self.ale.reset_game()

    def _extract(self, s):
        # accept either a single frame or a list of frames
        if hasattr(s, "shape"):
            s = s[..., 0]
            s_a = [s]
        else:
            s_a = [s_i[..., 0] for s_i in s]

        all_pos = []
        for s in s_a:
            # try to find ourselves in the vertical strip
            vert = s[:, self.vert_l:self.vert_r]
            # "racecar" strip detector
            u_d = np.zeros((vert.shape[1]))
            hodl = 0. * u_d - 1
            strip_dets = [[] for i in range(len(u_d))]
            for ii in list(range(len(vert)))[::-1]:
                detect = (vert[ii] == 233)
                for jj in range(len(detect)):
                    if u_d[jj] == 1:
                        strip_dets[jj].append(ii)
                    if u_d[jj] == 0 and detect[jj] == True:
                        u_d[jj] = 1.
                        hodl[jj] = ii
                        strip_dets[jj].append(ii)
                    elif u_d[jj] == 1 and detect[jj] == False:
                        u_d[jj] = 0.
                        hodl[jj] = -1
            flat_dets = [(n, sl) for n, l in enumerate(strip_dets) for sl in l]
            # count of occurrence for every pixel - max should be vert_l - vert_r
            counts = Counter(flat_dets)
            # aggregate counts across a pixel window
            # pseudoconv 3x3
            c_counts = {}
            h = 3
            w = 3
            for k1 in counts.keys():
                c_counts[k1] = 0
                for k2 in counts.keys():
                    if abs(k1[0] - k2[0]) <= w and abs(k1[1] - k2[1]) <= h:
                        c_counts[k1] += 1
            mx = max(c_counts.values())
            all_mx = [k for k in c_counts if c_counts[k] == mx]
            # for now, assume med_w always == 2 since we move vertically
            # med_w = int(np.median([k[0] for k in all_mx]))
            med_w = 2
            med_h = int(np.median([k[1] for k in all_mx]))
            all_pos.append(med_h)

            # background-difference detectors in bands around our position
            ib = 7
            ob = 20
            mb = 3
            uu = med_h - ob
            u = med_h - ib
            d = med_h + ib
            dd = med_h + ob
            l = 45 + med_w - ib
            ll = 45 + med_w - ob
            r = 45 + med_w + ib
            rr = 45 + med_w + ob
            u_diff = s[uu:u, ll:rr] - self.background[uu:u, ll:rr]
            d_diff = s[d:dd, ll:rr] - self.background[d:dd, ll:rr]
            lu_diff = s[u:med_h - mb, ll:l] - self.background[u:med_h - mb, ll:l]
            lm_diff = s[med_h - mb:med_h + mb, ll:l] - self.background[med_h - mb:med_h + mb, ll:l]
            ld_diff = s[med_h + mb:d, ll:l] - self.background[med_h + mb:d, ll:l]
            ru_diff = s[u:med_h - mb, r:rr] - self.background[u:med_h - mb, r:rr]
            rm_diff = s[med_h - mb:med_h + mb, r:rr] - self.background[med_h - mb:med_h + mb, r:rr]
            rd_diff = s[med_h + mb:d, r:rr] - self.background[med_h + mb:d, r:rr]
            u_det = (np.abs(u_diff).sum() > 0)
            d_det = (np.abs(d_diff).sum() > 0)
            lu_det = (np.abs(lu_diff).sum() > 0)
            lm_det = (np.abs(lm_diff).sum() > 0)
            ld_det = (np.abs(ld_diff).sum() > 0)
            ru_det = (np.abs(ru_diff).sum() > 0)
            rm_det = (np.abs(rm_diff).sum() > 0)
            rd_det = (np.abs(rd_diff).sum() > 0)
            # buffer detector around our position
            s_min = [med_h, u_det, d_det, lu_det, lm_det, ld_det, ru_det, rm_det, rd_det]
        s_min = [int(np.min(all_pos)), int(np.max(all_pos))]
        return s_min

    def get_next_state_reward(self, state, action):
        # ignores the state input because the ALE itself is stateful
        frameskip = 48
        total_reward = 0
        all_s = []
        for i in range(frameskip):
            reward = self.ale.act(action)
            s = self.ale.getScreenGrayscale()
            all_s.append(s)
            total_reward += reward
        # use last 2 frames?
        s_min = self._extract(all_s)
        return s_min, total_reward

    def get_action_space(self):
        return list([a for a in self.ale.getMinimalActionSet()])

    def get_valid_actions(self, state):
        return list([a for a in self.ale.getMinimalActionSet()])

    def get_ale_clone(self):
        # if the state manager is stateful, use this to clone
        return self.ale.cloneState()

    def get_init_state(self):
        ss = self.ale.getScreenGrayscale()
        s = self._extract(ss)
        s[0] = ss.shape[0]
        s[1] = ss.shape[0]
        return s

    def rollout_fn(self, state):
        # can define a custom rollout function
        return self.random_state.choice(self.get_valid_actions(state))

    def score(self, state):
        return 0.
    def is_finished(self, state):
        # If this check is slow, rewrite it as _is_finished and add
        #     self.is_finished = MemoizeMutable(self._is_finished)
        # to __init__ instead.
        # Returns (winner, score, end):
        #     winner normally in [-1, 0, 1]; for a one-player game [0, 1] is fine
        #     score is an arbitrary float value
        #     end in [True, False]
        fin = self.ale.game_over()
        return (1, 1., True) if fin else (0, 0., False)

    def rollout_from_state(self, state):
        # example rollout function
        s = state
        w, sc, e = self.is_finished(state)
        if e:
            return sc
        c = 0
        t_r = 0
        while True:
            a = self.rollout_fn(s)
            s_n, r = self.get_next_state_reward(s, a)
            t_r += r
            # is_finished returns (winner, score, end); unpack so the end flag is
            # checked rather than the (always truthy) tuple
            w, sc, e = self.is_finished(s_n)
            s = s_n
            c += 1
            if e:
                return np.abs(s[1] - s[0]) / 200.
                # return t_r
            if c > self.rollout_limit:
                return np.abs(s[1] - s[0]) / 200.
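# --- Example usage (sketch, not part of the original listing) ---
# Runs a single random rollout with the Freeway state manager. Construction replays
# ~1000 null-op frames to build the background model, and 'atari_roms/freeway.bin'
# must exist as assumed in __init__.
rng = np.random.RandomState(0)
manager = FreewayStateManager(rng, rollout_limit=50)
state = manager.get_init_state()            # [min_row, max_row] of the detected player
value = manager.rollout_from_state(state)   # random rollout; value is |max - min| / 200.
print('rollout value estimate:', value)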