Ejemplo n.º 1
0
class TrainPipeline():
    """Self-play training loop for an AlphaZero-style Gomoku agent.

    Alternates between generating self-play games with an MCTS player backed
    by ``PolicyValueNet``, training that network on a replay buffer of
    symmetry-augmented positions, and periodically benchmarking it against a
    pure-MCTS opponent.
    """

    def __init__(self, init_model=None):
        # Board / game rules: 15x15 board, five in a row wins.
        self.board_width, self.board_height = 15, 15
        self.n_in_row = 5
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)

        # Optimization hyper-parameters.
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0   # rescaled in policy_update based on KL
        self.temp = 1.0            # self-play move-selection temperature
        self.epochs = 5            # train steps per policy_update call
        self.kl_targ = 0.02        # target KL divergence per update

        # Search hyper-parameters.
        self.n_playout = 400       # MCTS simulations per move
        self.c_puct = 5

        # Replay buffer and batching.
        self.buffer_size = 10000
        self.batch_size = 512      # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1   # self-play games per outer iteration

        # Schedule and evaluation.
        self.check_freq = 200
        self.game_batch_num = 5000
        self.best_win_ratio = 0.0
        # Playouts used by the pure-MCTS opponent during evaluation.
        self.pure_mcts_playout_num = 1000

        # Resume from ``init_model`` when one is given, else start fresh.
        net_kwargs = {"use_gpu": True}
        if init_model:
            net_kwargs["model_file"] = init_model
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height,
                                               **net_kwargs)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Expand each sample into its 8 rotated/mirrored equivalents.

        play_data: iterable of (state, mcts_prob, winner_z) tuples, where
        ``state`` is a stack of 2-D planes and ``mcts_prob`` is a flat vector
        of board_height * board_width move probabilities.

        For every sample, 4 counter-clockwise rotations are emitted, each
        followed by its left-right mirror. The probability grid is flipped
        vertically before rotating and again before flattening; presumably
        the state planes are stored upside-down relative to move indices
        (TODO confirm against Board.current_state).
        """
        augmented = []
        grid_shape = (self.board_height, self.board_width)
        for state, mcts_prob, winner in play_data:
            prob_grid = np.flipud(mcts_prob.reshape(grid_shape))
            for n_turns in (1, 2, 3, 4):
                rot_state = np.array([np.rot90(plane, n_turns)
                                      for plane in state])
                rot_prob = np.rot90(prob_grid, n_turns)
                augmented.append(
                    (rot_state, np.flipud(rot_prob).flatten(), winner))

                mirror_state = np.array([np.fliplr(plane)
                                         for plane in rot_state])
                mirror_prob = np.fliplr(rot_prob)
                augmented.append(
                    (mirror_state, np.flipud(mirror_prob).flatten(), winner))
        return augmented

    def collect_selfplay_data(self, n_games=10):
        """Play ``n_games`` self-play games and push augmented data to the buffer.

        Also records the length of the last game in ``self.episode_len``.
        """
        for _ in range(n_games):
            _winner, raw_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            moves = list(raw_data)
            self.episode_len = len(moves)
            self.data_buffer.extend(self.get_equi_data(moves))

    def policy_update(self):
        """Train on one random mini-batch and adapt the learning rate.

        Runs up to ``epochs`` train steps and stops early once the KL
        divergence between the pre-update and current policy exceeds
        4 * kl_targ. The multiplier then shrinks (KL > 2 * kl_targ) or grows
        (KL < kl_targ / 2) by a factor 1.5, within roughly [0.1, 10].

        Returns:
            (loss, entropy) from the last train step.
        """
        def kl_divergence(p, q):
            return np.mean(
                np.sum(p * (np.log(p + 1e-10) - np.log(q + 1e-10)), axis=1))

        samples = random.sample(self.data_buffer, self.batch_size)
        states = [sample[0] for sample in samples]
        target_probs = [sample[1] for sample in samples]
        winners = [sample[2] for sample in samples]

        net = self.policy_value_net
        old_probs, old_v = net.policy_value(states)
        step_lr = self.learn_rate * self.lr_multiplier
        for _ in range(self.epochs):
            loss, entropy = net.train_step(states, target_probs, winners,
                                           step_lr)
            new_probs, new_v = net.policy_value(states)
            kl = kl_divergence(old_probs, new_probs)
            if kl > 4 * self.kl_targ:  # policy moved too far: stop early
                break

        if kl > 2 * self.kl_targ and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < 0.5 * self.kl_targ and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        # Fraction of the outcome variance explained by the value head,
        # before and after this update.
        winner_arr = np.array(winners)
        explained_var_old = (1 - np.var(winner_arr - old_v.flatten())
                             / np.var(winner_arr))
        explained_var_new = (1 - np.var(winner_arr - new_v.flatten())
                             / np.var(winner_arr))
        template = ("kl:{:.5f},"
                    "lr_multiplier:{:.3f},"
                    "loss:{},"
                    "entropy:{},"
                    "explained_var_old:{:.3f},"
                    "explained_var_new:{:.3f}")
        print(template.format(kl, self.lr_multiplier, loss, entropy,
                              explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=15):
        """Play ``n_games`` against pure MCTS and return the score ratio.

        Wins count 1 and ties (winner == -1) count 0.5; the two sides
        alternate who moves first. Used only to monitor training progress.
        """
        trained_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                    c_puct=self.c_puct,
                                    n_playout=self.n_playout)
        baseline = MCTS_Pure(c_puct=5,
                             n_playout=self.pure_mcts_playout_num)
        tally = defaultdict(int)
        for game_no in range(n_games):
            print("running for game", game_no)
            outcome = self.game.start_play(trained_player,
                                           baseline,
                                           start_player=game_no % 2,
                                           is_shown=0)
            tally[outcome] += 1
        wins, losses, ties = tally[1], tally[2], tally[-1]
        win_ratio = 1.0 * (wins + 0.5 * ties) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, wins, losses, ties))
        return win_ratio

    def run(self):
        """Run the full training loop; Ctrl-C stops it cleanly.

        Each iteration plays ``play_batch_size`` self-play games and trains
        once the buffer holds more than ``batch_size`` samples. Every
        ``check_freq`` iterations it also evaluates the net against pure MCTS
        and saves checkpoints.
        """
        try:
            for batch_no in range(1, self.game_batch_num + 1):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    batch_no, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                if batch_no % self.check_freq != 0:
                    continue
                # Periodic evaluation and checkpointing.
                print("current self-play batch: {}".format(batch_no))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('./current_policy.model')
                if win_ratio >= self.best_win_ratio:
                    print("New best policy!!!!!!!!")
                    self.best_win_ratio = win_ratio
                    self.policy_value_net.save_model('./best_policy.pt')
                    # Pure MCTS beaten in every game: make it stronger (up
                    # to 5000 playouts) and restart best-ratio tracking.
                    if (self.best_win_ratio == 1.0
                            and self.pure_mcts_playout_num < 5000):
                        self.pure_mcts_playout_num += 1000
                        self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
Ejemplo n.º 2
0
class TrainPipeline():
    """AlphaZero-style training pipeline for 10x10 five-in-a-row.

    Generates self-play games, trains the policy-value network on an
    augmented replay buffer, and every ``check_freq`` iterations evaluates it
    against a pure-MCTS opponent, keeping the best checkpoint.
    """

    def __init__(self, init_model=None):
        # Board and game settings.
        self.board_width = self.board_height = 10
        self.n_in_row = 5
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)

        # Learning settings.
        self.learn_rate = 2e-3          # base learning rate
        self.lr_multiplier = 1.2        # scaled automatically based on KL
        self.temp = 1.0                 # temperature parameter
        self.epochs = 5                 # train steps per update
        self.kl_targ = 0.02             # KL target for early stopping

        # Tree-search settings.
        self.n_playout = 400            # simulations per move
        self.c_puct = 5                 # exploitation/exploration trade-off

        # Replay buffer (a bounded deque) and batching.
        self.buffer_size = 10000
        self.batch_size = 512           # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1

        # Schedule.
        self.check_freq = 50            # evaluate/checkpoint every 50 batches
        self.game_batch_num = 2000      # total self-play batches
        self.best_win_ratio = 0.0       # best evaluation score seen so far
        # Playouts for the weak pure-MCTS opponent used in evaluation.
        self.pure_mcts_playout_num = 1000

        # Load weights from init_model when provided, otherwise start fresh.
        extra = {"model_file": init_model} if init_model else {}
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height,
                                               use_gpu=True,
                                               **extra)
        # Training player; is_selfplay=1 because it plays against itself.
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Return the 8 rotation/mirror variants of every sample.

        play_data: iterable of (state, mcts_prob, winner_z). Each sample is
        rotated counter-clockwise 1..4 times, and each rotation is also
        mirrored left-right. The probability grid is flipped vertically around
        the transform; presumably state planes are stored upside-down relative
        to move indices (TODO confirm against Board).
        """
        def grid_to_vector(grid):
            return np.flipud(grid).flatten()

        result = []
        for state, mcts_prob, winner in play_data:
            base_grid = np.flipud(
                mcts_prob.reshape(self.board_height, self.board_width))
            for k in [1, 2, 3, 4]:
                turned_state = np.array([np.rot90(p, k) for p in state])
                turned_grid = np.rot90(base_grid, k)
                result.append((turned_state, grid_to_vector(turned_grid),
                               winner))
                flipped_state = np.array([np.fliplr(p) for p in turned_state])
                flipped_grid = np.fliplr(turned_grid)
                result.append((flipped_state, grid_to_vector(flipped_grid),
                               winner))
        return result

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and store their augmented samples.

        ``self.episode_len`` is set to the number of moves in the last game.
        """
        for _ in range(n_games):
            _, game_data = self.game.start_self_play(self.mcts_player,
                                                     temp=self.temp)
            game_data = list(game_data)
            self.episode_len = len(game_data)
            self.data_buffer.extend(self.get_equi_data(game_data))

    def policy_update(self):
        """Run up to ``epochs`` train steps on one mini-batch.

        Training stops early when the KL divergence between the policy before
        the update and the current policy exceeds 4 * kl_targ. Afterwards the
        learning-rate multiplier is divided or multiplied by 1.5 depending on
        whether KL ended above 2 * kl_targ or below kl_targ / 2.

        Returns:
            (loss, entropy) of the last train step.
        """
        batch = random.sample(self.data_buffer, self.batch_size)
        states = [item[0] for item in batch]
        probs = [item[1] for item in batch]
        outcomes = [item[2] for item in batch]

        # Snapshot of the policy and value before training.
        old_probs, old_v = self.policy_value_net.policy_value(states)
        lr = self.learn_rate * self.lr_multiplier
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                states, probs, outcomes, lr)
            new_probs, new_v = self.policy_value_net.policy_value(states)
            # KL(old || new) averaged over the batch.
            log_ratio = np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)
            kl = np.mean(np.sum(old_probs * log_ratio, axis=1))
            if kl > self.kl_targ * 4:
                break

        # Adjust the learning-rate multiplier.
        too_fast = kl > self.kl_targ * 2
        too_slow = kl < self.kl_targ / 2
        if too_fast and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif too_slow and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        outcome_arr = np.array(outcomes)
        outcome_var = np.var(outcome_arr)
        explained_var_old = 1 - np.var(outcome_arr - old_v.flatten()) / outcome_var
        explained_var_new = 1 - np.var(outcome_arr - new_v.flatten()) / outcome_var
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, self.lr_multiplier, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Score the current net over ``n_games`` games against pure MCTS.

        Returns (wins + 0.5 * ties) / n_games. Games are not rendered, and
        the first player alternates between the two sides.
        """
        candidate = MCTSPlayer(self.policy_value_net.policy_value_fn,
                               c_puct=self.c_puct,
                               n_playout=self.n_playout)
        opponent = MCTS_Pure(c_puct=5, n_playout=self.pure_mcts_playout_num)
        counts = defaultdict(int)
        for idx in range(n_games):
            result = self.game.start_play(candidate, opponent,
                                          start_player=idx % 2, is_shown=0)
            counts[result] += 1
        win_ratio = 1.0 * (counts[1] + 0.5 * counts[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, counts[1], counts[2], counts[-1]))
        return win_ratio

    def _evaluate_and_checkpoint(self, batch_no):
        """Evaluate against pure MCTS and save the current/best checkpoints."""
        print("current self-play batch: {}".format(batch_no))
        win_ratio = self.policy_evaluate()
        self.policy_value_net.save_model('./current_policy.model')
        if not win_ratio > self.best_win_ratio:
            return
        print("发现新的最优策略,进行策略更新")
        self.best_win_ratio = win_ratio
        self.policy_value_net.save_model('./best_policy.model')
        # A perfect score: strengthen the opponent (up to 5000 playouts) and
        # reset the best ratio so later improvements are measured afresh.
        if self.best_win_ratio == 1.0 and self.pure_mcts_playout_num < 5000:
            self.pure_mcts_playout_num += 1000
            self.best_win_ratio = 0.0

    def run(self):
        """Train for ``game_batch_num`` batches; Ctrl-C stops gracefully.

        Each batch plays ``play_batch_size`` self-play games.
        """
        try:
            for batch_no in range(1, self.game_batch_num + 1):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    batch_no, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                if batch_no % self.check_freq == 0:
                    self._evaluate_and_checkpoint(batch_no)
        except KeyboardInterrupt:
            print('\n\rquit')
Ejemplo n.º 3
0
class TrainPipeline():
    """AlphaZero-style self-play training pipeline with text-file logging.

    Board size and win length are configurable. Loss/entropy and evaluation
    win ratios are appended to text files under ``info/``, and model
    checkpoints are written under ``./model/``.
    """

    def __init__(self, init_model=None, board_width=6, board_height=6,
                 n_in_row=4, n_playout=400, use_gpu=False, is_shown=False,
                 output_file_name="", game_batch_number=1500):
        """
        Args:
            init_model: passed to ``PolicyValueNet`` as ``model_file``.
            board_width, board_height, n_in_row: game rules.
            n_playout: MCTS simulations per move for the trained player.
            use_gpu: forwarded to ``PolicyValueNet``.
            is_shown: forwarded to ``Game.start_play`` during evaluation.
            output_file_name: suffix used in log and model file names.
            game_batch_number: number of self-play iterations run by ``run``.
        """
        # Game and board parameters.
        self.board_width = board_width
        self.board_height = board_height
        self.n_in_row = n_in_row
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # Training parameters.
        self.learn_rate = 2e-3  # learning rate alpha: 0.002
        self.lr_multiplier = 1.0  # adapted from the KL divergence
        self.temp = 1.0  # temperature t
        self.n_playout = n_playout  # MCTS playouts per move
        self.c_puct = 5  # c_puct constant
        self.buffer_size = 10000  # replay buffer capacity
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # train steps per update
        self.kl_targ = 0.02
        self.check_freq = 100
        self.game_batch_num = game_batch_number  # number of training games
        self.best_win_ratio = 0.0
        # Playouts of the pure-MCTS opponent used to evaluate the model.
        self.pure_mcts_playout_num = 7000
        self.use_gpu = use_gpu
        self.is_shown = is_shown
        self.output_file_name = output_file_name  # suffix of the txt logs
        # Policy-value network (optionally loaded from init_model).
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height,
                                               model_file=init_model,
                                               use_gpu=self.use_gpu
                                               )
        # Self-play player driven by the network.
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment samples with rotations and mirrors (8 variants each).

        play_data: [(state, mcts_prob, winner_z), ...]. Each sample yields
        4 counter-clockwise rotations, each followed by its left-right mirror.
        The probability grid is flipped vertically around the transform;
        presumably state planes are stored upside-down relative to move
        indices (TODO confirm against Board).
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and store their augmented data.

        ``self.episode_len`` records the move count of the last game.
        """
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)[:]  # materialize the iterator
            self.episode_len = len(play_data)
            # Add the symmetric equivalents before storing.
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """Train the policy-value net on one random mini-batch.

        Up to ``epochs`` train steps are run, stopping early when the KL
        divergence from the pre-update policy exceeds 4 * kl_targ. The
        learning-rate multiplier is then adapted.

        Returns:
            (loss, entropy) of the last train step.
        """
        # Draw batch_size random samples from the buffer.
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Evaluate the current policy against the pure-MCTS player.

        Only used to monitor training progress. Returns
        (wins + 0.5 * ties) / n_games.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=self.is_shown)
            win_cnt[winner] += 1
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def _loss_log_path(self):
        """Path of the text log holding 'iteration,loss,entropy' lines."""
        return ("info/" + str(self.board) + "_loss_"
                + self.output_file_name + ".txt")

    def _win_ratio_log_path(self):
        """Path of the text log holding evaluation win ratios."""
        # "_win_ration" (sic) is kept so existing log files/tools still match.
        return ("info/" + str(self.board) + "_win_ration"
                + self.output_file_name + ".txt")

    def _model_path(self, kind):
        """Checkpoint path for ``kind`` ('current' or 'best')."""
        return ('./model/' + str(self.board_height)
                + '_' + str(self.board_width)
                + '_' + str(self.n_in_row)
                + '_' + kind + '_policy_' + self.output_file_name + '.model')

    @staticmethod
    def _append_line(path, line):
        """Append ``line`` plus a newline to the text file at ``path``."""
        with open(path, 'a', encoding='utf-8') as log_file:
            log_file.write(line + '\n')

    def run(self):
        """Run training, logging progress and saving checkpoints.

        The run proceeds in these steps:

        1. Truncate both log files and write a header row to each.
        2. For each of ``game_batch_num`` iterations:
           - play self-play games;
           - once the buffer holds more than ``batch_size`` samples, train
             and append a loss line;
           - every ``check_freq`` iterations, evaluate against pure MCTS,
             append the win ratio, and save the current model (plus the
             best model when it does not regress).

        Ctrl-C stops training gracefully.
        """
        import os

        loss_path = self._loss_log_path()
        win_ratio_path = self._win_ratio_log_path()
        # Output directories may not exist yet (e.g. on a fresh checkout).
        for path in (loss_path, win_ratio_path, self._model_path('current')):
            os.makedirs(os.path.dirname(path) or '.', exist_ok=True)

        # Start both logs afresh with a header row. UTF-8 keeps the Chinese
        # headers writable regardless of the platform's locale encoding.
        with open(loss_path, 'w', encoding='utf-8') as loss_file:
            loss_file.write("self-play次数,loss,entropy\n")
        with open(win_ratio_path, 'w', encoding='utf-8') as win_ratio_file:
            win_ratio_file.write("self-play次数, pure_MCTS战力, 胜率\n")

        try:
            for i in range(self.game_batch_num):
                # Play game(s) from the empty board to a result; collect data.
                self.collect_selfplay_data(self.play_batch_size)
                print("对局 i:{}, 事件数(走了多少步):{}".format(
                        i+1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    self._append_line(
                        loss_path,
                        str(i+1) + ',' + str(loss) + ',' + str(entropy))
                # check the performance of the current model,
                # and save the model params
                if (i+1) % self.check_freq == 0:
                    print("当前的 self-play 对局: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self._append_line(
                        win_ratio_path,
                        str(i+1) + ',' + str(self.pure_mcts_playout_num)
                        + ',' + str(win_ratio))
                    self.policy_value_net.save_model(
                        self._model_path('current'))
                    if win_ratio >= self.best_win_ratio:
                        print("更好的策略产生!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model(
                            self._model_path('best'))
                        # Perfect score: strengthen the opponent and reset
                        # the best ratio so progress is measured afresh.
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 50000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
Ejemplo n.º 4
0
class TrainPipeline:
    """Self-play training pipeline driven by a ``Config`` object.

    Each iteration collects fresh self-play data and trains one pass over it
    in mini-batches, with a random board symmetry applied per sample. It then
    clears the buffer and periodically saves and evaluates the model.
    """

    def __init__(self, config: "Config"):
        # The annotation is a string so the class can be defined without
        # Config in scope; callers are unaffected.
        self.config = config
        self.game = Game.from_config(config)

        # training params
        self.buffer_size = 10000
        self.min_data_to_collect = 128
        self.batch_size = 128  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)

        self.save_freq = 20
        self.eval_freq = self.save_freq * 100
        self.num_total_iter = self.eval_freq * 4
        assert self.num_total_iter % self.save_freq == 0
        assert self.num_total_iter % self.eval_freq == 0

        self.best_win_ratio = 0.0
        self.policy_value_net = PolicyValueNet(self.config.size,
                                               model_file=config.model_file)
        self.mcts_player = MCTSPlayer(
            self.policy_value_net,
            c_puct=config.c_puct,
            n_playout=config.n_playout,
            is_selfplay=True,
        )

    def collect_selfplay_data(self):
        """Play self-play games until the buffer exceeds ``min_data_to_collect``.

        Only the first three fields of each sample are kept.

        Returns:
            (number of samples in the buffer, number of games played).
        """
        n_game = 0

        while not len(self.data_buffer) > self.min_data_to_collect:
            _, play_data = self.game.start_self_play(self.mcts_player,
                                                     display=True)
            play_data = [(data[0], data[1], data[2]) for data in play_data]
            self.data_buffer.extend(play_data)
            n_game += 1

        return len(self.data_buffer), n_game

    def transform_data(self, board_input_batch, mcts_probs_batch):
        """Apply a random board symmetry to each sample, in place.

        For each sample independently, draw a rotation count (0-3) and a
        flag for transposing the board. Samples drawing neither stay
        unchanged. Board planes are rotated/transposed over axes (1, 2); the
        leading board_size**2 move probabilities are reshaped to a
        (board_size, board_size) grid and transformed identically. Any
        trailing entries (e.g. a pass move) are kept as-is.

        Assumes each board input is (channels, size, size) and that the move
        probabilities are laid out row-major over the same grid; TODO confirm
        against get_current_input.

        Returns:
            The (mutated) ``board_input_batch`` and ``mcts_probs_batch``.
        """
        for i in range(board_input_batch.shape[0]):
            n_rotate = np.random.randint(4)
            flip = np.random.randint(2)

            # Skip only the identity transform; rotation alone or flip alone
            # are both valid symmetries.
            if not n_rotate and not flip:
                continue

            board_input = board_input_batch[i]
            board_size = board_input.shape[-1]  # spatial size, not channels
            n_place = board_size * board_size
            mcts_probs = mcts_probs_batch[i]
            mcts_place_probs = np.reshape(mcts_probs[:n_place],
                                          (board_size, board_size))
            mcts_extra_probs = mcts_probs[n_place:]  # e.g. the pass move

            if n_rotate:
                board_input = np.rot90(board_input, n_rotate, axes=(1, 2))
                mcts_place_probs = np.rot90(mcts_place_probs, n_rotate)
            if flip:
                # Transpose each plane, leaving the channel axis alone.
                board_input = np.swapaxes(board_input, 1, 2)
                mcts_place_probs = mcts_place_probs.T

            # Assign whole rows; copy() avoids writing a view onto itself.
            board_input_batch[i] = board_input.copy()
            mcts_probs_batch[i] = np.concatenate(
                (mcts_place_probs.ravel(), mcts_extra_probs))

        return board_input_batch, mcts_probs_batch

    def policy_update(self):
        """Train one pass over the shuffled buffer in mini-batches.

        Samples past the last full batch are skipped. Each batch gets random
        symmetries, a single train step, and a printed KL/loss/explained-
        variance summary.
        """
        random.shuffle(self.data_buffer)
        n_batchs = len(self.data_buffer) // self.batch_size

        for i in range(n_batchs):
            mini_batch = list(
                itertools.islice(self.data_buffer, i * self.batch_size,
                                 (i + 1) * self.batch_size))

            board_input_batch = np.array(
                [get_current_input(data[0]) for data in mini_batch])
            mcts_probs_batch = np.array([data[1] for data in mini_batch])
            winner_batch = np.array([data[2] for data in mini_batch])
            board_input_batch, mcts_probs_batch = self.transform_data(
                board_input_batch, mcts_probs_batch)

            old_probs, old_v = self.policy_value_net.policy_value(
                board_input_batch)

            self.policy_value_net.set_train_mode()

            loss, entropy = self.policy_value_net.train_step(
                board_input_batch, mcts_probs_batch, winner_batch)
            new_probs, new_v = self.policy_value_net.policy_value(
                board_input_batch)
            # KL(old || new), averaged over the batch (monitoring only).
            kl = np.mean(
                np.sum(
                    old_probs *
                    (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1,
                ))

            explained_var_old = 1 - np.var(
                np.array(winner_batch) - old_v.flatten()) / np.var(
                    np.array(winner_batch))
            explained_var_new = 1 - np.var(
                np.array(winner_batch) - new_v.flatten()) / np.var(
                    np.array(winner_batch))
            print(("batch:{}, "
                   "kl:{:.5f}, "
                   "loss:{:.5f}, "
                   "entropy:{:.5f}, "
                   "explained_var_old:{:.3f}, "
                   "explained_var_new:{:.3f}").format(
                       i,
                       kl,
                       loss,
                       entropy,
                       explained_var_old,
                       explained_var_new,
                   ))

    def run(self):
        """Run the training pipeline for ``num_total_iter`` iterations.

        The current model is saved every ``save_freq`` iterations. Every
        ``eval_freq`` iterations the model is evaluated, and the best one is
        saved. Ctrl-C stops training gracefully.
        """
        np.random.seed(0)
        try:
            for i in range(self.num_total_iter):
                n_data, n_game = self.collect_selfplay_data()

                print("iteration {}: total {} data collected from {} game(s)".
                      format(i, n_data, n_game))

                self.policy_value_net.set_train_mode()
                self.policy_update()
                # Train only on freshly generated data each iteration.
                self.data_buffer.clear()
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % self.save_freq == 0:
                    print("saving current model at {}: file={}".format(
                        i + 1, self.config.get_current_model_name()))
                    self.policy_value_net.save_model(
                        self.config.get_current_model_name())
                if (i + 1) % self.eval_freq == 0:
                    print("evalutating current model: {}".format(i + 1))
                    current_mcts_player = MCTSPlayer(
                        self.policy_value_net,
                        c_puct=self.config.c_puct,
                        n_playout=self.config.n_playout,
                    )
                    win_ratio = evaluate.evaluate_policy(
                        self.game, current_mcts_player)
                    if win_ratio > self.best_win_ratio:
                        print(
                            "saving the new best policy at {}! win_ratio={}, file={}"
                            .format(
                                i + 1,
                                win_ratio,
                                self.config.get_best_model_name(),
                            ))
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model(
                            self.config.get_best_model_name())
        except KeyboardInterrupt:
            print("\n\rquit")
Ejemplo n.º 5
0
class TrainPipeline():
    """AlphaZero-style training pipeline with TensorBoard logging.

    Each batch collects self-play games (from the empty board and, when
    ``full_boards_selfplay`` is set, also from pre-filled ``BOARD_*_FULL``
    start positions), trains the policy-value net on a random mini-batch with
    a KL-adapted learning rate, and every ``check_freq`` batches checkpoints
    the net and evaluates it against a pure-MCTS opponent.
    """

    def __init__(self, init_model=None):
        """Set up the board, game, hyper-parameters and networks.

        Args:
            init_model: optional model file to start training from; when
                falsy a freshly initialised policy-value net is created.
        """
        self.writer = SummaryWriter(WRITER_DIR)

        # params of the board and the game
        self.board_width = 6
        self.board_height = 6
        self.n_in_row = 4

        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)

        self.game = Game(self.board)

        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param

        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update (must be >= 1)
        self.kl_targ = 0.02

        self.check_freq = 50  # evaluate / checkpoint every check_freq batches
        self.game_batch_num = 5000

        # stop training after this many consecutive evaluations without a
        # new best win ratio
        self.improvement_counter = 1000
        self.best_win_ratio = 0.0

        self.input_plains_num = INPUT_PLANES_NUM

        self.c_puct = 5
        self.n_playout = 50  # num of simulations for each move
        self.shutter_threshold_availables = 1
        # also self-play from the pre-filled BOARD_*_FULL start positions
        self.full_boards_selfplay = False

        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy; raised by
        # pure_mcts_playout_num_step whenever the net wins every game
        self.pure_mcts_playout_num = 200
        self.pure_mcts_playout_num_step = 200

        net_kwargs = {}
        if init_model:
            # start training from an initial policy-value net
            net_kwargs['model_file'] = init_model
        self.policy_value_net = PolicyValueNet(
            self.board_width,
            self.board_height,
            self.input_plains_num,
            shutter_threshold_availables=self.shutter_threshold_availables,
            **net_kwargs)

        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        Every sample yields 8 symmetric variants: 4 counter-clockwise
        rotations, each with and without a left-right flip.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner_z)`` where
                ``state`` is a stack of 2-D planes and ``mcts_prob`` is a flat
                vector of ``board_height * board_width`` move probabilities.

        Returns:
            list of augmented ``(state, mcts_prob, winner_z)`` tuples.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            # The probability grid is flipped vertically relative to the state
            # planes (presumably rows are indexed bottom-up; confirm against
            # Board), so flip it before transforming and back before flattening.
            prob_grid = np.flipud(
                mcts_prob.reshape(self.board_height, self.board_width))
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(prob_grid, i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def _play_selfplay_game(self, **play_kwargs):
        """Play one self-play game, store its augmented data, return its move count.

        ``play_kwargs`` (``start_player`` or ``initial_state``) are forwarded
        to ``Game.start_self_play``.
        """
        _, play_data = self.game.start_self_play(
            self.mcts_player,
            temp=self.temp,
            is_last_move=(self.input_plains_num == 4),
            **play_kwargs)
        play_data = list(play_data)
        self.data_buffer.extend(self.get_equi_data(play_data))
        return len(play_data)

    def _full_board_templates(self):
        """Return ``(board_id, template)`` pairs of pre-filled start positions for this board size."""
        if self.board_width == 6:
            return ((1, BOARD_1_FULL), (2, BOARD_2_FULL))
        return ((3, BOARD_3_FULL), (4, BOARD_4_FULL), (5, BOARD_5_FULL))

    def _full_board_state(self, template):
        """Encode ``template[0]`` as two binary planes: stones == 1 and stones == 2.

        The position is flipped vertically first (same orientation convention
        as the original per-board code). The planes are shaped
        ``(board_height, board_width)`` to match numpy's row-major layout.
        """
        board = np.flipud(copy.deepcopy(template[0]))
        state = np.zeros((2, self.board_height, self.board_width))
        state[0] = board == 1
        state[1] = board == 2
        return state

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training.

        Plays ``n_games`` games from the empty board, alternating the starting
        player between 1 and 2. When ``full_boards_selfplay`` is set, it also
        plays one game per iteration from each pre-filled start position. All
        games are augmented with ``get_equi_data`` and appended to
        ``data_buffer``.

        Side effects: resets, then sets, ``episode_len`` (mean empty-board game
        length) and ``episode_len_full_1`` .. ``episode_len_full_5`` (mean
        length per pre-filled position; positions not played stay 0).
        """
        self.episode_len = 0
        self.episode_len_full_1 = 0
        self.episode_len_full_2 = 0
        self.episode_len_full_3 = 0
        self.episode_len_full_4 = 0
        self.episode_len_full_5 = 0

        for i in range(n_games):
            # EMPTY BOARD
            n_moves = self._play_selfplay_game(start_player=i % 2 + 1)
            self.episode_len += n_moves / n_games

            if not self.full_boards_selfplay:
                continue

            # PRE-FILLED BOARDS
            for board_id, template in self._full_board_templates():
                n_moves = self._play_selfplay_game(
                    initial_state=self._full_board_state(template))
                attr = 'episode_len_full_{}'.format(board_id)
                setattr(self, attr, getattr(self, attr) + n_moves / n_games)

    def policy_update(self, iteration):
        """Update the policy-value net on one random mini-batch.

        Runs up to ``epochs`` train steps on the same mini-batch. It stops
        early once the KL divergence between the pre-update and current policy
        exceeds ``4 * kl_targ``. It then adapts ``lr_multiplier`` from that KL
        and logs the statistics to stdout and to TensorBoard at step
        ``iteration + 1``.

        Args:
            iteration: zero-based index of the current self-play batch.

        Returns:
            ``(loss, entropy)`` from the last train step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)

        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)

            new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))

            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break

        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        winners = np.array(winner_batch)
        explained_var_old = (1 - np.var(winners - old_v.flatten()) /
                             np.var(winners))
        explained_var_new = (1 - np.var(winners - new_v.flatten()) /
                             np.var(winners))

        train_str = ("kl:{:.5f}, "
                     "lr_multiplier:{:.3f}, "
                     "loss:{}, "
                     "entropy:{}, "
                     "explained_var_old:{:.3f}, "
                     "explained_var_new:{:.3f}").format(
                         kl, self.lr_multiplier, loss, entropy,
                         explained_var_old, explained_var_new)

        print(train_str)

        step = iteration + 1
        for tag, value in (('lr multiplier', self.lr_multiplier),
                           ('kl_', kl),
                           ('explained var old', explained_var_old),
                           ('explained var new', explained_var_new),
                           ('training loss', loss),
                           ('training entropy', entropy)):
            self.writer.add_scalar(tag, value, step)

        return loss, entropy

    def policy_evaluate(self, iteration, n_games=10):
        """Evaluate the trained policy by playing against the pure MCTS player.

        Note: this is only for monitoring the progress of training.

        Args:
            iteration: zero-based batch index, used as the TensorBoard step.
            n_games: number of evaluation games; the starting player
                alternates between 1 and 2.

        Returns:
            win ratio of the trained player (player 1), counting each tie
            (winner ``-1``) as half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)

        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)

        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2 + 1,
                                          is_shown=0,
                                          savefig=False)
            win_cnt[winner] += 1
        win_ratio = (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts: {}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))

        self.writer.add_text(
            tag='evaluation results',
            text_string=
            f"num_playouts: {self.pure_mcts_playout_num}, win: {win_cnt[1]}, lose: {win_cnt[2]}, tie:{win_cnt[-1]}",
            global_step=iteration + 1)

        return win_ratio

    def _full_board_ids(self):
        """Return the ids of the pre-filled positions played this run (empty if none)."""
        if not self.full_boards_selfplay:
            return ()
        return (1, 2) if self.board_width == 6 else (3, 4, 5)

    def _log_episode_lengths(self, iteration):
        """Print and log to TensorBoard the episode lengths of the last batch."""
        step = iteration + 1
        board_ids = self._full_board_ids()
        full_lens = [(board_id,
                      getattr(self, 'episode_len_full_{}'.format(board_id)))
                     for board_id in board_ids]

        message = "batch i:{}, episode_len:{}".format(step, self.episode_len)
        message += "".join(", episode len full {}: {}".format(board_id, length)
                           for board_id, length in full_lens)
        print(message)

        for board_id, length in full_lens:
            self.writer.add_scalar('episode len full {}'.format(board_id),
                                   length, step)
        self.writer.add_scalar('episode len', self.episode_len, step)

    def run(self):
        """Run the training pipeline.

        Trains for up to ``game_batch_num`` batches. Every ``check_freq``
        batches it evaluates the net and writes
        ``MODEL_DIR/current_policy_<batch>.model``. A perfect win ratio raises
        the pure-MCTS opponent's playout budget and resets the best ratio.
        Training stops after ``improvement_counter`` consecutive evaluations
        without a new best, or on Ctrl-C.
        """
        os.makedirs(MODEL_DIR, exist_ok=True)

        try:
            improvement_counter_local = 0

            for i in range(self.game_batch_num):

                self.writer.add_scalar('MCTS playouts num',
                                       self.pure_mcts_playout_num, i + 1)

                self.collect_selfplay_data(self.play_batch_size)
                self._log_episode_lengths(i)

                if len(self.data_buffer) > self.batch_size:
                    self.policy_update(iteration=i)

                # check the performance of the current model,
                # and save the model params
                if (i + 1) % self.check_freq != 0:
                    continue

                print("current self-play batch: {}".format(i + 1))
                win_ratio = self.policy_evaluate(iteration=i)

                self.policy_value_net.save_model(
                    f'{MODEL_DIR}/current_policy_{i + 1}.model')

                if win_ratio > self.best_win_ratio:
                    # Only a text marker is logged; no separate best-model
                    # file is written.
                    self.writer.add_text('best model savings',
                                         'better model found', i + 1)

                    print("New best policy!!!!!!!!")

                    improvement_counter_local = 0
                    self.best_win_ratio = win_ratio

                    # a perfect score: make the opponent stronger and restart
                    # the best-ratio tracking at that new level
                    if self.best_win_ratio == 1.0:
                        self.pure_mcts_playout_num += self.pure_mcts_playout_num_step
                        self.best_win_ratio = 0.0

                else:
                    improvement_counter_local += 1
                    if improvement_counter_local == self.improvement_counter:
                        print(
                            f"No better policy was found in the last {self.improvement_counter} "
                            f"checks. Ending training. ")

                        self.writer.add_text(
                            'best model savings',
                            f"No better policy was found "
                            f"in the last {self.improvement_counter} "
                            f"checks. Ending training. ", i + 1)
                        break

        except KeyboardInterrupt:
            print('\n\rquit')
Ejemplo n.º 6
0
class TrainPipeline():
    """Small demonstration training loop for a 6x6, four-in-a-row board.

    Each iteration of ``run`` collects self-play games, trains the
    policy-value net for ``epochs`` steps on one random mini-batch, and
    evaluates it against a pure-MCTS opponent.
    """

    def __init__(self, init_model=None):
        """Set up the board, game, hyper-parameters and networks.

        Args:
            init_model: accepted for interface compatibility but currently
                ignored; the policy-value net is always freshly initialised.
        """
        # params of the board and the game
        self.board_width = 6
        self.board_height = 6
        self.n_in_row = 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.data_buffer = deque(maxlen=1000)
        self.batch_size = 10
        self.temp = 1.0  # the temperature param
        self.n_playout = 40  # num of simulations for each move
        self.c_puct = 5
        self.epochs = 50  # train steps per policy_update, all on one mini-batch

        # playouts of the pure-MCTS opponent used in policy_evaluate
        self.pure_mcts_playout_num = 2
        self.best_win_ratio = 0.0

        # NOTE(review): init_model is not forwarded to PolicyValueNet; confirm
        # whether resuming from a saved model is meant to be supported here.
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training.

        Plays ``n_games`` self-play games and appends their raw
        ``(state, mcts_prob, winner)`` samples to ``data_buffer``. No symmetry
        augmentation is applied.
        """
        print("Phase 1: Collecting Data")
        for _ in range(n_games):
            _, play_data = self.game.start_self_play(self.mcts_player,
                                                     temp=self.temp)
            self.data_buffer.extend(list(play_data))

    def policy_update(self):
        """Update the policy-value net.

        Draws one random mini-batch of ``batch_size`` samples and runs
        ``epochs`` train steps on it with a fixed learning rate.

        Returns:
            ``(loss, entropy)`` from the last train step, or ``(None, None)``
            if ``epochs`` is 0.
        """
        print("Phase 2: Updating the Network")
        mini_batch = random.sample(self.data_buffer, self.batch_size)

        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]

        # initialised so that epochs == 0 returns instead of raising
        loss, entropy = None, None
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch, self.learn_rate)
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Evaluate the trained policy by playing against the pure MCTS player.

        Note: this is only for monitoring the progress of training.

        Args:
            n_games: number of evaluation games; ``start_player`` alternates
                between 0 and 1.

        Returns:
            win ratio of the trained player (winner ``1``), counting each tie
            (winner ``-1``) as half a win.
        """
        print("Phase 3: Evaluating the Network")
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self, n_iter=5, n_games=5, n_eval_games=7):
        """Run the training pipeline.

        Args:
            n_iter: number of collect/update/evaluate iterations.
            n_games: self-play games collected per iteration.
            n_eval_games: evaluation games against pure MCTS per iteration.

        The net is written to ``./best_policy.model`` once, after the final
        iteration. Note that this is the final net, not necessarily the
        best-scoring one. Ctrl-C aborts without saving.
        """
        try:
            for i in range(n_iter):
                print("Then {} / {} Training".format(i, n_iter))
                self.collect_selfplay_data(n_games)
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    print("Final Loss : {} , Final Entropy: {}".format(
                        loss, entropy))

                win_ratio = self.policy_evaluate(n_eval_games)
                print("Win-Ratio: ", win_ratio)
                if win_ratio > self.best_win_ratio:
                    print("New best policy!!!!!!!!")
                    self.best_win_ratio = win_ratio
            self.policy_value_net.save_model('./best_policy.model')
        except KeyboardInterrupt:
            print('\n\rquit')