Example no. 1
class FiveChessTrain():
    def __init__(self):
        self.policy_evaluate_size = 20  # number of simulated games when evaluating the policy's win rate
        self.batch_size = 256  # number of samples in one training batch
        self.max_keep_size = 500000  # number of most recent self-play samples to keep; a game yields roughly 400~600 samples, so this covers about the last 1000 games

        # training parameters
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0  # adaptive learning-rate multiplier driven by KL divergence
        self.temp = 1  # temperature for scaling move probabilities: 0.01 for actual play, 1 for training
        self.n_playout = 600  # number of MCTS simulations per move
        self.play_batch_size = 1  # number of self-play rounds per iteration
        self.epochs = 1  # number of training passes per batch; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network
        
        # number of simulations for pure MCTS, used to evaluate the policy model
        self.pure_mcts_playout_num = 4000  # random playouts the pure-MCTS baseline uses to build its initial tree
        self.c_puct = 4  # MCTS exploration weight, balances exploration vs. exploitation; default 5

        if os.path.exists(model_file):
            # use a previously trained policy-value network
            self.policy_value_net = PolicyValueNet(size, model_file=model_file)
        else:
            # use a new policy-value network
            self.policy_value_net = PolicyValueNet(size)

        print("start data loader")
        self.dataset = Dataset(data_dir, self.max_keep_size)
        print("dataset len:",len(self.dataset),"index:",self.dataset.index)
        print("end data loader")

    def policy_update(self, sample_data, epochs=1):
        """更新策略价值网络policy-value"""
        # 训练策略价值网络
        state_batch, mcts_probs_batch, winner_batch = sample_data

        # old_probs, old_v = self.policy_value_net.policy_value(state_batch)  
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(state_batch, mcts_probs_batch, winner_batch, self.learn_rate * self.lr_multiplier)
            # new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            # kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            # if kl > self.kl_targ * epochs:  # stop early if D_KL drifts too far
            #     break

        # automatically adjust the learning rate
        # if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
        #     self.lr_multiplier /= 1.5
        # elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
        #     self.lr_multiplier *= 1.5
        # if the value head is learning, explained_var should approach 1; if nothing is learned (all predictions near a small constant), it stays around 0
        # explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch)))
        # explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch)))
        # entropy: lower is better
        # logging.info(("TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        #               ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy, explained_var_old, explained_var_new))
        return loss, v_loss, p_loss, entropy

    def run(self):
        """启动训练"""
        try:
            dataset_len = len(self.dataset)      
            training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4,)
            old_probs = None
            test_batch = None
            for i, data in enumerate(training_loader):  # iterate over the planned training batches
                loss, v_loss, p_loss, entropy = self.policy_update(data, self.epochs)              
                if (i+1) % 10 == 0:
                    logging.info(("TRAIN idx {} : {} / {} v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}")\
                        .format(i, i*self.batch_size, dataset_len, v_loss, p_loss, entropy))
                    
                    # dynamically adjust the learning rate
                    if old_probs is None:
                        test_batch, _, _ = data
                        old_probs, _ = self.policy_value_net.policy_value(test_batch) 
                    else:
                        new_probs, _ = self.policy_value_net.policy_value(test_batch)
                        kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
                        old_probs = None
        
                        if kl > self.kl_targ * 2:
                            self.lr_multiplier /= 1.5
                        elif kl < self.kl_targ / 2 and self.lr_multiplier < 100:
                            self.lr_multiplier *= 1.5
                        else:
                            continue
                        logging.info("kl:{} lr_multiplier:{} lr:{}".format(kl, self.lr_multiplier, self.learn_rate*self.lr_multiplier))

                    # logging.info("Train idx {} : {} / {}".format(i, i*self.batch_size, len(self.dataset)))
            self.policy_value_net.save_model(model_file)
        except KeyboardInterrupt:
            logging.info('quit')
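
The run() loop above (and the commented-out block in policy_update) adapts the learning rate by measuring how far the policy distribution moves between two snapshots of the same batch. A minimal, self-contained sketch of that KL-based adjustment (the helper name adapt_lr_multiplier is ours; old_probs/new_probs are assumed to be NumPy arrays of shape (batch, n_actions) as returned by policy_value()):

import numpy as np

def adapt_lr_multiplier(old_probs, new_probs, lr_multiplier, kl_targ=0.02):
    # D_KL(P || Q) = sum_i p_i * (log p_i - log q_i), averaged over the batch
    kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
    if kl > kl_targ * 2 and lr_multiplier > 0.1:
        lr_multiplier /= 1.5   # policy moved too fast: cool the learning rate down
    elif kl < kl_targ / 2 and lr_multiplier < 10:
        lr_multiplier *= 1.5   # policy barely moved: speed the learning rate up
    return kl, lr_multiplier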
Example no. 2
class Train():
    def __init__(self):
        self.game_batch_num = 1000000  # number of self-play games
        self.batch_size = 512  # start training the model once the data_buffer holds more than this many games

        # training params
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0  # adaptive learning-rate multiplier driven by KL divergence
        self.temp = 1  # MCTS temperature: larger means less deterministic; 1 for training, 1e-3 for prediction
        self.n_playout = 256  # number of MCTS simulations per move
        self.play_batch_size = 5  # number of self-play rounds per iteration
        self.buffer_size = 300000  # size of the replay cache
        self.epochs = 2  # number of training steps per policy-value network update; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network
        self.best_win_ratio = 0.0

        self.c_puct = 0.1  # MCTS exploration weight, balances exploration vs. exploitation; default 5
        self.policy_value_net = PolicyValueNet(GAME_WIDTH,
                                               GAME_HEIGHT,
                                               GAME_ACTIONS_NUM,
                                               model_file=model_file)

    def get_equi_data(self, play_data):
        """
        Augment the dataset by flipping.
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            extend_data.append((state, mcts_porb, winner))
            # horizontal flip
            equi_state = np.array([np.fliplr(s) for s in state])
            equi_mcts_prob = mcts_porb[[0, 2, 1, 3]]
            extend_data.append((equi_state, equi_mcts_prob, winner))
        return extend_data
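    # Worked example for the flip augmentation above (a sketch; it assumes a
    # 4-action space in which actions 1 and 2 are horizontal mirrors of each
    # other, e.g. "move left"/"move right"):
    #   mcts_porb               = np.array([0.1, 0.6, 0.2, 0.1])
    #   mcts_porb[[0, 2, 1, 3]] -> array([0.1, 0.2, 0.6, 0.1])
    # Flipping each state plane with np.fliplr must therefore be paired with
    # swapping the two mirrored action probabilities so the label still matches
    # the flipped board.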

    def collect_selfplay_data(self):
        """收集自我对抗数据用于训练"""
        # 使用MCTS蒙特卡罗树搜索进行自我对抗
        logging.info("TRAIN Self Play starting ...")
        # 游戏代理
        agent = Agent()

        # create an MCTS player that uses the policy-value network to guide tree search and evaluate leaf nodes
        mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                 c_puct=self.c_puct,
                                 n_playout=self.n_playout,
                                 is_selfplay=1)
        for _ in range(3):
            # start playing
            reward, piececount, agentcount, play_data = agent.start_self_play(
                mcts_player, temp=self.temp)
            play_data = list(play_data)[:]
            episode_len = len(play_data)

            # optionally add flipped board data to the dataset
            # play_data = self.get_equi_data(play_data)
            logging.info("TRAIN Self Play end. length:%s saving ..." %
                         episode_len)
            # save the self-play data to the data_buffer
            for obj in play_data:
                filename = "{}.pkl".format(uuid.uuid1())
                savefile = os.path.join(data_wait_dir, filename)
                pickle.dump(obj, open(savefile, "wb"))
                # self.dataset.save(obj)

            if agent.limit_max_height == 10:
                jsonfile = os.path.join(data_dir, "result.json")
                if os.path.exists(jsonfile):
                    result = json.load(open(jsonfile, "r"))
                else:
                    result = {"reward": 0, "steps": 0, "agent": 0}
                if "1k" not in result:
                    result["1k"] = {"reward": 0, "steps": 0, "agent": 0}
                result["reward"] = result["reward"] + reward
                result["steps"] = result["steps"] + piececount
                result["agent"] = result["agent"] + agentcount
                result["1k"]["reward"] = result["1k"]["reward"] + reward
                result["1k"]["steps"] = result["1k"]["steps"] + piececount
                result["1k"]["agent"] = result["1k"]["agent"] + agentcount

                if result["agent"] > 0 and result["agent"] % 100 <= 1:
                    result[str(result["agent"])] = {
                        "reward":
                        result["1k"]["reward"] / result["1k"]["agent"],
                        "steps": result["1k"]["steps"] / result["1k"]["agent"]
                    }

                if result["agent"] > 0 and result["agent"] % 1000 == 0:

                    # extra checkpoint save
                    steps = round(result["1k"]["steps"] /
                                  result["1k"]["agent"])
                    model_file = os.path.join(model_dir,
                                              'model_%s.pth' % steps)
                    self.policy_value_net.save_model(model_file)

                    for key in list(result.keys()):
                        if key.isdigit():
                            c = int(key)
                            if c % 1000 > 10:
                                del result[key]
                    result["1k"] = {"reward": 0, "steps": 0, "agent": 0}

                json.dump(result, open(jsonfile, "w"), ensure_ascii=False)

            if reward >= 1: break

    def policy_update(self, sample_data, epochs=1):
        """更新策略价值网络policy-value"""
        # 训练策略价值网络
        # 随机抽取data_buffer中的对抗数据
        # mini_batch = self.dataset.loadData(sample_data)
        state_batch, mcts_probs_batch, winner_batch = sample_data
        # # for x in mini_batch:
        # #     print("-----------------")
        # #     print(x)
        # # state_batch = [data[0] for data in mini_batch]
        # # mcts_probs_batch = [data[1] for data in mini_batch]
        # # winner_batch = [data[2] for data in mini_batch]

        # print(state_batch)

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # stop early if D_KL drifts too far
                break

        # automatically adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        # if the value head is learning, explained_var should approach 1; if nothing is learned, it stays around 0
        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        # entropy: lower is better
        logging.info((
            "TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy,
                 explained_var_old, explained_var_new))
        return loss, entropy

    def run(self):
        """启动训练"""
        try:
            # print("start data loader")
            # self.dataset = Dataset(data_dir, self.buffer_size)
            # print("end data loader")

            # step = 0
            # # if less than half of the training data has been collected, gather self-play data first
            # if self.dataset.curr_game_batch_num/self.dataset.buffer_size<0.5:
            #     for _ in range(8):
            #         logging.info("TRAIN Batch:{} starting".format(self.dataset.curr_game_batch_num,))
            #         # n_playout=self.n_playout
            #         # self.n_playout=8
            #         self.collect_selfplay_data()
            #         # self.n_playout=n_playout
            #         logging.info("TRAIN Batch:{} end".format(self.dataset.curr_game_batch_num,))
            #         step += 1

            # training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=2,)

            # for i, data in enumerate(training_loader):  # iterate over the planned training batches
            #     # retrain the policy-value network on the self-play data
            #     loss, entropy = self.policy_update(data, self.epochs)

            # self.policy_value_net.save_model(model_file)
            # collect self-play data
            # for _ in range(self.play_batch_size):
            self.collect_selfplay_data()
            # logging.info("TRAIN {} self-play end, size: {}".format(self.dataset.curr_game_batch_num, self.dataset.curr_size()))

        except KeyboardInterrupt:
            logging.info('quit')
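
Example 2 also logs an explained-variance statistic to check whether the value head is learning anything. A small sketch of that formula (the function name explained_variance is ours; the predictions and winner labels are assumed to be 1-D arrays, and the 1e-10 guard against a zero-variance batch is an addition):

import numpy as np

def explained_variance(values_pred, values_true):
    # 1 - Var(true - pred) / Var(true): approaches 1 when the value head tracks
    # the game outcomes, and stays near 0 when it has learned nothing.
    values_pred = np.asarray(values_pred, dtype=np.float64).flatten()
    values_true = np.asarray(values_true, dtype=np.float64).flatten()
    return 1 - np.var(values_true - values_pred) / (np.var(values_true) + 1e-10)

# e.g. explained_variance([0.9, -0.8, 0.7], [1, -1, 1]) is close to 1,
# while explained_variance([0.0, 0.0, 0.0], [1, -1, 1]) is about 0.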
Example no. 3
class Train():
    def __init__(self):
        self.game_batch_num = 2000000  # number of self-play games
        self.batch_size = 512  # start training the model once the data_buffer holds more than this many games

        # training params
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0  # adaptive learning-rate multiplier driven by KL divergence
        self.temp = 1  # MCTS temperature: larger means less deterministic; 1 for training, 1e-3 for prediction
        self.n_playout = 64  # number of MCTS simulations per move
        self.play_batch_size = 1  # number of self-play rounds per iteration
        self.buffer_size = 200000  # size of the replay cache
        self.epochs = 1  # number of training steps per policy-value network update; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network
        self.best_win_ratio = 0.0
        
        self.c_puct = 0.1  # MCTS exploration weight, balances exploration vs. exploitation; default 5

    def get_equi_data(self, play_data):
        """
        Augment the dataset by flipping.
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            extend_data.append((state, mcts_porb, winner))
            # horizontal flip
            equi_state = np.array([np.fliplr(s) for s in state])
            equi_mcts_prob = mcts_porb[[0,2,1,3]]
            extend_data.append((equi_state, equi_mcts_prob, winner))
        return extend_data

    def collect_selfplay_data(self):
        """收集自我对抗数据用于训练"""       
        # 使用MCTS蒙特卡罗树搜索进行自我对抗
        logging.info("TRAIN Self Play starting ...")
        # 游戏代理
        agent = Agent()

        # create an MCTS player that uses the policy-value network to guide tree search and evaluate leaf nodes
        mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, c_puct=self.c_puct, n_playout=self.n_playout, is_selfplay=1)
        # start playing
        winer, play_data = agent.start_self_play(mcts_player, temp=self.temp)
        play_data = list(play_data)[:]
        episode_len = len(play_data)

        # optionally add flipped board data to the dataset
        # play_data = self.get_equi_data(play_data)
        logging.info("TRAIN Self Play end. length:%s saving ..." % episode_len)
        # save the self-play data to the data_buffer
        for obj in play_data:
            self.dataset.save(obj)

    def policy_update(self, sample_data, epochs=1):
        """更新策略价值网络policy-value"""
        # 训练策略价值网络
        # 随机抽取data_buffer中的对抗数据
        # mini_batch = self.dataset.loadData(sample_data)
        state_batch, mcts_probs_batch, winner_batch = sample_data
        totle = torch.sum(winner_batch)
        totle_value = totle.item()
        # # for x in mini_batch:
        # #     print("-----------------")
        # #     print(x)
        # # state_batch = [data[0] for data in mini_batch]
        # # mcts_probs_batch = [data[1] for data in mini_batch]
        # # winner_batch = [data[2] for data in mini_batch]

        # print(state_batch)

        # old_probs, old_v = self.policy_value_net.policy_value(state_batch)  
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(state_batch, mcts_probs_batch, winner_batch, self.learn_rate * self.lr_multiplier)
            # new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            # kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            # if kl > self.kl_targ * 4:  # stop early if D_KL drifts too far
            #     break

        # automatically adjust the learning rate
        # if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
        #     self.lr_multiplier /= 1.5
        # elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
        #     self.lr_multiplier *= 1.5
        # if the value head is learning, explained_var should approach 1; if nothing is learned, it stays around 0
        # explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch)))
        # explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch)))
        # entropy: lower is better
        # logging.info(("TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        #               ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy, explained_var_old, explained_var_new))
        return totle_value, v_loss, p_loss, entropy

    def run(self):
        """启动训练"""
        try:
            print("start data loader")
            self.dataset = Dataset(data_dir, self.buffer_size)
            newsample=self.dataset.newsample
            self.testdataset = TestDataset(data_dir, 10, newsample)
            print("end data loader")

            self.policy_value_net = PolicyValueNet(GAME_WIDTH, GAME_HEIGHT, GAME_ACTIONS_NUM, model_file=model_file)
            self.policy_value_net.save_model(model_file+".bak")
            # step = 0
            # # if less than half of the training data has been collected, gather self-play data first
            # if self.dataset.curr_game_batch_num/self.dataset.buffer_size<0.5:
            #     for _ in range(8):
            #         logging.info("TRAIN Batch:{} starting".format(self.dataset.curr_game_batch_num,))
            #         # n_playout=self.n_playout
            #         # self.n_playout=8
            #         self.collect_selfplay_data()
            #         # self.n_playout=n_playout
            #         logging.info("TRAIN Batch:{} end".format(self.dataset.curr_game_batch_num,))
            #         step += 1
            dataset_len = len(self.dataset)  
            training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=1,)
            testing_loader = torch.utils.data.DataLoader(self.testdataset, batch_size=self.batch_size, shuffle=True, num_workers=1,)
            old_probs = None
            test_batch = None
            totle = 0
            for i, data in enumerate(training_loader):  # iterate over the planned training batches
                if i==0:
                    _batch, _probs, _win = data
                    print(_batch[0][0])
                    print(_batch[0][1])
                    print(_probs[0])
                    print(_win[0])

                # retrain the policy-value network on the self-play data
                totle_value, v_loss, p_loss, entropy = self.policy_update(data, self.epochs)
                totle = totle + totle_value
                if i%10 == 0:
                    logging.info(("TRAIN idx {} : {} / {} v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}")\
                        .format(i, i*self.batch_size, dataset_len, v_loss, p_loss, entropy))

                    # dynamically adjust the learning rate
                    if old_probs is None:
                        test_batch, test_probs, test_win = next(iter(testing_loader))
                        old_probs, old_value = self.policy_value_net.policy_value(test_batch) 
                    else:
                        new_probs, new_value = self.policy_value_net.policy_value(test_batch)
                        kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
                        
                        if i % 50 == 0:   
                            logging.info("probs[0] old:{}".format(old_probs[0]))   
                            logging.info("probs[0] new:{}".format(new_probs[0]))   
                            logging.info("probs[0] tg: {}".format(test_probs[0])) 
                            maxlen = min(10, len(test_win)) 
                            for j in range(maxlen): 
                                logging.info("value[0] old:{} new:{} tg:{}".format(old_value[j][0], new_value[j][0], test_win[j]))  

                        old_probs = None
                        
                        if kl > self.kl_targ * 2:
                            self.lr_multiplier /= 1.5
                        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                            self.lr_multiplier *= 1.5
                        else:
                            continue
                        logging.info("kl:{} lr_multiplier:{} lr:{}".format(kl, self.lr_multiplier, self.learn_rate*self.lr_multiplier))



            self.policy_value_net.save_model(model_file)
            # collect self-play data
            # for _ in range(self.play_batch_size):
            #     self.collect_selfplay_data()
            # logging.info("TRAIN {} self-play end, size: {}".format(self.dataset.curr_game_batch_num, self.dataset.curr_size()))
            # with winner labels of +1 (win) and -1 (loss):
            #   win - lost = totle,  win + lost = dataset_len
            win = (totle+dataset_len)//2
            print("win:", win, "lost:", dataset_len-win, "prop:", win/dataset_len)           
    
        except KeyboardInterrupt:
            logging.info('quit')
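
The tally printed at the end of Example 3 recovers wins and losses from two facts: totle is the sum of the winner labels (win minus lost, when every label is +1 or -1) and dataset_len is the number of samples (win plus lost). A quick check of that arithmetic with made-up labels:

# assumes each winner label is +1 (win) or -1 (loss)
labels = [1, 1, -1, 1, -1, -1, 1, 1, 1, -1]
dataset_len = len(labels)             # win + lost = 10
totle = sum(labels)                   # win - lost = 2
win = (totle + dataset_len) // 2      # (2 + 10) // 2 = 6
lost = dataset_len - win              # 4
assert win - lost == totle and win + lost == dataset_len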
Example no. 4
class Train():
    def __init__(self):
        self.game_batch_num = 2000000  # number of self-play games
        self.batch_size = 512  # start training the model once the data_buffer holds more than this many games

        # training params
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0  # adaptive learning-rate multiplier driven by KL divergence
        self.temp = 1  # MCTS temperature: larger means less deterministic; 1 for training, 1e-3 for prediction
        self.n_playout = 64  # number of MCTS simulations per move
        self.play_batch_size = 1  # number of self-play rounds per iteration
        self.buffer_size = 200000  # size of the replay cache
        self.epochs = 1  # number of training steps per policy-value network update; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network
        self.best_win_ratio = 0.0
        
        self.c_puct = 2  # MCTS exploration weight, balances exploration vs. exploitation; default 5
   

    def policy_update(self, sample_data, epochs=1):
        """更新策略价值网络policy-value"""
        # 训练策略价值网络
        # 随机抽取data_buffer中的对抗数据
        state_batch, mcts_probs_batch, values_batch = sample_data
        # 训练策略价值网络
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(state_batch, mcts_probs_batch, values_batch, self.learn_rate * self.lr_multiplier)
         
        return loss, v_loss, p_loss, entropy

    def run(self):
        """启动训练"""
        try:
            print("start data loader")
            self.dataset = Dataset(data_dir, self.buffer_size)
            self.testdataset = copy.copy(self.dataset)
            self.testdataset.test=True
            print("end data loader")

            self.policy_value_net = PolicyValueNet(GAME_WIDTH, GAME_HEIGHT, GAME_ACTIONS_NUM, model_file=model_file)
            self.policy_value_net.save_model(model_file+".bak")           

            dataset_len = len(self.dataset)  
            training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
            testing_loader = torch.utils.data.DataLoader(self.testdataset, batch_size=self.batch_size, shuffle=False,num_workers=0)
            old_probs = None
            test_batch = None

            for i, data in enumerate(training_loader):  # iterate over the planned training batches
                if i==0:
                    _batch, _qvals, _actions = data
                    for j in range(len(_batch[0])):
                        print(_batch[0][j])
                    print(_qvals[0])
                    print(_actions[0])

                # retrain the policy-value network on the self-play data
                _, v_loss, p_loss, entropy = self.policy_update(data, self.epochs)
                if i%10 == 0:
                    print(("TRAIN idx {} : {} / {} v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}")\
                        .format(i, i*self.batch_size, dataset_len, v_loss, p_loss, entropy))

                    # dynamically adjust the learning rate
                    if old_probs is None:
                        test_batch, test_probs, test_valus = next(iter(testing_loader))
                        old_probs, old_value = self.policy_value_net.policy_value(test_batch) 
                    else:
                        new_probs, new_value = self.policy_value_net.policy_value(test_batch)
                        kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
                        
                        if i % 50 == 0:   
                            print("probs[0] old:{}".format(old_probs[0]))   
                            print("probs[0] new:{}".format(new_probs[0]))
                            print("probs[0] dst:{}".format(test_probs[0]))   
                            maxlen = min(10, len(test_batch)) 
                            for j in range(maxlen): 
                                print("value[0] old:{} new:{} tg:{}".format(old_value[j][0], new_value[j][0], test_valus[j]))  

                        old_probs = None
                        
                        if kl > self.kl_targ * 2:
                            self.lr_multiplier /= 1.5
                        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                            self.lr_multiplier *= 1.5
                        else:
                            continue
                        print("kl:{} lr_multiplier:{} lr:{}".format(kl, self.lr_multiplier, self.learn_rate*self.lr_multiplier))

            self.policy_value_net.save_model(model_file)
   
    
        except KeyboardInterrupt:
            print('quit')
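
These examples write individual (state, mcts_prob, value) samples to disk as one pickle file each and later read them back through a Dataset class that is not shown here. A minimal sketch of what such a dataset could look like (the class name PickleDirDataset and the directory layout are assumptions, not the project's actual implementation):

import glob
import os
import pickle

import numpy as np
import torch
from torch.utils.data import Dataset

class PickleDirDataset(Dataset):
    """Sketch: reads (state, mcts_prob, value) tuples pickled one per file."""
    def __init__(self, data_dir, max_keep_size=None):
        files = sorted(glob.glob(os.path.join(data_dir, "*.pkl")), key=os.path.getmtime)
        if max_keep_size:
            files = files[-max_keep_size:]  # keep only the most recent samples
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with open(self.files[idx], "rb") as f:
            state, mcts_prob, value = pickle.load(f)
        return (torch.as_tensor(np.asarray(state), dtype=torch.float32),
                torch.as_tensor(np.asarray(mcts_prob), dtype=torch.float32),
                torch.as_tensor(float(value), dtype=torch.float32))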
Example no. 5
    def collect_selfplay_data(self):
        """收集自我对抗数据用于训练"""
        print("TRAIN Self Play starting ...")

        jsonfile = os.path.join(data_dir, "result.json")

        # game agent
        agent = Agent()

        max_game_num = 1
        agentcount, agentreward, piececount, agentscore = 0, 0, 0, 0

        borads = []
        game_num = 0

        cpuct_first_flag = random.random() > 0.5

        # try not to revisit identical game positions
        game_keys = []
        game_datas = []
        # start one game
        for _ in count():
            start_time = time.time()
            game_num += 1
            print('start game :', game_num, 'time:',
                  datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))

            result = self.read_status_file(jsonfile)
            print("QVal:", result["QVal"])

            # the c_puct parameter is tuned automatically in steps of 0.1
            cpuct_list = []
            for cp in result["cpuct"]:
                cpuct_list.append(cp)
                if len(cpuct_list) == 2: break
            cpuct_list.sort()

            print("cpuct:", result["cpuct"])

            if cpuct_first_flag:
                cpuct = float(cpuct_list[0])
            else:
                cpuct = float(cpuct_list[1])
            cpuct_first_flag = not cpuct_first_flag

            print("c_puct:", cpuct, "n_playout:", self.n_playout)
            policy_value_net = PolicyValueNet(GAME_WIDTH,
                                              GAME_HEIGHT,
                                              GAME_ACTIONS_NUM,
                                              model_file=model_file)
            player = MCTSPlayer(policy_value_net.policy_value_fn,
                                c_puct=cpuct,
                                n_playout=self.n_playout)

            _data = {
                "steps": [],
                "shapes": [],
                "last_state": 0,
                "score": 0,
                "piece_count": 0
            }
            # game = copy.deepcopy(agent)
            game = Agent(isRandomNextPiece=False)

            if game_num == 1 or game_num == max_game_num:
                game.show_mcts_process = True

            piece_idx = []

            for i in count():
                _step = {"step": i}
                _step["state"] = game.current_state()
                _step["piece_count"] = game.piececount
                _step["shape"] = game.fallpiece["shape"]
                _step["piece_height"] = game.pieceheight

                if game_num == 1:
                    action, move_probs = player.get_action(game,
                                                           temp=self.temp,
                                                           return_prob=1,
                                                           need_random=False)
                else:
                    action, move_probs = player.get_action(game,
                                                           temp=self.temp,
                                                           return_prob=1,
                                                           need_random=False)

                    if game.get_key() in game_keys:
                        print(game.steps, game.piececount,
                              game.fallpiece["shape"], game.piecesteps, "key:",
                              game.get_key(), "key_len:", len(game_keys))
                        action = random.choice(game.get_availables())

                _, reward = game.step(action)

                _step["key"] = game.get_key()
                # cap the reward at 1 so multi-line clears are not specially encouraged here
                _step["reward"] = 1 if reward > 0 else 0
                _step["action"] = action
                _step["move_probs"] = move_probs

                _data["shapes"].append(_step["shape"])
                _data["steps"].append(_step)

                # the raw reward here is the number of cleared lines
                if reward > 0:
                    result = self.read_status_file(jsonfile)
                    if result["curr"]["height"] == 0:
                        result["curr"]["height"] = game.pieceheight
                    else:
                        result["curr"]["height"] = round(
                            result["curr"]["height"] * 0.99 +
                            game.pieceheight * 0.01, 2)
                    result["shapes"][_step["shape"]] += reward

                    # on the first reward, record how many pieces had been placed
                    if game.score == reward:
                        if result["first_reward"] == 0:
                            result["first_reward"] = game.piececount
                        else:
                            result["first_reward"] = result[
                                "first_reward"] * 0.99 + game.piececount * 0.01

                        # if the first reward came earlier than the running average, also reward the preceding pieces
                        if game.piececount < result["first_reward"]:
                            for idx in piece_idx:
                                _data["steps"][idx]["reward"] = 0.5

                    json.dump(result, open(jsonfile, "w"), ensure_ascii=False)
                    print("#"*40, 'score:', game.score, 'height:', game.pieceheight, 'piece:', game.piececount, "shape:", game.fallpiece["shape"], \
                        'step:', i, "step time:", round((time.time()-start_time)/i,3), "#"*40)

                # record the step index at which the current piece was placed
                if game.state != 0:
                    piece_idx.append(i)

                # the more pieces placed, the better
                if game.terminal or (reward > 0 and game.pieceheight > 8):
                    _game_last_reward = 0  # game.getNoEmptyCount()/200.
                    _data["reward"] = _game_last_reward
                    _data["score"] = game.score
                    _data["piece_count"] = game.piececount

                    # update running statistics
                    game_reward = _game_last_reward + game.score

                    result = self.read_status_file(jsonfile)
                    if result["QVal"] == 0:
                        result["QVal"] = game_reward
                    else:
                        result["QVal"] = result[
                            "QVal"] * 0.999 + game_reward * 0.001
                    paytime = time.time() - start_time
                    steptime = paytime / game.steps
                    if result["time"]["agent_time"] == 0:
                        result["time"]["agent_time"] = paytime
                        result["time"]["step_time"] = steptime
                    else:
                        result["time"]["agent_time"] = round(
                            result["time"]["agent_time"] * 0.99 +
                            paytime * 0.01, 3)
                        d = game.steps / 10000.0
                        if d > 1: d = 0.99
                        result["time"]["step_time"] = round(
                            result["time"]["step_time"] * (1 - d) +
                            steptime * d, 3)

                    # record the running result for the current c_puct
                    if str(cpuct) in result["cpuct"]:
                        result["cpuct"][str(cpuct)] = result["cpuct"][str(
                            cpuct)] * 0.99 + game_reward * 0.01

                    if game_reward > result["best"]["reward"]:
                        result["best"]["reward"] = game_reward
                        result["best"]["pieces"] = game.piececount
                        result["best"]["score"] = game.score
                        result["best"]["agent"] = result["agent"] + agentcount

                    result["agent"] += 1
                    result["curr"]["reward"] += game.score
                    result["curr"]["pieces"] += game.piececount
                    result["curr"]["agent1000"] += 1
                    result["curr"]["agent100"] += 1
                    json.dump(result, open(jsonfile, "w"), ensure_ascii=False)

                    game.print()
                    print(game_num, 'reward:', game.score, "Qval:",
                          game_reward, 'len:', i, "piececount:",
                          game.piececount, "time:",
                          time.time() - start_time)
                    print("pay:", time.time() - start_time, "s\n")
                    agentcount += 1
                    agentscore += game.score
                    agentreward += game_reward
                    piececount += game.piececount

                    break

            for step in _data["steps"]:
                if not step["key"] in game_keys:
                    game_keys.append(step["key"])

            game_datas.append(_data)

            borads.append(game.board)

            # stop once more than 10000 training samples have been collected
            if len(game_keys) > 10000: break

            # stop once the maximum number of games has been reached
            if game_num >= max_game_num: break

        # print the boards:
        from game import blank
        for y in range(agent.height):
            line = ""
            for b in borads:
                line += "| "
                for x in range(agent.width):
                    if b[x][y] == blank:
                        line += "  "
                    else:
                        line += "%s " % b[x][y]
            print(line)
        print((" " + " -" * agent.width + " ") * len(borads))

        ## abandoned: propagating the reward backwards with a 0.5 decay
        # only the steps of the scoring piece matter: every step of a line-clearing piece gets reward 1
        for data in game_datas:
            step_count = len(data["steps"])
            piece_count = -1
            v = 0
            vlist = []
            for i in range(step_count - 1, -1, -1):
                if piece_count != data["steps"][i]["piece_count"]:
                    piece_count = data["steps"][i]["piece_count"]
                    v = data["steps"][i][
                        "reward"]  # 0.5*v+data["steps"][i]["reward"]
                    if v > 1: v = 1
                    vlist.insert(0, v)
                data["steps"][i]["reward"] = v
            print(vlist)

        # total score = line-clear reward + (this game's line-clear reward - average per-game reward) / average per-game reward
        # for data in game_datas:
        #     step_count = len(data["steps"])
        #     weight = (data["score"]-result["QVal"])/result["QVal"]
        #     for i in range(step_count):
        #         # if data["steps"][i]["reward"] < 1:
        #         v = data["steps"][i]["reward"] + weight
        #             # if v>1: v=1
        #         data["steps"][i]["reward"] = v

        # print("fixed reward")
        # for data in game_datas:
        #     step_count = len(data["steps"])
        #     piece_count = -1
        #     vlist=[]
        #     for i in range(step_count):
        #         if piece_count!=data["steps"][i]["piece_count"]:
        #             piece_count = data["steps"][i]["piece_count"]
        #             vlist.append(data["steps"][i]["reward"])
        #     print("score:", data["score"], "piece_count:", data["piece_count"],  [round(num, 2) for num in vlist])

        # state / move probabilities / per-step reward / per-game score
        states, mcts_probs, values, score = [], [], [], []

        for data in game_datas:
            for step in data["steps"]:
                states.append(step["state"])
                mcts_probs.append(step["move_probs"])
                values.append(step["reward"])
                score.append(data["score"])

        # # used to collect the std of values per shape
        # pieces_idx={"t":[], "i":[], "j":[], "l":[], "s":[], "z":[], "o":[]}

        # var_keys = set()

        # for data in game_datas:
        #     for shape in set(data["shapes"]):
        #         var_keys.add(shape)
        # step_key_name = "shape"

        # for key in var_keys:
        #     _states, _mcts_probs, _values = [], [], []
        #     # _pieces_idx={"t":[], "i":[], "j":[], "l":[], "s":[], "z":[], "o":[]}
        #     for data in game_datas:
        #         for step in data["steps"]:
        #             if step[step_key_name]!=key: continue
        #             _states.append(step["state"])
        #             _mcts_probs.append(step["move_probs"])
        #             _values.append(step["reward"])
        #             # _pieces_idx[step["shape"]].append(len(values)+len(_values)-1)

        #     if len(_values)==0: continue

        #     # recompute the statistics for this key
        #     curr_avg_value = sum(_values)/len(_values)
        #     curr_std_value = np.std(_values)
        #     if curr_std_value<0.01: continue

        #     # for shape in _pieces_idx:
        #     #     pieces_idx[shape].extend(_pieces_idx[shape])

        #     _normalize_vals = []
        #     # renormalize the values using a normal-distribution transform
        #     curr_std_value_fix = curr_std_value + 1e-8 # * (2.0**0.5) # curr_std_value / result["vars"]["std"]
        #     for v in _values:
        #         # standardized value: (x-μ)/(σ/std), with std = 1 # 1/sqrt(2)
        #         _nv = (v-curr_avg_value)/curr_std_value_fix
        #         if _nv <-1 : _nv = -1
        #         if _nv >1  : _nv = 1
        #         if _nv == 0: _nv = 1e-8
        #         _normalize_vals.append(_nv)

        #     # set the value of the best step to 1
        #     # max_normalize_val = max(_normalize_vals)-1
        #     # for i in range(len(_normalize_vals)):
        #     #     _normalize_vals[i] -= max_normalize_val

        #     print(key, len(_normalize_vals), "max:", max(_normalize_vals), "min:", min(_normalize_vals), "std:", curr_std_value)

        #     states.extend(_states)
        #     mcts_probs.extend(_mcts_probs)
        #     values.extend(_normalize_vals)
        #     result["vars"]["max"] = result["vars"]["max"]*0.999 + max(_normalize_vals)*0.001
        #     result["vars"]["min"] = result["vars"]["min"]*0.999 + min(_normalize_vals)*0.001
        #     result["vars"]["avg"] = result["vars"]["avg"]*0.999 + np.average(_normalize_vals)*0.001
        #     result["vars"]["std"] = result["vars"]["std"]*0.999 + np.std(_normalize_vals)*0.001
        #     # _states, _mcts_probs, _values = [], [], []

        # # if result["vars"]["max"]>1 or result["vars"]["min"]<-1:
        # #     result["vars"]["std"] = round(result["vars"]["std"]-0.0001,4)
        # # else:
        # #     result["vars"]["std"] = round(result["vars"]["std"]+0.0001,4)

        # json.dump(result, open(jsonfile,"w"), ensure_ascii=False)

        assert len(states) > 0
        assert len(states) == len(values)
        assert len(states) == len(mcts_probs)

        print("TRAIN Self Play end. length:%s value sum:%s saving ..." %
              (len(states), sum(values)))

        # save the self-play data to the data_buffer
        for obj in self.get_equi_data(states, mcts_probs, values, score):
            filename = "{}.pkl".format(uuid.uuid1())
            savefile = os.path.join(data_wait_dir, filename)
            pickle.dump(obj, open(savefile, "wb"))

        # print the std per shape
        # for shape in pieces_idx:
        #     test_data=[]
        #     for i in pieces_idx[shape]:
        #         if i>=(len(values)): break
        #         test_data.append(values[i])
        #     if len(test_data)==0: continue
        #     print(shape, "len:", len(test_data), "max:", max(test_data), "min:", min(test_data), "std:", np.std(test_data))

        result = self.read_status_file(jsonfile)
        if result["curr"]["agent100"] > 100:
            result["reward"].append(
                round(result["curr"]["reward"] / result["curr"]["agent1000"],
                      2))
            result["pieces"].append(
                round(result["curr"]["pieces"] / result["curr"]["agent1000"],
                      2))
            result["qvals"].append(round(result["QVal"], 2))
            result["height"].append(result["curr"]["height"])
            result["time"]["step_times"].append(result["time"]["step_time"])
            result["curr"]["agent100"] -= 100
            while len(result["reward"]) > 200:
                result["reward"].remove(result["reward"][0])
            while len(result["pieces"]) > 200:
                result["pieces"].remove(result["pieces"][0])
            while len(result["qvals"]) > 200:
                result["qvals"].remove(result["qvals"][0])
            while len(result["height"]) > 200:
                result["height"].remove(result["height"][0])
            while len(result["time"]["step_times"]) > 200:
                result["time"]["step_times"].remove(
                    result["time"]["step_times"][0])

            # update the c_puct candidates every 100 games
            qval = result["QVal"]
            # c_puct reflects how much the prior probabilities are trusted
            if result["cpuct"][cpuct_list[0]] > result["cpuct"][cpuct_list[1]]:
                cpuct = round(float(cpuct_list[0]) - 0.01, 2)
                if cpuct <= 0.01:
                    result["cpuct"] = {"0.01": qval, "1.01": qval}
                else:
                    result["cpuct"] = {
                        str(cpuct): qval,
                        str(round(cpuct + 1, 2)): qval
                    }
            else:
                cpuct = round(float(cpuct_list[0]) + 0.01, 2)
                result["cpuct"] = {
                    str(cpuct): qval,
                    str(round(cpuct + 1, 2)): qval
                }

            if max(result["reward"]) == result["reward"][-1]:
                newmodelfile = model_file + "_reward_" + str(
                    result["reward"][-1])
                if not os.path.exists(newmodelfile):
                    policy_value_net.save_model(newmodelfile)

        if result["curr"]["agent1000"] > 1000:
            result["curr"] = {
                "reward": 0,
                "pieces": 0,
                "agent1000": 0,
                "agent100": 0,
                "height": 0
            }

            newmodelfile = model_file + "_" + str(result["agent"])
            if not os.path.exists(newmodelfile):
                policy_value_net.save_model(newmodelfile)
        result["lastupdate"] = datetime.datetime.now().strftime(
            '%Y-%m-%d %H:%M:%S')
        json.dump(result, open(jsonfile, "w"), ensure_ascii=False)
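
Example 5 keeps most of its running statistics (QVal, average height, step time, the per-c_puct scores) as exponential moving averages of the form old * (1 - alpha) + new * alpha. A tiny sketch of that update rule (the helper name ema_update is ours):

def ema_update(current, new_value, alpha):
    # exponential moving average: a small alpha gives a slow, stable estimate
    return current * (1 - alpha) + new_value * alpha

# e.g. the QVal update above uses alpha = 0.001 and bootstraps from the first game:
qval = 0
for game_reward in [2, 0, 3, 1]:
    qval = game_reward if qval == 0 else ema_update(qval, game_reward, 0.001)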
Example no. 6
class FiveChessPlay():
    def __init__(self):
        self.policy_evaluate_size = 100  # number of simulated games when evaluating the policy's win rate
        self.batch_size = 512  # number of samples in one training batch
        self.max_keep_size = 500000  # number of most recent self-play samples to keep; a game yields roughly 400~600 samples, so this covers about the last 1000 games

        # training parameters
        self.learn_rate = 1e-4
        self.lr_multiplier = 1.0  # adaptive learning-rate multiplier driven by KL divergence
        self.temp = 1  # temperature for scaling move probabilities: 0.01 for actual play, 1 for training
        self.n_playout = 1000  # number of MCTS simulations per move
        self.play_batch_size = 1  # number of self-play rounds per iteration
        self.epochs = 1  # number of training passes per batch; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network

        # number of simulations for pure MCTS, used to evaluate the policy model
        self.pure_mcts_playout_num = 1000  # random playouts the pure-MCTS baseline uses to build its initial tree
        self.c_puct = 4  # MCTS exploration weight, balances exploration vs. exploitation; default 5
        self.mcts_win = [0, 0]  # win/loss record against pure MCTS
        self.best_win = [0, 0]  # win/loss record against the historical best model

        if os.path.exists(model_file):
            # use a previously trained policy-value network
            self.policy_value_net = PolicyValueNet(size, model_file=model_file)
        else:
            # use a new policy-value network
            self.policy_value_net = PolicyValueNet(size)
        self.best_policy_value_net = None

        # win count per historical best model; the more a checkpoint wins, the more often it should be played against
        self.best_model_files_win = {}

    def save_wait_data(self, obj):
        filename = "{}.pkl".format(uuid.uuid1())
        savefile = os.path.join(data_wait_dir, filename)
        pickle.dump(obj, open(savefile, "wb"))

    def get_equi_data(self, play_data):
        """
        Augment the dataset by rotation and flipping.
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            mcts_porb = mcts_porb.reshape(size, size)
            for i in [1, 2, 3, 4]:
                # counter-clockwise rotation
                equi_state = np.array([np.rot90(s, i) for s in state])
                # equi_mcts_prob = np.rot90(np.flipud(mcts_porb.reshape(size, size)), i)
                equi_mcts_prob = np.rot90(mcts_porb, i)
                # extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                extend_data.append(
                    (equi_state, equi_mcts_prob.flatten(), winner))
                # horizontal flip
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                # extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                extend_data.append(
                    (equi_state, equi_mcts_prob.flatten(), winner))
        return extend_data
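    # Note on the augmentation above: each (state, mcts_prob, winner) sample is
    # expanded into 8 samples (4 counter-clockwise rotations, each stored with
    # and without a horizontal flip). The flat probability vector is reshaped to
    # the size x size board, transformed with the same np.rot90 / np.fliplr as
    # the state planes, then flattened again, so every augmented label stays
    # aligned with its transformed board.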

    def collect_selfplay_data(self):
        """收集自我对抗数据用于训练"""
        # 使用MCTS蒙特卡罗树搜索进行自我对抗
        logging.info("TRAIN Self Play starting ...")
        agent = Agent(size, n_in_row, is_shown=0)
        # create an MCTS player that uses the policy-value network to guide tree search and evaluate leaf nodes
        mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                 c_puct=self.c_puct,
                                 n_playout=self.n_playout,
                                 is_selfplay=0)

        files = os.listdir(model_dir)
        his_best_model_files = []
        his_best_model_weights = []
        for file in files:
            if file.startswith("best_model_15_5.pth."):
                if file not in self.best_model_files_win:
                    self.best_model_files_win[file] = 0
                his_best_model_files.append(file)
                his_best_model_weights.append(self.best_model_files_win[file])

        weights_min = min(his_best_model_weights)
        weights_max = max(his_best_model_weights)
        for i in range(len(his_best_model_weights)):
            if weights_max == weights_min:
                his_best_model_weights[i] = 1. / len(his_best_model_weights)
            else:
                his_best_model_weights[i] = 1.0 * (
                    his_best_model_weights[i] - weights_min) / (weights_max -
                                                                weights_min)

        curr_best_model_file = random.choices(
            his_best_model_files, weights=his_best_model_weights)[0]
        print(self.best_model_files_win)
        print("loading", curr_best_model_file)
        curr_best_policy_value_net = PolicyValueNet(size,
                                                    model_file=os.path.join(
                                                        model_dir,
                                                        curr_best_model_file))
        his_best_mcts_player = MCTSPlayer(
            curr_best_policy_value_net.policy_value_fn,
            c_puct=self.c_puct,
            n_playout=self.n_playout,
            is_selfplay=0)

        his_best_mcts_player.mcts._limit_max_var = False
        mcts_player.mcts._limit_max_var = False

        # with some probability, play against pure MCTS instead
        # r = random.random()
        # if r>0.5:
        # pure_mcts_player = MCTSPurePlayer(c_puct=self.c_puct, n_playout=self.pure_mcts_playout_num)
        # print("AI VS MCTS, pure_mcts_playout_num:", self.pure_mcts_playout_num)
        # else:
        #     pure_mcts_player = None

        # start playing
        winner, play_data = agent.start_self_play(mcts_player,
                                                  his_best_mcts_player,
                                                  temp=self.temp)

        if his_best_mcts_player is not None:
            if winner == mcts_player.player:
                self.mcts_win[0] = self.mcts_win[0] + 1
                # self.pure_mcts_playout_num=min(2000, self.pure_mcts_playout_num+100)
                print("Curr Model Win!", "win:", self.mcts_win[0], "lost",
                      self.mcts_win[1], "playout_num",
                      self.pure_mcts_playout_num)
            if winner == his_best_mcts_player.player:
                self.mcts_win[1] = self.mcts_win[1] + 1
                self.pure_mcts_playout_num = max(
                    500, self.pure_mcts_playout_num - 100)
                print("Curr Model Lost!", "win:", self.mcts_win[0], "lost",
                      self.mcts_win[1], "playout_num",
                      self.pure_mcts_playout_num)
        agent.game.print()

        play_data = list(play_data)[:]

        if winner == his_best_mcts_player.player:
            self.best_model_files_win[curr_best_model_file] += 1
        if winner == mcts_player.player:
            self.best_model_files_win[curr_best_model_file] -= 1

        # augment the sample set with board symmetries
        play_data = self.get_equi_data(play_data)
        logging.info("Self Play end. length:%s saving ..." % len(play_data))
        # save the training data
        for obj in play_data:
            self.save_wait_data(obj)

        return play_data[-1]

    def policy_evaluate(self):
        """
        Policy evaluation: play n games between the current model and the best model and measure the win rate.
        """
        # if no best model exists yet, save the current model as the best model
        if not os.path.exists(best_model_file):
            self.policy_value_net.save_model(best_model_file)
            return

        # the currently trained model
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        if self.best_policy_value_net is None:
            self.best_policy_value_net = PolicyValueNet(
                size, model_file=best_model_file)
        best_mcts_player = MCTSPlayer(
            self.best_policy_value_net.policy_value_fn,
            c_puct=self.c_puct,
            n_playout=self.n_playout)

        current_mcts_player.mcts._limit_max_var = False
        best_mcts_player.mcts._limit_max_var = False

        agent = Agent(size, n_in_row, is_shown=0)
        winner, play_data = agent.start_self_evaluate(
            current_mcts_player,
            best_mcts_player,
            temp=self.temp,
            start_player=sum(self.best_win) % 2)
        if winner == current_mcts_player.player:
            self.best_win[0] = self.best_win[0] + 1
            print("Curr Model Win!", "win:", self.best_win[0], "lost",
                  self.best_win[1])
        if winner == best_mcts_player.player:
            self.best_win[1] = self.best_win[1] + 1
            print("Curr Model Lost!", "win:", self.best_win[0], "lost",
                  self.best_win[1])
        agent.game.print()

        # save the training data
        play_data = list(play_data)[:]
        play_data = self.get_equi_data(play_data)
        logging.info("Eval Play end. length:%s saving ..." % len(play_data))
        for obj in play_data:
            self.save_wait_data(obj)

    def run(self):
        """启动训练"""
        try:
            # run up to 100000 training rounds
            for i in range(100000):
                logging.info(
                    "TRAIN Batch:{} starting, Size:{}, n_in_row:{}".format(
                        i, size, n_in_row))

                # with probability 0.2, insert a game against the historical best model to generate samples
                if random.random() > 0.8:
                    state, mcts_porb, winner = self.collect_selfplay_data()
                    if i == 0:
                        print("-" * 50, "state", "-" * 50)
                        print(state)
                        print("-" * 50, "mcts_porb", "-" * 50)
                        print(mcts_porb)
                        print("-" * 50, "winner", "-" * 50)
                        print(winner)

                self.policy_evaluate()

                rate_of_winning = 0.6
                if (i + 1) % self.policy_evaluate_size == 0 or self.best_win[
                        1] > (self.policy_evaluate_size *
                              (1 - rate_of_winning)):
                    # if self.mcts_win[0]>self.mcts_win[1]:
                    #     self.pure_mcts_playout_num=self.pure_mcts_playout_num+50
                    # if self.mcts_win[0]<self.mcts_win[1]:
                    #     self.pure_mcts_playout_num=self.pure_mcts_playout_num-50
                    # self.mcts_win=[0, 0]

                    # if the current model's win rate is at least 0.6, keep it as the new best model
                    v = 1.0 * self.best_win[0] / self.policy_evaluate_size
                    if v >= rate_of_winning:
                        t = os.path.getctime(best_model_file)
                        timeStruct = time.localtime(t)
                        timestr = time.strftime('%Y_%m_%d_%H_%M', timeStruct)
                        os.rename(best_model_file,
                                  best_model_file + "." + timestr)
                        self.policy_value_net.save_model(best_model_file)
                        self.best_policy_value_net = None
                        print("save curr modle to best model")
                    else:
                        print("curr:", v, "< 0.65, keep best model")

                    self.best_win = [0, 0]
                    self.policy_value_net = PolicyValueNet(
                        size, model_file=model_file)

            # after a round of training, compare against the best model
            # # if it lost, train once more
            # if win_ratio<=0.5:
            #     self.policy_evaluate(self.policy_evaluate_size)
            #     print("lost all, add more sample")
        except KeyboardInterrupt:
            logging.info('quit')
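
collect_selfplay_data in Example 6 picks which historical best model to play against by turning each checkpoint's win count into a sampling weight (min-max normalised, then fed to random.choices), so checkpoints that beat the current model more often are selected more often. A small sketch of that selection logic (pick_opponent and the file names are made up for illustration):

import random

def pick_opponent(win_counts):
    # win_counts: {checkpoint filename: number of times it beat the current model}
    files = list(win_counts)
    wins = [win_counts[f] for f in files]
    lo, hi = min(wins), max(wins)
    if hi == lo:
        weights = [1.0 / len(files)] * len(files)       # no signal yet: uniform choice
    else:
        weights = [(w - lo) / (hi - lo) for w in wins]  # stronger checkpoints get larger weight
    return random.choices(files, weights=weights)[0]

# hypothetical checkpoints:
print(pick_opponent({"best_model_15_5.pth.a": 3, "best_model_15_5.pth.b": 1, "best_model_15_5.pth.c": 0}))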