Example #1
class TrainPipeline:
    def __init__(self, init_model=None):

        # params of the board and the game
        self.game = Game()
        # training params
        self.config = TrainConfig()
        self.greedy_config = TrainGreedyConfig()
        self.data_buffer = deque(maxlen=self.config.buffer_size)
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet()

        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.config.c_puct,
                                      n_playout=self.config.n_playout,
                                      is_selfplay=1)
        self.mcts_player_greedy = MCTSPlayerGreedy(
            self.policy_value_net.policy_value_fn,
            c_puct=self.greedy_config.c_puct,
            n_playout=self.greedy_config.n_playout,
            is_selfplay=1)

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(
                self.mcts_player,
                temp=self.config.temp,
                greedy_player=self.mcts_player_greedy,
                who_greedy="B")
            play_data = list(play_data)
            # augment the data
            play_data = symmetry_board_moves(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        state_batch = [data[0] for data in self.data_buffer]
        mcts_probs_batch = [data[1] for data in self.data_buffer]
        winner_batch = [data[2] for data in self.data_buffer]
        self.policy_value_net.train(state_batch, mcts_probs_batch,
                                    winner_batch, self.config.epochs)
        self.policy_value_net.save_model("model.h5")

    def run(self):
        """run the training pipeline"""
        try:
            self.collect_selfplay_data(self.config.play_batch_size)
            self.policy_update()
        except KeyboardInterrupt:
            print('\n\rquit')

    def summary(self):
        self.policy_value_net.model.summary()
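
A minimal usage sketch for the pipeline above, assuming the Game, TrainConfig, TrainGreedyConfig and PolicyValueNet classes it references are importable from the same project; the __main__ guard and the resume path are illustrative assumptions, not part of the original example.

if __name__ == "__main__":
    # train from scratch; pass init_model="models/model.h5" (hypothetical path) to resume
    pipeline = TrainPipeline()
    # print the network summary (likely a Keras model, given the .h5 checkpoints)
    pipeline.summary()
    # one round of self-play data collection followed by a policy update
    pipeline.run()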
Example #2
class NetTrainer:
    """
    Network trainer.
    """
    def __init__(self, init_model=None):
        # board width
        self.board_width = 11
        # board height
        self.board_height = 11
        # number of stones in a row needed to win
        self.n_in_row = 5
        # total number of self-play games
        self.self_game_num = 5000
        # evaluate playing strength every check_freq self-play games
        self.check_freq = 50
        # number of training passes per mini-batch
        self.repeat_train_epochs = 5
        # number of MCTS playouts during training
        self.train_play_out_n = 2000
        # learning rate
        # self.learn_rate = 2e-3
        # try a lower learning rate
        self.learn_rate = 2e-4
        # mini-batch size
        # self.batch_size = 500
        self.batch_size = 1000
        # maximum size of the training pool
        self.buffer_max_size = 20000
        # training data buffer
        self.data_buffer = deque(maxlen=self.buffer_max_size)

        # number of MCTS playouts for the pure-MCTS rival
        self.rival_play_out_n = 5000
        # best win ratio achieved so far
        self.best_win_ratio = 0.0

        # initialize the policy-value network
        self.brain = PolicyValueNet(self.board_width, self.board_height,
                                    init_model)
        # initialize the self-play player
        self.self_player = MCTSSelfPlayer(self.brain.policy_value,
                                          self.train_play_out_n)

    def self_play_once(self):
        """
        Play one self-play game and collect training data.
        :return: the winner and the collected training data
        """

        game_board = GomokuBoard(width=self.board_width,
                                 height=self.board_height,
                                 n_in_row=self.n_in_row)
        batch_states = []
        batch_probs = []
        current_players = []

        while True:
            action, acts_probs = self.self_player.get_action(game_board)
            # record the current board state
            batch_states.append(game_board.state())
            # record the move probabilities returned by MCTS for this state
            batch_probs.append(acts_probs)
            # record which player is to move
            current_players.append(game_board.current_player)

            # play the chosen move
            game_board.move(action)

            # check whether the game has ended
            end, winner = game_board.check_winner()
            if end:
                batch_values = np.zeros(len(current_players))
                # unless the game ended in a draw, score the winner's positions +1 and the loser's -1
                if winner != GomokuPlayer.Nobody:
                    batch_values[np.array(current_players) == winner] = 1.0
                    batch_values[np.array(current_players) != winner] = -1.0
                batch_values = np.reshape(batch_values, [-1, 1])
                return winner, list(
                    zip(batch_states, batch_probs, batch_values))

    def get_equi_data(self, play_data):
        """
        Augment the data with equivalent positions (rotations and mirrors).
        :param play_data: iterable of (state, prob, value) tuples
        :return: the augmented data
        """
        extend_data = []
        for state, prob, value in play_data:
            for i in [1, 2, 3, 4]:
                # rotate the state and probabilities by i * 90 degrees
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_prob = np.rot90(
                    prob.reshape(self.board_height, self.board_width), i)
                extend_data.append((equi_state, equi_prob.flatten(), value))

                # left-right mirror of that rotation
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_prob = np.fliplr(equi_prob)
                extend_data.append((equi_state, equi_prob.flatten(), value))

        return extend_data

    def policy_evaluate(self, show=None):
        """
        Evaluate the current network's playing strength against a pure-MCTS player.
        :return: win ratio of the network player
        """
        net_player = MCTSPlayer(self.brain.policy_value, self.train_play_out_n)
        mcts_player = MCTSPlayer(rollout_policy_value, self.rival_play_out_n)

        net_win = 0

        # the network player plays Black and moves first
        for i in range(5):
            game_board = GomokuBoard(width=self.board_width,
                                     height=self.board_height,
                                     n_in_row=self.n_in_row)
            while True:
                action = net_player.get_action(game_board)
                game_board.move(action)
                end, winner = game_board.check_winner()
                if show:
                    game_board.dbg_print()
                if end:
                    if show:
                        print(winner)
                    if winner == GomokuPlayer.Black:
                        net_win += 1
                    break

                action = mcts_player.get_action(game_board)
                game_board.move(action)
                end, winner = game_board.check_winner()
                if show:
                    game_board.dbg_print()
                if end:
                    if show:
                        print(winner)
                    break

        # the pure-MCTS player plays Black and moves first
        for i in range(5):
            game_board = GomokuBoard(width=self.board_width,
                                     height=self.board_height,
                                     n_in_row=self.n_in_row)
            while True:
                action = mcts_player.get_action(game_board)
                game_board.move(action)
                end, winner = game_board.check_winner()
                if show:
                    game_board.dbg_print()
                if end:
                    if show:
                        print(winner)
                    break

                action = net_player.get_action(game_board)
                game_board.move(action)
                end, winner = game_board.check_winner()
                if show:
                    game_board.dbg_print()
                if end:
                    if show:
                        print(winner)
                    if winner == GomokuPlayer.White:
                        net_win += 1
                    break

        return net_win / 10

    def run(self):
        """
        Run the training loop.
        """
        black_win_num = 0
        white_win_num = 0
        nobody_win_num = 0

        for i in range(self.self_game_num):

            # collect training data from one self-play game
            print("Self Game: {}".format(i))
            winner, play_data = self.self_play_once()
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)
            if winner == GomokuPlayer.Black:
                black_win_num += 1
            elif winner == GomokuPlayer.White:
                white_win_num += 1
            else:
                nobody_win_num += 1
            print("Black: {:.2f} White: {:.2f} Nobody: {:.2f}".format(
                black_win_num / (i + 1), white_win_num / (i + 1),
                nobody_win_num / (i + 1)))

            # once enough data has accumulated, train on a random mini-batch
            if len(self.data_buffer) > (self.batch_size * 2):
                mini_batch = random.sample(self.data_buffer, self.batch_size)

                batch_states = [data[0] for data in mini_batch]
                batch_probs = [data[1] for data in mini_batch]
                batch_values = [data[2] for data in mini_batch]

                total_loss = 0.0
                total_entropy = 0.0
                for j in range(self.repeat_train_epochs):
                    loss, entropy = self.brain.train(batch_states, batch_probs,
                                                     batch_values,
                                                     self.learn_rate)
                    total_loss += loss
                    total_entropy += entropy
                print("Loss: {:.2f}, Entropy: {:.2f}".format(
                    total_loss / self.repeat_train_epochs,
                    total_entropy / self.repeat_train_epochs))

            if (i + 1) % self.check_freq == 0:
                self.brain.save_model(".\\CurrentModel\\GomokuAi")
                win_ratio = self.policy_evaluate()
                print("Rival({}), Net Win Ratio: {:.2f}".format(
                    self.rival_play_out_n, win_ratio))
                if win_ratio > self.best_win_ratio:
                    self.best_win_ratio = win_ratio
                    self.brain.save_model(".\\BestModel\\GomokuAi")
                if self.best_win_ratio >= 1.0:
                    self.best_win_ratio = 0.0
                    self.rival_play_out_n += 1000
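
The get_equi_data step above exploits the fact that a Gomoku position keeps its value under the eight symmetries of the square (four rotations, each optionally mirrored), so every self-play sample yields eight training samples. A self-contained numpy sketch of the same idea, using a toy 3x3 probability grid instead of the trainer's real data:

import numpy as np

def eight_symmetries(grid):
    """Return the 8 dihedral symmetries of a square grid (4 rotations x optional mirror)."""
    variants = []
    for k in range(1, 5):
        rotated = np.rot90(grid, k)          # rotate by k * 90 degrees
        variants.append(rotated)
        variants.append(np.fliplr(rotated))  # left-right mirror of that rotation
    return variants

# toy 3x3 grid standing in for a reshaped move-probability vector
probs = np.arange(9, dtype=float).reshape(3, 3)
print(len(eight_symmetries(probs)))  # 8 equivalent training targets from one sample

Running the trainer itself only requires instantiating NetTrainer() and calling run(); self-play, augmentation, training and periodic evaluation are all handled inside run().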
Example #3
from Game import Game
from Player import MCTSPlayer, HumanPlayer
from policy_value_net import PolicyValueNet
from utils import symmetry_board_moves
# set up a game between a human player and the net-guided MCTS player
game = Game()
goat = HumanPlayer()
pvnet = PolicyValueNet("models/model.h5")
pvnet_fn = pvnet.policy_value_fn
bagh = MCTSPlayer(pvnet_fn, n_playout=500)

# play one game and collect its (state, mcts_probs, winner) records
data = game.start_play(bagh, goat)
data = list(data)
# augment the game record with symmetric board positions
data = symmetry_board_moves(data)

# split the records into training batches
state_batch = [x[0] for x in data]
mcts_probs_batch = [x[1] for x in data]
winner_batch = [x[2] for x in data]

# fine-tune the policy-value net on this game for 5 epochs and save it
pvnet.train(state_batch, mcts_probs_batch, winner_batch, 5)
pvnet.save_model("model.h5")
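
If one wanted to keep refining the net over several human-vs-net games, the same steps can be wrapped in a loop. This is only a sketch reusing the names already defined above; the number of rounds is an arbitrary assumption, not part of the original example.

for round_idx in range(3):  # arbitrary number of rounds
    # play another game, augment it, and train on the fresh records
    game_data = symmetry_board_moves(list(game.start_play(bagh, goat)))
    pvnet.train([x[0] for x in game_data],  # board states
                [x[1] for x in game_data],  # MCTS move probabilities
                [x[2] for x in game_data],  # winner labels
                5)
    pvnet.save_model("model.h5")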