Example #1
class AtariPlayer(gym.Env):
    """
    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.

    Info:
        score: the accumulated reward in the current game
        gameOver: True when the current game is Over
    """
    def __init__(self,
                 rom_file,
                 viz=0,
                 frame_skip=4,
                 nullop_start=30,
                 live_lost_as_eoe=True,
                 max_num_frames=0):
        """
        Args:
            rom_file: path to the rom
            frame_skip: repeat the chosen action for this many frames
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a positive number to be the delay between frames to show.
                Set to a string to be a directory to store frames.
            nullop_start: start with a random number of null ops.
            live_lost_as_eoe: treat loss of a life as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode.
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "ROM {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            if execute_only_once():
                logger.warn("You're not using the latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # ALE's manual suggests a nonzero default here; 0 makes actions
            # deterministic, matching the DeepMind setup
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(self.height, self.width),
                                            dtype=np.uint8)
        self._restart_episode()

    def get_action_meanings(self):
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # element-wise max with the previous screen, to remove flickering sprites
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz and isinstance(self.viz, float):
            # display the current frame; matplotlib replaces cv2.imshow here
            plt.figure()
            plt.imshow(ret)
            plt.show()

        ret = ret.astype('float32')
        # RGB -> grayscale with weights 0.299, 0.587, 0.114, same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start
        n = self.rng.randint(self.nullop_start)
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def reset(self):
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def render(self, *args, **kwargs):
        pass  # visualization for this env is through the viz= argument when creating the player

    def step(self, act):
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info
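
A minimal usage sketch for this wrapper (not part of the original example; the ROM name and the random-agent loop are assumptions):

player = AtariPlayer('breakout.bin', frame_skip=4, viz=0)
state = player.reset()
isOver = False
while not isOver:
    # sample a random action from the gym action space
    act = player.action_space.sample()
    state, reward, isOver, info = player.step(act)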
Example #2
class AtariEnv(object):
    def __init__(self,
                 frame_skip=None,
                 repeat_action_probability=0.0,
                 state_shape=[84, 84],
                 rom_path=None,
                 game_name='pong',
                 random_state=None,
                 rendering=False,
                 record_dir=None,
                 obs_showing=False,
                 channel_weights=[0.5870, 0.2989, 0.1140]):
        self.ale = ALEInterface()
        self.frame_skip = frame_skip
        self.state_shape = state_shape
        if random_state is None:
            random_state = np.random.RandomState(1234)
        self.rng = random_state
        self.channel_weights = channel_weights
        self.ale.setInt(b'random_seed', self.rng.randint(1000))
        self.ale.setFloat(b'repeat_action_probability',
                          repeat_action_probability)
        self.ale.setBool(b'color_averaging', False)
        if rendering:
            if sys.platform == 'darwin':
                import pygame
                pygame.init()
                self.ale.setBool(b'sound', False)  # Sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                self.ale.setBool(b'sound', True)
            self.ale.setBool(b'display_screen', True)
        if rendering and record_dir is not None:  # must be set before loadROM
            self.ale.setString(b'record_screen_dir', record_dir.encode())
            self.ale.setString(b'record_sound_filename',
                               os.path.join(record_dir, 'sound.wav').encode())
            self.ale.setInt(b'fragsize',
                            64)  # to ensure proper sound sync (see ALE doc)
        self.ale.loadROM(str.encode(rom_path + game_name + '.bin'))
        self.legal_actions = self.ale.getMinimalActionSet()
        self.nb_actions = len(self.legal_actions)
        (self.screen_width, self.screen_height) = self.ale.getScreenDims()
        self._buffer = np.empty((self.screen_height, self.screen_width, 3),
                                dtype=np.uint8)

        self.obs_showing = obs_showing

    def reset(self):
        self.ale.reset_game()
        return self.get_state()

    def step(self, action):
        reward = 0.0
        # frame_skip may be None (single step), an int, or a (low, high) range
        if self.frame_skip is None:
            num_steps = 1
        elif isinstance(self.frame_skip, int):
            num_steps = self.frame_skip
        else:
            num_steps = self.rng.randint(self.frame_skip[0],
                                         self.frame_skip[1])
        for i in range(num_steps):
            reward += self.ale.act(self.legal_actions[action])
        return self.get_state(), reward, self.ale.game_over(), {}

    def _get_image(self):
        # fill the preallocated buffer with the current RGB screen
        self.ale.getScreenRGB(self._buffer)
        # weighted sum of the three channels -> grayscale (float)
        gray = (self.channel_weights[0] * self._buffer[:, :, 0] +
                self.channel_weights[1] * self._buffer[:, :, 1] +
                self.channel_weights[2] * self._buffer[:, :, 2])
        # cv2.resize takes (width, height); state_shape is square here
        x = cv2.resize(gray,
                       tuple(self.state_shape),
                       interpolation=cv2.INTER_LINEAR)
        return x

    def get_state(self):
        return self._get_image()

    def get_lives(self):
        return self.ale.lives()
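
A minimal usage sketch (hypothetical; the rom_path and the random policy below are assumptions, not part of the original example):

env = AtariEnv(frame_skip=4, rom_path='atari_roms/', game_name='pong',
               random_state=np.random.RandomState(0))
state = env.reset()
done = False
while not done:
    # uniform random policy over the minimal action set
    action = env.rng.randint(env.nb_actions)
    state, reward, done, _ = env.step(action)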
class FreewayStateManager(object):
    """
    A state manager with a goal state and action space of 2

    State manager must itself have *no* internal state / updating behavior;
    otherwise we would need deepcopy() or cloneSystemState in get_action_probs,
    making it slower
    """
    def __init__(self, random_state, rollout_limit=1000):
        self.rollout_limit = rollout_limit
        self.random_state = random_state
        self.ale = ALEInterface()
        self.ale.setInt('random_seed', 123)
        rom_path = "atari_roms/freeway.bin"
        self.ale.loadROM(rom_path)

        # Set USE_SDL to true to display the screen. ALE must be compiled
        # with SDL enabled for this to work. On OSX, pygame init is used to
        # proxy-call SDL_main.
        USE_SDL = False
        if USE_SDL:
            if sys.platform == 'darwin':
                import pygame
                pygame.init()
                self.ale.setBool('sound', False)  # Sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                # the sound is awful
                self.ale.setBool('sound', False)
            self.ale.setBool('display_screen', True)

        # build a background image for background subtraction
        s = self.ale.getScreenGrayscale()
        num = 1000
        empty = np.zeros((num, s.shape[0], s.shape[1]))
        for ii in range(num):
            self.ale.act(0)  # null-op to advance the emulator
            s = self.ale.getScreenGrayscale()
            empty[ii] = s[..., 0]

        # max and min are 214, 142
        # per-pixel mode over the null-op frames -> background image
        o = np.zeros((empty.shape[1], empty.shape[2]), dtype=int)
        for i in range(o.shape[0]):
            for j in range(o.shape[1]):
                this_pixel = empty[:, i, j]
                (values, cnts) = np.unique(this_pixel, return_counts=True)
                o[i, j] = int(values[np.argmax(cnts)])
        self.background = o
        self.color = 233
        self.vert_l = 45
        self.vert_r = 50
        self.ale.reset_game()

    def _extract(self, s):
        # accept either a single frame or a list of frames
        if hasattr(s, "shape"):
            s_a = [s[..., 0]]
        else:
            s_a = [s_i[..., 0] for s_i in s]
        all_pos = []
        for s in s_a:
            # try to locate the player (the chicken)
            vert = s[:, self.vert_l:self.vert_r]
            # "racecar" strip detector
            u_d = np.zeros((vert.shape[1]))
            hodl = 0. * u_d - 1
            strip_dets = [[] for i in range(len(u_d))]
            for ii in list(range(len(vert)))[::-1]:
                detect = (vert[ii] == 233)
                for jj in range(len(detect)):
                    if u_d[jj] == 1:
                        strip_dets[jj].append(ii)
                    if u_d[jj] == 0 and detect[jj] == True:
                        u_d[jj] = 1.
                        hodl[jj] = ii
                        strip_dets[jj].append(ii)
                    elif u_d[jj] == 1 and detect[jj] == False:
                        u_d[jj] = 0.
                        hodl[jj] = -1
            flat_dets = [(n, sl) for n, l in enumerate(strip_dets) for sl in l]
            # count of occurrences for every pixel - max should be vert_r - vert_l
            counts = Counter(flat_dets)
            # aggregate counts across pixel window
            # pseudoconv 3x3
            c_counts = {}
            h = 3
            w = 3
            for k1 in counts.keys():
                c_counts[k1] = 0
                for k2 in counts.keys():
                    if abs(k1[0] - k2[0]) <= w and abs(k1[1] - k2[1]) <= h:
                        c_counts[k1] += 1
            mx = max(c_counts.values())
            all_mx = [k for k in c_counts if c_counts[k] == mx]
            # for now, assume med_w always == 2 since we move vertically
            #med_w = int(np.median([k[0] for k in all_mx]))
            med_w = 2
            med_h = int(np.median([k[1] for k in all_mx]))

            all_pos.append(med_h)

            # inner (ib), outer (ob), and middle (mb) margins in pixels
            ib = 7
            ob = 20
            mb = 3
            uu = med_h - ob
            u = med_h - ib
            d = med_h + ib
            dd = med_h + ob
            l = self.vert_l + med_w - ib
            ll = self.vert_l + med_w - ob
            r = self.vert_l + med_w + ib
            rr = self.vert_l + med_w + ob
            u_diff = s[uu:u, ll:rr] - self.background[uu:u, ll:rr]
            d_diff = s[d:dd, ll:rr] - self.background[d:dd, ll:rr]
            lu_diff = s[u:med_h-mb, ll:l] - self.background[u:med_h-mb, ll:l]
            lm_diff = s[med_h-mb:med_h+mb, ll:l] - self.background[med_h-mb:med_h+mb, ll:l]
            ld_diff = s[med_h+mb:d, ll:l] - self.background[med_h+mb:d, ll:l]
            ru_diff = s[u:med_h-mb, r:rr] - self.background[u:med_h-mb, r:rr]
            rm_diff = s[med_h-mb:med_h+mb, r:rr] - self.background[med_h-mb:med_h+mb, r:rr]
            rd_diff = s[med_h+mb:d, r:rr] - self.background[med_h+mb:d, r:rr]
            u_det = (np.abs(u_diff).sum() > 0)
            d_det = (np.abs(d_diff).sum() > 0)
            lu_det = (np.abs(lu_diff).sum() > 0)
            lm_det = (np.abs(lm_diff).sum() > 0)
            ld_det = (np.abs(ld_diff).sum() > 0)
            ru_det = (np.abs(ru_diff).sum() > 0)
            rm_det = (np.abs(rm_diff).sum() > 0)
            rd_det = (np.abs(rd_diff).sum() > 0)
            # buffer detector around our position
            # note: this per-frame feature vector is overwritten below;
            # only the min/max positions across frames are returned
            s_min = [med_h, u_det, d_det,
                     lu_det, lm_det, ld_det,
                     ru_det, rm_det, rd_det]
        s_min = [int(np.min(all_pos)), int(np.max(all_pos))]
        return s_min

    def get_next_state_reward(self, state, action):
        # ignores state input due to ale being stateful
        frameskip = 48
        total_reward = 0

        all_s = []
        for i in range(frameskip):
            reward = self.ale.act(action)
            s = self.ale.getScreenGrayscale()
            all_s.append(s)
            total_reward += reward
        # use last 2 frames?
        s_min = self._extract(all_s)
        return s_min, total_reward

    def get_action_space(self):
        return list(self.ale.getMinimalActionSet())

    def get_valid_actions(self, state):
        return list(self.ale.getMinimalActionSet())

    def get_ale_clone(self):
        # if state manager is stateful, will use this to clone
        return self.ale.cloneState()

    def get_init_state(self):
        ss = self.ale.getScreenGrayscale()
        s = self._extract(ss)
        s[0] = ss.shape[0]
        s[1] = ss.shape[0]
        return s

    def rollout_fn(self, state):
        # can define custom rollout function
        return self.random_state.choice(self.get_valid_actions(state))

    def score(self, state):
        return 0.

    def is_finished(self, state):
        # if this check is slow
        # can rewrite as _is_finished
        # then add
        # self.is_finished = MemoizeMutable(self._is_finished)
        # to __init__ instead

        # return winner, score, end
        # winner normally in [-1, 0, 1]
        # if it's one player, can just use [0, 1] and it's fine
        # score arbitrary float value
        # end in [True, False]
        #return (1, 1., True) if state == self.goal_state else (0, 0., False)
        fin = self.ale.game_over()
        return (1, 1., True) if fin else (0, 0., False)

    def rollout_from_state(self, state):
        # example rollout function
        s = state
        w, sc, e = self.is_finished(state)
        if e:
            return sc

        c = 0
        t_r = 0
        while True:
            a = self.rollout_fn(s)
            s_n, r = self.get_next_state_reward(s, a)
            t_r += r

            # is_finished returns (winner, score, end); unpack to get the end flag
            w, sc, e = self.is_finished(s_n)
            s = s_n
            c += 1
            if e:
                return np.abs(s[1] - s[0]) / 200.

            if c > self.rollout_limit:
                return np.abs(s[1] - s[0]) / 200.
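
A minimal driver sketch for the state manager (hypothetical; assumes the freeway ROM at the path hard-coded in __init__):

rng = np.random.RandomState(0)
manager = FreewayStateManager(rng, rollout_limit=50)
state = manager.get_init_state()
# estimate a value for the initial state with one random rollout
value = manager.rollout_from_state(state)
print('rollout value:', value)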
def test_full_run():
    from atari_py.ale_python_interface import ALEInterface

    game = "atari_roms/breakout.bin"

    ale = ALEInterface()

    # Get & Set the desired settings
    ale.setInt('random_seed', 123)

    # Load the ROM file
    ale.loadROM(game)

    # Get the list of legal actions
    legal_actions = ale.getLegalActionSet()

    batch_size = 10
    exp_replay = ReplayBuffer(batch_size)

    (screen_width, screen_height) = ale.getScreenDims()

    import os
    tot_m, used_m, free_m = os.popen("free -th").readlines()[-1].split()[1:]
    last_counter = 0
    random_state = np.random.RandomState(218)
    print("initial: {}, {}, {}".format(tot_m, used_m, free_m))
    # Play 2k episodes
    for episode in range(2000):
        total_reward = 0
        S = np.zeros(screen_width * screen_height, dtype=np.uint8)
        S = S.reshape(screen_height, screen_width)[:84, :84]
        this_counter = exp_replay.sent_counter
        if this_counter > last_counter + 1000:
            last_counter = this_counter
            tot_m, used_m, free_m = os.popen(
                "free -th").readlines()[-1].split()[1:]
            # the first three entries should match til 1M steps
            # then the second 2 should continue in lock step
            print("{}: {}, {}; {}, {}, {}".format(
                exp_replay.sent_counter, len(exp_replay.memory),
                len(exp_replay.reverse_experience_lookup.keys()), tot_m,
                used_m, free_m))
        while not ale.game_over():
            S_prime = np.zeros(screen_width * screen_height, dtype=np.uint8)
            ale.getScreen(S_prime)
            S_prime = S_prime.reshape(screen_height, screen_width)[:84, :84]
            a = random_state.choice(len(legal_actions))
            action = legal_actions[a]
            # Apply an action and get the resulting reward
            reward = ale.act(action)
            won = 0
            ongoing_flag = 1
            experience = (S_prime, action, reward, won, ongoing_flag)
            S = S_prime
            exp_replay.add_experience(experience)
            batch = exp_replay.get_minibatch()
            batch = exp_replay.get_minibatch(index_list=[1, 2, 3, 10, 11])
            if batch is not None:
                mb_S = batch[0]
                other_info = batch[1]
            del batch
            total_reward += reward
        print('Episode', episode, 'ended with score:', total_reward)
        ale.reset_game()

    lst = 0
    for i in range(10000):
        if i > lst + 1000:
            tot_m, used_m, free_m = os.popen(
                "free -th").readlines()[-1].split()[1:]
            print("POST MEM {}: {}, {}; {}, {}, {}".format(
                exp_replay.sent_counter, len(exp_replay.memory),
                len(exp_replay.reverse_experience_lookup.keys()), tot_m,
                used_m, free_m))
            lst = i

        batch = exp_replay.get_minibatch()
        mb_S = batch[0]
        other_info = batch[1]
    # debugging aid: drop into an interactive shell for manual inspection,
    # then fail loudly so the test output is preserved
    from IPython import embed
    embed()
    raise ValueError()
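
A sketch of invoking the test directly (assumes ReplayBuffer and the breakout ROM are available; note the test deliberately ends in an interactive shell):

if __name__ == "__main__":
    test_full_run()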
Example #6
    print('Starting up: ' + args.game)
    game_path = atari_py.get_game_path(args.game)
    ale.loadROM(str.encode(game_path))
    print('Legal Actions: ', ale.getLegalActionSet())
    pygame.init()

    (w, h) = ale.getScreenDims()
    dim = (w * args.scale, h * args.scale)

    pygame.display.set_mode(dim)
    clock = pygame.time.Clock()
    FPS = 32

    done = False
    while not done and not ale.game_over():
        for event in pygame.event.get():
            if event.type == QUIT:
                done = True
                break
            if event.type == KEYDOWN and event.key == K_ESCAPE:
                done = True
                break
        key_state = pygame.key.get_pressed()

        left = key_state[K_LEFT] or key_state[K_a]
        right = key_state[K_RIGHT] or key_state[K_d]
        up = key_state[K_UP] or key_state[K_w]
        down = key_state[K_DOWN] or key_state[K_s]
        button1 = key_state[K_z] or key_state[K_SPACE]
        button2 = key_state[K_x] or key_state[K_RSHIFT] or key_state[K_LSHIFT]
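
The fragment ends before the key states are turned into an action; a hedged continuation sketch follows (the action-index mapping, the drawing code, and the numpy-as-np import are assumptions, not the original code):

        # map the key combination to one of the basic ALE actions
        action = 0  # PLAYER_A_NOOP
        if button1:
            action = 1  # FIRE
        elif up:
            action = 2  # UP
        elif right:
            action = 3  # RIGHT
        elif left:
            action = 4  # LEFT
        elif down:
            action = 5  # DOWN
        reward = ale.act(action)

        # blit the current screen, scaled to the window, and cap the frame rate
        frame = np.transpose(ale.getScreenRGB(), (1, 0, 2))
        surf = pygame.transform.scale(pygame.surfarray.make_surface(frame), dim)
        pygame.display.get_surface().blit(surf, (0, 0))
        pygame.display.flip()
        clock.tick(FPS)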