Example 1
class FiveChessTrain():
    def __init__(self):
        self.policy_evaluate_size = 20  # number of simulated games when evaluating the policy's win rate
        self.batch_size = 256  # length of one training batch
        self.max_keep_size = 500000  # number of most recent self-play samples to keep; a game averages about 400-600 samples, so this covers roughly the last 1000 games

        # training params
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1  # temperature for scaling move probabilities; 0.01 in real prediction, 1 during training
        self.n_playout = 600  # number of MCTS simulations per move
        self.play_batch_size = 1  # number of self-play games per collection round
        self.epochs = 1  # number of training passes per update; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network
        
        # number of playouts for the pure-MCTS opponent used to evaluate the policy model
        self.pure_mcts_playout_num = 4000  # random rollouts the pure MCTS uses when building its initial tree
        self.c_puct = 4  # MCTS c_puct weight, balances exploration vs. exploitation; default 5

        if os.path.exists(model_file):
            # use a previously trained policy-value network
            self.policy_value_net = PolicyValueNet(size, model_file=model_file)
        else:
            # use a new policy-value network
            self.policy_value_net = PolicyValueNet(size)

        print("start data loader")
        self.dataset = Dataset(data_dir, self.max_keep_size)
        print("dataset len:",len(self.dataset),"index:",self.dataset.index)
        print("end data loader")

    def policy_update(self, sample_data, epochs=1):
        """更新策略价值网络policy-value"""
        # 训练策略价值网络
        state_batch, mcts_probs_batch, winner_batch = sample_data

        # old_probs, old_v = self.policy_value_net.policy_value(state_batch)  
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(state_batch, mcts_probs_batch, winner_batch, self.learn_rate * self.lr_multiplier)
            # new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            # kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            # if kl > self.kl_targ * epochs:  # early stopping if D_KL diverges badly
            #     break

        # adaptively adjust the learning rate
        # if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
        #     self.lr_multiplier /= 1.5
        # elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
        #     self.lr_multiplier *= 1.5
        # if learning succeeds, explained_var should approach 1; if nothing is learned (all values stay near zero), it stays around 0
        # explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch)))
        # explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch)))
        # entropy: the lower, the better
        # logging.info(("TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        #               ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy, explained_var_old, explained_var_new))
        return loss, v_loss, p_loss, entropy

    def run(self):
        """启动训练"""
        try:
            dataset_len = len(self.dataset)      
            training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4,)
            old_probs = None
            test_batch = None
            for i, data in enumerate(training_loader):  # iterate over the planned training batches
                loss, v_loss, p_loss, entropy = self.policy_update(data, self.epochs)              
                if (i+1) % 10 == 0:
                    logging.info(("TRAIN idx {} : {} / {} v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}")\
                        .format(i, i*self.batch_size, dataset_len, v_loss, p_loss, entropy))
                    
                    # dynamically adjust the learning rate
                    if old_probs is None:
                        test_batch, _, _ = data
                        old_probs, _ = self.policy_value_net.policy_value(test_batch) 
                    else:
                        new_probs, _ = self.policy_value_net.policy_value(test_batch)
                        kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
                        old_probs = None
        
                        if kl > self.kl_targ * 2:
                            self.lr_multiplier /= 1.5
                        elif kl < self.kl_targ / 2 and self.lr_multiplier < 100:
                            self.lr_multiplier *= 1.5
                        else:
                            continue
                        logging.info("kl:{} lr_multiplier:{} lr:{}".format(kl, self.lr_multiplier, self.learn_rate*self.lr_multiplier))

                    # logging.info("Train idx {} : {} / {}".format(i, i*self.batch_size, len(self.dataset)))
            self.policy_value_net.save_model(model_file)
        except KeyboardInterrupt:
            logging.info('quit')
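
The KL-driven learning-rate schedule used in run() above (and repeated in the examples below) can be isolated into a small helper. This is a minimal sketch under the same 1.5x update rule; the names kl_divergence and adjust_lr_multiplier are illustrative, not part of the original code, and the lower/upper bounds on the multiplier vary between the examples (0.1/10 in most, 100 as the ceiling in Example 1).

import numpy as np

def kl_divergence(old_probs, new_probs, eps=1e-10):
    # D(P||Q) = sum( p_i * (log(p_i) - log(q_i)) ), averaged over the batch
    return np.mean(np.sum(old_probs * (np.log(old_probs + eps) - np.log(new_probs + eps)), axis=1))

def adjust_lr_multiplier(kl, kl_targ, lr_multiplier, factor=1.5, low=0.1, high=10):
    # shrink the multiplier when one update moved the policy too far,
    # grow it when the policy barely moved
    if kl > kl_targ * 2 and lr_multiplier > low:
        lr_multiplier /= factor
    elif kl < kl_targ / 2 and lr_multiplier < high:
        lr_multiplier *= factor
    return lr_multiplier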
Example 2
class TrainPipeline():
    def __init__(self, mol=None, init_model=None):
        # params of the board and the game
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 30  # num of simulations for each move
        self.c_puct = 1
        self.buffer_size = 200
        self.batch_size = 200  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.epochs = 50  # num of train_steps for each update
        self.kl_targ = 0.2
        self.check_freq = 5
        self.mol = mol
        self.play_batch_size = 1
        self.game_batch_num = 15
        self.in_dim = 1024
        self.n_hidden_1 = 1024
        self.n_hidden_2 = 1024
        self.out_dim = 1
        self.output_smi = []
        self.output_qed = []
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.in_dim,
                                                   self.n_hidden_1,
                                                   self.n_hidden_2,
                                                   self.out_dim,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.in_dim,
                                                   self.n_hidden_1,
                                                   self.n_hidden_2,
                                                   self.out_dim)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            play_data = start_self_play(self.mcts_player,
                                        self.mol,
                                        temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            print(self.episode_len)
            # augment the data
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        # mini_batch = random.sample(self.data_buffer, self.batch_size)
        # state_batch = [data[0] for data in mini_batch]
        # mcts_probs_batch = [data[1] for data in mini_batch]
        # old_probs = self.policy_value_net.policy_value(state_batch)

        for i in range(self.epochs):
            mini_batch = random.sample(self.data_buffer, self.batch_size)
            state_batch = [data[0] for data in mini_batch]
            mcts_probs_batch = [data[1] for data in mini_batch]
            old_probs = self.policy_value_net.policy_value(state_batch)
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs = self.policy_value_net.policy_value(state_batch)

            # per-sample KL summed over actions, then averaged over the batch
            kl = np.mean(
                np.sum(
                    old_probs *
                    (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1))
            #if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
            #    print("early stopping!!")
            #    break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},").format(
                   kl,
                   self.lr_multiplier,
                   loss,
                   entropy,
               ))
        return loss, entropy

    def policy_evaluate(self):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        player = MCTSPlayer(self.policy_value_net.policy_value,
                            c_puct=self.c_puct,
                            n_playout=30)
        environment = Molecule(["C", "O", "N"],
                               init_mol=self.mol,
                               allow_removal=True,
                               allow_no_modification=False,
                               allow_bonds_between_rings=False,
                               allowed_ring_sizes=[5, 6],
                               max_steps=10,
                               target_fn=None,
                               record_path=False)
        environment.initialize()
        environment.init_qed = QED.qed(Chem.MolFromSmiles(self.mol))

        moves, fp, _S_P, _Qs = player.get_action(environment,
                                                 temp=self.temp,
                                                 return_prob=1,
                                                 rand=False)

        return moves, _S_P, _Qs

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i: {}, episode_len: {}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) >= self.batch_size:
                    loss, entropy = self.policy_update()
                    print("loss is {}  entropy is {}".format(loss, entropy))
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    move_list, _S_P, _Qs = self.policy_evaluate()
                    # self.policy_value_net.save_model('./current_policy.model')
                    print(move_list)
                    print(_Qs)
                    print(_S_P)

                    self.output_smi.extend(move_list)
                    o_qed = list(
                        map(lambda x: QED.qed(Chem.MolFromSmiles(x)),
                            move_list))
                    print(o_qed)
                    print("#" * 30)
                    self.output_qed.extend(o_qed)
        except KeyboardInterrupt:
            print('\n\rquit')
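
The QED scoring that run() applies to the generated SMILES can be exercised on its own. A minimal sketch using RDKit's Chem and QED modules as in the code above; the example SMILES strings are illustrative, not outputs of the pipeline.

from rdkit import Chem
from rdkit.Chem import QED

smiles_list = ["CCO", "c1ccccc1O"]  # illustrative molecules
qed_scores = [QED.qed(Chem.MolFromSmiles(s)) for s in smiles_list]
print(list(zip(smiles_list, qed_scores)))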
Example 3
class Train():
    def __init__(self):
        self.game_batch_num = 1000000  # number of self-play games
        self.batch_size = 512  # start model training once the data_buffer holds more than this many samples

        # training params
        self.learn_rate = 1e-4
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1  # MCTS temperature; larger means more uncertainty; 1 for training, 1e-3 for prediction
        self.n_playout = 500  # number of MCTS simulations per move
        self.play_batch_size = 5  # number of self-play games per collection round
        self.buffer_size = 500000  # number of cached samples
        self.epochs = 2  # number of training steps per policy-value network update; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network
        self.best_win_ratio = 0.0

        self.c_puct = 5  # MCTS c_puct weight, balances exploration vs. exploitation; default 5
        self.policy_value_net = PolicyValueNet(GAME_WIDTH,
                                               GAME_HEIGHT,
                                               GAME_ACTIONS_NUM,
                                               model_file=model_file)

    def collect_selfplay_data(self):
        """收集自我对抗数据用于训练"""
        print("TRAIN Self Play starting ...")
        # game agent
        agent = Agent()

        # start playing
        agentcount, reward, piececount, keys, play_data = agent.start_self_play(
            self.policy_value_net)

        play_data = list(play_data)[:]
        episode_len = len(play_data)

        print("TRAIN Self Play end. length:%s saving ..." % episode_len)
        # save each self-play sample as a pickle file in the wait directory
        for obj in play_data:
            filename = "{}.pkl".format(uuid.uuid1())
            savefile = os.path.join(data_wait_dir, filename)
            pickle.dump(obj, open(savefile, "wb"))

        jsonfile = os.path.join(data_dir, "result.json")
        if os.path.exists(jsonfile):
            result = json.load(open(jsonfile, "r"))
        else:
            result = {"agent": 0, "reward": [], "pieces": []}
            result["curr"] = {"reward": 0, "pieces": 0, "agent": 0}

        result["agent"] += agentcount
        result["curr"]["reward"] += reward
        result["curr"]["pieces"] += piececount
        result["curr"]["agent"] += agentcount

        agent = result["agent"]
        if agent % 100 == 0:
            result["reward"].append(
                round(result["curr"]["reward"] / result["curr"]["agent"], 2))
            result["pieces"].append(
                round(result["curr"]["pieces"] / result["curr"]["agent"], 2))
            if len(result["reward"]) > 100:
                result["reward"].remove(min(result["reward"]))
            if len(result["pieces"]) > 100:
                result["pieces"].remove(min(result["pieces"]))
        if result["curr"]["agent"] > 1000:
            result["curr"] = {"reward": 0, "pieces": 0, "agent": 0}

        json.dump(result, open(jsonfile, "w"), ensure_ascii=False)

    def policy_update(self, sample_data, epochs=1):
        """更新策略价值网络policy-value"""
        # 训练策略价值网络
        # 随机抽取data_buffer中的对抗数据
        # mini_batch = self.dataset.loadData(sample_data)
        state_batch, mcts_probs_batch, winner_batch = sample_data
        # # for x in mini_batch:
        # #     print("-----------------")
        # #     print(x)
        # # state_batch = [data[0] for data in mini_batch]
        # # mcts_probs_batch = [data[1] for data in mini_batch]
        # # winner_batch = [data[2] for data in mini_batch]

        # print(state_batch)

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break

        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        # if learning succeeds, explained_var should approach 1; if nothing is learned (all values stay near zero), it stays around 0
        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        # entropy: the lower, the better
        print((
            "TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy,
                 explained_var_old, explained_var_new))
        return loss, entropy

    def run(self):
        """启动训练"""
        try:
            self.collect_selfplay_data()
        except KeyboardInterrupt:
            print('quit')
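
The explained-variance diagnostic printed by policy_update can be factored into a standalone helper. A minimal sketch; the function name explained_variance is illustrative and not part of the original code.

import numpy as np

def explained_variance(predicted_values, winner_batch):
    # approaches 1 when the value head predicts the recorded outcomes well,
    # stays near 0 when it has learned nothing
    winners = np.array(winner_batch)
    return 1 - np.var(winners - predicted_values.flatten()) / np.var(winners)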
Example 4
class Train():
    def __init__(self):
        self.game_batch_num = 1000000  # number of self-play games
        self.batch_size = 512  # start model training once the data_buffer holds more than this many samples

        # training params
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1  # MCTS temperature; larger means more uncertainty; 1 for training, 1e-3 for prediction
        self.n_playout = 256  # number of MCTS simulations per move
        self.play_batch_size = 5  # number of self-play games per collection round
        self.buffer_size = 300000  # number of cached samples
        self.epochs = 2  # number of training steps per policy-value network update; 5 is recommended
        self.kl_targ = 0.02  # target KL divergence for the policy-value network
        self.best_win_ratio = 0.0

        self.c_puct = 0.1  # MCTS c_puct weight, balances exploration vs. exploitation; default 5
        self.policy_value_net = PolicyValueNet(GAME_WIDTH,
                                               GAME_HEIGHT,
                                               GAME_ACTIONS_NUM,
                                               model_file=model_file)

    def get_equi_data(self, play_data):
        """
        Augment the dataset by horizontal flipping.
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            extend_data.append((state, mcts_prob, winner))
            # horizontal flip
            equi_state = np.array([np.fliplr(s) for s in state])
            equi_mcts_prob = mcts_prob[[0, 2, 1, 3]]
            extend_data.append((equi_state, equi_mcts_prob, winner))
        return extend_data

    def collect_selfplay_data(self):
        """收集自我对抗数据用于训练"""
        # 使用MCTS蒙特卡罗树搜索进行自我对抗
        logging.info("TRAIN Self Play starting ...")
        # game agent
        agent = Agent()

        # create an MCTS player that uses the policy-value network to guide the tree search and evaluate leaf nodes
        mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                 c_puct=self.c_puct,
                                 n_playout=self.n_playout,
                                 is_selfplay=1)
        for _ in range(3):
            # start playing
            reward, piececount, agentcount, play_data = agent.start_self_play(
                mcts_player, temp=self.temp)
            play_data = list(play_data)[:]
            episode_len = len(play_data)

            # add flipped-board data to the dataset
            # play_data = self.get_equi_data(play_data)
            logging.info("TRAIN Self Play end. length:%s saving ..." %
                         episode_len)
            # save each self-play sample as a pickle file in the wait directory
            for obj in play_data:
                filename = "{}.pkl".format(uuid.uuid1())
                savefile = os.path.join(data_wait_dir, filename)
                pickle.dump(obj, open(savefile, "wb"))
                # self.dataset.save(obj)

            if agent.limit_max_height == 10:
                jsonfile = os.path.join(data_dir, "result.json")
                if os.path.exists(jsonfile):
                    result = json.load(open(jsonfile, "r"))
                else:
                    result = {"reward": 0, "steps": 0, "agent": 0}
                if "1k" not in result:
                    result["1k"] = {"reward": 0, "steps": 0, "agent": 0}
                result["reward"] = result["reward"] + reward
                result["steps"] = result["steps"] + piececount
                result["agent"] = result["agent"] + agentcount
                result["1k"]["reward"] = result["1k"]["reward"] + reward
                result["1k"]["steps"] = result["1k"]["steps"] + piececount
                result["1k"]["agent"] = result["1k"]["agent"] + agentcount

                if result["agent"] > 0 and result["agent"] % 100 <= 1:
                    result[str(result["agent"])] = {
                        "reward":
                        result["1k"]["reward"] / result["1k"]["agent"],
                        "steps": result["1k"]["steps"] / result["1k"]["agent"]
                    }

                if result["agent"] > 0 and result["agent"] % 1000 == 0:

                    # save an extra checkpoint
                    steps = round(result["1k"]["steps"] /
                                  result["1k"]["agent"])
                    model_file = os.path.join(model_dir,
                                              'model_%s.pth' % steps)
                    self.policy_value_net.save_model(model_file)

                    for key in list(result.keys()):
                        if key.isdigit():
                            c = int(key)
                            if c % 1000 > 10:
                                del result[key]
                    result["1k"] = {"reward": 0, "steps": 0, "agent": 0}

                json.dump(result, open(jsonfile, "w"), ensure_ascii=False)

            if reward >= 1: break

    def policy_update(self, sample_data, epochs=1):
        """更新策略价值网络policy-value"""
        # 训练策略价值网络
        # 随机抽取data_buffer中的对抗数据
        # mini_batch = self.dataset.loadData(sample_data)
        state_batch, mcts_probs_batch, winner_batch = sample_data
        # # for x in mini_batch:
        # #     print("-----------------")
        # #     print(x)
        # # state_batch = [data[0] for data in mini_batch]
        # # mcts_probs_batch = [data[1] for data in mini_batch]
        # # winner_batch = [data[2] for data in mini_batch]

        # print(state_batch)

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break

        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        # if learning succeeds, explained_var should approach 1; if nothing is learned (all values stay near zero), it stays around 0
        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        # entropy: the lower, the better
        logging.info((
            "TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy,
                 explained_var_old, explained_var_new))
        return loss, entropy

    def run(self):
        """启动训练"""
        try:
            # print("start data loader")
            # self.dataset = Dataset(data_dir, self.buffer_size)
            # print("end data loader")

            # step = 0
            # # if less than half of the training data has been collected, gather more data first
            # if self.dataset.curr_game_batch_num/self.dataset.buffer_size<0.5:
            #     for _ in range(8):
            #         logging.info("TRAIN Batch:{} starting".format(self.dataset.curr_game_batch_num,))
            #         # n_playout=self.n_playout
            #         # self.n_playout=8
            #         self.collect_selfplay_data()
            #         # self.n_playout=n_playout
            #         logging.info("TRAIN Batch:{} end".format(self.dataset.curr_game_batch_num,))
            #         step += 1

            # training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=2,)

            # for i, data in enumerate(training_loader):  # iterate over the planned training batches
            #     # retrain the policy-value network with the self-play data
            #     loss, entropy = self.policy_update(data, self.epochs)

            # self.policy_value_net.save_model(model_file)
            # collect self-play data
            # for _ in range(self.play_batch_size):
            self.collect_selfplay_data()
            # logging.info("TRAIN {} self-play end, size: {}".format(self.dataset.curr_game_batch_num, self.dataset.curr_size()))

        except KeyboardInterrupt:
            logging.info('quit')
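
collect_selfplay_data() writes each (state, mcts_prob, winner) sample as its own .pkl file; Example 1 later reads them back through a Dataset class that is not shown in these snippets. A minimal sketch of what such a loader could look like, assuming one sample per file as written above; PickleSampleDataset and its details are hypothetical.

import glob
import os
import pickle

import torch

class PickleSampleDataset(torch.utils.data.Dataset):
    """Hypothetical loader for the per-sample .pkl files written during self-play."""

    def __init__(self, data_dir, max_keep_size):
        files = sorted(glob.glob(os.path.join(data_dir, "*.pkl")), key=os.path.getmtime)
        self.files = files[-max_keep_size:]  # keep only the most recent samples

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        with open(self.files[idx], "rb") as f:
            state, mcts_prob, winner = pickle.load(f)
        return state, mcts_prob, winner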
Example 5
class Agent_MCTS(nn.Module):
    def __init__(self,args,share_model,opti,board_max,param,is_selfplay=True):
        super().__init__()
        self._is_selfplay=is_selfplay
        self.learn_rate = 5e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0 # the temperature param
        self.n_playout = 100 # num of simulations for each move
        self.c_puct = 5
        self.batch_size = 32 # mini-batch size for training
        self.play_batch_size = 1 
        self.epochs = 5 # num of train_steps for each update
        self.kl_targ = 0.025
        self.check_freq = 50 
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000  
        
        self.policy_value_net = PolicyValueNet(board_max,board_max,net_params = param)
        self.mcts = MCTS(self.policy_value_net.policy_value_fn, self.c_puct, self.n_playout)
        
        
        self.batch_size = args.batch_size 
        self.discount = args.discount
        self.epsilon = args.epsilon
        self.action_space = args.action_space
        self.hidden_size = args.hidden_size
        self.state_space = args.state_space
        
#        self.main_dqn= DQN_model(args)
        
#        self.main_dqn.train()
#        self.target_dqn = DQN_rainbow(args)
#        self.target_dqn = target_model
#        self.target_dqn_update()
#        self.target_dqn.eval()
        
#        self.optimizer = optim.Adam(self.main_dqn.parameters(), lr=args.lr, eps=args.adam_eps)
        
      
    def reset_player(self):
        self.mcts.update_with_move(-1) 
        
    def save(self):
        print('save')
        torch.save(self.policy_value_net.policy_value_net.state_dict(),'./net_param')
        
        
#    def save(self,path ='./param.p'):
#        torch.save(self.main_dqn.state_dict(),path)
#        
#    def load(self,path ='./param.p'):
#        if os.path.exists(path):
#            self.main_dqn.load_state_dict(torch.load(path))
#        else :
#            print("file not exist")
    
#    def target_dqn_update(self):
#        self.target_dqn.parameter_update(self.main_dqn)
    
   
    def get_action(self, board, temp=1e-3, return_prob=0):
        sensible_moves = board.availables
        move_probs = np.zeros(board.width*board.height) # the pi vector returned by MCTS as in the AlphaGo Zero paper
        if len(sensible_moves) > 0:
            acts, probs = self.mcts.get_move_probs(board, temp)
            move_probs[list(acts)] = probs         
            if self._is_selfplay:
                # add Dirichlet Noise for exploration (needed for self-play training)
                move = np.random.choice(acts, p=0.75*probs + 0.25*np.random.dirichlet(0.3*np.ones(len(probs))))    
                self.mcts.update_with_move(move) # update the root node and reuse the search tree
            else:
                # with the default temp=1e-3, this is almost equivalent to choosing the move with the highest prob
                move = np.random.choice(acts, p=probs)       
                # reset the root node
                self.mcts.update_with_move(-1)             
#                location = board.move_to_location(move)
#                print("AI move: %d,%d\n" % (location[0], location[1]))
                
            return move, move_probs
        else:            
            print("WARNING: the board is full")

        
    
#    def train(self):
#        self.main_dqn.train()
#    def eval(self):
#        self.main_dqn.eval()
        
    
    def learn(self,data_buffer):
        """update the policy-value net"""
        mini_batch = random.sample(data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]            
        old_probs, old_v = self.policy_value_net.policy_value(state_batch) 
        for i in range(self.epochs): 
            loss, entropy = self.policy_value_net.train_step(state_batch, mcts_probs_batch, winner_batch, self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))  
            if kl > self.kl_targ * 4:   # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
            
        explained_var_old =  1 - np.var(np.array(winner_batch) - old_v.flatten())/np.var(np.array(winner_batch))
        explained_var_new = 1 - np.var(np.array(winner_batch) - new_v.flatten())/np.var(np.array(winner_batch))        
        print("kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}".format(
                kl, self.lr_multiplier, loss, entropy, explained_var_old, explained_var_new))
        return loss, entropy
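
The Dirichlet-noise exploration inside get_action during self-play can be illustrated in isolation. A minimal sketch with made-up legal moves and probabilities; the 0.75/0.25 mixture and alpha = 0.3 follow the code above.

import numpy as np

acts = np.array([12, 40, 41, 55])          # illustrative legal moves
probs = np.array([0.5, 0.3, 0.15, 0.05])   # MCTS visit-count probabilities (sum to 1)

# mix the MCTS probabilities with Dirichlet noise to keep self-play exploratory
noise = np.random.dirichlet(0.3 * np.ones(len(probs)))
move = np.random.choice(acts, p=0.75 * probs + 0.25 * noise)
print(move)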