Code example #1
File: ale_learning.py  Project: vyraun/ALE_dqn

# Standard-library and third-party imports used by this snippet; AleInterface,
# DLNetwork and the module-level logger are project-specific and assumed to be
# importable alongside this file.
import json
import os
import random
import time
from collections import deque

import cv2
import numpy as np


class DQNLearning(object):
    def __init__(self, game_name, args):

        self.game_name = game_name
        self.logger = logger

        self.game = AleInterface(game_name, args)
        self.actions = self.game.get_actions_num()

        # DQN parameters
        self.observe = args.observe
        self.explore = args.explore
        self.replay_memory = args.replay_memory
        self.batch_size = args.batch_size
        self.gamma = args.gamma
        self.init_epsilon = args.init_epsilon
        self.final_epsilon = args.final_epsilon
        self.save_model_freq = args.save_model_freq

        self.update_frequency = args.update_frequency
        self.action_repeat = args.action_repeat

        self.frame_seq_num = args.frame_seq_num
        if args.saved_model_dir == "":
            self.param_file = "./saved_networks/%s.json" % game_name
        else:
            self.param_file = "%s/%s.json" % (args.saved_model_dir, game_name)

        self.net = DLNetwork(game_name, self.actions, args)

        # screen parameters
        # self.screen = (args.screen_width, args.screen_height)
        # pygame.display.set_caption(game_name)
        # self.display = pygame.display.set_mode(self.screen)

        self.deque = deque()

    def param_serierlize(self, epsilon, step):
        with open(self.param_file, "w") as f:
            json.dump({"epsilon": epsilon, "step": step}, f)

    def param_unserierlize(self):
        if os.path.exists(self.param_file):
            with open(self.param_file, "r") as f:
                jd = json.load(f)
            return jd["epsilon"], jd["step"]
        else:
            return self.init_epsilon, 0

    def process_snapshot(self, snap_shot):
        # rgb to gray, and resize
        snap_shot = cv2.cvtColor(cv2.resize(snap_shot, (80, 80)), cv2.COLOR_BGR2GRAY)
        # image binary
        # _, snap_shot = cv2.threshold(snap_shot, 1, 255, cv2.THRESH_BINARY)
        return snap_shot

    def show_screen(self, np_array):
        return
        # np_array = cv2.resize(np_array, self.screen)
        # surface = pygame.surfarray.make_surface(np_array)
        # surface = pygame.transform.rotate(surface, 270)
        # rect = pygame.draw.rect(self.display, (255, 255, 255), (0, 0, self.screen[0], self.screen[1]))
        # self.display.blit(surface, rect)
        # pygame.display.update()

    def train_net(self):
        # training
        max_reward = 0
        epsilon, global_step = self.param_unserierlize()
        step = 0
        epoch = 0
        while True:  # loop epochs
            epoch += 1
            # initial state
            self.game.reset_game()
            # initial state sequences
            state_seq = np.empty((80, 80, self.frame_seq_num))
            for i in range(self.frame_seq_num):
                state = self.game.get_screen_rgb()
                self.show_screen(state)
                state = self.process_snapshot(state)
                state_seq[:, :, i] = state
            stage_reward = 0
            while True:  # loop game frames
                # select action
                best_act = self.net.predict([state_seq])[0]
                if random.random() <= epsilon or len(np.unique(best_act)) == 1:  # random select
                    action = random.randint(0, self.actions - 1)
                else:
                    action = np.argmax(best_act)
                # carry out selected action
                reward_n = self.game.act(action)
                state_n = self.game.get_screen_rgb()
                self.show_screen(state_n)
                state_n = self.process_snapshot(state_n)
                state_n = np.reshape(state_n, (80, 80, 1))
                state_seq_n = np.append(state_n, state_seq[:, :, : (self.frame_seq_num - 1)], axis=2)
                terminal_n = self.game.game_over()
                # scale down epsilon
                if step > self.observe and epsilon > self.final_epsilon:
                    epsilon -= (self.init_epsilon - self.final_epsilon) / self.explore
                # store experience
                act_onehot = np.zeros(self.actions)
                act_onehot[action] = 1
                self.deque.append((state_seq, act_onehot, reward_n, state_seq_n, terminal_n))
                if len(self.deque) > self.replay_memory:
                    self.deque.popleft()
                # minibatch train
                if step > self.observe and step % self.update_frequency == 0:
                    for _ in xrange(self.action_repeat):
                        mini_batch = random.sample(self.deque, self.batch_size)
                        batch_state_seq = [item[0] for item in mini_batch]
                        batch_action = [item[1] for item in mini_batch]
                        batch_reward = [item[2] for item in mini_batch]
                        batch_state_seq_n = [item[3] for item in mini_batch]
                        batch_terminal = [item[4] for item in mini_batch]
                        # predict
                        target_rewards = []
                        batch_pred_act_n = self.net.predict(batch_state_seq_n)
                        for i in xrange(len(mini_batch)):
                            if batch_terminal[i]:
                                t_r = batch_reward[i]
                            else:
                                t_r = batch_reward[i] + self.gamma * np.max(batch_pred_act_n[i])
                            target_rewards.append(t_r)
                        # train Q network
                        self.net.fit(batch_state_seq, batch_action, target_rewards)
                # update state
                state_seq = state_seq_n
                step += 1
                # serialize params
                # save network model
                if step % self.save_model_freq == 0:
                    global_step += step
                    self.param_serierlize(epsilon, global_step)
                    self.net.save_model("%s-dqn" % self.game_name, global_step=global_step)
                    self.logger.info("save network model, global_step=%d, cur_step=%d" % (global_step, step))
                # state description
                if step < self.observe:
                    state_desc = "observe"
                elif epsilon > self.final_epsilon:
                    state_desc = "explore"
                else:
                    state_desc = "train"
                # record reward
                print "game running, step=%d, action=%s, reward=%d, max_Q=%.6f, min_Q=%.6f" % \
                          (step, action, reward_n, np.max(best_act), np.min(best_act))
                if reward_n > stage_reward:
                    stage_reward = reward_n
                if terminal_n:
                    break
            # record reward
            if stage_reward > max_reward:
                max_reward = stage_reward
            self.logger.info(
                "epoch=%d, state=%s, step=%d(%d), max_reward=%d, epsilon=%.5f, reward=%d, max_Q=%.6f" %
                (epoch, state_desc, step, global_step, max_reward, epsilon, stage_reward, np.max(best_act)))

    def play_game(self, epsilon):
        # play games
        max_reward = 0
        epoch = 0
        if epsilon == 0.0:
            epsilon, _ = self.param_unserierlize()
        while True:  # epoch
            epoch += 1
            self.logger.info("game start...")
            # init state
            self.game.reset_game()
            state_seq = np.empty((80, 80, self.frame_seq_num))
            for i in range(self.frame_seq_num):
                state = self.game.get_screen_rgb()
                self.show_screen(state)
                state = self.process_snapshot(state)
                state_seq[:, :, i] = state
            stage_step = 0
            stage_reward = 0
            while True:
                # select action
                best_act = self.net.predict([state_seq])[0]
                if random.random() < epsilon or len(np.unique(best_act)) == 1:
                    action = random.randint(0, self.actions - 1)
                else:
                    action = np.argmax(best_act)
                # carry out selected action
                reward_n = self.game.act(action)
                state_n = self.game.get_screen_rgb()
                self.show_screen(state_n)
                state_n = self.process_snapshot(state_n)
                state_n = np.reshape(state_n, (80, 80, 1))
                state_seq_n = np.append(state_n, state_seq[:, :, : (self.frame_seq_num - 1)], axis=2)
                terminal_n = self.game.game_over()

                state_seq = state_seq_n
                # record
                if reward_n > stage_reward:
                    stage_reward = reward_n
                if terminal_n:
                    time.sleep(2)
                    break
                else:
                    stage_step += 1
                    stage_reward = reward_n
                    print "game running, step=%d, action=%d, reward=%d" % \
                          (stage_step, action, reward_n)
            # record reward
            if stage_reward > max_reward:
                max_reward = stage_reward
            self.logger.info("game over, epoch=%d, step=%d, reward=%d, max_reward=%d" %
                             (epoch, stage_step, stage_reward, max_reward))
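
Both train_net and play_game above keep the network input as an (80, 80, frame_seq_num) stack and push each new grayscale screen onto the front of it with np.append, dropping the oldest slice. Below is a minimal standalone sketch of that update; push_frame and the toy shapes are illustrative only, not part of the project:

import numpy as np

def push_frame(state_seq, new_frame):
    """Prepend an 80x80 frame to an (80, 80, N) stack, dropping the oldest slice."""
    new_frame = np.reshape(new_frame, (80, 80, 1))
    return np.append(new_frame, state_seq[:, :, :state_seq.shape[2] - 1], axis=2)

# Usage: a 4-frame stack updated with one new screen grab.
stack = np.zeros((80, 80, 4))
stack = push_frame(stack, np.ones((80, 80)))
assert stack.shape == (80, 80, 4)
assert stack[:, :, 0].all() and not stack[:, :, 1].any()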
Code example #2

# Standard-library and third-party imports used by this snippet; AleInterface,
# DLNetwork, Logger, q_network, SARSALambdaAgent and the CROP_OFFSET constant are
# project-specific and assumed to be importable alongside this file.
import cPickle
import json
import os
import random
import time
from collections import deque

import cv2
import numpy as np


class ALEtestbench(object):
    def __init__(self, args):

        # Save game name
        self.game_name = args.game

        # Initialize logger
        self._exptime = str(int(time.time()))
        if (args.log_dir is None):
            self._log_dir = "./log_" + self._exptime
        else:
            self._log_dir = args.log_dir
        self.logger = Logger(self._log_dir,
                             self.game_name,
                             verbosity=args.verbosity)

        # Initialize ALE
        self.game = AleInterface(self.game_name, args)

        self.actions = self.game.get_actions_num()
        self.actionsB = self.game.get_actions_numB()

        # Set the number of iterations
        self.iterations = args.iterations

        # DQN parameters
        self.observe = args.observe
        self.explore = args.explore
        self.replay_memory = args.replay_memory
        self.batch_size = args.batch_size
        self.gamma = args.gamma
        self.init_epsilon = args.init_epsilon
        self.final_epsilon = args.final_epsilon
        self.save_model_freq = args.save_model_freq
        self.update_frequency = args.update_frequency
        self.action_repeat = args.action_repeat
        self.frame_seq_num = args.frame_seq_num
        self.save_model_at_termination = args.save_model_at_termination

        # Screen buffer for player B
        self.buffer_length = 2
        self.buffer_count = 0
        self.screen_buffer = np.empty((self.buffer_length, 80, 80),
                                      dtype=np.uint8)

        # Create folder for saved NN
        if (args.saved_model_dir == 'saved_networks'):
            args.saved_model_dir += "_" + self._exptime
        if not os.path.isdir("./" + args.saved_model_dir):
            os.makedirs("./" + args.saved_model_dir)

        # Parameters file of DQN
        self.param_file = "%s/%s.json" % (args.saved_model_dir, self.game_name)

        # Player A
        # DQN network
        self.net = DLNetwork(self.actions, self.logger, args)

        # Player B
        # SARSA learner and network
        self.sarsa_agent = self.sarsa_init(args)

        # Experience double ended queue
        self.deque = deque()

    # SARSA agent init
    def sarsa_init(self, args):

        random_seed = random.randint(0, 20)  # randint is inclusive, so seeds span 0-20
        rng = np.random.RandomState(random_seed)

        # Check the presence of an NN file for SARSA when in play mode
        if (args.handle == 'play') and (args.nn_file is None):
            raise Exception('Error: no SARSA NN file to load')

        if args.nn_file is None:
            # New network creation
            self.logger.info("Creating network for SARSA")
            sarsa_network = q_network.DeepQLearner(
                args.screen_width,
                args.screen_height,
                self.actionsB,
                args.phi_length,  #num_frames
                args.discount,
                args.learning_rate,
                args.rms_decay,  #rho
                args.rms_epsilon,
                args.momentum_sarsa,
                1,  #clip_delta
                10000,  #freeze_interval
                args.batch_size,  #batch_size
                args.network_type,
                args.update_rule,
                # args.lambda_decay, #batch_accumulator
                'sum',
                rng)
        else:
            # Pretrained network loading
            # Mandatory for play mode, optional for training
            with open(args.nn_file, 'rb') as network_file_handle:
                sarsa_network = cPickle.load(network_file_handle)

        self.logger.info("Creating SARSA agent")
        sarsa_agent_inst = SARSALambdaAgent(sarsa_network, args,
                                            args.epsilon_min,
                                            args.epsilon_decay,
                                            args.experiment_prefix,
                                            self.logger, rng)

        return sarsa_agent_inst

    """ Merge the previous two screen images """

    def sarsa_get_observation(self, args):

        assert self.buffer_count >= 2
        index = self.buffer_count % self.buffer_length - 1
        max_image = np.maximum(self.screen_buffer[index, ...],
                               self.screen_buffer[index - 1, ...])
        return max_image
        #return self.sarsa_resize_image(max_image, args)

    """ Appropriately resize a single image """

    def sarsa_resize_image(self, image, args):

        if args.resize_method == 'crop':
            # resize keeping aspect ratio
            resize_height = int(
                round(float(self.height) * self.resized_width / self.width))

            resized = cv2.resize(image, (self.resized_width, resize_height),
                                 interpolation=cv2.INTER_LINEAR)

            # Crop the part we want
            crop_y_cutoff = resize_height - CROP_OFFSET - self.resized_height
            cropped = resized[crop_y_cutoff:crop_y_cutoff +
                              self.resized_height, :]

            return cropped
        elif args.resize_method == 'scale':
            return cv2.resize(image, (args.screen_width, args.screen_height),
                              interpolation=cv2.INTER_LINEAR)
        else:
            raise ValueError('Unrecognized image resize method.')

    def param_serierlize(self, epsilon, step):
        with open(self.param_file, "w") as f:
            json.dump({"epsilon": epsilon, "step": step}, f)

    def param_unserierlize(self):
        if os.path.exists(self.param_file):
            with open(self.param_file, "r") as f:
                jd = json.load(f)
            return jd["epsilon"], jd["step"]
        else:
            self.logger.info("Using epsilon: %.6f" % self.init_epsilon)
            return self.init_epsilon, 0

    def process_snapshot(self, snap_shot):
        # rgb to gray, and resize
        snap_shot = cv2.cvtColor(cv2.resize(snap_shot, (80, 80)),
                                 cv2.COLOR_BGR2GRAY)
        # image binary
        # _, snap_shot = cv2.threshold(snap_shot, 1, 255, cv2.THRESH_BINARY)
        return snap_shot

    """ Training function, this is the main loop for training """

    def train_net(self, args):

        self.logger.info("Running training experiment")
        self.logger.info("Player A (DQN agent), settings:")
        self.logger.info("Initial Epsilon: %.6f" % (args.init_epsilon))
        self.logger.info("Epsilon decay rate: 1/%.6f epsilon per episode" %
                         (args.explore))
        self.logger.info("Final Epsilon: %.6f" % (args.final_epsilon))
        self.logger.info("Training starting...")

        #self.logger.info("Player B (SARSA agent), settings:")

        # Initialize variables
        epoch = 0
        epsilon, global_step = self.param_unserierlize()
        max_game_iterations = self.iterations

        # Epochs loop
        while True:

            # Initialize variables
            game_dif_score = 0
            playerA_score = 0
            playerB_score = 0
            step = 0

            # initial state
            self.game.ale.reset_game()

            # Two-player mode switch ON
            self.game.ale.setMode(1)

            # Initial state sequences (process_snapshot always returns 80x80 frames)
            state_seq = np.empty((80, 80, self.frame_seq_num))
            for i in range(self.frame_seq_num):
                state = self.game.ale.getScreenRGB()
                state = self.process_snapshot(state)
                state_seq[:, :, i] = state

            # Player B initialization flag
            playerB_is_uninitiallized = True

            # Episode loop: each iteration is one "step", i.e. one new game frame
            while True:

                # Select action player A
                self.logger.info("Selecting player A action")
                best_act = self.net.predict([state_seq])[0]

                # Prevent player A from taking actions on the first two frames, for fairness
                if step < 2:
                    actionA = 0
                else:
                    # Normal epsilon-greedy policy
                    if (random.random() <= epsilon) or (len(
                            np.unique(best_act)) == 1):
                        actionA = random.randint(0, self.actions -
                                                 1)  # random action selection
                        self.logger.info(
                            "Random action selected for player A actionA=%d" %
                            (actionA))
                    else:
                        actionA = np.argmax(best_act)  # Best action selection
                        self.logger.info(
                            "Action selected for player A actionA=%d" %
                            (actionA))

                # Select action player B
                self.logger.info("Selecting player B action")
                if self.buffer_count >= self.buffer_length + 1:
                    if playerB_is_uninitiallized:
                        self.logger.info("Initializing player B")
                        actionB = self.sarsa_agent.start_episode(
                            self.sarsa_get_observation(args))
                        actionB += 18  # TODO: fix this offset properly, it is annoying
                        playerB_is_uninitiallized = False
                    else:
                        actionB = self.sarsa_agent.step(
                            -reward_n, playerB_observation)
                        actionB += 18  # TODO: fix this offset properly, it is annoying
                else:
                    actionB = 18  # TODO: fix this, we should use just one value

                self.logger.info("Action selected for player B actionB=%d" %
                                 (actionB))

                # Carry out selected actions
                self.logger.info("Executing actions")
                reward_n = self.game.ale.actAB(actionA, actionB)

                # Save scores
                game_dif_score += reward_n
                if (reward_n > 0):
                    playerA_score += reward_n
                elif (reward_n < 0):
                    playerB_score += -reward_n

                # Getting screen image
                state_n = self.game.ale.getScreenRGB()
                state_n_grayscale = self.process_snapshot(state_n)

                # Get observation for player A
                state_n = np.reshape(state_n_grayscale, (80, 80, 1))
                state_seq_n = np.append(state_n,
                                        state_seq[:, :, :(self.frame_seq_num -
                                                          1)],
                                        axis=2)
                self.logger.info("Player A observation over")

                # Get observation for player B
                screen_buffer_index = self.buffer_count % self.buffer_length
                self.screen_buffer[screen_buffer_index,
                                   ...] = state_n_grayscale
                # Wait until the buffer is full
                if self.buffer_count >= self.buffer_length:
                    playerB_observation = self.sarsa_get_observation(args)
                # Reset buffer counter to avoid overflow
                if self.buffer_count == (10 * self.buffer_length):
                    self.buffer_count = self.buffer_length + 1
                else:
                    self.buffer_count += 1
                self.logger.info("Player B observation over")

                # Check game over state
                terminal_n = self.game.ale.game_over()

                # TODO: add frame limit

                # Scale down epsilon
                if (step > self.observe) and (epsilon > self.final_epsilon):
                    epsilon -= (self.init_epsilon -
                                self.final_epsilon) / self.explore

                # Store experience
                act_onehot = np.zeros(self.actions)
                act_onehot[actionA] = 1
                self.deque.append(
                    (state_seq, act_onehot, reward_n, state_seq_n, terminal_n))
                if len(self.deque) > self.replay_memory:
                    self.deque.popleft()

                # DQN Minibatch train
                if ((step > self.observe)
                        and ((step % self.update_frequency) == 0)):
                    self.logger.info("Player A training")
                    for _ in xrange(self.action_repeat):
                        mini_batch = random.sample(self.deque, self.batch_size)
                        batch_state_seq = [item[0] for item in mini_batch]
                        batch_action = [item[1] for item in mini_batch]
                        batch_reward = [item[2] for item in mini_batch]
                        batch_state_seq_n = [item[3] for item in mini_batch]
                        batch_terminal = [item[4] for item in mini_batch]

                        # Predict
                        target_rewards = []
                        batch_pred_act_n = self.net.predict(batch_state_seq_n)

                        for i in xrange(len(mini_batch)):
                            if batch_terminal[i]:
                                t_r = batch_reward[i]
                            else:
                                t_r = batch_reward[i] + self.gamma * np.max(
                                    batch_pred_act_n[i])
                            target_rewards.append(t_r)
                        # Train Q network
                        self.net.fit(batch_state_seq, batch_action,
                                     target_rewards)

                # Update state sequence
                state_seq = state_seq_n

                # Increase step counter
                step += 1

                # Save network model if using the save_model_freq param
                if ((step % self.save_model_freq) == 0
                        and not self.save_model_at_termination):
                    global_step += step
                    self.param_serierlize(epsilon, global_step)
                    self.net.save_model("%s-dqn" % self.game_name,
                                        global_step=global_step)
                    self.logger.info(
                        "Saving network model, global_step=%d, cur_step=%d" %
                        (global_step, step))

                # Player A state description
                if step < self.observe:
                    state_desc = "observe"
                elif epsilon > self.final_epsilon:
                    state_desc = "explore"
                else:
                    state_desc = "train"

                # Record step information
                self.logger.exp([
                    epoch, step, actionA, actionB, reward_n, game_dif_score,
                    playerA_score, playerB_score, epsilon
                ])

                # End the episode
                if terminal_n:
                    self.logger.info("Episode %d is over" % epoch)

                    self.sarsa_agent.end_episode(-reward_n)

                    # Save DQN network model
                    global_step += step
                    self.param_serierlize(epsilon, global_step)
                    self.net.save_model("%s-dqn" % self.game_name,
                                        global_step=global_step)
                    self.logger.info(
                        "Saving network model, global_step=%d, cur_step=%d" %
                        (global_step, step))

                    # Save SARSA model (binary pickle, hence 'wb')
                    net_file = open(
                        args.saved_model_dir + '/' + self.game_name + '_' +
                        str(epoch) + '.pkl', 'wb')
                    cPickle.dump(self.sarsa_agent.network, net_file, -1)
                    net_file.close()

                    break
                #****************** END episode loop ****************************************

            # Log end of epoch
            self.logger.info(
                "epoch=%d, state=%s, step=%d(%d), epsilon=%.5f, reward=%d, max_Q=%.6f"
                % (epoch, state_desc, step, global_step, epsilon,
                   game_dif_score, np.max(best_act)))

            # Increase epoch counter
            epoch += 1

            # Break the loop after max_game_iterations
            if epoch >= max_game_iterations:
                self.sarsa_agent.finish_epoch(epoch)
                break
            #****************** END epoch loop ****************************************

    """ Play function, this is the main loop for using pretrained networks """

    def play_game(self, args):

        # Init vars
        epoch = 0
        epsilon = args.play_epsilon  #default 1

        max_game_iterations = self.iterations

        # Load epsilon value
        if epsilon == 0.0:
            epsilon, _ = self.param_unserierlize()
        else:
            self.logger.info("Using epsilon: %.6f" % epsilon)

        # Epochs loop
        while True:

            # Reinitialize the SARSA agent, forcing it to forget the net learned in the previous episode
            self.sarsa_agent = self.sarsa_init(args)

            self.logger.info("Episode start...")

            # Initialize variables
            stage_step = 0
            game_dif_score = 0
            playerA_score = 0
            playerB_score = 0

            # Reset the game in ALE
            self.game.ale.reset_game()

            # Two-player mode switch ON
            self.game.ale.setMode(1)

            # Init state sequence (process_snapshot always returns 80x80 frames)
            state_seq = np.empty((80, 80, self.frame_seq_num))
            for i in range(self.frame_seq_num):
                state = self.game.get_screen_rgb()
                state = self.process_snapshot(state)
                state_seq[:, :, i] = state

            # Player B initialization flag
            playerB_is_uninitiallized = True

            # Episodes loop
            while True:

                # Get action player A
                self.logger.info("Selecting player A action")
                best_act = self.net.predict([state_seq])[0]

                # Greedy action selection (random only if all Q-values are identical)
                if len(np.unique(best_act)) == 1:
                    actionA = random.randint(0, self.actions - 1)
                    self.logger.info(
                        "Random action selected for player A actionA=%d" %
                        (actionA))
                else:
                    actionA = np.argmax(best_act)
                    self.logger.info(
                        "Action selected for player A actionA=%d" % (actionA))

                # Get action player B
                self.logger.info("Selecting player B action")
                if self.buffer_count >= self.buffer_length + 1:
                    if playerB_is_uninitiallized:
                        self.logger.info("Initializing player B")
                        actionB = self.sarsa_agent.start_episode(
                            self.sarsa_get_observation(args))
                        actionB += 18  # TODO: fix this offset properly, it is annoying
                        playerB_is_uninitiallized = False
                    else:
                        actionB = self.sarsa_agent.step(
                            -reward_n, playerB_observation)
                        actionB += 18  # TODO: fix this offset properly, it is annoying
                else:
                    actionB = 18  # TODO: fix this, we should use just one value

                self.logger.info("Action selected for player B actionB=%d" %
                                 (actionB))

                # Carry out selected actions
                self.logger.info("Executing actions")
                reward_n = self.game.ale.actAB(actionA, actionB)

                # Save scores
                game_dif_score += reward_n
                if (reward_n > 0):
                    playerA_score += reward_n
                elif (reward_n < 0):
                    playerB_score += -reward_n

                # Getting screen image
                state_n = self.game.ale.getScreenRGB()
                state_n_grayscale = self.process_snapshot(state_n)

                # Get observation for player A
                #state_n = np.reshape(state_n_grayscale, (args.screen_width, args.screen_height, 1))
                state_n = np.reshape(state_n_grayscale, (80, 80, 1))
                state_seq_n = np.append(state_n,
                                        state_seq[:, :, :(self.frame_seq_num -
                                                          1)],
                                        axis=2)
                self.logger.info("Player A observation over")

                # Get observation for player B
                screen_buffer_index = self.buffer_count % self.buffer_length
                self.screen_buffer[screen_buffer_index,
                                   ...] = state_n_grayscale
                # Wait until the buffer is full
                if self.buffer_count >= self.buffer_length:
                    playerB_observation = self.sarsa_get_observation(args)
                # Reset buffer counter to avoid overflow
                if self.buffer_count == (10 * self.buffer_length):
                    self.buffer_count = self.buffer_length + 1
                else:
                    self.buffer_count += 1
                self.logger.info("Player B observation over")

                # Check game over state
                terminal_n = self.game.game_over()

                # Update state sequence
                state_seq = state_seq_n

                # Increase step counter
                stage_step += 1

                # Record step information
                self.logger.exp([
                    epoch, stage_step, actionA, actionB, reward_n,
                    game_dif_score, playerA_score, playerB_score, epsilon
                ])

                # End the episode
                if terminal_n:
                    self.sarsa_agent.end_episode(-reward_n)
                    self.logger.info("Episode %d is over" % epoch)
                    time.sleep(2)
                    break
                #****************** END episode loop ****************************************

            self.logger.info(
                "game over, epoch=%d, steps=%d, score_player_A=%d, score_player_B=%d, difference=%d"
                % (epoch, stage_step, playerA_score, playerB_score,
                   game_dif_score))

            # Increase the epoch counter
            epoch += 1

            # Break the loop after max_game_iterations
            if epoch >= max_game_iterations:
                # self.sarsa_agent.finish_epoch(epoch) is skipped: we are in play mode, so there is no need to save the network
                break
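
Player B's observation in the class above is the pixel-wise maximum of the two most recent 80x80 grayscale screens, kept in a small circular screen_buffer to smooth Atari flicker. Below is a self-contained sketch of that pattern; the FlickerBuffer name and its push/observation methods are invented here for illustration:

import numpy as np

class FlickerBuffer(object):
    """Circular buffer of the last two grayscale screens; observation() returns
    their pixel-wise maximum, mirroring screen_buffer / sarsa_get_observation."""

    def __init__(self, length=2, width=80, height=80):
        self.length = length
        self.count = 0
        self.buffer = np.empty((length, height, width), dtype=np.uint8)

    def push(self, frame):
        self.buffer[self.count % self.length, ...] = frame
        self.count += 1

    def observation(self):
        assert self.count >= 2
        index = self.count % self.length - 1
        return np.maximum(self.buffer[index, ...], self.buffer[index - 1, ...])

# Usage
buf = FlickerBuffer()
buf.push(np.zeros((80, 80), dtype=np.uint8))
buf.push(np.full((80, 80), 7, dtype=np.uint8))
print(buf.observation().max())  # 7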