def __init__(self, config):
    """Populate the learner from ``config`` (keys are described in README.md).

    Every key read below is required; with a plain ``dict`` a missing key
    raises ``KeyError``.  Besides copying scalar settings this constructs the
    GUI object, the bounded examples buffer and the neural-network wrapper.
    """
    # gomoku
    board_size = config['n']
    self.n = board_size
    self.n_in_row = config['n_in_row']
    self.gomoku_gui = GomokuGUI(board_size, config['human_color'])
    self.action_size = self.n ** 2  # one action per board cell

    # train: settings stored under the same name as their config key
    for key in ('num_iters', 'num_eps', 'check_freq', 'contest_num',
                'dirichlet_alpha', 'temp', 'update_threshold',
                'explore_num'):
        setattr(self, key, config[key])

    # bounded buffer: the oldest entry is dropped once maxlen is reached
    self.examples_buffer = deque(maxlen=config['examples_buffer_max_len'])

    # neural network
    self.batch_size = config['batch_size']
    self.mcts_use_gpu = config['mcts_use_gpu']
    self.nnet = NeuralNetWorkWrapper(
        config['lr'], config['l2'], config['kl_targ'], config['epochs'],
        config['num_channels'], board_size, self.action_size,
        self.mcts_use_gpu)

    # mcts
    for key in ('num_mcts_sims', 'c_puct', 'c_virtual_loss',
                'thread_pool_size'):
        setattr(self, key, config[key])
Beispiel #2
0
    def __init__(self, config):
        """Populate the learner from ``config`` and start the GUI thread.

        Every key read below is required; with a plain ``dict`` a missing key
        raises ``KeyError``.  Besides copying scalar settings this constructs
        the GUI object, the bounded examples buffer and the neural-network
        wrapper, then runs ``self.gomoku_gui.loop`` on a new thread.
        """
        # gomoku
        board_size = config['n']
        self.n = board_size
        self.n_in_row = config['n_in_row']
        self.gomoku_gui = GomokuGUI(board_size, config['human_color'])
        self.action_size = config['action_size']

        # train: settings stored under the same name as their config key
        for key in ('num_iters', 'num_eps', 'num_train_threads', 'check_freq',
                    'num_contest', 'dirichlet_alpha', 'temp',
                    'update_threshold', 'num_explore'):
            setattr(self, key, config[key])

        # bounded buffer: the oldest entry is dropped once maxlen is reached
        self.examples_buffer = deque(
            maxlen=config['examples_buffer_max_len'])

        # mcts
        for key in ('num_mcts_sims', 'c_puct', 'c_virtual_loss',
                    'num_mcts_threads', 'libtorch_use_gpu'):
            setattr(self, key, config[key])

        # neural network
        self.batch_size = config['batch_size']
        self.epochs = config['epochs']
        self.nnet = NeuralNetWorkWrapper(
            config['lr'], config['l2'], config['num_layers'],
            config['num_channels'], board_size, self.action_size,
            config['train_use_gpu'], self.libtorch_use_gpu)

        # start gui
        gui_thread = threading.Thread(target=self.gomoku_gui.loop)
        gui_thread.start()
Beispiel #3
0
    def __init__(self, config):
        """Configure logging and copy the training settings from ``config``.

        Every key read below is required; with a plain ``dict`` a missing key
        raises ``KeyError``.  The log file named by ``train_log_file`` is
        opened with mode ``'w+'`` and receives bare messages at DEBUG level.
        """
        # logging train messages
        logging.basicConfig(
            filename=config['train_log_file'],
            level=logging.DEBUG,
            format='%(message)s',
            filemode='w+',
        )

        # board config
        for key in ('n', 'n_in_row', 'action_size'):
            setattr(self, key, config[key])

        # train config: settings stored under the same name as their key
        for key in ('num_iters', 'num_eps', 'num_train_threads', 'check_freq',
                    'num_contest', 'dirichlet_alpha', 'temp',
                    'update_threshold', 'num_explore'):
            setattr(self, key, config[key])
        # bounded buffer: the oldest entry is dropped once maxlen is reached
        self.examples_buffer = deque(
            maxlen=config['examples_buffer_max_len'])

        # mcts config
        for key in ('libtorch_use_gpu', 'num_mcts_sims', 'c_puct',
                    'c_virtual_loss', 'num_mcts_threads'):
            setattr(self, key, config[key])

        # nn config
        self.batch_size = config['batch_size']
        self.epochs = config['epochs']
        self.nnet = NeuralNetWorkWrapper(
            config['lr'], config['l2'], config['num_layers'],
            config['num_channels'], config['n'], config['action_size'],
            config['train_use_gpu'], config['libtorch_use_gpu'])

        # train debug config
        self.show_train_board = config['show_train_board']
Beispiel #4
0
class Leaner():
    def __init__(self, config):
        """Populate the learner from ``config`` and start the GUI thread.

        Every key read below is required; with a plain ``dict`` a missing key
        raises ``KeyError``.  Besides copying scalar settings this constructs
        the GUI object, the bounded examples buffer and the neural-network
        wrapper, then runs ``self.gomoku_gui.loop`` on a new thread.
        """
        # gomoku
        board_size = config['n']
        self.n = board_size
        self.n_in_row = config['n_in_row']
        self.gomoku_gui = GomokuGUI(board_size, config['human_color'])
        self.action_size = config['action_size']

        # train: settings stored under the same name as their config key
        train_keys = (
            'num_iters', 'num_eps', 'num_train_threads', 'check_freq',
            'num_contest', 'dirichlet_alpha', 'temp', 'update_threshold',
            'num_explore',
        )
        for key in train_keys:
            setattr(self, key, config[key])

        # bounded buffer: the oldest entry is dropped once maxlen is reached
        self.examples_buffer = deque(
            maxlen=config['examples_buffer_max_len'])

        # mcts
        mcts_keys = ('num_mcts_sims', 'c_puct', 'c_virtual_loss',
                     'num_mcts_threads', 'libtorch_use_gpu')
        for key in mcts_keys:
            setattr(self, key, config[key])

        # neural network
        self.batch_size = config['batch_size']
        self.epochs = config['epochs']
        self.nnet = NeuralNetWorkWrapper(
            config['lr'], config['l2'], config['num_layers'],
            config['num_channels'], board_size, self.action_size,
            config['train_use_gpu'], self.libtorch_use_gpu)

        # start gui
        gui_thread = threading.Thread(target=self.gomoku_gui.loop)
        gui_thread.start()

    def learn(self):
        """Train the model by self-play for ``self.num_iters`` iterations.

        Each iteration plays ``self.num_eps`` self-play games on a thread
        pool, appends the collected examples to ``self.examples_buffer``,
        trains ``self.nnet`` on the flattened buffer and saves the model and
        the samples.  Every ``self.check_freq`` iterations the current
        checkpoint plays ``self.num_contest`` games against
        ``best_checkpoint`` and replaces it when its share of decisive games
        exceeds ``self.update_threshold``.
        """
        # train the model by self play

        # Resume when a saved sample file exists; otherwise write the initial
        # models (presumably ./models/checkpoint.pt, which is loaded below --
        # TODO confirm against NeuralNetWorkWrapper.save_model defaults).
        if path.exists(path.join('models', 'checkpoint.example')):
            print("loading checkpoint...")
            self.nnet.load_model()
            self.load_samples()
        else:
            # save torchscript
            self.nnet.save_model()
            self.nnet.save_model('models', "best_checkpoint")

        for itr in range(1, self.num_iters + 1):
            print("ITER :: {}".format(itr))

            # self play in parallel
            libtorch = NeuralNetwork(
                './models/checkpoint.pt', self.libtorch_use_gpu,
                self.num_mcts_threads * self.num_train_threads)
            itr_examples = []
            with concurrent.futures.ThreadPoolExecutor(
                    max_workers=self.num_train_threads) as executor:
                # First colour alternates with the iteration parity; only the
                # first episode (k == 1) is shown in the GUI.
                futures = [
                    executor.submit(self.self_play, 1 if itr % 2 else -1,
                                    libtorch, k == 1)
                    for k in range(1, self.num_eps + 1)
                ]
                # Results are collected in submission order; result() blocks
                # until that episode has finished.
                for k, f in enumerate(futures):
                    examples = f.result()
                    itr_examples += examples

                    # decrease libtorch batch size
                    # (remain = episodes not yet collected, capped by the
                    # number of worker threads; never below 1)
                    remain = min(
                        len(futures) - (k + 1), self.num_train_threads)
                    libtorch.set_batch_size(
                        max(remain * self.num_mcts_threads, 1))
                    print("EPS: {}, EXAMPLES: {}".format(k + 1, len(examples)))

            # release gpu memory
            del libtorch

            # prepare train data
            # The buffer holds one list per iteration, so its maxlen counts
            # iterations, not examples.  NOTE: when the buffer holds a single
            # list, reduce() returns that very list, so the shuffle below
            # reorders itr_examples in place.
            self.examples_buffer.append(itr_examples)
            train_data = reduce(lambda a, b: a + b, self.examples_buffer)
            random.shuffle(train_data)

            # train neural network
            # NOTE(review): `*` and `//` associate left to right, so this is
            # (epochs * (len + batch_size - 1)) // batch_size, which can differ
            # from epochs * ceil(len / batch_size) -- confirm which is meant.
            epochs = self.epochs * (len(itr_examples) + self.batch_size -
                                    1) // self.batch_size
            self.nnet.train(train_data, self.batch_size, int(epochs))
            self.nnet.save_model()
            self.save_samples()

            # compare performance
            if itr % self.check_freq == 0:
                libtorch_current = NeuralNetwork(
                    './models/checkpoint.pt', self.libtorch_use_gpu,
                    self.num_mcts_threads * self.num_train_threads // 2)
                libtorch_best = NeuralNetwork(
                    './models/best_checkpoint.pt', self.libtorch_use_gpu,
                    self.num_mcts_threads * self.num_train_threads // 2)

                one_won, two_won, draws = self.contest(libtorch_current,
                                                       libtorch_best,
                                                       self.num_contest)
                print("NEW/PREV WINS : %d / %d ; DRAWS : %d" %
                      (one_won, two_won, draws))

                # Draws are ignored in the win rate; with no decisive game at
                # all the new model is rejected.
                if one_won + two_won > 0 and float(one_won) / (
                        one_won + two_won) > self.update_threshold:
                    print('ACCEPTING NEW MODEL')
                    self.nnet.save_model('models', "best_checkpoint")
                else:
                    print('REJECTING NEW MODEL')

                # release gpu memory
                del libtorch_current
                del libtorch_best

    def self_play(self, first_color, libtorch, show):
        """
        Play one self-play episode and return its training examples.

        Two MCTS players that share the same ``libtorch`` network alternate
        moves until the game ends.  For every move the position, the search
        probabilities and their symmetries are recorded; once the game ends
        each record is completed with a value derived from the winner.

        Args:
            first_color: colour passed to ``Gomoku`` as the first mover.
            libtorch: network object handed to both MCTS players.
            show: when true, the moves are mirrored to ``self.gomoku_gui``.

        Returns:
            A list of tuples ``(board, last_action, cur_player, prob,
            cur_player * winner)``.
        """
        train_examples = []

        player1 = MCTS(libtorch, self.num_mcts_threads, self.c_puct,
                       self.num_mcts_sims, self.c_virtual_loss,
                       self.action_size)
        player2 = MCTS(libtorch, self.num_mcts_threads, self.c_puct,
                       self.num_mcts_sims, self.c_virtual_loss,
                       self.action_size)
        # Indexed with player_index + 1: -1 -> player2, 1 -> player1.
        players = [player2, None, player1]
        player_index = 1

        gomoku = Gomoku(self.n, self.n_in_row, first_color)

        if show:
            self.gomoku_gui.reset_status()

        episode_step = 0
        while True:
            episode_step += 1
            player = players[player_index + 1]

            # get action prob
            # (temperature self.temp for the first num_explore moves, 0 after)
            if episode_step <= self.num_explore:
                prob = np.array(
                    list(player.get_action_probs(gomoku, self.temp)))
            else:
                prob = np.array(list(player.get_action_probs(gomoku, 0)))

            # generate sample
            # The sample stores `prob` as returned by the search, i.e. before
            # the Dirichlet noise below is mixed in.
            board = tuple_2d_to_numpy_2d(gomoku.get_board())
            last_action = gomoku.get_last_move()
            cur_player = gomoku.get_current_color()

            sym = self.get_symmetries(board, prob, last_action)
            for b, p, a in sym:
                train_examples.append([b, a, cur_player, p])

            # dirichlet noise
            # assumes legal_moves is a 0/1 mask with one entry per action --
            # TODO confirm against Gomoku.get_legal_moves
            legal_moves = list(gomoku.get_legal_moves())
            noise = 0.1 * np.random.dirichlet(
                self.dirichlet_alpha * np.ones(np.count_nonzero(legal_moves)))

            # Mix 90% search probabilities with 10% noise on the legal moves
            # only, then renormalise.
            prob = 0.9 * prob
            j = 0
            for i in range(len(prob)):
                if legal_moves[i] == 1:
                    prob[i] += noise[j]
                    j += 1
            prob /= np.sum(prob)

            # execute move
            action = np.random.choice(len(prob), p=prob)

            if show:
                self.gomoku_gui.execute_move(cur_player, action)
            gomoku.execute_move(action)
            # Both search trees are advanced so they stay in sync with the game.
            player1.update_with_move(action)
            player2.update_with_move(action)

            # next player
            player_index = -player_index

            # is ended
            ended, winner = gomoku.get_game_status()
            if ended == 1:
                # b, last_action, cur_player, p, v
                return [(x[0], x[1], x[2], x[3], x[2] * winner)
                        for x in train_examples]

    def contest(self, network1, network2, num_contest):
        """compare new and old model
           Args: player1, player2 is neural network
           Return: one_won, two_won, draws
        """
        one_won, two_won, draws = 0, 0, 0

        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self.num_train_threads) as executor:
            futures = [executor.submit(\
                self._contest, network1, network2, 1 if k <= num_contest // 2 else -1, k == 1) for k in range(1, num_contest + 1)]
            for f in futures:
                winner = f.result()
                if winner == 1:
                    one_won += 1
                elif winner == -1:
                    two_won += 1
                else:
                    draws += 1

        return one_won, two_won, draws

    def _contest(self, network1, network2, first_player, show):
        """Play a single game between two networks and return the winner.

        Args:
            network1: network used by the player playing side ``1``.
            network2: network used by the player playing side ``-1``.
            first_player: side (``1`` or ``-1``) that moves first.
            show: when true, the game is mirrored to ``self.gomoku_gui``.

        Returns:
            The ``winner`` value reported by ``Gomoku.get_game_status``.
        """
        # create MCTS (identical search settings for both sides)
        search_args = (self.num_mcts_threads, self.c_puct, self.num_mcts_sims,
                       self.c_virtual_loss, self.action_size)
        player1 = MCTS(network1, *search_args)
        player2 = MCTS(network2, *search_args)

        # prepare: seats[side + 1] maps -1 -> player2 and 1 -> player1
        seats = [player2, None, player1]
        side = first_player
        gomoku = Gomoku(self.n, self.n_in_row, first_player)
        if show:
            self.gomoku_gui.reset_status()

        # play
        while True:
            # select best move: the most probable action of the side to move
            probs = seats[side + 1].get_action_probs(gomoku)
            move = int(np.argmax(np.array(list(probs))))

            # execute move
            gomoku.execute_move(move)
            if show:
                self.gomoku_gui.execute_move(side, move)

            # check game status
            ended, winner = gomoku.get_game_status()
            if ended == 1:
                return winner

            # update search tree: both trees follow the move actually played
            for searcher in (player1, player2):
                searcher.update_with_move(move)

            # next player
            side = -side

    def get_symmetries(self, board, pi, last_action):
        # mirror, rotational
        assert (len(pi) == self.action_size)  # 1 for pass

        pi_board = np.reshape(pi, (self.n, self.n))
        last_action_board = np.zeros((self.n, self.n))
        last_action_board[last_action // self.n][last_action % self.n] = 1
        l = []

        for i in range(1, 5):
            for j in [True, False]:
                newB = np.rot90(board, i)
                newPi = np.rot90(pi_board, i)
                newAction = np.rot90(last_action_board, i)
                if j:
                    newB = np.fliplr(newB)
                    newPi = np.fliplr(newPi)
                    newAction = np.fliplr(last_action_board)
                l += [(newB, newPi.ravel(),
                       np.argmax(newAction) if last_action != -1 else -1)]
        return l

    def play_with_human(self,
                        human_first=True,
                        checkpoint_name="best_checkpoint"):
        """Play one interactive game between a human (via the GUI) and a model.

        Args:
            human_first: when true the human makes the first move, otherwise
                the model does.
            checkpoint_name: base name (without ``.pt``) of the model file
                under ``./models`` to play against.

        The result is printed when the game ends: ``HUMAN WIN``,
        ``ALPHA ZERO WIN`` or ``DRAW``.
        """
        # load the requested model (the name used to be ignored and
        # best_checkpoint was always loaded; the default is unchanged)
        libtorch_best = NeuralNetwork(
            './models/{}.pt'.format(checkpoint_name),
            self.libtorch_use_gpu, 12)
        # A stronger search than in training: 3x the threads, 6x the sims.
        mcts_best = MCTS(libtorch_best, self.num_mcts_threads * 3,
                         self.c_puct, self.num_mcts_sims * 6,
                         self.c_virtual_loss, self.action_size)

        # create gomoku game
        human_color = self.gomoku_gui.get_human_color()
        first_color = human_color if human_first else -human_color
        gomoku = Gomoku(self.n, self.n_in_row, first_color)

        # players[color + 1] names who plays that colour (-1 or 1)
        if human_color == 1:
            players = ["alpha", None, "human"]
        else:
            players = ["human", None, "alpha"]
        player_index = first_color

        self.gomoku_gui.reset_status()

        while True:
            player = players[player_index + 1]

            # select move
            if player == "alpha":
                prob = mcts_best.get_action_probs(gomoku)
                best_move = int(np.argmax(np.array(list(prob))))
                self.gomoku_gui.execute_move(player_index, best_move)
            else:
                self.gomoku_gui.set_is_human(True)
                # wait human action: poll until the GUI clears the flag
                while self.gomoku_gui.get_is_human():
                    time.sleep(0.1)
                best_move = self.gomoku_gui.get_human_move()

            # execute move
            gomoku.execute_move(best_move)

            # check game status
            ended, winner = gomoku.get_game_status()
            if ended == 1:
                break

            # update tree search
            mcts_best.update_with_move(best_move)

            # next player
            player_index = -player_index

        # A winner that is neither colour is a draw; it used to be reported
        # as a model win.
        if winner == human_color:
            print("HUMAN WIN")
        elif winner == -human_color:
            print("ALPHA ZERO WIN")
        else:
            print("DRAW")

    def load_samples(self, folder="models", filename="checkpoint.example"):
        """load self.examples_buffer
        """

        filepath = path.join(folder, filename)
        with open(filepath, 'rb') as f:
            self.examples_buffer = pickle.load(f)

    def save_samples(self, folder="models", filename="checkpoint.example"):
        """save self.examples_buffer
        """

        if not path.exists(folder):
            mkdir(folder)

        filepath = path.join(folder, filename)
        with open(filepath, 'wb') as f:
            pickle.dump(self.examples_buffer, f, -1)
class Learner():
    def __init__(self, config):
        """Populate the learner from ``config`` (no GUI is created).

        Every key read below is required; with a plain ``dict`` a missing key
        raises ``KeyError``.  Besides copying scalar settings this constructs
        the bounded examples buffer and the neural-network wrapper.
        """
        # gomoku (the GUI is intentionally not constructed in this variant)
        board_size = config['n']
        self.n = board_size
        self.n_in_row = config['n_in_row']
        self.action_size = config['action_size']

        # train: settings stored under the same name as their config key
        for key in ('num_iters', 'num_eps', 'num_train_threads', 'check_freq',
                    'num_contest', 'dirichlet_alpha', 'temp',
                    'update_threshold', 'num_explore'):
            setattr(self, key, config[key])

        # bounded buffer: the oldest entry is dropped once maxlen is reached
        self.examples_buffer = deque(
            maxlen=config['examples_buffer_max_len'])

        # mcts
        for key in ('num_mcts_sims', 'c_puct', 'c_virtual_loss',
                    'num_mcts_threads', 'libtorch_use_gpu'):
            setattr(self, key, config[key])

        # neural network
        self.batch_size = config['batch_size']
        self.epochs = config['epochs']
        self.nnet = NeuralNetWorkWrapper(
            config['lr'], config['l2'], config['num_layers'],
            config['num_channels'], board_size, self.action_size,
            config['train_use_gpu'], self.libtorch_use_gpu)

    def learn(self, model_dir, model_id):
        """Run one training round on samples stored on disk.

        Loads model ``model_id`` from ``model_dir``, trains it on every
        sample file found in ``../build/data`` and saves the result as model
        ``model_id + 1`` in the same directory.

        Args:
            model_dir: directory that holds the numbered model files.
            model_id: integer id of the model to start from; the file
                ``<model_dir>/<model_id>.pkl`` must exist.
        """
        # train the model by self play
        model_path = path.join(model_dir, str(model_id))
        assert path.exists(model_path +
                           '.pkl'), f"{model_path+'.pkl'} not exists!!!"
        print(f"loading {model_id}-th model")
        self.nnet.load_model(model_path)

        # Samples are read from a fixed location relative to the working dir.
        train_data = self.load_samples(path.join('..', 'build', 'data'))
        random.shuffle(train_data)

        # train neural network
        sample_count = len(train_data)
        # `*` and `//` associate left to right:
        # (epochs * (count + batch_size - 1)) // batch_size
        epochs = self.epochs * (sample_count + self.batch_size -
                                1) // self.batch_size
        # The batch is capped so it never exceeds the number of samples.
        effective_batch = min(self.batch_size, sample_count)
        self.nnet.train(train_data, effective_batch, int(epochs))

        next_model_path = path.join(model_dir, str(model_id + 1))
        self.nnet.save_model(next_model_path)

    def get_symmetries(self, board, pi, last_action):
        # mirror, rotational
        assert (len(pi) == self.action_size)  # 1 for pass

        pi_board = np.reshape(pi, (self.n, self.n))
        last_action_board = np.zeros((self.n, self.n))
        last_action_board[last_action // self.n][last_action % self.n] = 1
        l = []

        for i in range(1, 5):
            for j in [True, False]:

                newB = np.rot90(board, i)
                newPi = np.rot90(pi_board, i)
                newAction = np.rot90(last_action_board, i)
                if j:
                    newB = np.fliplr(newB)
                    newPi = np.fliplr(newPi)
                    newAction = np.fliplr(last_action_board)
                l += [(newB, newPi.ravel(),
                       np.argmax(newAction) if last_action != -1 else -1)]
        return l

    def load_samples(self, folder):
        """load self.examples_buffer
        """
        BOARD_SIZE = self.n
        train_examples = []
        data_files = os.listdir(folder)
        for file_name in data_files:
            file_path = path.join(folder, file_name)
            with open(file_path, 'rb') as binfile:
                # size = os.path.getsize(filepath) #获得文件大小
                step = binfile.read(4)
                step = int().from_bytes(step, byteorder='little', signed=True)
                board = np.zeros((step, BOARD_SIZE * BOARD_SIZE))
                for i in range(step):
                    for j in range(BOARD_SIZE * BOARD_SIZE):
                        data = binfile.read(4)
                        data = int().from_bytes(data,
                                                byteorder='little',
                                                signed=True)
                        board[i][j] = data
                board = np.reshape(board, (-1, BOARD_SIZE, BOARD_SIZE))
                prob = np.zeros((step, BOARD_SIZE * BOARD_SIZE))
                for i in range(step):
                    for j in range(BOARD_SIZE * BOARD_SIZE):
                        data = binfile.read(4)
                        data = struct.unpack('f', data)[0]
                        prob[i][j] = data
                        # p = p.reshape((-1,BOARD_SIZE,BOARD_SIZE))
                    # print(p)

                v = []
                for i in range(step):
                    data = binfile.read(4)
                    data = int().from_bytes(data,
                                            byteorder='little',
                                            signed=True)
                    v.append(data)
                    # print(v)

                color = []
                for i in range(step):
                    data = binfile.read(4)
                    data = int().from_bytes(data,
                                            byteorder='little',
                                            signed=True)
                    color.append(data)

                last_action = []
                for i in range(step):
                    data = binfile.read(4)
                    data = int().from_bytes(data,
                                            byteorder='little',
                                            signed=True)
                    last_action.append(data)

                for i in range(step):
                    sym = self.get_symmetries(board[i], prob[i],
                                              last_action[i])
                    for b, p, a in sym:
                        train_examples.append([b, a, color[i], p, v[i]])
        return train_examples
Beispiel #6
0
class TrainPipeline():
    def __init__(self, config):
        """Configure logging and copy the training settings from ``config``.

        Every key read below is required; with a plain ``dict`` a missing key
        raises ``KeyError``.  The log file named by ``train_log_file`` is
        opened with mode ``'w+'`` and receives bare messages at DEBUG level.
        """
        # logging train messages
        logging.basicConfig(
            filename=config['train_log_file'],
            level=logging.DEBUG,
            format='%(message)s',
            filemode='w+',
        )

        # board config
        for key in ('n', 'n_in_row', 'action_size'):
            setattr(self, key, config[key])

        # train config: settings stored under the same name as their key
        train_keys = (
            'num_iters', 'num_eps', 'num_train_threads', 'check_freq',
            'num_contest', 'dirichlet_alpha', 'temp', 'update_threshold',
            'num_explore',
        )
        for key in train_keys:
            setattr(self, key, config[key])
        # bounded buffer: the oldest entry is dropped once maxlen is reached
        self.examples_buffer = deque(
            maxlen=config['examples_buffer_max_len'])

        # mcts config
        mcts_keys = ('libtorch_use_gpu', 'num_mcts_sims', 'c_puct',
                     'c_virtual_loss', 'num_mcts_threads')
        for key in mcts_keys:
            setattr(self, key, config[key])

        # nn config
        self.batch_size = config['batch_size']
        self.epochs = config['epochs']
        self.nnet = NeuralNetWorkWrapper(
            config['lr'], config['l2'], config['num_layers'],
            config['num_channels'], config['n'], config['action_size'],
            config['train_use_gpu'], config['libtorch_use_gpu'])

        # train debug config
        self.show_train_board = config['show_train_board']

    def learn(self):
        """Train the model by self-play for ``self.num_iters`` iterations.

        Each iteration plays ``self.num_eps`` self-play games on a thread
        pool, appends the collected examples to ``self.examples_buffer``,
        trains ``self.nnet`` on the flattened buffer (only when it holds at
        least one batch) and saves the samples.  Every ``self.check_freq``
        iterations the current checkpoint plays ``self.num_contest`` games
        against ``best_checkpoint`` and replaces it when its share of
        decisive games exceeds ``self.update_threshold``.  Progress is
        written through ``logging.debug``.
        """
        # Resume when a saved sample file exists; otherwise write the initial
        # models (presumably ./models/checkpoint.pt, which is loaded below --
        # TODO confirm against NeuralNetWorkWrapper.save_model defaults).
        if path.exists(path.join('models', 'checkpoint.example')):
            logging.debug('loading checkpoint...')
            self.nnet.load_model()
            self.load_samples() 
        else:
            self.nnet.save_model()
            self.nnet.save_model('models', 'best_checkpoint')
        
        for itr in range(1, self.num_iters + 1):
            logging.debug('-' * 65)
            logging.debug('iter: {}'.format(itr))
            logging.debug('-' * 65)

            libtorch = NeuralNetwork('./models/checkpoint.pt', self.libtorch_use_gpu, self.num_mcts_threads * self.num_train_threads)

            itr_examples = []
            with concurrent.futures.ThreadPoolExecutor(max_workers = self.num_train_threads) as executor:
                # Only the first episode (k == 0) may be displayed.
                futures = [executor.submit(self.self_play, libtorch, 
                    self.show_train_board if k == 0 else False) for k in range(self.num_eps)]
                # Results are collected in submission order; result() blocks
                # until that episode has finished.
                for k, f in enumerate(futures):
                    examples = f.result()
                    # remain = episodes not yet collected, capped by the number
                    # of worker threads; the batch size never drops below 1.
                    remain = min(len(futures) - (k + 1), self.num_train_threads)
                    libtorch.set_batch_size(max(remain * self.num_mcts_threads, 1))
                    itr_examples.extend(examples)
                    # self_play stores 8 symmetric examples per move, hence // 8.
                    logging.debug('eps: {}, examples: {}, moves: {}'.format(k + 1, len(examples), len(examples) // 8) )

            # presumably releases the model's (GPU) memory -- TODO confirm
            del libtorch

            # The buffer holds one list per iteration, so its maxlen counts
            # iterations, not examples.  NOTE: when the buffer holds a single
            # list, reduce() returns that very list, so the shuffle below
            # reorders itr_examples in place.
            self.examples_buffer.append(itr_examples)
            train_data = reduce(lambda a, b : a + b, self.examples_buffer)

            # the number of train data cannot less than batch size
            if len(train_data) >= self.batch_size:
                random.shuffle(train_data)
                # ceil(len(itr_examples) / batch_size) * self.epochs
                epochs = (len(itr_examples) + self.batch_size - 1) // self.batch_size * self.epochs
                epoch_res = self.nnet.train(train_data, self.batch_size, int(epochs))
                for epo, loss, entropy in epoch_res:
                    logging.debug("epoch: {}, loss: {}, entropy: {}".format(epo, loss, entropy))
                self.nnet.save_model()

            self.save_samples()
            
            # evaluate the new model every check_freq iters
            if itr % self.check_freq == 0:
                num_half_threads = max(self.num_mcts_threads * self.num_train_threads // 2, 1)
                libtorch_current = NeuralNetwork('./models/checkpoint.pt', self.libtorch_use_gpu, num_half_threads)
                libtorch_best = NeuralNetwork('./models/best_checkpoint.pt', self.libtorch_use_gpu, num_half_threads)

                win_cnt, lose_cnt, draw_cnt = self.contest(libtorch_current, libtorch_best, self.num_contest)
                logging.debug('new vs. prev: {:d} wins, {:d} loses, {:d} draws'.format(win_cnt, lose_cnt, draw_cnt))

                # accept when the win rate greater than update_threshold
                # (draws are ignored; with no decisive game the model is rejected)
                if win_cnt + lose_cnt > 0 and win_cnt / (win_cnt + lose_cnt) > self.update_threshold:
                    logging.debug('new model accepted.')
                    self.nnet.save_model('models', 'best_checkpoint')
                else:
                    logging.debug('new model rejected')

                del libtorch_current
                del libtorch_best
    
    def self_play(self, libtorch, show):
        """Play one self-play game and return its training examples.

        A single ``AlphaZero`` player chooses the moves for both sides.  For
        every move the encoded states and the move probabilities are stored
        together with their symmetries; once the game ends each record is
        completed with a value derived from the winner.

        Args:
            libtorch: network object handed to the ``AlphaZero`` player.
            show: when true, probabilities and the board are printed after
                every move.

        Returns:
            A list of tuples ``(states, prob, cur_player * winner)``.
        """
        if show:
            print('display of a self play round begins\n')

        train_examples = []
        player = AlphaZero(libtorch, self.num_mcts_threads, self.num_mcts_sims, self.c_puct, self.c_virtual_loss)
        board = Board(self.n, self.n_in_row)

        episode_step = 0
        while True:
            episode_step += 1

            # have exploration in the first num_explore steps
            if episode_step <= self.num_explore:
                prob = np.array(list(player.get_action_probs(board, self.temp)))

                # add dirichlet noise
                # (75% search probabilities + 25% noise spread over the moves
                # returned by get_moves(), then renormalised)
                # assumes get_moves() yields flat action indices -- TODO confirm
                legal_moves = list(board.get_moves())
                noise = 0.25 * np.random.dirichlet(self.dirichlet_alpha * np.ones(len(legal_moves)))
                prob = 0.75 * prob
                for i in range(len(legal_moves)):
                    prob[legal_moves[i]] += noise[i]
                prob /= np.sum(prob)
            else:
                prob = np.array(list(player.get_action_probs(board, 0)))

            # get action according to prob
            action = np.random.choice(len(prob), p = prob)

            # States and player are read before the move is executed.
            states = board.get_encode_states()
            cur_player = board.get_cur_player()
            
            # get equivalent data, augment the dataset
            # NOTE(review): during exploration `prob` already contains the
            # Dirichlet noise, so the noisy distribution is what gets stored
            # as the training target -- confirm this is intended.
            sym = self.get_symmetries(states, prob)
            for s, p in sym:
                train_examples.append([s, p, cur_player])

            board.exec_move(action)

            if show:
                # display the action probability
                # (one row of n values per line; the chosen action is printed
                # with the ANSI bold-red escape sequence)
                for i in range(self.action_size):
                    if i % self.n == 0:
                        print()
                    if i == action:
                        print('\033[31;1m{:.3f}\033[0m'.format(prob[i]), end = ' ')
                    else:
                        print('{:.3f}'.format(prob[i]), end = ' ')
                print('\n')
                # display the board
                board.display()

            player.update_with_move(action)

            ended, winner = board.get_result()
            if ended:
                if show:
                    print('display of a self play round finished\n')
                # The stored player is multiplied by the winner to form the
                # third element of every example.
                return [(x[0], x[1], x[2] * winner) for x in train_examples]

    def contest(self, network1, network2, num_contest):
        win_cnt, lose_cnt, draw_cnt = 0, 0, 0
        with concurrent.futures.ThreadPoolExecutor(max_workers = self.num_train_threads) as executor:
            futures = [executor.submit(self._contest, network1, network2, 
                1 if k < num_contest // 2 else -1, self.show_train_board if k == 0 else 0) for k in range(num_contest)]
            for f in futures:
                winner = f.result()
                if winner == 1:
                    win_cnt += 1
                elif winner == -1:
                    lose_cnt += 1
                else:
                    draw_cnt += 1
        return win_cnt, lose_cnt, draw_cnt

    def _contest(self, network1, network2, start_player, show):
        """Play a single game between two networks and return the winner.

        Args:
            network1: network used by the player on side ``1``.
            network2: network used by the player on side ``-1``.
            start_player: side (``1`` or ``-1``) that moves first.
            show: when true, the board is printed after every move.

        Returns:
            The ``winner`` value reported by ``Board.get_result``.
        """
        if show:
            print('display of a contest round begins\n')

        # Identical search settings for both sides.
        search_args = (self.num_mcts_threads, self.num_mcts_sims,
                       self.c_puct, self.c_virtual_loss)
        player1 = AlphaZero(network1, *search_args)
        player2 = AlphaZero(network2, *search_args)

        # seats[side + 1] maps -1 -> player2 and 1 -> player1
        seats = [player2, None, player1]
        side = start_player
        board = Board(self.n, self.n_in_row, start_player)

        while True:
            move = seats[side + 1].get_action(board)
            board.exec_move(move)

            if show:
                board.display()

            ended, winner = board.get_result()
            if ended == 1:
                if show:
                    print('display of a contest round finished\n')
                return winner

            # Both search trees follow the move that was actually played.
            for searcher in (player1, player2):
                searcher.update_with_move(move)
            side = -side

    def get_symmetries(self, states, prob):
        prob = np.reshape(prob, (self.n, self.n))
        res = []
        for i in range(4):
            # augment dataset by rotate the board
            equi_states = np.array([np.rot90(s, i) for s in states])
            equi_prob = np.rot90(prob, i)
            res.append((equi_states, equi_prob.ravel()))
            # augment dataset by flip the board
            equi_states = np.array([np.fliplr(s) for s in equi_states])
            equi_prob = np.fliplr(equi_prob)
            res.append((equi_states, equi_prob.ravel()))
        return res

    def load_samples(self, folder = 'models', filename = 'checkpoint.example'):
        filepath = path.join(folder, filename)
        with open(filepath, 'rb') as f:
            self.examples_buffer = pickle.load(f)
        
    def save_samples(self, folder = 'models', filename = 'checkpoint.example'):
        if not path.exists(folder):
            mkdir(folder)
        filepath = path.join(folder, filename)
        with open(filepath, 'wb') as f:
            pickle.dump(self.examples_buffer, f, -1)
class Leaner():
    def __init__(self, config):
        """Copy the settings out of ``config`` and build the GUI and network.

        Args:
            config: mapping from setting name to value (see README.md).
        """
        board_size = config['n']

        # gomoku
        self.n = board_size
        self.n_in_row = config['n_in_row']
        self.gomoku_gui = GomokuGUI(board_size, config['human_color'])
        self.action_size = board_size**2

        # train: each of these is stored under the same name as its key
        for key in ('num_iters', 'num_eps', 'check_freq', 'contest_num',
                    'dirichlet_alpha', 'temp', 'update_threshold',
                    'explore_num'):
            setattr(self, key, config[key])

        # bounded buffer of self-play examples
        self.examples_buffer = deque(
            maxlen=config['examples_buffer_max_len'])

        # neural network
        self.batch_size = config['batch_size']
        self.mcts_use_gpu = config['mcts_use_gpu']

        net_settings = [config[key]
                        for key in ('lr', 'l2', 'kl_targ', 'epochs',
                                    'num_channels')]
        self.nnet = NeuralNetWorkWrapper(*net_settings, board_size,
                                         self.action_size, self.mcts_use_gpu)

        # mcts
        for key in ('num_mcts_sims', 'c_puct', 'c_virtual_loss',
                    'thread_pool_size'):
            setattr(self, key, config[key])

    def learn(self):
        """Run the self-play training loop for ``self.num_iters`` iterations.

        Each iteration plays ``self.num_eps`` self-play games (alternating the
        first color), trains the network on a batch drawn from the examples
        buffer with ``sample`` and saves it as "checkpoint". Every
        ``self.check_freq`` iterations the buffer is saved and the new
        checkpoint plays against "best_checkpoint"; it replaces the best model
        when its share of the decisive games exceeds
        ``self.update_threshold``.

        The GUI loop runs on its own thread, which is joined before returning.
        """
        gui_thread = threading.Thread(target=self.gomoku_gui.loop)
        gui_thread.start()

        # resume from an earlier run when a checkpoint is present
        if os.path.exists('./models/checkpoint'):
            print("loading checkpoint...")
            self.nnet.load_model('models', "checkpoint")
            self.load_samples("models", "checkpoint")

        # generate .pt for libtorch
        for model_name in ("checkpoint", "best_checkpoint"):
            self.nnet.save_model('models', model_name)

        for iteration in range(1, self.num_iters + 1):
            print("ITER ::: {}".format(iteration))

            # self play, flipping the starting color after every episode
            first_color = 1
            for episode in range(1, self.num_eps + 1):
                examples = self.self_play(first_color)
                self.examples_buffer.extend(examples)
                first_color = -first_color
                print("EPS :: {}, EXAMPLES :: {}".format(
                    episode, len(examples)))

            # wait until the buffer can fill a whole batch
            if len(self.examples_buffer) < self.batch_size:
                continue

            print("sampling...")
            train_data = sample(self.examples_buffer, self.batch_size)

            # train neural network
            self.nnet.train(train_data)
            self.nnet.save_model('models', "checkpoint")

            # evaluation only happens every check_freq iterations
            if iteration % self.check_freq != 0:
                continue

            self.save_samples("models", "checkpoint")

            # compare performance: both searchers share every setting except
            # the model file they load
            search_settings = (self.thread_pool_size, self.c_puct,
                               self.num_mcts_sims, self.c_virtual_loss,
                               self.action_size, self.mcts_use_gpu)
            mcts = MCTS("./models/checkpoint.pt", *search_settings)
            mcts_best = MCTS("./models/best_checkpoint.pt", *search_settings)

            one_won, two_won, draws = self.contest(mcts, mcts_best,
                                                   self.contest_num)
            print("NEW/PREV WINS : %d / %d ; DRAWS : %d" %
                  (one_won, two_won, draws))

            # draws are ignored when computing the win rate
            decisive = one_won + two_won
            if decisive > 0 and \
                    float(one_won) / decisive > self.update_threshold:
                print('ACCEPTING NEW MODEL')
                self.nnet.save_model('models', "best_checkpoint")
            else:
                print('REJECTING NEW MODEL')

            del mcts
            del mcts_best

        gui_thread.join()

    def self_play(self, first_color):
        """Play one self-play game and return its training examples.

        Every visited position contributes one example per board symmetry.
        The move actually played is drawn from the search probabilities mixed
        with Dirichlet noise over the legal moves (weights 0.75 / 0.25).

        Args:
            first_color: color passed to ``Gomoku`` as the side moving first.

        Returns:
            A list of ``(board, last_action, cur_player, prob, value)``
            tuples, where ``value`` is ``cur_player * winner``.
        """
        # NOTE(review): each symmetric board is stored with the untransformed
        # last_action -- confirm whether the consumer expects the action to be
        # rotated/flipped together with the board.
        game = Gomoku(self.n, self.n_in_row, first_color)
        searcher = MCTS("./models/checkpoint.pt", self.thread_pool_size,
                        self.c_puct, self.num_mcts_sims, self.c_virtual_loss,
                        self.action_size, self.mcts_use_gpu)
        history = []

        move_count = 0
        while True:
            move_count += 1

            # self.temp for the first explore_num moves, temperature 0 after
            if move_count <= self.explore_num:
                temperature = self.temp
            else:
                temperature = 0
            prob = np.array(
                list(searcher.get_action_probs(game, temperature)))

            # record every symmetry of the current position
            board = tuple_2d_to_numpy_2d(game.get_board())
            last_action = game.get_last_move()
            cur_player = game.get_current_color()
            history.extend(
                [sym_board, last_action, cur_player, sym_prob]
                for sym_board, sym_prob in self.get_symmetries(board, prob))

            # dirichlet noise, one sample per legal move
            legal_moves = list(game.get_legal_moves())
            legal_count = np.count_nonzero(legal_moves)
            noise = 0.25 * np.random.dirichlet(
                self.dirichlet_alpha * np.ones(legal_count))

            noisy_prob = 0.75 * prob
            next_noise = 0
            for move in range(len(noisy_prob)):
                if legal_moves[move] != 1:
                    continue
                noisy_prob[move] += noise[next_noise]
                next_noise += 1
            noisy_prob /= np.sum(noisy_prob)
            action = np.random.choice(len(noisy_prob), p=noisy_prob)

            # play the move on the board and advance the search tree
            game.execute_move(action)
            searcher.update_with_move(action)

            ended, winner = game.get_game_status()
            if ended != 1:
                continue

            # b, last_action, cur_player, p, v
            return [(b, a, player, p, player * winner)
                    for b, a, player, p in history]

    def contest(self, player1, player2, contest_num):
        """compare new and old model
           Args: player1, player2 is white/balck player
           Return: one_won, two_won, draws
        """

        one_won, two_won, draws = 0, 0, 0

        for i in range(contest_num):
            if i < contest_num // 2:
                # first half, white first
                winner = self._contest(player1, player2, 1)
            else:
                # second half, black first
                winner = self._contest(player1, player2, -1)

            if winner == 1:
                one_won += 1
            elif winner == -1:
                two_won += 1
            else:
                draws += 1

        return one_won, two_won, draws

    def _contest(self, player1, player2, first_player):
        """Play one game between two players, mirroring the moves in the GUI.

        Args:
            player1: moves whenever the side to move is 1.
            player2: moves whenever the side to move is -1.
            first_player: side (1 or -1) that makes the first move.

        Returns:
            The ``winner`` value reported by ``Gomoku.get_game_status`` once
            the game has ended.
        """
        # NOTE(review): the players' search trees are not reset here although
        # `contest` passes the same players to every game -- confirm that
        # stale trees from the previous game are handled elsewhere.
        game = Gomoku(self.n, self.n_in_row, first_player)
        self.gomoku_gui.reset_status()

        # looked up with side + 1, so -1 -> player2 and 1 -> player1
        seats = [player2, None, player1]
        side = first_player

        while True:
            mover = seats[side + 1]

            # pick the move with the highest search probability
            policy = np.array(list(mover.get_action_probs(game)))
            best_move = int(np.argmax(policy))

            # apply it to the game and show it
            game.execute_move(best_move)
            self.gomoku_gui.execute_move(side, best_move)

            ended, winner = game.get_game_status()
            if ended == 1:
                return winner

            # both search trees follow the move, whoever played it
            for searcher in (player1, player2):
                searcher.update_with_move(best_move)

            side = -side

    def get_symmetries(self, board, pi):
        # mirror, rotational
        assert (len(pi) == self.action_size)  # 1 for pass

        pi_board = np.reshape(pi, (self.n, self.n))
        l = []

        for i in range(1, 5):
            for j in [True, False]:
                newB = np.rot90(board, i)
                newPi = np.rot90(pi_board, i)
                if j:
                    newB = np.fliplr(newB)
                    newPi = np.fliplr(newPi)
                l += [(newB, newPi.ravel())]
        return l

    def play_with_human(self,
                        human_first=True,
                        checkpoint_name="best_checkpoint"):
        """Play one game between a human (through the GUI) and the MCTS player.

        The GUI loop runs on its own thread; this call returns after the game
        has ended and that thread has finished.

        Args:
            human_first: when true the human makes the first move.
            checkpoint_name: base name of the model to load; the file
                "./models/<checkpoint_name>.pt" is used.
        """
        t = threading.Thread(target=self.gomoku_gui.loop)
        t.start()

        # load the requested model; the search gets twice the simulations
        # used during training
        mcts_best = MCTS("./models/" + checkpoint_name + ".pt",
                         self.thread_pool_size, self.c_puct,
                         self.num_mcts_sims * 2, self.c_virtual_loss,
                         self.action_size, self.mcts_use_gpu)

        # create gomoku game
        human_color = self.gomoku_gui.get_human_color()
        first_color = human_color if human_first else -human_color
        gomoku = Gomoku(self.n, self.n_in_row, first_color)

        # looked up with color + 1, so the human sits at its own color's slot
        players = ["alpha", None, "human"
                   ] if human_color == 1 else ["human", None, "alpha"]
        player_index = first_color

        while True:
            player = players[player_index + 1]

            # select move
            if player == "alpha":
                prob = mcts_best.get_action_probs(gomoku)
                best_move = int(np.argmax(np.array(list(prob))))
                self.gomoku_gui.execute_move(player_index, best_move)
            else:
                self.gomoku_gui.set_is_human(True)
                # poll until the GUI clears the flag, then fetch the move
                while self.gomoku_gui.get_is_human():
                    time.sleep(0.1)
                best_move = self.gomoku_gui.get_human_move()

            # execute move
            gomoku.execute_move(best_move)

            # check game status
            ended, winner = gomoku.get_game_status()
            if ended == 1:
                break

            # update tree search
            mcts_best.update_with_move(best_move)

            # next player
            player_index = -player_index

        if winner == human_color:
            print("human win")
        elif winner == -human_color:
            print("alpha win")
        else:
            # neither color won; `contest` counts this outcome as a draw
            print("draw")

        t.join()

    def load_samples(self, folder="models", filename="checkpoint"):
        """load self.examples_buffer
        """

        filepath = os.path.join(folder, filename + '.example')
        with open(filepath, 'rb') as f:
            self.examples_buffer = pickle.load(f)

    def save_samples(self, folder="models", filename="checkpoint"):
        """save self.examples_buffer
        """

        if not os.path.exists(folder):
            os.mkdir(folder)

        filepath = os.path.join(folder, filename + '.example')
        with open(filepath, 'wb') as f:
            pickle.dump(self.examples_buffer, f, -1)