def _init_ale(rand_seed, rom_file):
    """Build an ``ALEInterface`` with fixed settings and load a ROM into it.

    Args:
        rand_seed: value given to ALE's ``random_seed`` setting.
        rom_file: ROM path handed directly to ``loadROM``; its existence is
            not checked here.

    Returns:
        The configured ``ALEInterface`` with the ROM loaded.
    """
    ale = ALEInterface()
    # Every setting is applied in this order, before loadROM is called.
    configuration = (
        (ale.setInt, 'random_seed', rand_seed),
        (ale.setBool, 'showinfo', False),
        (ale.setInt, 'frame_skip', 1),
        (ale.setFloat, 'repeat_action_probability', 0.0),
        (ale.setBool, 'color_averaging', False),
    )
    for apply_setting, key, value in configuration:
        apply_setting(key, value)
    ale.loadROM(rom_file)
    return ale
def __init__(self, frame_skip=None, repeat_action_probability=0.0,
             state_shape=None, rom_path=None, game_name='pong',
             random_state=None, rendering=False, record_dir=None,
             obs_showing=False, channel_weights=None):
    """Create an ALE emulator, configure it, and load ``game_name``'s ROM.

    Args:
        frame_skip: stored as ``self.frame_skip``; not passed to ALE here.
        repeat_action_probability: forwarded to ALE's
            ``repeat_action_probability`` setting.
        state_shape: stored as ``self.state_shape``; defaults to a fresh
            ``[84, 84]`` list per instance.
        rom_path: prefix of the ROM location. The file loaded is
            ``rom_path + game_name + '.bin'`` (plain concatenation, so a
            directory prefix should end with a path separator).
        game_name: ROM base name without the ``.bin`` extension.
        random_state: ``np.random.RandomState`` used to draw ALE's
            ``random_seed``; defaults to ``RandomState(1234)``.
        rendering: if True, turn on ALE's ``display_screen`` (and ``sound``
            on Linux; on macOS pygame is initialised and sound disabled).
        record_dir: when ``rendering`` is True, directory in which ALE
            records screens and ``sound.wav``.
        obs_showing: stored as ``self.obs_showing``.
        channel_weights: stored as ``self.channel_weights``; defaults to a
            fresh ``[0.5870, 0.2989, 0.1140]`` list per instance.

    Raises:
        ValueError: if ``rom_path`` is None (there is no ROM to load).
    """
    if rom_path is None:
        raise ValueError('rom_path is required to locate the ROM file')
    # Fresh lists per instance instead of shared mutable default arguments.
    if state_shape is None:
        state_shape = [84, 84]
    if channel_weights is None:
        channel_weights = [0.5870, 0.2989, 0.1140]

    self.ale = ALEInterface()
    self.frame_skip = frame_skip
    self.state_shape = state_shape
    if random_state is None:
        random_state = np.random.RandomState(1234)
    self.rng = random_state
    self.channel_weights = channel_weights
    self.ale.setInt(b'random_seed', self.rng.randint(1000))
    self.ale.setFloat(b'repeat_action_probability', repeat_action_probability)
    self.ale.setBool(b'color_averaging', False)
    if rendering:
        if sys.platform == 'darwin':
            import pygame
            pygame.init()
            self.ale.setBool(b'sound', False)  # Sound doesn't work on OSX
        elif sys.platform.startswith('linux'):
            self.ale.setBool(b'sound', True)
        self.ale.setBool(b'display_screen', True)
    if rendering and record_dir is not None:  # should be before loadROM
        self.ale.setString(b'record_screen_dir', record_dir.encode())
        # 'sound.wav' must be a relative component: os.path.join drops every
        # earlier component when it meets an absolute one like '/sound.wav'.
        self.ale.setString(b'record_sound_filename',
                           os.path.join(record_dir, 'sound.wav').encode())
        self.ale.setInt(b'fragsize', 64)  # to ensure proper sound sync (see ALE doc)
    self.ale.loadROM(str.encode(rom_path + game_name + '.bin'))
    self.legal_actions = self.ale.getMinimalActionSet()
    self.nb_actions = len(self.legal_actions)
    (self.screen_width, self.screen_height) = self.ale.getScreenDims()
    # Scratch RGB buffer sized (height, width, 3) for screen captures.
    self._buffer = np.empty((self.screen_height, self.screen_width, 3), dtype=np.uint8)
    self.obs_showing = obs_showing
def __init__(self, random_state, rollout_limit=1000):
    """Set up a Freeway ALE instance and estimate its static background.

    Args:
        random_state: stored as ``self.random_state`` (not used here).
        rollout_limit: stored as ``self.rollout_limit``.

    The background (``self.background``) is estimated by stepping the
    emulator 1000 times with action 0 and taking, for every pixel, the most
    frequent grayscale value over those frames (smallest value on ties).
    """
    self.rollout_limit = rollout_limit
    self.random_state = random_state
    self.ale = ALEInterface()
    self.ale.setInt('random_seed', 123)

    # Set USE_SDL to True to display the screen. ALE must be compiled
    # with SDL enabled for this to work. On OSX, pygame init is used to
    # proxy-call SDL_main. ALE only honours these settings when they are
    # applied before loadROM, so this runs first.
    USE_SDL = False
    if USE_SDL:
        if sys.platform == 'darwin':
            import pygame
            pygame.init()
            self.ale.setBool('sound', False)  # Sound doesn't work on OSX
        elif sys.platform.startswith('linux'):
            # the sound is awful
            self.ale.setBool('sound', False)
        self.ale.setBool('display_screen', True)

    rom_path = "atari_roms/freeway.bin"
    self.ale.loadROM(rom_path)

    # get background subtraction
    first_screen = self.ale.getScreenGrayscale()
    num = 1000
    # uint8 holds every grayscale value and needs 1/8 the memory of float64.
    frames = np.zeros((num, first_screen.shape[0], first_screen.shape[1]),
                      dtype=np.uint8)
    for ii in range(num):
        self.ale.act(0)
        # Indexing [..., 0] drops the trailing singleton channel axis.
        frames[ii] = self.ale.getScreenGrayscale()[..., 0]
    # max and min are 214, 142
    # Per-pixel mode. bincount(...).argmax() returns the smallest value
    # among ties, matching the former np.unique + argmax loop. `np.int` no
    # longer exists in NumPy >= 1.24; builtin int is the same dtype.
    self.background = np.apply_along_axis(
        lambda pixel: np.bincount(pixel).argmax(), 0, frames).astype(int)
    self.color = 233
    self.vert_l = 45
    self.vert_r = 50
    self.ale.reset_game()
def __init__(self, rom_file, viz=0, frame_skip=4, nullop_start=30,
             live_lost_as_eoe=True, max_num_frames=0):
    """
    Args:
        rom_file: path to the rom
        frame_skip: skip every k frames and repeat the action
        viz: visualization to be done.
            Set to 0 to disable.
            Set to a positive number to be the delay between frames to show.
            Set to a string to be a directory to store frames.
        nullop_start: start with random number of null ops.
        live_lost_as_eoe: consider loss of lives as end of episode. Useful for training.
        max_num_frames: maximum number of frames per episode.
    """
    super(AtariPlayer, self).__init__()
    # A bare ROM name (no '/') that is not a local file is resolved
    # through the dataset directory.
    if not os.path.isfile(rom_file) and '/' not in rom_file:
        rom_file = get_dataset_path('atari_rom', rom_file)
    assert os.path.isfile(rom_file), \
        "ROM {} not found. Please download at {}".format(rom_file, ROM_URL)

    try:
        ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
    except AttributeError:
        # Older ALE builds lack setLoggerMode; warn once and continue.
        if execute_only_once():
            logger.warn("You're not using latest ALE")

    # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
    with _ALE_LOCK:
        self.ale = ALEInterface()
        self.rng = get_rng(self)
        self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
        self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
        self.ale.setBool(b"showinfo", False)

        self.ale.setInt(b"frame_skip", 1)
        self.ale.setBool(b'color_averaging', False)
        # manual.pdf suggests otherwise.
        self.ale.setFloat(b'repeat_action_probability', 0.0)

        # viz setup
        if isinstance(viz, six.string_types):
            assert os.path.isdir(viz), viz
            # Encode like every other ALE key/value here: the ctypes
            # binding takes bytes (c_char_p), not str, on Python 3.
            self.ale.setString(b'record_screen_dir', viz.encode('utf-8'))
            viz = 0
        if isinstance(viz, int):
            viz = float(viz)
        self.viz = viz
        if self.viz and isinstance(self.viz, float):
            self.windowname = os.path.basename(rom_file)
            # NOTE: window creation is disabled here:
            # cv2.namedWindow(self.windowname)

        self.ale.loadROM(rom_file.encode('utf-8'))
    self.width, self.height = self.ale.getScreenDims()
    self.actions = self.ale.getMinimalActionSet()

    self.live_lost_as_eoe = live_lost_as_eoe
    self.frame_skip = frame_skip
    self.nullop_start = nullop_start

    self.action_space = spaces.Discrete(len(self.actions))
    self.observation_space = spaces.Box(
        low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
    self._restart_episode()
return S, stacked_tuple_info else: return None if __name__ == "__main__": from atari_py.ale_python_interface import ALEInterface import os if len(sys.argv) < 2: print('Using default game atari_roms/breakout.bin') game = "atari_roms/breakout.bin" else: game = sys.argv[1] ale = ALEInterface() # Get & Set the desired settings ale.setInt('random_seed', 123) # Load the ROM file ale.loadROM(game) # Get the list of legal actions legal_actions = ale.getLegalActionSet() batch_size = 10 exp_replay = ReplayBuffer(batch_size) (screen_width, screen_height) = ale.getScreenDims()
def test_full_run():
    """Smoke/memory test: fill a ReplayBuffer from random Breakout play.

    Plays 2000 episodes with uniformly random legal actions, adding every
    transition to a ``ReplayBuffer`` and sampling minibatches along the way,
    then draws 10000 more minibatches. System memory reported by
    ``free -th`` (Linux only) is printed periodically; run pytest with
    ``-s`` to see the report.
    """
    from atari_py.ale_python_interface import ALEInterface
    import os

    def _system_memory():
        # Last line of `free -th` is the totals row: label, total, used, free.
        with os.popen("free -th") as proc:
            return proc.readlines()[-1].split()[1:]

    def _report(prefix):
        # Print buffer counters next to current system memory usage.
        tot_m, used_m, free_m = _system_memory()
        print(prefix + "{}: {}, {}; {}, {}, {}".format(
            exp_replay.sent_counter, len(exp_replay.memory),
            len(exp_replay.reverse_experience_lookup.keys()),
            tot_m, used_m, free_m))

    game = "atari_roms/breakout.bin"
    ale = ALEInterface()
    # Get & Set the desired settings
    ale.setInt('random_seed', 123)
    # Load the ROM file
    ale.loadROM(game)
    # Get the list of legal actions
    legal_actions = ale.getLegalActionSet()

    batch_size = 10
    exp_replay = ReplayBuffer(batch_size)
    (screen_width, screen_height) = ale.getScreenDims()

    tot_m, used_m, free_m = _system_memory()
    last_counter = 0
    random_state = np.random.RandomState(218)
    print("initial: {}, {}, {}".format(tot_m, used_m, free_m))

    # Play 2k episodes
    for episode in range(2000):
        total_reward = 0
        this_counter = exp_replay.sent_counter
        if this_counter > last_counter + 1000:
            last_counter = this_counter
            # the first three entries should match til 1M steps
            # then the second 2 should continue in lock step
            _report("")
        while not ale.game_over():
            S_prime = np.zeros(screen_width * screen_height, dtype=np.uint8)
            ale.getScreen(S_prime)
            S_prime = S_prime.reshape(screen_height, screen_width)[:84, :84]
            a = random_state.choice(len(legal_actions))
            action = legal_actions[a]
            # Apply an action and get the resulting reward
            reward = ale.act(action)
            won = 0
            ongoing_flag = 1
            experience = (S_prime, action, reward, won, ongoing_flag)
            exp_replay.add_experience(experience)
            # Exercise both the random and the explicit-index sampling paths.
            exp_replay.get_minibatch()
            batch = exp_replay.get_minibatch(index_list=[1, 2, 3, 10, 11])
            if batch is not None:
                # Touch both fields so a malformed batch fails loudly.
                _mb_S, _other_info = batch[0], batch[1]
            del batch
            total_reward += reward
        print('Episode {} ended with score: {}'.format(episode, total_reward))
        ale.reset_game()

    lst = 0
    for i in range(10000):
        if i > lst + 1000:
            _report("POST MEM ")
            lst = i
        batch = exp_replay.get_minibatch()
        assert batch is not None, "replay buffer returned no minibatch"
        _mb_S, _other_info = batch[0], batch[1]