class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.board_width = 8  # 6
        self.board_height = 8  # 6
        self.n_in_row = 5  # 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 5e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        # c_puct controls the exploration-exploitation trade-off in MCTS:
        # larger values bias the search toward uniform exploration, smaller
        # values bias it toward the branches with the most visits
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.025
        # how often to check the current policy's win ratio: evaluate the
        # policy every 50 training batches, and save the model whenever a
        # better policy is found
        self.check_freq = 50
        # number of training iterations
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as the
        # opponent to evaluate the trained policy; initialized to 1000 and
        # increased over the course of training
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net
            policy_param = pickle.load(open(init_model, 'rb'))
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   net_params=policy_param)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_prob.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data_zip2list = list(play_data)  # added by haward
            self.episode_len = len(play_data_zip2list)
            # augment the data
            play_data = self.get_equi_data(play_data_zip2list)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            # KL divergence (relative entropy) measures how two probability
            # distributions over the same event space differ
            kl = np.mean(
                np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                    np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate: the KL divergence (information
        # gain) between the old and new network outputs controls the learning
        # rate, letting it rise quickly and then decay gradually
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        explained_var_old = 1 - np.var(
            np.array(winner_batch) - old_v.flatten()) / np.var(
                np.array(winner_batch))
        explained_var_new = 1 - np.var(
            np.array(winner_batch) - new_v.flatten()) / np.var(
                np.array(winner_batch))
        print("kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},"
              "explained_var_old:{:.3f},explained_var_new:{:.3f}".format(
                  kl, self.lr_multiplier, loss, entropy,
                  explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Evaluate the trained policy by playing games against the pure MCTS
        player.
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        # win ratio: 1 point for a win, 0.5 for a tie, 0 for a loss,
        # divided by the total number of games
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model
                # and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    # play ten games against the pure MCTS player
                    # and compute the win ratio
                    win_ratio = self.policy_evaluate()
                    net_params = self.policy_value_net.get_policy_param()  # get model params
                    pickle.dump(net_params,
                                open('current_policy_8_8_5_new.model', 'wb'),
                                pickle.HIGHEST_PROTOCOL)  # save model param to file
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        pickle.dump(net_params,
                                    open('best_policy_8_8_5_new.model', 'wb'),
                                    pickle.HIGHEST_PROTOCOL)  # update the best_policy
                    # if the current policy-value net wins every game, raise
                    # the pure MCTS playout number and keep training
                    if (self.best_win_ratio == 1.0
                            and self.pure_mcts_playout_num < 5000):
                        self.pure_mcts_playout_num += 1000
                        self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
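The listing above starts mid-module. Below is a minimal sketch of the imports and entry point that would drive it; the project-module names in the commented imports are assumptions based on how Board, Game, PolicyValueNet, MCTSPlayer, and MCTS_Pure are used above, and the checkpoint path is simply the file that run() writes.

import pickle
import random
from collections import defaultdict, deque

import numpy as np

# Hypothetical project-level imports; the actual module names depend on the
# surrounding repository layout.
# from game import Board, Game
# from policy_value_net import PolicyValueNet
# from mcts_alphaZero import MCTSPlayer
# from mcts_pure import MCTSPlayer as MCTS_Pure

if __name__ == '__main__':
    # Pass a saved parameter file to resume from a checkpoint, e.g. the
    # 'current_policy_8_8_5_new.model' file written by run().
    training_pipeline = TrainPipeline()
    training_pipeline.run()  # Ctrl-C raises KeyboardInterrupt, which run() catches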
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        # width of the board
        self.board_width = 8  # 6  # 10
        # height of the board
        self.board_height = 8  # 6  # 10
        # number of pieces in a row needed to win
        self.n_in_row = 5  # 4  # 5
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 5e-3  # learning rate
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000  # maximum number of elements in the queue
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)  # queue of training data
        self.play_batch_size = 1  # collect one batch of self-play data per iteration
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.025  # KL target
        # check frequency: evaluate the current AI model every 50 batches of
        # self-play; the evaluation plays 10 games between the latest model
        # and the pure-MCTS AI (based on random rollouts)
        self.check_freq = 50  # 50
        self.game_batch_num = 200  # the number of training batches
        self.best_win_ratio = 0.0  # best win ratio seen so far
        # num of simulations used for the pure mcts, which is used as the
        # opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net:
            # pickle.load(file) deserializes the saved parameters back into
            # a Python object; 'rb' reads the file in binary mode
            policy_param = pickle.load(open(init_model, 'rb'))
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   net_params=policy_param)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_prob.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data_zip2list = list(play_data)  # added by haward
            self.episode_len = len(play_data_zip2list)
            # augment the data
            play_data = self.get_equi_data(play_data_zip2list)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                    np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        explained_var_old = 1 - np.var(
            np.array(winner_batch) - old_v.flatten()) / np.var(
                np.array(winner_batch))
        explained_var_new = 1 - np.var(
            np.array(winner_batch) - new_v.flatten()) / np.var(
                np.array(winner_batch))
        print("kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},"
              "explained_var_old:{:.3f},explained_var_new:{:.3f}".format(
                  kl, self.lr_multiplier, loss, entropy,
                  explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10, batch=0):
        """Evaluate the trained policy by playing games against the pure MCTS
        player.
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("batch_i:{}, num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            batch, self.pure_mcts_playout_num, win_cnt[1], win_cnt[2],
            win_cnt[-1]))
        logging.debug("batch_i {} num_playouts {} win {} lose {} tie {}".format(
            batch, self.pure_mcts_playout_num, win_cnt[1], win_cnt[2],
            win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                # collect one batch of self-play data
                self.collect_selfplay_data(self.play_batch_size)
                print("batch_i:{}, episode_len:{}".format(i + 1,
                                                          self.episode_len))
                # logging.debug("batch_i:{}, episode_len:{}".format(i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    # once the queue holds more samples than one mini-batch,
                    # randomly draw a mini-batch from it to train and update
                    # the network
                    loss, entropy = self.policy_update()
                    logging.debug("batch_i {} loss {}".format(i + 1, loss))
                # check the performance of the current model
                # and save the model params
                if (i + 1) % self.check_freq == 0:
                    # evaluate the current AI model every 50 self-play batches
                    print("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate(batch=i + 1)
                    net_params = self.policy_value_net.get_policy_param()  # get model params
                    # save the model: serialize the params and write the
                    # resulting byte stream to the file object
                    pickle.dump(net_params,
                                open('current_policy_8_8_5_new.model', 'wb'),
                                pickle.HIGHEST_PROTOCOL)  # save model param to file
                    if win_ratio > self.best_win_ratio:
                        # the win ratio improved: update the best model params
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        pickle.dump(net_params,
                                    open('best_policy_8_8_5_new.model', 'wb'),
                                    pickle.HIGHEST_PROTOCOL)  # update the best_policy
                    if (self.best_win_ratio == 1.0
                            and self.pure_mcts_playout_num < 5000):
                        # once the win ratio reaches 100%, add 1000 playouts to
                        # the pure MCTS opponent and reset the best win ratio
                        self.pure_mcts_playout_num += 1000
                        self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
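The KL term that drives both the early stop and the learning-rate multiplier in policy_update is the discrete KL divergence between the old and new policy outputs, averaged over the batch of states. A self-contained numeric sketch with made-up probability rows (the 1e-10 guard is the same one used above to avoid log(0)):

import numpy as np

# Two hypothetical policy outputs over a 4-move board, one row per state.
old_probs = np.array([[0.70, 0.10, 0.10, 0.10],
                      [0.25, 0.25, 0.25, 0.25]])
new_probs = np.array([[0.60, 0.20, 0.10, 0.10],
                      [0.40, 0.20, 0.20, 0.20]])

# Same expression as in policy_update: D_KL(old || new) per state, then mean.
kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                 np.log(new_probs + 1e-10)), axis=1))
print(kl)  # ~0.044 for these rows

# policy_update compares this value against kl_targ = 0.025: it stops the
# epoch loop early if kl > 4 * kl_targ, shrinks lr_multiplier if
# kl > 2 * kl_targ, and grows it if kl < kl_targ / 2.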