class FiveChessTrain():
    def __init__(self):
        self.policy_evaluate_size = 20      # number of games simulated when evaluating the policy's win rate
        self.batch_size = 256               # mini-batch size for training
        self.max_keep_size = 500000         # number of recent self-play samples to keep; one game yields roughly
                                            # 400~600 samples, so this covers about the last 1000 games
        # training params
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0            # adaptively adjust the learning rate based on KL
        self.temp = 1                       # temperature: 0.01 for actual play, 1 during training
        self.n_playout = 600                # number of MCTS simulations per move
        self.play_batch_size = 1            # number of self-play games per iteration
        self.epochs = 1                     # number of training passes per update (5 recommended)
        self.kl_targ = 0.02                 # target KL value for the policy-value network
        # number of playouts for the pure MCTS opponent used to evaluate the policy model
        self.pure_mcts_playout_num = 4000
        self.c_puct = 4                     # MCTS child weight; controls exploration vs. exploitation (default 5)

        if os.path.exists(model_file):
            # continue from a previously trained policy-value network
            self.policy_value_net = PolicyValueNet(size, model_file=model_file)
        else:
            # start from a fresh policy-value network
            self.policy_value_net = PolicyValueNet(size)

        print("start data loader")
        self.dataset = Dataset(data_dir, self.max_keep_size)
        print("dataset len:", len(self.dataset), "index:", self.dataset.index)
        print("end data loader")

    def policy_update(self, sample_data, epochs=1):
        """Update the policy-value network."""
        state_batch, mcts_probs_batch, winner_batch = sample_data
        # old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            # new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            # kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            # if kl > self.kl_targ * epochs:  # stop early if D_KL diverges badly
            #     break

        # adaptively adjust the learning rate
        # if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
        #     self.lr_multiplier /= 1.5
        # elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
        #     self.lr_multiplier *= 1.5

        # explained_var should approach 1 when the value head is learning;
        # it stays near 0 when nothing is learned (e.g. all predicted values are tiny)
        # explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch)))
        # explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch)))
        # entropy: lower is better
        # logging.info(("TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        #               ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy, explained_var_old, explained_var_new))

        return loss, v_loss, p_loss, entropy

    def run(self):
        """Start training."""
        try:
            dataset_len = len(self.dataset)
            training_loader = torch.utils.data.DataLoader(
                self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)
            old_probs = None
            test_batch = None
            for i, data in enumerate(training_loader):  # iterate over the scheduled training batches
                loss, v_loss, p_loss, entropy = self.policy_update(data, self.epochs)
                if (i + 1) % 10 == 0:
                    logging.info("TRAIN idx {} : {} / {} v_loss:{:.5f}, p_loss:{:.5f}, entropy:{:.5f}"
                                 .format(i, i * self.batch_size, dataset_len, v_loss, p_loss, entropy))

                    # dynamically adjust the learning rate
                    if old_probs is None:
                        test_batch, _, _ = data
                        old_probs, _ = self.policy_value_net.policy_value(test_batch)
                    else:
                        new_probs, _ = self.policy_value_net.policy_value(test_batch)
                        kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
                        old_probs = None

                        if kl > self.kl_targ * 2:
                            self.lr_multiplier /= 1.5
                        elif kl < self.kl_targ / 2 and self.lr_multiplier < 100:
                            self.lr_multiplier *= 1.5
                        else:
                            continue
                        logging.info("kl:{} lr_multiplier:{} lr:{}".format(
                            kl, self.lr_multiplier, self.learn_rate * self.lr_multiplier))

            self.policy_value_net.save_model(model_file)
        except KeyboardInterrupt:
            logging.info('quit')
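
# Minimal, self-contained sketch of the KL-based learning-rate adaptation used above.
# It only assumes numpy; `old_probs` / `new_probs` stand in for the policy outputs
# before and after a train_step, shaped (batch, n_actions). The helper names
# `kl_divergence` and `adapt_lr_multiplier` are illustrative, not part of the project.

import numpy as np

def kl_divergence(old_probs, new_probs, eps=1e-10):
    """Mean KL divergence D(P_old || P_new) over a batch of policy distributions."""
    return np.mean(np.sum(old_probs * (np.log(old_probs + eps) - np.log(new_probs + eps)), axis=1))

def adapt_lr_multiplier(kl, lr_multiplier, kl_targ=0.02):
    """Shrink the multiplier when the update moved the policy too far, grow it when it barely moved."""
    if kl > kl_targ * 2 and lr_multiplier > 0.1:
        lr_multiplier /= 1.5
    elif kl < kl_targ / 2 and lr_multiplier < 10:
        lr_multiplier *= 1.5
    return lr_multiplier

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    old = rng.dirichlet(np.ones(16), size=4)    # toy "before" policies
    new = rng.dirichlet(np.ones(16), size=4)    # toy "after" policies
    kl = kl_divergence(old, new)
    print(kl, adapt_lr_multiplier(kl, 1.0))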
class TrainPipeline():
    def __init__(self, mol=None, init_model=None):
        # params of the board and the game
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0    # adaptively adjust the learning rate based on KL
        self.temp = 1.0             # the temperature param
        self.n_playout = 30         # num of simulations for each move
        self.c_puct = 1
        self.buffer_size = 200
        self.batch_size = 200       # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.epochs = 50            # num of train_steps for each update
        self.kl_targ = 0.2
        self.check_freq = 5
        self.mol = mol
        self.play_batch_size = 1
        self.game_batch_num = 15
        self.in_dim = 1024
        self.n_hidden_1 = 1024
        self.n_hidden_2 = 1024
        self.out_dim = 1
        self.output_smi = []
        self.output_qed = []
        # num of simulations used for the pure MCTS, which is used as
        # the opponent to evaluate the trained policy
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.in_dim, self.n_hidden_1,
                                                   self.n_hidden_2, self.out_dim,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.in_dim, self.n_hidden_1,
                                                   self.n_hidden_2, self.out_dim)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training."""
        for i in range(n_games):
            play_data = start_self_play(self.mcts_player, self.mol, temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            print(self.episode_len)
            # augment the data
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """Update the policy-value net."""
        for i in range(self.epochs):
            mini_batch = random.sample(self.data_buffer, self.batch_size)
            state_batch = [data[0] for data in mini_batch]
            mcts_probs_batch = [data[1] for data in mini_batch]
            old_probs = self.policy_value_net.policy_value(state_batch)
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            # if kl > self.kl_targ * 4:   # early stopping if D_KL diverges badly
            #     print("early stopping!!")
            #     break

            # adaptively adjust the learning rate
            if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
                self.lr_multiplier /= 1.5
            elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                self.lr_multiplier *= 1.5

            print(("kl:{:.5f},"
                   "lr_multiplier:{:.3f},"
                   "loss:{},"
                   "entropy:{},").format(kl, self.lr_multiplier, loss, entropy))
        return loss, entropy

    def policy_evaluate(self):
        """
        Evaluate the trained policy by playing against the pure MCTS player.
        Note: this is only for monitoring the progress of training.
        """
        player = MCTSPlayer(self.policy_value_net.policy_value,
                            c_puct=self.c_puct,
                            n_playout=30)
        environment = Molecule(["C", "O", "N"],
                               init_mol=self.mol,
                               allow_removal=True,
                               allow_no_modification=False,
                               allow_bonds_between_rings=False,
                               allowed_ring_sizes=[5, 6],
                               max_steps=10,
                               target_fn=None,
                               record_path=False)
        environment.initialize()
        environment.init_qed = QED.qed(Chem.MolFromSmiles(self.mol))
        moves, fp, _S_P, _Qs = player.get_action(environment,
                                                 temp=self.temp,
                                                 return_prob=1,
                                                 rand=False)
        return moves, _S_P, _Qs

    def run(self):
        """Run the training pipeline."""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i: {}, episode_len: {}".format(i + 1, self.episode_len))
                if len(self.data_buffer) >= self.batch_size:
                    loss, entropy = self.policy_update()
                    print("loss is {} entropy is {}".format(loss, entropy))
                # check the performance of the current model, and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    move_list, _S_P, _Qs = self.policy_evaluate()
                    # self.policy_value_net.save_model('./current_policy.model')
                    print(move_list)
                    print(_Qs)
                    print(_S_P)
                    self.output_smi.extend(move_list)
                    o_qed = list(map(lambda x: QED.qed(Chem.MolFromSmiles(x)), move_list))
                    print(o_qed)
                    print("#" * 30)
                    self.output_qed.extend(o_qed)
        except KeyboardInterrupt:
            print('\n\rquit')
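
# A hedged usage sketch for TrainPipeline above. The seed SMILES string is a
# placeholder chosen for illustration, not a value taken from the original project;
# init_model would point to a saved policy-value checkpoint if one exists.
if __name__ == "__main__":
    start_smiles = "CCO"    # hypothetical starting molecule (ethanol)
    pipeline = TrainPipeline(mol=start_smiles, init_model=None)
    pipeline.run()
    # After training, the generated SMILES and their QED scores are accumulated in:
    print(pipeline.output_smi)
    print(pipeline.output_qed)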
class Train():
    def __init__(self):
        self.game_batch_num = 1000000   # number of self-play games
        self.batch_size = 512           # start model training once data_buffer holds more than this many games
        # training params
        self.learn_rate = 1e-4
        self.lr_multiplier = 1.0        # adaptively adjust the learning rate based on KL
        self.temp = 1                   # MCTS temperature: larger means more uncertainty; 1 for training, 1e-3 for prediction
        self.n_playout = 500            # number of MCTS simulations per move
        self.play_batch_size = 5        # number of self-play games per iteration
        self.buffer_size = 500000       # size of the sample cache
        self.epochs = 2                 # number of training steps per policy-value update (5 recommended)
        self.kl_targ = 0.02             # target KL value for the policy-value network
        self.best_win_ratio = 0.0
        self.c_puct = 5                 # MCTS child weight; controls exploration vs. exploitation (default 5)
        self.policy_value_net = PolicyValueNet(GAME_WIDTH, GAME_HEIGHT, GAME_ACTIONS_NUM, model_file=model_file)

    def collect_selfplay_data(self):
        """Collect self-play data for training."""
        print("TRAIN Self Play starting ...")
        # game agent
        agent = Agent()
        # run self-play
        agentcount, reward, piececount, keys, play_data = agent.start_self_play(self.policy_value_net)
        play_data = list(play_data)[:]
        episode_len = len(play_data)
        print("TRAIN Self Play end. length:%s saving ..." % episode_len)

        # save each self-play sample to the waiting directory
        for i, obj in enumerate(play_data):
            filename = "{}.pkl".format(uuid.uuid1())
            savefile = os.path.join(data_wait_dir, filename)
            pickle.dump(obj, open(savefile, "wb"))

        # update the running statistics in result.json
        jsonfile = os.path.join(data_dir, "result.json")
        if os.path.exists(jsonfile):
            result = json.load(open(jsonfile, "r"))
        else:
            result = {"agent": 0, "reward": [], "pieces": []}
            result["curr"] = {"reward": 0, "pieces": 0, "agent": 0}

        result["agent"] += agentcount
        result["curr"]["reward"] += reward
        result["curr"]["pieces"] += piececount
        result["curr"]["agent"] += agentcount

        agent_total = result["agent"]
        if agent_total % 100 == 0:
            result["reward"].append(round(result["curr"]["reward"] / result["curr"]["agent"], 2))
            result["pieces"].append(round(result["curr"]["pieces"] / result["curr"]["agent"], 2))
            if len(result["reward"]) > 100:
                result["reward"].remove(min(result["reward"]))
            if len(result["pieces"]) > 100:
                result["pieces"].remove(min(result["pieces"]))
            if result["curr"]["agent"] > 1000:
                result["curr"] = {"reward": 0, "pieces": 0, "agent": 0}

        json.dump(result, open(jsonfile, "w"), ensure_ascii=False)

    def policy_update(self, sample_data, epochs=1):
        """Update the policy-value network."""
        # sample_data already holds the states, MCTS probabilities and winners
        # drawn from data_buffer
        # mini_batch = self.dataset.loadData(sample_data)
        state_batch, mcts_probs_batch, winner_batch = sample_data

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            if kl > self.kl_targ * 4:  # stop early if D_KL diverges badly
                break

        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        # explained_var should approach 1 when the value head is learning;
        # it stays near 0 when nothing is learned (e.g. all predicted values are tiny)
        explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch)))
        explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch)))
        # entropy: lower is better
        print(("TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
               ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy, explained_var_old, explained_var_new))

        return loss, entropy

    def run(self):
        """Start training."""
        try:
            self.collect_selfplay_data()
        except KeyboardInterrupt:
            print('quit')
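
# Standalone sketch of the explained-variance diagnostic computed in policy_update():
# explained_var = 1 - Var(z - v) / Var(z), where z are the self-play outcomes and v the
# value-head predictions. Values near 1 mean the value head tracks the outcomes; values
# near 0 mean it predicts little. Only numpy is assumed; the helper name is illustrative.

import numpy as np

def explained_variance(winner_batch, value_pred):
    winner_batch = np.asarray(winner_batch, dtype=np.float64)
    value_pred = np.asarray(value_pred, dtype=np.float64).flatten()
    return 1 - np.var(winner_batch - value_pred) / np.var(winner_batch)

if __name__ == "__main__":
    z = np.array([1.0, -1.0, 1.0, 1.0, -1.0])
    print(explained_variance(z, z * 0.8))           # 0.96: value head tracks outcomes
    print(explained_variance(z, np.zeros_like(z)))  # 0.0: uninformative predictions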
class Train():
    def __init__(self):
        self.game_batch_num = 1000000   # number of self-play games
        self.batch_size = 512           # start model training once data_buffer holds more than this many games
        # training params
        self.learn_rate = 1e-5
        self.lr_multiplier = 1.0        # adaptively adjust the learning rate based on KL
        self.temp = 1                   # MCTS temperature: larger means more uncertainty; 1 for training, 1e-3 for prediction
        self.n_playout = 256            # number of MCTS simulations per move
        self.play_batch_size = 5        # number of self-play games per iteration
        self.buffer_size = 300000       # size of the sample cache
        self.epochs = 2                 # number of training steps per policy-value update (5 recommended)
        self.kl_targ = 0.02             # target KL value for the policy-value network
        self.best_win_ratio = 0.0
        self.c_puct = 0.1               # MCTS child weight; controls exploration vs. exploitation (default 5)
        self.policy_value_net = PolicyValueNet(GAME_WIDTH, GAME_HEIGHT, GAME_ACTIONS_NUM, model_file=model_file)

    def get_equi_data(self, play_data):
        """
        Augment the dataset by flipping the board.
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            extend_data.append((state, mcts_prob, winner))
            # horizontal flip
            equi_state = np.array([np.fliplr(s) for s in state])
            equi_mcts_prob = mcts_prob[[0, 2, 1, 3]]
            extend_data.append((equi_state, equi_mcts_prob, winner))
        return extend_data

    def collect_selfplay_data(self):
        """Collect self-play data for training."""
        # self-play guided by Monte Carlo tree search
        logging.info("TRAIN Self Play starting ...")
        # game agent
        agent = Agent()
        # MCTS player that uses the policy-value network to guide the tree search
        # and to evaluate leaf nodes
        mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                 c_puct=self.c_puct,
                                 n_playout=self.n_playout,
                                 is_selfplay=1)

        for _ in range(3):
            # play one self-play game
            reward, piececount, agentcount, play_data = agent.start_self_play(mcts_player, temp=self.temp)
            play_data = list(play_data)[:]
            episode_len = len(play_data)
            # optionally add the flipped-board data to the dataset
            # play_data = self.get_equi_data(play_data)
            logging.info("TRAIN Self Play end. length:%s saving ..." % episode_len)

            # save each self-play sample to the waiting directory
            for obj in play_data:
                filename = "{}.pkl".format(uuid.uuid1())
                savefile = os.path.join(data_wait_dir, filename)
                pickle.dump(obj, open(savefile, "wb"))
                # self.dataset.save(obj)

            if agent.limit_max_height == 10:
                jsonfile = os.path.join(data_dir, "result.json")
                if os.path.exists(jsonfile):
                    result = json.load(open(jsonfile, "r"))
                else:
                    result = {"reward": 0, "steps": 0, "agent": 0}
                if "1k" not in result:
                    result["1k"] = {"reward": 0, "steps": 0, "agent": 0}

                result["reward"] = result["reward"] + reward
                result["steps"] = result["steps"] + piececount
                result["agent"] = result["agent"] + agentcount
                result["1k"]["reward"] = result["1k"]["reward"] + reward
                result["1k"]["steps"] = result["1k"]["steps"] + piececount
                result["1k"]["agent"] = result["1k"]["agent"] + agentcount

                if result["agent"] > 0 and result["agent"] % 100 <= 1:
                    result[str(result["agent"])] = {
                        "reward": result["1k"]["reward"] / result["1k"]["agent"],
                        "steps": result["1k"]["steps"] / result["1k"]["agent"]
                    }

                if result["agent"] > 0 and result["agent"] % 1000 == 0:
                    # save an extra checkpoint named after the average step count
                    steps = round(result["1k"]["steps"] / result["1k"]["agent"])
                    model_file = os.path.join(model_dir, 'model_%s.pth' % steps)
                    self.policy_value_net.save_model(model_file)
                    for key in list(result.keys()):
                        if key.isdigit():
                            c = int(key)
                            if c % 1000 > 10:
                                del result[key]
                    result["1k"] = {"reward": 0, "steps": 0, "agent": 0}

                json.dump(result, open(jsonfile, "w"), ensure_ascii=False)

            if reward >= 1:
                break

    def policy_update(self, sample_data, epochs=1):
        """Update the policy-value network."""
        # sample_data already holds the states, MCTS probabilities and winners
        # mini_batch = self.dataset.loadData(sample_data)
        state_batch, mcts_probs_batch, winner_batch = sample_data

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(epochs):
            loss, v_loss, p_loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)

            # KL divergence:
            # D(P||Q) = sum( pi * log( pi / qi) ) = sum( pi * (log(pi) - log(qi)) )
            kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            if kl > self.kl_targ * 4:  # stop early if D_KL diverges badly
                break

        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        # explained_var should approach 1 when the value head is learning;
        # it stays near 0 when nothing is learned (e.g. all predicted values are tiny)
        explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch)))
        explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch)))
        # entropy: lower is better
        logging.info((
            "TRAIN kl:{:.5f},lr_multiplier:{:.3f},v_loss:{:.5f},p_loss:{:.5f},entropy:{:.5f},var_old:{:.5f},var_new:{:.5f}"
        ).format(kl, self.lr_multiplier, v_loss, p_loss, entropy, explained_var_old, explained_var_new))

        return loss, entropy

    def run(self):
        """Start training."""
        try:
            # print("start data loader")
            # self.dataset = Dataset(data_dir, self.buffer_size)
            # print("end data loader")
            #
            # step = 0
            # # if less than half of the training data has been collected, gather data first
            # if self.dataset.curr_game_batch_num / self.dataset.buffer_size < 0.5:
            #     for _ in range(8):
            #         logging.info("TRAIN Batch:{} starting".format(self.dataset.curr_game_batch_num))
            #         self.collect_selfplay_data()
            #         logging.info("TRAIN Batch:{} end".format(self.dataset.curr_game_batch_num))
            #         step += 1
            #
            # training_loader = torch.utils.data.DataLoader(self.dataset, batch_size=self.batch_size,
            #                                               shuffle=True, num_workers=2)
            # for i, data in enumerate(training_loader):
            #     # retrain the policy-value network with the self-play data
            #     loss, entropy = self.policy_update(data, self.epochs)
            # self.policy_value_net.save_model(model_file)

            # collect self-play data
            # for _ in range(self.play_batch_size):
            self.collect_selfplay_data()
            # logging.info("TRAIN {} self-play end, size: {}".format(
            #     self.dataset.curr_game_batch_num, self.dataset.curr_size()))
        except KeyboardInterrupt:
            logging.info('quit')
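
# Sketch of the horizontal-flip augmentation performed by get_equi_data() above. The state
# is assumed to be a stack of 2-D feature planes and the MCTS policy a 4-way distribution
# whose indices 1 and 2 form the left/right pair, which is why they are swapped by
# [0, 2, 1, 3]. These index semantics are an assumption for illustration, not taken from
# the game code.

import numpy as np

def flip_sample(state, mcts_prob, winner):
    """Return the horizontally mirrored version of one (state, mcts_prob, winner) sample."""
    equi_state = np.array([np.fliplr(plane) for plane in state])
    equi_prob = mcts_prob[[0, 2, 1, 3]]
    return equi_state, equi_prob, winner

if __name__ == "__main__":
    state = np.arange(2 * 3 * 3).reshape(2, 3, 3)   # 2 toy feature planes
    prob = np.array([0.1, 0.6, 0.2, 0.1])           # toy policy over 4 actions
    s2, p2, w2 = flip_sample(state, prob, winner=1.0)
    print(p2)   # [0.1 0.2 0.6 0.1]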
class Agent_MCTS(nn.Module):
    def __init__(self, args, share_model, opti, board_max, param, is_selfplay=True):
        super().__init__()
        self._is_selfplay = is_selfplay
        self.learn_rate = 5e-3
        self.lr_multiplier = 1.0    # adaptively adjust the learning rate based on KL
        self.temp = 1.0             # the temperature param
        self.n_playout = 100        # num of simulations for each move
        self.c_puct = 5
        self.batch_size = 32        # mini-batch size for training
        self.play_batch_size = 1
        self.epochs = 5             # num of train_steps for each update
        self.kl_targ = 0.025
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure MCTS, which is used as the
        # opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000

        self.policy_value_net = PolicyValueNet(board_max, board_max, net_params=param)
        self.mcts = MCTS(self.policy_value_net.policy_value_fn, self.c_puct, self.n_playout)

        self.batch_size = args.batch_size
        self.discount = args.discount
        self.epsilon = args.epsilon
        self.action_space = args.action_space
        self.hidden_size = args.hidden_size
        self.state_space = args.state_space

        # self.main_dqn = DQN_model(args)
        # self.main_dqn.train()
        # self.target_dqn = DQN_rainbow(args)
        # self.target_dqn = target_model
        # self.target_dqn_update()
        # self.target_dqn.eval()
        # self.optimizer = optim.Adam(self.main_dqn.parameters(), lr=args.lr, eps=args.adam_eps)

    def reset_player(self):
        self.mcts.update_with_move(-1)

    def save(self):
        print('save')
        torch.save(self.policy_value_net.policy_value_net.state_dict(), './net_param')

    # def save(self, path='./param.p'):
    #     torch.save(self.main_dqn.state_dict(), path)
    #
    # def load(self, path='./param.p'):
    #     if os.path.exists(path):
    #         self.main_dqn.load_state_dict(torch.load(path))
    #     else:
    #         print("file not exist")
    #
    # def target_dqn_update(self):
    #     self.target_dqn.parameter_update(self.main_dqn)

    def get_action(self, board, temp=1e-3, return_prob=0):
        sensible_moves = board.availables
        # the pi vector returned by MCTS, as in the AlphaGo Zero paper
        move_probs = np.zeros(board.width * board.height)
        if len(sensible_moves) > 0:
            acts, probs = self.mcts.get_move_probs(board, temp)
            move_probs[list(acts)] = probs
            if self._is_selfplay:
                # add Dirichlet noise for exploration (needed for self-play training)
                move = np.random.choice(
                    acts,
                    p=0.75 * probs + 0.25 * np.random.dirichlet(0.3 * np.ones(len(probs))))
                # update the root node and reuse the search tree
                self.mcts.update_with_move(move)
            else:
                # with the default temp=1e-3, this is almost equivalent to
                # choosing the move with the highest prob
                move = np.random.choice(acts, p=probs)
                # reset the root node
                self.mcts.update_with_move(-1)
                # location = board.move_to_location(move)
                # print("AI move: %d,%d\n" % (location[0], location[1]))
            return move, move_probs
        else:
            print("WARNING: the board is full")

    # def train(self):
    #     self.main_dqn.train()
    #
    # def eval(self):
    #     self.main_dqn.eval()

    def learn(self, data_buffer):
        """Update the policy-value net."""
        mini_batch = random.sample(data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            if kl > self.kl_targ * 4:   # early stopping if D_KL diverges badly
                break

        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = 1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch))
        explained_var_new = 1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch))
        print("kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}".format(
            kl, self.lr_multiplier, loss, entropy, explained_var_old, explained_var_new))
        return loss, entropy
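
# Minimal numpy sketch of the Dirichlet-noise exploration used in get_action() during
# self-play: the root policy produced by MCTS is mixed with Dirichlet(0.3) noise before
# sampling a move, as in the AlphaGo Zero paper. `acts` and `probs` stand in for what
# mcts.get_move_probs() returns; the helper name is illustrative.

import numpy as np

def sample_selfplay_move(acts, probs, noise_eps=0.25, dirichlet_alpha=0.3, rng=None):
    """Sample a move from (1 - eps) * pi + eps * Dirichlet(alpha)."""
    if rng is None:
        rng = np.random.default_rng()
    noise = rng.dirichlet(dirichlet_alpha * np.ones(len(probs)))
    mixed = (1 - noise_eps) * probs + noise_eps * noise
    return rng.choice(acts, p=mixed / mixed.sum())

if __name__ == "__main__":
    acts = np.array([12, 13, 21, 22])           # toy move indices
    probs = np.array([0.5, 0.3, 0.1, 0.1])      # toy MCTS visit-count distribution
    print(sample_selfplay_move(acts, probs))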