class pyrlcade_environment(object):
    """Thin wrapper around an ALE emulator that exposes RAM as the observation.

    Typical use: ``init(rom_file, frame_skip)``, then repeatedly
    ``set_action(a)``, ``step()`` (True on game over), ``get_state()`` and
    ``get_reward()``.
    """

    def init(self, rom_file, ale_frame_skip):
        """Create the emulator, apply settings, load *rom_file* and read the RAM.

        Args:
            rom_file: Path of the ROM to load.
            ale_frame_skip: Value for ALE's ``frame_skip`` setting.
        """
        self.ale = ALEInterface()

        self.max_frames_per_episode = self.ale.getInt("max_num_frames_per_episode")
        self.ale.set("random_seed", 123)
        self.ale.set("disable_color_averaging", 1)
        self.ale.set("frame_skip", ale_frame_skip)

        self.ale.loadROM(rom_file)
        self.legal_actions = self.ale.getMinimalActionSet()
        ram_size = self.ale.getRAMSize()
        self.ram = np.zeros((ram_size), dtype=np.uint8)

        # getRAM fills self.ram in place; its return value is not relied on
        # (depending on the ALE binding it may be None), and one call is enough.
        self.ale.getRAM(self.ram)
        self.state = self.ram

        # No step has been taken yet: report 0 instead of raising AttributeError.
        self.reward = 0

    def reset_state(self):
        """Reset the emulator to the start of a new game."""
        self.ale.reset_game()

    def set_action(self, a):
        """Select the ALE action *a* that the next ``step()`` applies."""
        self.action = a

    def step(self):
        """Apply the selected action and return True if the game is over.

        The reward returned by ``ale.act`` is kept for ``get_reward()``.
        """
        self.reward = self.ale.act(self.action)
        return self.ale.game_over()

    def get_state(self):
        """Refresh and return the RAM buffer.

        The same array object is overwritten and returned on every call; copy
        it if a snapshot is needed.
        """
        self.ale.getRAM(self.ram)
        return self.ram

    def get_reward(self):
        """Return the reward of the most recent ``step()`` (0 right after ``init``)."""
        return self.reward
class pyrlcade_environment(object):
    """Thin wrapper around an ALE emulator that exposes RAM as the observation.

    Typical use: ``init(rom_file, frame_skip)``, then repeatedly
    ``set_action(a)``, ``step()`` (True on game over), ``get_state()`` and
    ``get_reward()``.
    """

    def init(self, rom_file, ale_frame_skip):
        """Create the emulator, apply settings, load *rom_file* and read the RAM.

        Args:
            rom_file: Path of the ROM to load.
            ale_frame_skip: Value for ALE's ``frame_skip`` setting.
        """
        self.ale = ALEInterface()

        self.max_frames_per_episode = self.ale.getInt(
            "max_num_frames_per_episode")
        self.ale.set("random_seed", 123)
        self.ale.set("disable_color_averaging", 1)
        self.ale.set("frame_skip", ale_frame_skip)

        self.ale.loadROM(rom_file)
        self.legal_actions = self.ale.getMinimalActionSet()
        ram_size = self.ale.getRAMSize()
        self.ram = np.zeros((ram_size), dtype=np.uint8)

        # getRAM fills self.ram in place; its return value is not relied on
        # (depending on the ALE binding it may be None), and one call is enough.
        self.ale.getRAM(self.ram)
        self.state = self.ram

        # No step has been taken yet: report 0 instead of raising AttributeError.
        self.reward = 0

    def reset_state(self):
        """Reset the emulator to the start of a new game."""
        self.ale.reset_game()

    def set_action(self, a):
        """Select the ALE action *a* that the next ``step()`` applies."""
        self.action = a

    def step(self):
        """Apply the selected action and return True if the game is over.

        The reward returned by ``ale.act`` is kept for ``get_reward()``.
        """
        self.reward = self.ale.act(self.action)
        return self.ale.game_over()

    def get_state(self):
        """Refresh and return the RAM buffer.

        The same array object is overwritten and returned on every call; copy
        it if a snapshot is needed.
        """
        self.ale.getRAM(self.ram)
        return self.ram

    def get_reward(self):
        """Return the reward of the most recent ``step()`` (0 right after ``init``)."""
        return self.reward
Esempio n. 3
0
    # Clear the pygame display to black before redrawing this frame.
    screen.fill((0,0,0))

    # Get the Atari screen pixels and blit them. numpy_surface is an int32
    # view over game_surface's pixel buffer (np.frombuffer), so getScreenRGB
    # writes directly into the surface. Assumes a 32-bit surface whose pixel
    # layout matches ALE's packed RGB output — TODO confirm.
    numpy_surface = np.frombuffer(game_surface.get_buffer(),dtype=np.int32)
    ale.getScreenRGB(numpy_surface)

    # Log the action and the screen for this time step.
    logger.log(a, TYPE_ACTION, cur_time)
    # (disabled) log the screen only on even time steps: if cur_time %2 == 0:
    logger.log(numpy_surface, TYPE_SCREEN, cur_time)

    # Drop the buffer view before blitting (presumably to release the lock
    # taken by get_buffer — verify). Then draw the emulator image at 2x size
    # in the top-left corner.
    del numpy_surface
    screen.blit(pygame.transform.scale2x(game_surface),(0,0))

    # Read the current RAM into a freshly allocated uint8 buffer.
    ram_size = ale.getRAMSize()
    ram = np.zeros((ram_size),dtype=np.uint8)
    ale.getRAM(ram)

    # Draw the "RAM:" heading.
    font = pygame.font.SysFont("Ubuntu Mono",32)
    text = font.render("RAM: " ,1,(255,208,208))
    screen.blit(text,(330,10))

    # Smaller font for the hex dump. height (1.2x the font height) is
    # presumably the row spacing used after the end of this fragment.
    font = pygame.font.SysFont("Ubuntu Mono",25)
    height = font.get_height()*1.2

    line_pos = 40
    ram_pos = 0
    # Format the first 128 RAM bytes as rows of up to 16 hex values
    # (the rest of the loop body is truncated in this fragment).
    while(ram_pos < 128):
        ram_string = ''.join(["%02X "%ram[x] for x in range(ram_pos,min(ram_pos+16,128))])
def main():
    """Replay a fixed action sequence repeatedly and compare the outcomes.

    For every game in ``all_game_list`` the same action sequence is replayed
    ``bunch`` times. Names containing ``-n`` are skipped. Before each replay,
    the emulator is re-initialized according to the module-level ``test`` mode:

    * ``'loadROM'`` -- reload the ROM with ``random_seed = bunch_i``;
    * ``'restoreState'`` -- restore the state cloned right after reset;
    * ``'restoreSystemState'`` -- restore the system state cloned after reset;
    * ``'setRAM'`` -- restore the system state, then overwrite the RAM bytes
      selected by ``./stochasticity_ram_mask/<game>.npy`` with
      ``bunch_i % 255``.

    The final screens of the replays are grouped into clusters of identical
    screens: ``grouped_num`` clusters, with sizes given by ``distribution``.
    The per-step screens and processed RAM of all replays are also rendered
    into comparison images under ``./results/``.

    Relies on module-level names defined elsewhere in the file: ``test``,
    ``sequence``, ``bunch``, ``frame_skip``, ``clear_print``,
    ``process_ram``, ``ALEInterface``, ``np``, ``copy`` and ``plt``.
    """
    # all_game_list = ['air_raid-n', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis']
    # all_game_list = ['bank_heist', 'battle_zone', 'beam_rider', 'berzerk-n', 'bowling', 'boxing', 'breakout', 'carnival-n']
    # all_game_list = ['centipede', 'chopper_command', 'crazy_climber', 'demon_attack', 'double_dunk']
    # all_game_list = ['elevator_action-n', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar']
    # all_game_list = ['hero', 'ice_hockey', 'jamesbond', 'journey_escape-n', 'kangaroo', 'krull', 'kung_fu_master']
    # all_game_list = ['montezuma_revenge-n', 'ms_pacman', 'name_this_game', 'phoenix-n', 'pitfall-n', 'pong', 'pooyan-n']
    # all_game_list = ['private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank', 'seaquest', 'skiing-n']
    # all_game_list = ['solaris-n', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down']
    # all_game_list = ['venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge-n', 'zaxxon']

    # all_game_list = ['pong', 'assault','ms_pacman']
    all_game_list = ['assault']

    result_str = ''
    # Per-game summaries printed at the end. They are defined up front so the
    # final report also works when every game was skipped.
    game_list = []
    grouped_num_list = []

    for game in all_game_list:

        if '-n' in game:
            # Games that are not in the Nature DQN list.
            continue

        import atari_py
        game_path = str.encode(atari_py.get_game_path(game))

        env = ALEInterface()
        env.setFloat('repeat_action_probability'.encode('utf-8'), 0.0)
        env.setInt(b'random_seed', 3)
        env.loadROM(game_path)
        env.reset_game()
        minimal_actions = env.getMinimalActionSet()

        # Snapshot the emulator right after reset so replays can start from it.
        if test in ['restoreState']:
            state_after_reset = env.cloneState()
        if test in ['restoreSystemState']:
            state_after_reset = env.cloneSystemState()
        if test in ['setRAM']:
            ram_after_reset = env.getRAM()
            state_after_reset = env.cloneSystemState()
            # Presumably a 0/1 per-byte mask, as the RAM blend below assumes.
            ram_candidate = np.load(
                './stochasticity_ram_mask/{}.npy'.format(game))

        print('=====================================================')
        action_sequence_path = './action_sequence/action_sequence_{}_{}.npy'.format(
            sequence, game)
        try:
            action_sequence = np.load(action_sequence_path)
            print('action_sequence loaded')
        except (OSError, ValueError):
            # No usable cached sequence: generate a random one and cache it.
            action_sequence = np.random.randint(len(minimal_actions),
                                                size=sequence)
            np.save(action_sequence_path, action_sequence)
            print('action_sequence generated')
        print('=====================================================')

        bunch_obs = []  # distinct final screens seen so far
        distribution = []  # number of replays that ended on each of them
        state_metrix = []  # per replay: one screen per step
        ram_metrix = []  # per replay: one processed RAM per step
        for bunch_i in range(bunch):

            if test in ['loadROM']:
                env.setInt(b'random_seed', bunch_i)
                env.loadROM(game_path)
                env.reset_game()
            elif test in ['restoreState']:
                env.restoreState(state_after_reset)
            elif test in ['restoreSystemState']:
                env.restoreSystemState(state_after_reset)
            elif test in ['setRAM']:
                env.reset_game()
                env.restoreSystemState(state_after_reset)
                env.setRAM(ram_after_reset)
                env.setRAM(env.getRAM() * (1 - ram_candidate) + ram_candidate *
                           (bunch_i % 255))

            state_sequence = []
            ram_sequence = []
            # Reset per replay so the warning below describes this replay only.
            episode_length = -1
            has_terminated = False
            for sequence_i in range(sequence):

                for _ in range(frame_skip):
                    env.act(minimal_actions[action_sequence[sequence_i]])
                    if env.game_over():
                        episode_length = sequence_i
                        has_terminated = True
                        break

                try:
                    clear_print('[{}|{}|{}]'.format(bunch_i, sequence_i,
                                                    episode_length))
                except Exception:
                    # Progress output is best-effort only.
                    pass

                state_sequence.append(env.getScreenRGB())
                ram_sequence.append(process_ram(env.getRAM()))

                if has_terminated:
                    break

            if sequence > 0 and episode_length < 0:
                print('# WARNING: Did not terminated')

            obs = env.getScreenRGB()

            state_metrix.append(copy.deepcopy(state_sequence))
            ram_metrix.append(copy.deepcopy(ram_sequence))

            # Add this replay to the cluster of an identical final screen, or
            # start a new cluster.
            for bunch_obs_i, seen_obs in enumerate(bunch_obs):
                if np.max(np.abs(obs - seen_obs)) < 1:
                    distribution[bunch_obs_i] += 1
                    break
            else:
                bunch_obs.append(obs)
                distribution.append(1)

        grouped_num = len(bunch_obs)
        result_str = '{}game:{} grouped_num:{} distribution:{} \n'.format(
            result_str,
            game,
            grouped_num,
            distribution,
        )
        game_list.append(game)
        grouped_num_list.append(grouped_num)

        # Pad every replay to the longest one with all-zero frames.
        first_state = state_metrix[0][0]
        first_ram = ram_metrix[0][0]
        max_lenth = max(len(frames) for frames in state_metrix)
        blank_state = np.zeros(shape=first_state.shape, dtype=first_state.dtype)
        blank_ram = np.zeros(shape=first_ram.shape, dtype=first_ram.dtype)
        for frames, rams in zip(state_metrix, ram_metrix):
            # Compute the pad length once. Measuring it after padding the
            # screens would leave the RAM sequence unpadded.
            pad = max_lenth - len(frames)
            frames += [blank_state] * pad
            rams += [blank_ram] * pad

        num_bunch = len(state_metrix)
        num_step = len(state_metrix[0])

        # Give every distinct screen an id. state_metrix_id[b][s] is the id of
        # the screen of replay b at step s.
        state_list = []
        state_metrix_id = np.zeros((num_bunch, num_step), dtype=int)
        for bunch_i in range(num_bunch):
            for sequence_i in range(num_step):
                frame = state_metrix[bunch_i][sequence_i]
                for state_list_id, known_frame in enumerate(state_list):
                    if np.max(np.abs(known_frame - frame)) < 1:
                        state_metrix_id[bunch_i][sequence_i] = state_list_id
                        break
                else:
                    state_list.append(np.copy(frame))
                    state_metrix_id[bunch_i][sequence_i] = len(state_list) - 1

        # Sort replays by their id rows (lexicographically, descending). Keep
        # the permutation so each sorted row is drawn with its own replay's RAM:
        # sorted_order[i] is the replay shown in sorted row i.
        sorted_order = sorted(range(num_bunch),
                              key=lambda b: state_metrix_id[b].tolist(),
                              reverse=True)
        state_metrix_id = state_metrix_id[sorted_order]

        fig, ax = plt.subplots()
        ax.imshow(state_metrix_id)
        # Save before show(): once the shown window is closed the figure is
        # released, and a later savefig would write an empty canvas.
        fig.savefig(
            './results/{}_state_metrix_id.jpg'.format(game),
            dpi=600,
        )
        plt.show()

        state_h, state_w = first_state.shape[0], first_state.shape[1]
        ram_h, ram_w = first_ram.shape[0], first_ram.shape[1]
        state_row = 10 + state_h  # 10-pixel separator band above each frame
        ram_row = 5 + ram_h  # 5-pixel separator band above each RAM image

        state_metrix_figure = np.zeros(
            (state_row * num_bunch, state_w * num_step, first_state.shape[2]),
            dtype=first_state.dtype)
        # Assumes process_ram returns a 3-D (rows, cols, channels) image — TODO confirm.
        ram_metrix_figure = np.zeros(
            (ram_row * num_bunch, ram_w * num_step, first_ram.shape[2]),
            dtype=first_ram.dtype)

        # Fill every RAM separator band with channel 2 = 255.
        for bunch_i in range(num_bunch):
            ram_metrix_figure[bunch_i * ram_row:bunch_i * ram_row + 5, :, 2] = 255

        # Tile the sorted screens and the matching replays' RAM images.
        for bunch_i, source_bunch in enumerate(sorted_order):
            for sequence_i in range(num_step):
                state_metrix_figure[
                    bunch_i * state_row + 10:(bunch_i + 1) * state_row,
                    sequence_i * state_w:(sequence_i + 1) * state_w] = state_list[
                        state_metrix_id[bunch_i][sequence_i]]
                ram_metrix_figure[
                    bunch_i * ram_row + 5:(bunch_i + 1) * ram_row,
                    sequence_i * ram_w:(sequence_i + 1) * ram_w] = ram_metrix[
                        source_bunch][sequence_i]

        # Where a row first differs from the row above, mark its separator band
        # from that step to the right edge: channel 0 = 255 (in the RAM figure,
        # channels 1 and 2 are also cleared).
        for bunch_i in range(1, num_bunch):
            for sequence_i in range(num_step):
                if (state_metrix_id[bunch_i][sequence_i] !=
                        state_metrix_id[bunch_i - 1][sequence_i]):
                    state_metrix_figure[bunch_i * state_row:bunch_i * state_row + 10,
                                        sequence_i * state_w:, 0] = 255
                    ram_metrix_figure[bunch_i * ram_row:bunch_i * ram_row + 5,
                                      sequence_i * ram_w:, 0] = 255
                    ram_metrix_figure[bunch_i * ram_row:bunch_i * ram_row + 5,
                                      sequence_i * ram_w:, 1:] = 0
                    # Later differing steps would only re-mark a subset of
                    # the same span.
                    break

        from PIL import Image
        Image.fromarray(state_metrix_figure).save(
            "./results/{}_state_metrix_figure.jpeg".format(game))
        Image.fromarray(ram_metrix_figure.astype(
            state_metrix_figure.dtype)).save(
                "./results/{}_ram_metrix_figure.jpeg".format(game))

    print(result_str)
    print('===============')
    for name in game_list:
        print(name)
    for grouped_num in grouped_num_list:
        print(grouped_num)
    # Set bit 4 of the key bitmask from the Z key (lower bits are set before
    # this fragment, outside the visible source).
    keys |= pressed[pygame.K_z] << 4
    # Map the key bitmask to an ALE action through the lookup table, apply
    # it, and accumulate the reward returned by ALE.
    a = key_action_tform_table[keys]
    reward = ale.act(a)
    total_reward += reward

    # Clear the pygame display to black before redrawing.
    screen.fill((0, 0, 0))

    # Get the Atari screen pixels and blit them. numpy_surface is an int32
    # view over game_surface's pixel buffer, so getScreenRGB writes directly
    # into the surface. Assumes a 32-bit surface whose pixel layout matches
    # ALE's packed RGB output — TODO confirm.
    numpy_surface = np.frombuffer(game_surface.get_buffer(), dtype=np.int32)
    ale.getScreenRGB(numpy_surface)
    # Drop the buffer view (presumably to release the lock taken by
    # get_buffer — verify). Then draw the emulator image at 2x size in the
    # top-left corner.
    del numpy_surface
    screen.blit(pygame.transform.scale2x(game_surface), (0, 0))

    # Read the current RAM into a freshly allocated uint8 buffer.
    ram_size = ale.getRAMSize()
    ram = np.zeros((ram_size), dtype=np.uint8)
    ale.getRAM(ram)

    # Draw the "RAM:" heading.
    font = pygame.font.SysFont("Ubuntu Mono", 32)
    text = font.render("RAM: ", 1, (255, 208, 208))
    screen.blit(text, (330, 10))

    # Smaller font for the hex dump. height (1.2x the font height) is
    # presumably the row spacing used after the end of this fragment.
    font = pygame.font.SysFont("Ubuntu Mono", 25)
    height = font.get_height() * 1.2

    line_pos = 40
    ram_pos = 0
    # Render the first 128 RAM bytes as hex rows (the loop body is truncated
    # in this fragment).
    while (ram_pos < 128):
        ram_string = ''.join(
def main():
    """Search for the RAM byte that makes ALE replays diverge across seeds.

    For each game in ``all_game_list`` (names containing ``-n`` are skipped):

    1. Play a fixed action sequence from seed 3 and record the screen and RAM
       at the start of every step (the "base" trajectory). The sequence must
       reach game over; otherwise an exception is raised.
    2. Replay the same actions from seeds ``0 .. bunch-1``. At the first step
       whose screen differs from the base trajectory, only the RAM bytes that
       differed from the base at the previous step stay candidates.
       ``ram_candidate`` is the 0/1 mask of remaining candidates.
    3. When exactly one candidate byte remains, the mask is saved to
       ``./stochasticity_ram_mask/<game>.npy`` and the next game is processed.

    Relies on module-level names defined elsewhere in the file: ``sequence``,
    ``bunch``, ``frame_skip``, ``ALEInterface`` and ``np``.

    Raises:
        Exception: if the action sequence does not reach game over.
        RuntimeError: if every candidate byte gets eliminated.
    """
    # all_game_list = ['air_raid-n', 'alien', 'amidar', 'assault', 'asterix', 'asteroids', 'atlantis']
    # all_game_list = ['bank_heist', 'battle_zone', 'beam_rider', 'berzerk-n', 'bowling', 'boxing', 'breakout', 'carnival-n']
    # all_game_list = ['centipede', 'chopper_command', 'crazy_climber', 'demon_attack', 'double_dunk']
    # all_game_list = ['elevator_action-n', 'enduro', 'fishing_derby', 'freeway', 'frostbite', 'gopher', 'gravitar']
    # all_game_list = ['hero', 'ice_hockey', 'jamesbond', 'journey_escape-n', 'kangaroo', 'krull', 'kung_fu_master']
    # all_game_list = ['montezuma_revenge-n', 'ms_pacman', 'name_this_game', 'phoenix-n', 'pitfall-n', 'pong', 'pooyan-n']
    # all_game_list = ['private_eye', 'qbert', 'riverraid', 'road_runner', 'robotank', 'seaquest', 'skiing-n']
    # all_game_list = ['solaris-n', 'space_invaders', 'star_gunner', 'tennis', 'time_pilot', 'tutankham', 'up_n_down']
    # all_game_list = ['venture', 'video_pinball', 'wizard_of_wor', 'yars_revenge-n', 'zaxxon']

    all_game_list = ['assault']

    for game in all_game_list:

        if '-n' in game:
            # Games that are not in the Nature DQN list.
            continue

        import atari_py
        game_path = str.encode(atari_py.get_game_path(game))

        env = ALEInterface()
        env.setFloat('repeat_action_probability'.encode('utf-8'), 0.0)

        env.setInt(b'random_seed', 3)
        env.loadROM(game_path)
        env.reset_game()
        minimal_actions = env.getMinimalActionSet()

        print('=====================================================')
        action_sequence_path = './action_sequence/action_sequence_{}_{}.npy'.format(
            sequence, game)
        try:
            action_sequence = np.load(action_sequence_path)
            print('action_sequence loaded')
        except (OSError, ValueError):
            # No usable cached sequence: generate a random one and cache it.
            action_sequence = np.random.randint(len(minimal_actions),
                                                size=sequence)
            np.save(action_sequence_path, action_sequence)
            print('action_sequence generated')
        print('=====================================================')

        # 1. Base trajectory from seed 3 (screen/RAM recorded before acting).
        state_sequence_base = []
        ram_sequence_base = []
        has_terminated = False
        for sequence_i in range(sequence):
            state_sequence_base.append(env.getScreenRGB())
            ram_sequence_base.append(env.getRAM())

            for _ in range(frame_skip):
                env.act(minimal_actions[action_sequence[sequence_i]])
                if env.game_over():
                    has_terminated = True
                    break
            if has_terminated:
                break

        if not has_terminated:
            raise Exception('sequence length is not enough')

        # 2. Replay from other seeds and narrow down the candidate bytes.
        ram_candidate = np.ones((env.getRAMSize()), dtype=np.uint8)
        for bunch_i in range(bunch):

            env.setInt(b'random_seed', bunch_i)
            env.loadROM(game_path)
            env.reset_game()

            # RAM of *this* replay at the start of the previous step. It must
            # not be taken from an earlier seed's replay.
            previous_ram = None
            found = False
            for sequence_i in range(sequence):
                if sequence_i >= len(state_sequence_base):
                    # The base trajectory ended here; nothing left to compare.
                    break

                current_ram = env.getRAM()

                if sequence_i > 0:
                    max_value = np.max(
                        np.abs(env.getScreenRGB() -
                               state_sequence_base[sequence_i]))
                    if max_value > 0:
                        # Keep only the bytes that already differed one step
                        # before the screens diverged.
                        delta_ram = np.sign(
                            np.abs(previous_ram -
                                   ram_sequence_base[sequence_i - 1]))
                        ram_candidate *= delta_ram
                        remain = np.sum(ram_candidate)
                        print('remain {} bytes'.format(remain))
                        if remain == 1:
                            print(ram_candidate)
                            np.save(
                                './stochasticity_ram_mask/{}.npy'.format(game),
                                ram_candidate,
                            )
                            found = True
                        elif remain == 0:
                            raise RuntimeError(
                                'all candidate RAM bytes were eliminated for '
                                '{}'.format(game))
                        # This seed has diverged; continue with the next one.
                        break

                previous_ram = current_ram

                game_over = False
                for _ in range(frame_skip):
                    env.act(minimal_actions[action_sequence[sequence_i]])
                    if env.game_over():
                        game_over = True
                        break
                if game_over:
                    break

            if found:
                break
        else:
            print('# WARNING: {} candidate bytes remain for {}'.format(
                np.sum(ram_candidate), game))
    def __init__(self,
                 random_seed,
                 frame_skip,
                 repeat_action_probability,
                 sound,
                 display_screen,
                 block_state_repr=None,
                 enemy_state_repr=None,
                 friendly_state_repr=None,
                 rom_path='qbert.bin'):
        """Create and configure the ALE emulator and the verbose Q*bert state.

        Args:
            random_seed: Value for ALE's ``random_seed`` setting; not set
                when None.
            frame_skip: Value for ALE's ``frame_skip`` setting.
            repeat_action_probability: Value for ALE's
                ``repeat_action_probability`` setting.
            sound: Whether to enable sound; applied only when
                ``display_screen`` is true.
            display_screen: Value for ALE's ``display_screen`` setting.
            block_state_repr, enemy_state_repr, friendly_state_repr: Stored
                unchanged on the instance; their use is outside this method.
            rom_path: ROM file to load. Defaults to ``'qbert.bin'`` (resolved
                relative to the working directory), the previously hard-coded
                path.
        """
        ale = ALEInterface()

        # Emulator settings are applied before the ROM is loaded.
        if random_seed is not None:
            ale.setInt('random_seed', random_seed)
        ale.setInt('frame_skip', frame_skip)
        ale.setFloat('repeat_action_probability', repeat_action_probability)

        if display_screen:
            if sys.platform == 'darwin':
                # On macOS pygame is initialized before ALE opens its display
                # (presumably required there — verify).
                import pygame
                pygame.init()
            ale.setBool('sound', sound)

        ale.setBool('display_screen', display_screen)

        # Load the ROM file
        ale.loadROM(rom_path)

        # Get the list of legal actions
        legal_actions = ale.getLegalActionSet()
        minimal_actions = ale.getMinimalActionSet()
        logging.debug('Legal actions: {}'.format(
            [action_number_to_name(a) for a in legal_actions]))
        logging.debug('Minimal actions: {}'.format(
            [action_number_to_name(a) for a in minimal_actions]))

        # Preallocated buffers: an RGB screen (height x width x 3) and the RAM.
        width, height = ale.getScreenDims()
        rgb_screen = np.empty([height, width, 3], dtype=np.uint8)

        ram_size = ale.getRAMSize()
        ram = np.zeros(ram_size, dtype=np.uint8)

        # ALE components
        self.ale = ale
        self.lives = ale.lives()
        self.rgb_screen = rgb_screen
        self.ram_size = ram_size
        self.ram = ram

        # Verbose state representation.
        # NOTE(review): these attributes alias the module-level INITIAL_*
        # objects. If those are mutable and get mutated in place elsewhere, all
        # instances share the changes; consider copying them — verify.
        self.desired_color = COLOR_YELLOW
        self.block_colors = INITIAL_COLORS
        self.enemies = INITIAL_ENEMY_POSITIONS
        self.friendlies = INITIAL_FRIENDLY_POSITIONS
        self.discs = INITIAL_DISCS
        self.current_row, self.current_col = 0, 0
        self.level = 1
        self.enemy_present = False
        self.friendly_present = False
        self.block_state_repr = block_state_repr
        self.enemy_state_repr = enemy_state_repr
        self.friendly_state_repr = friendly_state_repr
        self.num_colored_blocks = 0
Esempio n. 8
0
class KungFuMaster(object):
    def __init__(
            self,
            rom='/home/josema/AI/ALE/Arcade-Learning-Environment/Roms/kung_fu_master.bin',
            trainsessionname='test'):

        self.agent = None
        self.isAuto = True
        self.gui_visible = False
        self.userquit = False
        self.optimalPolicyUser = False  # optimal policy set by user
        self.trainsessionname = trainsessionname
        self.elapsedtime = 0  # elapsed time for this experiment

        self.keys = 0

        # Configuration
        self.pause = False  # game is paused
        self.debug = False

        self.sleeptime = 0.0
        self.command = 0
        self.iteration = 0
        self.cumreward = 0
        self.cumreward100 = 0  # cum reward for statistics
        self.cumscore100 = 0
        self.ngoalreached = 0
        self.max_level = 1

        self.hiscore = 0
        self.hireward = -1000000
        self.resfile = open("data/" + self.trainsessionname + ".dat", "a+")

        self.legal_actions = 0
        self.rom = rom
        self.key_status = []

    def init(self, agent):  # init after creation (uses args set from cli)
        """Create the emulator, load the ROM and initialize *agent*.

        Meant to be called after construction, once attributes such as
        ``gui_visible`` have been set (e.g. from command-line arguments).

        Args:
            agent: Learning agent. It is stored on ``self.agent`` and
                initialized with ``agent.init(number_of_states,
                number_of_actions)``.
        """
        self.ale = ALEInterface()
        self.ale.setInt('random_seed', 123)
        ram_size = self.ale.getRAMSize()
        self.ram = np.zeros((ram_size), dtype=np.uint8)

        # Audio/display settings are applied before the ROM is loaded.
        if (self.gui_visible):
            os.environ['SDL_VIDEO_CENTERED'] = '1'
            if sys.platform == 'darwin':
                pygame.init()
                self.ale.setBool('sound', False)  # Sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                pygame.init()

                self.ale.setBool('sound', True)
                self.ale.setBool('display_screen', False)

        self.ale.loadROM(self.rom)
        self.legal_actions = self.ale.getLegalActionSet()

        if (self.gui_visible):
            (self.screen_width, self.screen_height) = self.ale.getScreenDims()
            print("width/height: " + str(self.screen_width) + "/" +
                  str(self.screen_height))

            (display_width, display_height) = (1024, 420)
            self.screen = pygame.display.set_mode(
                (display_width, display_height))

            pygame.display.set_caption(
                "Reinforcement Learning - Sapienza - Jose M Salas")
            self.numpy_surface = np.zeros(
                (self.screen_height, self.screen_width, 3), dtype=np.uint8)

            self.game_surface = pygame.Surface(
                (self.screen_width, self.screen_height))

            pygame.display.flip()
            #init clock
            self.clock = pygame.time.Clock()

        self.agent = agent
        # One action per entry of ALE's legal action set.
        self.nactions = len(self.legal_actions)
        # Rebuild the table (rather than appending) so calling init() again
        # does not grow it past one entry per action.
        self.key_status = [False] * self.nactions

        print(self.nactions)
        # Earlier state-space sizes, kept for reference:
        #   ns = 89999    # enemy type ram info without level number
        #   ns = 489999   # enemy type ram info (former final setting)
        #   ns = 48999
        # Presumably an upper bound on the indices returned by getstate() — verify.
        ns = 4899999  # Number of statuses if we use enemy type ram info
        print('Number of states: %d' % ns)
        self.agent.init(ns, self.nactions)  # 1 for RA not used here

    def initScreen(self):

        if (self.gui_visible):
            if sys.platform == 'darwin':
                pygame.init()
                self.ale.setBool('sound', False)  # Sound doesn't work on OSX
            elif sys.platform.startswith('linux'):
                pygame.init()

                self.ale.setBool('sound', True)
                self.ale.setBool('display_screen', False)
        if (self.gui_visible):
            (self.screen_width, self.screen_height) = self.ale.getScreenDims()
            print("width/height: " + str(self.screen_width) + "/" +
                  str(self.screen_height))

            (display_width, display_height) = (1024, 420)
            self.screen = pygame.display.set_mode(
                (display_width, display_height))

            pygame.display.set_caption(
                "Reinforcement Learning - Sapienza - Jose M Salas")
            self.numpy_surface = np.zeros(
                (self.screen_height, self.screen_width, 3), dtype=np.uint8)

            self.game_surface = pygame.Surface(
                (self.screen_width, self.screen_height))

            pygame.display.flip()
            #init clock
            self.clock = pygame.time.Clock()

    def reset(self):
        self.pos_x = 0
        self.pos_y = 0
        # Kung fu master observations
        self.enemy_pos = 0
        self.n_enemies = 0
        self.my_pos = 0
        self.danger_pos = 0
        self.danger_type = 0
        self.enemy_type = 0  # 0, 1, 2, 3, 80, 81, 82, 40
        self.blocked = 0
        self.prev_blocked = 0
        self.hold_hit = 0
        self.time_left1 = 0
        self.time_left2 = 0
        self.my_energy = 39
        self.previous_my_energy = 39
        self.lifes = 3
        self.previous_lifes = 3
        self.got_hit = 0
        self.got_blocked = 0
        self.got_unblocked = 0
        self.still_blocked = False
        self.starting_pos = 0
        self.level = 1

        self.score = 0
        self.cumreward = 0
        self.cumscore = 0
        self.action_reward = 0

        self.current_reward = 0  # accumulate reward over all events happened during this action until next different state

        self.prev_state = None  # previous state
        self.firstAction = True  # first action of the episode
        self.finished = False  # episode finished
        self.newstate = True  # new state reached
        self.numactions = 0  # number of actions in this episode
        self.iteration += 1

        self.agent.optimal = self.optimalPolicyUser or (
            self.iteration % 100
        ) == 0  # False #(random.random() < 0.5)  # choose greedy action selection for the entire episode

    def pair_function(self):
        # Combine the number of enemies, player blocked and danger type information into 7 different states
        if self.n_enemies > 0:
            self.danger_type = 0

    # print (str(self.n_enemies) + " - " + str(self.danger_type) + ' - ' + str(self.blocked))
        pair = (int)(
            (0.5 * (self.n_enemies + self.danger_type) *
             (self.n_enemies + self.danger_type + 1) + self.danger_type + 1) *
            (1 - (self.blocked / 128)))
        if pair > 8:
            return 5  #game not started yet
        else:
            return pair

    def enemy_type_s(self):
        if self.enemy_type > 127:
            return (self.enemy_type - 128 + 4)
        elif self.enemy_type == 64:
            return 8
        else:
            return self.enemy_type

    def getstate(self):
        """Pack the current observation into one integer state id.

        Each component is placed in its own decimal position:
        (level - 1) * 1e6, pair_function() * 1e5, enemy_type_s() * 1e4,
        and then the rounded my_pos/32, enemy_pos/32, danger_pos/32 and
        hold_hit/16 in the thousands, hundreds, tens and units.

        NOTE(review): components are not range-checked. enemy_type_s() can
        exceed 9, and so can the rounded positions / hold_hit for large byte
        values. Such values spill into neighbouring digits, so distinct
        observations may share an id.

        Calling this also runs pair_function(), which may reset
        ``self.danger_type`` to 0.
        """
        # Terms are computed and summed in the same left-to-right order as the
        # original single expression.
        level_term = (self.level - 1) * 1000000
        pair_term = self.pair_function() * 100000
        enemy_type_term = self.enemy_type_s() * 10000
        my_pos_term = np.rint(self.my_pos / 32) * 1000
        enemy_pos_term = np.rint(self.enemy_pos / 32) * 100
        danger_pos_term = np.rint(self.danger_pos / 32) * 10
        hold_hit_term = np.rint(self.hold_hit / 16)

        state_id = (level_term + pair_term + enemy_type_term + my_pos_term +
                    enemy_pos_term + danger_pos_term + hold_hit_term)
        return int(state_id)

    def goal_reached(self):
        """Return whether the current level (``self.level``) equals 5."""
        target_level = 5
        return self.level == target_level

    def update(self, a):
        """Decode game variables from RAM, execute action ``a`` and compute rewards.

        The RAM snapshot is taken *before* ``a`` is sent to ALE. Every decoded
        field, ``self.prev_state`` and the hit/blocked event values therefore
        describe the pre-action frame.

        Effects on the instance:
          * refreshes ``self.ram`` and the fields decoded from it;
          * sets ``got_hit`` / ``got_blocked`` / ``got_unblocked`` to the
            corresponding ``STATES`` value (or 0); getreward() adds these in;
          * ``score`` accumulates only the reward returned by ``ale.act``;
          * ``current_reward`` is that reward plus the movement-shaping term
            ``action_reward``, plus ``STATES['Alive']`` / ``STATES['Dead']``
            on goal / game over;
          * sets ``finished`` on goal or game over (the ALE game is reset on
            game over);
          * switches the GUI on once ``level > 2``.
        """
        self.command = a
        # Refresh the RAM snapshot; every field below is decoded from it.
        self.ale.getRAM(self.ram)

        # Game variables at fixed RAM addresses.
        self.enemy_pos = self.ram[72]
        self.n_enemies = self.ram[91]
        self.danger_pos = self.ram[73]
        self.my_pos = self.ram[74]
        self.hold_hit = self.ram[77]

        self.enemy_type = self.ram[54]

        # Level went up: measure movement from the position at the new level.
        if self.level < self.ram[31]:
            self.starting_pos = self.ram[74]
        self.level = self.ram[31]
        self.max_level = max(self.level, self.max_level)

        # Danger/Enemy position:
        # 49 = no danger
        # 50 = danger approaching from left
        # 208 = danger approaching from right

        # ram[96] = 6, danger comes from top
        # ram[96] = 29, danger comes from bottom
        # ram[96] = 188, none
        if self.ram[96] == 6:
            self.danger_type = 0
        elif self.ram[96] == 29:
            self.danger_type = 1
        else:
            self.danger_type = 2

        self.time_left1 = self.ram[27]
        self.time_left2 = self.ram[28]

        self.previous_my_energy = self.my_energy
        self.my_energy = self.ram[75]

        # An energy drop counts as a hit only when still_blocked is clear and
        # ram[34] == 0 (the "playing" condition used below as well).
        if (self.my_energy < self.previous_my_energy
                and not self.still_blocked and self.ram[34] == 0):
            self.got_hit = STATES['GotHit']
        else:
            self.got_hit = 0

        self.previous_lifes = self.lifes
        self.lifes = self.ram[29]

        # Edge-detect ram[61]: a rise starts a "blocked" episode, a fall ends
        # it. still_blocked makes each edge fire its event only once.
        self.prev_blocked = self.blocked
        self.blocked = self.ram[61]
        if self.blocked > self.prev_blocked and not self.still_blocked:
            self.got_blocked = STATES['GotBlocked']
            self.still_blocked = True
            self.got_unblocked = 0
        elif self.blocked < self.prev_blocked and self.still_blocked:
            self.got_unblocked = STATES['GotUnblocked']
            self.still_blocked = False
            self.got_blocked = 0
        else:
            self.got_blocked = 0
            self.got_unblocked = 0

        self.prev_state = self.getstate()  # state observed before acting
        self.numactions += 1  # total number of actions executed in this episode

        # firstAction is presumably raised on episode reset (not visible here);
        # the first action fixes the reference position for movement rewards.
        if self.firstAction:
            self.starting_pos = self.ram[74]
            self.firstAction = False
        self.current_reward = self.ale.act(a)

        # Movement shaping: a == 3 / a == 4 are compared against the
        # starting position (presumably ALE's RIGHT / LEFT action ids).
        if self.ram[34] == 0:  # only when playing
            if (a == 3 and self.starting_pos < self.my_pos) or (
                    a == 4 and self.starting_pos > self.my_pos):
                self.action_reward = STATES['MoveFW']
            elif (a == 3 and self.starting_pos > self.my_pos) or (
                    a == 4 and self.starting_pos < self.my_pos):
                self.action_reward = STATES['MoveBW']
            else:
                self.action_reward = STATES['NotMoving']
        else:
            # Not playing: no movement shaping. Reset explicitly so the value
            # from the last playing step is not re-added on every such frame.
            self.action_reward = 0

        self.score += self.current_reward
        self.current_reward += self.action_reward

        if self.goal_reached():
            self.current_reward += STATES['Alive']
            self.ngoalreached += 1
            self.finished = True

        if self.ale.game_over():
            self.current_reward += STATES['Dead']
            if self.level > 1:
                print('game over in level ' + str(self.level))
            if self.my_energy > 0 and self.lifes == 3:
                print('Game over alive????')
            self.ale.reset_game()
            self.finished = True

        # Turn on the display automatically once the agent gets past level 2.
        if self.level > 2:
            if self.gui_visible == False:
                self.gui_visible = True
                self.initScreen()

    def input(self):
        self.isPressed = False
        if self.gui_visible:

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    return False

                if event.type == pygame.KEYDOWN:

                    if event.key == pygame.K_SPACE:
                        self.pause = not self.pause
                        print "Game paused: ", self.pause
                    elif event.key == pygame.K_a:
                        self.isAuto = not self.isAuto
                        self.sleeptime = int(self.isAuto) * 0.07
                    elif event.key == pygame.K_s:
                        self.sleeptime = 1.0
                        self.agent.debug = False
                    elif event.key == pygame.K_d:
                        self.sleeptime = 0.07
                        self.agent.debug = False
                    elif event.key == pygame.K_f:
                        self.sleeptime = 0.005
                        self.agent.debug = False
                    elif event.key == pygame.K_g:
                        self.sleeptime = 0.0
                        self.agent.debug = False
                    elif event.key == pygame.K_o:
                        self.optimalPolicyUser = not self.optimalPolicyUser
                        print "Best policy: ", self.optimalPolicyUser
                    elif event.key == pygame.K_q:
                        self.userquit = True
                        print "User quit !!!"
                    else:

                        pressed = pygame.key.get_pressed()

                        self.keys = 0
                        self.keys |= pressed[pygame.K_UP]
                        self.keys |= pressed[pygame.K_DOWN] << 1
                        self.keys |= pressed[pygame.K_LEFT] << 2
                        self.keys |= pressed[pygame.K_RIGHT] << 3
                        self.keys |= pressed[pygame.K_z] << 4
                        self.command = key_action_tform_table[self.keys]
                        self.key_status[self.command] = True

                if event.type == pygame.KEYUP:
                    pressed = pygame.key.get_pressed()

                    self.keys = 0
                    self.keys |= pressed[pygame.K_UP]
                    self.keys |= pressed[pygame.K_DOWN] << 1
                    self.keys |= pressed[pygame.K_LEFT] << 2
                    self.keys |= pressed[pygame.K_RIGHT] << 3
                    self.keys |= pressed[pygame.K_z] << 4
                    self.command = key_action_tform_table[self.keys]
                    self.key_status[self.command] = False
                    if not (True in self.key_status):
                        self.command = 0

        return True

    def getUserAction(self):
        """Return ``self.command``.

        ``command`` is written by input() from the keyboard and by update()
        with the action it executes.
        """
        current_command = self.command
        return current_command

    def getreward(self):
        """Return this step's reward and add it to ``self.cumreward``.

        The reward is the rounded ``current_reward`` plus the ``got_hit``,
        ``got_blocked`` and ``got_unblocked`` event values. It then subtracts
        the rounded ``blocked / 128`` as a penalty.
        """
        total = np.rint(self.current_reward)
        # Added one at a time, in the same order as the original expression.
        for event_value in (self.got_hit, self.got_blocked, self.got_unblocked):
            total = total + event_value
        r = total - np.rint(self.blocked / 128)

        self.cumreward += r
        return r

    def print_report(self, printall=False):
        """Print the episode summary, periodic window averages, and log a CSV row.

        The one-line episode summary is printed only in these cases:
        ``printall`` is true, ``agent.optimal`` is set (the line is then
        marked with ``*``), or the episode set a new ``hiscore`` / ``hireward``.
        Those records are updated here.

        Every ``numiter`` episodes (by ``self.iteration``) it prints average
        reward, average score, ``max_level`` and the goal percentage for the
        window. It then resets the window accumulators.

        It always appends ``score,cumreward,goal_reached,numactions`` to
        ``self.resfile`` and flushes it.
        """
        numiter = 100  # episodes per averaging window
        separator = '----------------------------------------------------------------------------------------------------------------------'

        toprint = printall
        ch = ' '
        if self.agent.optimal:
            ch = '*'
            toprint = True

        s = 'Iter %6d, sc: %3d, l: %d,  na: %4d, r: %5d %c' % (
            self.iteration, self.score, self.level, self.numactions,
            self.cumreward, ch)

        if self.score > self.hiscore:
            self.hiscore = self.score
            s += ' HISCORE '
            toprint = True
        if self.cumreward > self.hireward:
            self.hireward = self.cumreward
            s += ' HIREWARD '
            toprint = True

        if toprint:
            print(s)

        self.cumreward100 += self.cumreward
        self.cumscore100 += self.score
        if self.iteration % numiter == 0:
            # All window statistics divide by the same window size.
            pgoal = float(self.ngoalreached * 100) / numiter
            print(separator)
            print(
                "%s %6d avg last %d: reward %d | score %.2f | level %d | p goals %.1f %%"
                % (self.trainsessionname, self.iteration, numiter,
                   self.cumreward100 / numiter,
                   float(self.cumscore100) / numiter, self.max_level, pgoal))
            print(separator)
            self.cumreward100 = 0
            self.cumscore100 = 0
            self.ngoalreached = 0

        sys.stdout.flush()

        self.resfile.write(
            "%d,%d,%d,%d\n" %
            (self.score, self.cumreward, self.goal_reached(), self.numactions))
        self.resfile.flush()

    def draw(self):
        """Render the emulator frame and a debug side panel to the pygame window.

        The panel shows ``self.ram`` as hex rows of 16 bytes (at most 128
        bytes), the current action code and the cumulative reward. Returns 0
        without drawing while ``gui_visible`` is false, and None after drawing.
        """
        if not self.gui_visible:
            return 0

        self.screen.fill((0, 0, 0))

        # Copy the current frame into the game surface. The transpose swaps
        # the first two axes, presumably (height, width) -> (width, height) as
        # surfarray expects. Then scale it onto the left of the window.
        self.ale.getScreenRGB(self.numpy_surface)
        pygame.surfarray.blit_array(
            self.game_surface, np.transpose(self.numpy_surface, (1, 0, 2)))
        self.screen.blit(
            pygame.transform.scale2x(
                pygame.transform.scale(
                    self.game_surface,
                    (self.screen_height, self.screen_height))), (0, 0))

        # RAM dump: header, then hex rows of 16 bytes each.
        title_font = self._panel_font(32)
        self.screen.blit(title_font.render("RAM: ", 1, (255, 208, 208)),
                         (430, 10))

        ram_font = self._panel_font(25)
        height = ram_font.get_height() * 1.2
        line_pos = 40
        # The panel has room for 128 bytes; never index past the actual buffer.
        ram_len = min(len(self.ram), 128)
        for row_start in range(0, ram_len, 16):
            row_end = min(row_start + 16, ram_len)
            ram_string = ''.join(
                ["%02X " % self.ram[x] for x in range(row_start, row_end)])
            text = ram_font.render(ram_string, 1, (255, 255, 255))
            self.screen.blit(text, (440, line_pos))
            line_pos += height

        # Current action.
        action_font = self._panel_font(32)
        text = action_font.render("Current Action: " + str(self.command), 1,
                                  (208, 208, 255))
        height = action_font.get_height() * 1.2
        self.screen.blit(text, (430, line_pos))
        line_pos += height

        # Cumulative reward.
        reward_font = self._panel_font(30)
        text = reward_font.render("Total Reward: " + str(self.cumreward), 1,
                                  (208, 255, 255))
        self.screen.blit(text, (430, line_pos))

        pygame.display.flip()

    def _panel_font(self, size):
        """Return an "Ubuntu Mono" SysFont of ``size``, created once and cached.

        draw() runs every frame. Caching avoids rebuilding the same font
        objects on each call.
        """
        cache = getattr(self, '_panel_fonts', None)
        if cache is None:
            cache = {}
            self._panel_fonts = cache
        font = cache.get(size)
        if font is None:
            font = pygame.font.SysFont("Ubuntu Mono", size)
            cache[size] = font
        return font

    def quit(self):
        """Close the results file and shut down pygame.

        pygame.quit() runs even if closing ``resfile`` raises, so the display
        is always released; the close error still propagates.
        """
        try:
            self.resfile.close()
        finally:
            pygame.quit()