예제 #1
0
class AtariPlayer(gym.Env):
    """
    A wrapper for ALE emulator, with configurations to mimic DeepMind DQN settings.

    Observations are gray-scale (height, width) uint8 images, max-pooled
    over the current raw screen and the previously stored one.

    Info (the dict returned as the 4th element of :meth:`step`):
        ale.lives: number of lives reported by the emulator after the step
    """
    def __init__(self,
                 rom_file,
                 viz=0,
                 frame_skip=4,
                 nullop_start=30,
                 live_lost_as_eoe=True,
                 max_num_frames=0):
        """
        Args:
            rom_file: path to the rom. A name without '/' that is not an
                existing file is resolved through ``get_dataset_path``.
            frame_skip: skip every k frames and repeat the action
            viz: visualization to be done.
                Set to 0 to disable.
                Set to a non-zero number to display every returned frame with
                matplotlib (the value itself is not used as a delay here).
                Set to a string to be a directory to store frames.
            nullop_start: start each episode with a random number of null ops,
                drawn from ``rng.randint(nullop_start)``. Values <= 0 disable
                the random null-op start.
            live_lost_as_eoe: consider lost of lives as end of episode. Useful for training.
            max_num_frames: maximum number of frames per episode
                (forwarded to ALE's ``max_num_frames_per_episode`` setting).
        """
        super(AtariPlayer, self).__init__()
        if not os.path.isfile(rom_file) and '/' not in rom_file:
            rom_file = get_dataset_path('atari_rom', rom_file)
        assert os.path.isfile(rom_file), \
            "ROM {} not found. Please download at {}".format(rom_file, ROM_URL)

        try:
            ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
        except AttributeError:
            # Older ALE builds have no logger control; warn a single time.
            if execute_only_once():
                logger.warn("You're not using latest ALE")

        # avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
        with _ALE_LOCK:
            self.ale = ALEInterface()
            self.rng = get_rng(self)
            self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
            self.ale.setInt(b"max_num_frames_per_episode", max_num_frames)
            self.ale.setBool(b"showinfo", False)

            # Frame skipping is done manually in step(), so ALE advances one
            # frame per act() call.
            self.ale.setInt(b"frame_skip", 1)
            self.ale.setBool(b'color_averaging', False)
            # manual.pdf suggests otherwise.
            self.ale.setFloat(b'repeat_action_probability', 0.0)

            # viz setup
            if isinstance(viz, six.string_types):
                # A string selects frame recording; on-screen display is off.
                assert os.path.isdir(viz), viz
                self.ale.setString(b'record_screen_dir', viz)
                viz = 0
            if isinstance(viz, int):
                viz = float(viz)
            self.viz = viz
            if self.viz and isinstance(self.viz, float):
                self.windowname = os.path.basename(rom_file)

            self.ale.loadROM(rom_file.encode('utf-8'))
        self.width, self.height = self.ale.getScreenDims()
        self.actions = self.ale.getMinimalActionSet()

        self.live_lost_as_eoe = live_lost_as_eoe
        self.frame_skip = frame_skip
        self.nullop_start = nullop_start

        self.action_space = spaces.Discrete(len(self.actions))
        self.observation_space = spaces.Box(low=0,
                                            high=255,
                                            shape=(self.height, self.width),
                                            dtype=np.uint8)
        self._restart_episode()

    def get_action_meanings(self):
        """
        :returns: the ``ACTION_MEANING`` entry for every action in the minimal action set
        """
        return [ACTION_MEANING[i] for i in self.actions]

    def _grab_raw_image(self):
        """
        :returns: the current 3-channel image
        """
        m = self.ale.getScreenRGB()
        return m.reshape((self.height, self.width, 3))

    def _current_state(self):
        """
        :returns: a gray-scale (h, w) uint8 image
        """
        ret = self._grab_raw_image()
        # max-pooled over the last screen
        ret = np.maximum(ret, self.last_raw_screen)
        if self.viz and isinstance(self.viz, float):
            # Show the max-pooled RGB frame in a fresh matplotlib figure.
            plt.figure()
            plt.imshow(ret)
            plt.show()

        ret = ret.astype('float32')
        # 0.299,0.587.0.114. same as rgb2y in torch/image
        ret = cv2.cvtColor(ret, cv2.COLOR_RGB2GRAY)[:, :]
        return ret.astype('uint8')  # to save some memory

    def _restart_episode(self):
        """
        Reset the emulator and play a random number of null ops (action 0).
        """
        with _ALE_LOCK:
            self.ale.reset_game()

        # random null-ops start; randint(0) would raise ValueError, so a
        # non-positive nullop_start simply means "no null ops".
        n = self.rng.randint(self.nullop_start) if self.nullop_start > 0 else 0
        self.last_raw_screen = self._grab_raw_image()
        for k in range(n):
            if k == n - 1:
                # Keep the screen preceding the final null op for max-pooling.
                self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

    def reset(self):
        """
        Restart the emulator only when the game is really over; otherwise
        (e.g. after a life loss reported as end of episode) play continues
        from the current emulator state.

        :returns: the current gray-scale observation
        """
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

    def render(self, *args, **kwargs):
        pass  # visualization for this env is through the viz= argument when creating the player

    def step(self, act):
        """
        Repeat the chosen action for up to ``frame_skip`` emulator frames.

        Args:
            act: index into the minimal action set (``self.actions``)

        :returns: (observation, accumulated reward, isOver, info) where info
            is ``{'ale.lives': <lives after the step>}``
        """
        oldlives = self.ale.lives()
        # Pre-bind so the code after the loop is well defined even when
        # frame_skip <= 0 and the loop body never runs.
        newlives = oldlives
        r = 0
        for k in range(self.frame_skip):
            if k == self.frame_skip - 1:
                # Screen before the last repeated frame, used for max-pooling.
                self.last_raw_screen = self._grab_raw_image()
            r += self.ale.act(self.actions[act])
            newlives = self.ale.lives()
            if self.ale.game_over() or \
                    (self.live_lost_as_eoe and newlives < oldlives):
                break

        isOver = self.ale.game_over()
        if self.live_lost_as_eoe:
            isOver = isOver or newlives < oldlives

        info = {'ale.lives': newlives}
        return self._current_state(), r, isOver, info
예제 #2
0
class AtariEnv(object):
    """Thin environment wrapper around an ``ALEInterface`` instance.

    Observations are gray-scale images computed as a weighted sum of the three
    screen channels and resized with ``cv2.resize`` to ``state_shape``.
    """

    def __init__(self,
                 frame_skip=None,
                 repeat_action_probability=0.0,
                 state_shape=(84, 84),
                 rom_path=None,
                 game_name='pong',
                 random_state=None,
                 rendering=False,
                 record_dir=None,
                 obs_showing=False,
                 channel_weights=(0.5870, 0.2989, 0.1140)):
        """
        Args:
            frame_skip: ``None`` for one emulator step per ``step`` call, an
                int for a fixed repeat count, or a 2-item sequence forwarded
                to ``rng.randint(low, high)`` to draw the count each call.
            repeat_action_probability: forwarded to the ALE setting of the
                same name.
            state_shape: 2-item size handed to ``cv2.resize`` as ``dsize``.
                NOTE(review): cv2 reads ``dsize`` as (width, height); this
                only matters for non-square shapes.
            rom_path: directory holding ``<game_name>.bin``; ``None`` or ''
                means the current working directory.
            game_name: ROM file name without the ``.bin`` extension.
            random_state: object providing ``randint``; defaults to
                ``np.random.RandomState(1234)``.
            rendering: enable ALE's ``display_screen`` (and sound on linux).
            record_dir: when rendering, directory for recorded screens and
                ``sound.wav``.
            obs_showing: stored on the instance; not used inside this class.
            channel_weights: weights applied to channels 0, 1 and 2 of the
                screen buffer when converting to gray-scale.
        """
        self.ale = ALEInterface()
        self.frame_skip = frame_skip
        # Copy into fresh lists so instances never share (or mutate) the
        # caller's sequence or the defaults.
        self.state_shape = list(state_shape)
        if random_state is None:
            random_state = np.random.RandomState(1234)
        self.rng = random_state
        self.channel_weights = list(channel_weights)
        self.ale.setInt(b'random_seed', self.rng.randint(1000))
        self.ale.setFloat(b'repeat_action_probability',
                          repeat_action_probability)
        self.ale.setBool(b'color_averaging', False)
        if rendering:
            if sys.platform == 'darwin':
                import pygame
                pygame.init()
                self.ale.setBool(b'sound', False)  # Sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                self.ale.setBool(b'sound', True)
            self.ale.setBool(b'display_screen', True)
        if rendering and record_dir is not None:  # should be before loadROM
            self.ale.setString(b'record_screen_dir', record_dir.encode())
            # The file name must be relative: os.path.join drops every
            # earlier component when a later one is absolute, which used to
            # send the recording to '/sound.wav'.
            self.ale.setString(b'record_sound_filename',
                               os.path.join(record_dir, 'sound.wav').encode())
            self.ale.setInt(b'fragsize',
                            64)  # to ensure proper sound sync (see ALE doc)
        # Join instead of concatenating so rom_path works with or without a
        # trailing separator, and so the None default does not raise TypeError.
        rom_file = os.path.join(rom_path or '', game_name + '.bin')
        self.ale.loadROM(rom_file.encode())
        self.legal_actions = self.ale.getMinimalActionSet()
        self.nb_actions = len(self.legal_actions)
        (self.screen_width, self.screen_height) = self.ale.getScreenDims()
        # Reused RGB buffer filled in place by getScreenRGB().
        self._buffer = np.empty((self.screen_height, self.screen_width, 3),
                                dtype=np.uint8)

        self.obs_showing = obs_showing

    def reset(self):
        """Reset the emulator and return the first observation."""
        self.ale.reset_game()
        return self.get_state()

    def step(self, action):
        """Apply ``action`` for the configured number of emulator steps.

        Args:
            action: index into ``self.legal_actions``.

        Returns:
            (observation, summed reward, ``ale.game_over()``, empty info dict)
        """
        reward = 0.0
        if self.frame_skip is None:
            num_steps = 1
        elif isinstance(self.frame_skip, int):
            num_steps = self.frame_skip
        else:
            # Random repeat count drawn from the configured (low, high) pair.
            num_steps = self.rng.randint(self.frame_skip[0],
                                         self.frame_skip[1])
        for _ in range(num_steps):
            reward += self.ale.act(self.legal_actions[action])
        return self.get_state(), reward, self.ale.game_over(), {}

    def _get_image(self):
        """Return the current screen as a resized gray-scale image."""
        self.ale.getScreenRGB(self._buffer)
        weights = self.channel_weights
        gray = (weights[0] * self._buffer[:, :, 0] +
                weights[1] * self._buffer[:, :, 1] +
                weights[2] * self._buffer[:, :, 2])
        return cv2.resize(gray,
                          tuple(self.state_shape),
                          interpolation=cv2.INTER_LINEAR)

    def get_state(self):
        """Return the current observation (see ``_get_image``)."""
        return self._get_image()

    def get_lives(self):
        """Return the number of lives reported by the emulator."""
        return self.ale.lives()
def test_full_run():
    """Stress ``ReplayBuffer`` with random Breakout play while logging memory.

    Plays 2000 episodes with uniformly random actions, pushing every
    transition into the replay buffer and sampling minibatches, then samples
    10000 more minibatches. Memory figures parsed from ``free -th`` are
    printed roughly every 1000 buffer insertions / iterations.

    The function deliberately ends by opening an IPython shell and raising
    ``ValueError`` so the final state can be inspected interactively.
    """
    import os

    from atari_py.ale_python_interface import ALEInterface

    def memory_stats():
        # Fields after the label on the last line of `free -th` (the totals
        # row added by -t). The pipe is closed as soon as it has been read.
        with os.popen("free -th") as pipe:
            return pipe.readlines()[-1].split()[1:]

    game = "atari_roms/breakout.bin"

    ale = ALEInterface()

    # Get & Set the desired settings
    ale.setInt('random_seed', 123)

    # Load the ROM file
    ale.loadROM(game)

    # Get the list of legal actions
    legal_actions = ale.getLegalActionSet()

    batch_size = 10
    exp_replay = ReplayBuffer(batch_size)

    (screen_width, screen_height) = ale.getScreenDims()

    tot_m, used_m, free_m = memory_stats()
    last_counter = 0
    random_state = np.random.RandomState(218)
    print("initial: {}, {}, {}".format(tot_m, used_m, free_m))
    # Play 2k episodes
    for episode in range(2000):
        total_reward = 0
        this_counter = exp_replay.sent_counter
        if this_counter > last_counter + 1000:
            last_counter = this_counter
            tot_m, used_m, free_m = memory_stats()
            # the first three entries should match til 1M steps
            # then the second 2 should continue in lock step
            print("{}: {}, {}; {}, {}, {}".format(
                exp_replay.sent_counter, len(exp_replay.memory),
                len(exp_replay.reverse_experience_lookup.keys()), tot_m,
                used_m, free_m))
        while not ale.game_over():
            S_prime = np.zeros(screen_width * screen_height, dtype=np.uint8)
            ale.getScreen(S_prime)
            # Crop the top-left 84x84 corner of the screen as the state.
            S_prime = S_prime.reshape(screen_height, screen_width)[:84, :84]
            a = random_state.choice(len(legal_actions))
            action = legal_actions[a]
            # Apply an action and get the resulting reward
            reward = ale.act(action)
            won = 0
            ongoing_flag = 1
            experience = (S_prime, action, reward, won, ongoing_flag)
            exp_replay.add_experience(experience)
            # Exercise both sampling paths: default and explicit indices.
            exp_replay.get_minibatch()
            batch = exp_replay.get_minibatch(index_list=[1, 2, 3, 10, 11])
            if batch is not None:
                # Indexing doubles as a check that both components exist.
                mb_S = batch[0]
                other_info = batch[1]
            del batch
            total_reward += reward
        print("Episode {} ended with score: {}".format(episode, total_reward))
        ale.reset_game()

    last_report = 0
    for i in range(10000):
        if i > last_report + 1000:
            tot_m, used_m, free_m = memory_stats()
            print("POST MEM {}: {}, {}; {}, {}, {}".format(
                exp_replay.sent_counter, len(exp_replay.memory),
                len(exp_replay.reverse_experience_lookup.keys()), tot_m,
                used_m, free_m))
            last_report = i

        batch = exp_replay.get_minibatch()
        mb_S = batch[0]
        other_info = batch[1]

    # Intentional debugging stop: inspect the buffer, then abort the test.
    from IPython import embed
    embed()
    raise ValueError()
예제 #5
0
                        default=sys.stdout)

    args = parser.parse_args()

    ale = ALEInterface()
    __USE_SDL = True
    ale.setInt(b'random_seed', args.seed)
    ale.setFloat(b'repeat_action_probability', 0.0)

    print('Starting up: ' + args.game)
    game_path = atari_py.get_game_path(args.game)
    ale.loadROM(str.encode(game_path))
    print('Legal Actions: ', ale.getLegalActionSet())
    pygame.init()

    (w, h) = ale.getScreenDims()
    dim = (w * args.scale, h * args.scale)

    pygame.display.set_mode(dim)
    clock = pygame.time.Clock()
    FPS = 32

    quit = False
    while not quit and not ale.game_over():
        for event in pygame.event.get():
            if event.type == QUIT:
                quit = True
                break
            if event.type == KEYDOWN and event.key == K_ESCAPE:
                quit = True
                break