class TrainPipeline:
    """Self-play training pipeline for a policy-value network on an 8x8 board.

    The pipeline alternates between generating self-play games with an
    MCTS player guided by the network, updating the network on samples
    drawn from a replay buffer, and periodically evaluating the network
    against a pure-MCTS opponent.
    """

    def __init__(self, init_model=None):
        """Create the board, hyper-parameters, network and self-play player.

        Args:
            init_model: Optional model file forwarded to ``PolicyValueNet``
                as ``model_file``. When falsy, a new network is created.
        """
        # Board data.
        self.board_width = 8
        self.board_height = 8
        self.board = chessboard(row=self.board_width, col=self.board_height)

        # Training parameters.
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # scaled up/down in policy_update() based on KL
        self.temp = 1.0  # temperature passed to the self-play player
        self.n_playout = 400  # MCTS simulations per move
        self.c_puct = 5
        self.buffer_size = 10000000
        self.batch_size = 512  # samples per mini-batch
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1  # self-play games collected per batch
        self.epochs = 5  # train steps per policy update
        self.kl_targ = 0.02
        self.check_freq = 2  # evaluate every `check_freq` batches
        self.game_batch_num = 1000  # number of batches run() iterates over
        self.best_win_ratio = 0.0
        # Initialised here so the attribute exists even before the first
        # self-play game has been collected.
        self.episode_len = 0
        # Playouts of the pure MCTS player that serves as the baseline.
        self.pure_mcts_playout_num = 400

        if init_model:
            # Continue from a pre-trained model.
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # Train from scratch.
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment self-play data with rotations and mirror images.

        Args:
            play_data: Iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            A list with eight ``(state, flat_mcts_prob, winner_z)`` entries
            per input sample: four quarter-turn rotations, each followed by
            its left-right mirror image.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # np.rot90 rotates counter-clockwise by i quarter turns.
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_prob.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # np.fliplr mirrors left-right (a horizontal flip).
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def start_self_play(self, player, is_shown=0, temp=1e-3):
        """Play one self-play game and record the training samples.

        Args:
            player: Player whose ``get_action`` is called with
                ``return_prob=1`` and which is reset when the game ends.
            is_shown: When truthy, display the board after every move and
                print the result.
            temp: Temperature forwarded to ``player.get_action``.

        Returns:
            ``(winner, samples)`` where ``samples`` is a ``zip`` iterator of
            ``(state, mcts_probs, winner_z)`` and ``winner`` is ``-1`` for a
            tie.
        """
        self.board.reset()
        states, mcts_probs, current_players = [], [], []
        while True:
            move, move_probs = player.get_action(self.board,
                                                 temp=temp,
                                                 return_prob=1)
            # Store the data before the move is applied.
            states.append(self.board.current_state())
            mcts_probs.append(move_probs)
            current_players.append(self.board.current_player)
            # Perform the move.
            self.board.do_move(move)
            if is_shown:
                display(self.board)
            end, winner = self.board.game_end()
            if end:
                # Outcome from the perspective of the player to move in
                # each recorded state; stays all-zero for a tie.
                winners_z = np.zeros(len(current_players))
                if winner != -1:
                    winners_z[np.array(current_players) == winner] = 1.0
                    winners_z[np.array(current_players) != winner] = -1.0
                # Reset the MCTS root node.
                player.reset_player()
                if is_shown:
                    if winner != -1:
                        print("Game end. Winner is player:", winner)
                    else:
                        print("Game end. Tie")
                return winner, zip(states, mcts_probs, winners_z)

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and add them to the buffer."""
        for _ in range(n_games):
            winner, play_data = self.start_self_play(self.mcts_player,
                                                     temp=self.temp,
                                                     is_shown=False)
            # start_self_play returns an iterator; materialise it so it can
            # be measured and then augmented.
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # Augment the data.
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """Run up to ``self.epochs`` train steps on one sampled mini-batch.

        Returns:
            ``(loss, entropy)`` as returned by the last ``train_step`` call.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:
                # Early stopping: the policy moved too far from the old one.
                break
        # Adapt the learning-rate multiplier to the observed KL divergence.
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 - np.var(
            np.array(winner_batch) - old_v.flatten()) / np.var(
                np.array(winner_batch)))
        explained_var_new = (1 - np.var(
            np.array(winner_batch) - new_v.flatten()) / np.var(
                np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, self.lr_multiplier,
                                                  loss, entropy,
                                                  explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def start_play(self, player1, player2, start_player=1, is_shown=1):
        """Play one game between two players.

        Args:
            player1: Player assigned the first id in ``self.board.players``.
            player2: Player assigned the second id.
            start_player: ``1`` or ``2``; forwarded to ``self.board.reset``.
            is_shown: When truthy, display the board and print the result.

        Returns:
            The winner as reported by ``self.board.game_end()`` (``-1`` for
            a tie).

        Raises:
            ValueError: If ``start_player`` is not 1 or 2.
        """
        if start_player not in (1, 2):
            # The message now matches the values the check accepts.
            raise ValueError('start_player should be either 1 (player1 first) '
                             'or 2 (player2 first)')
        self.board.reset(start_player)
        p1, p2 = self.board.players
        player1.set_player_ind(p1)
        player2.set_player_ind(p2)
        players = {p1: player1, p2: player2}
        if is_shown:
            display(self.board)
        while True:
            current_player = self.board.get_current_player()
            player_in_turn = players[current_player]
            move = player_in_turn.get_action(self.board)
            self.board.do_move(move)
            if is_shown:
                display(self.board)
            end, winner = self.board.game_end()
            if end:
                if is_shown:
                    if winner != -1:
                        print("Game end. Winner is", players[winner])
                    else:
                        print("Game end. Tie")
                return winner

    def policy_evaluate(self, n_games=10):
        """Evaluate the current network against the pure-MCTS baseline.

        Args:
            n_games: Number of games; the starting player alternates
                between 1 and 2.

        Returns:
            Win ratio of the network player, counting a tie as half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.start_play(current_mcts_player,
                                     pure_mcts_player,
                                     start_player=i % 2 + 1,
                                     is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    @run.change_dir
    @run.timethis
    def run(self):
        """Run the training loop and save the loss curve when it completes.

        Pressing Ctrl-C stops training; the interrupt is caught and the loss
        curve is not written in that case.
        """
        try:
            losses = []
            for i in tqdm.tqdm(range(self.game_batch_num)):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    losses.append(loss)
                # Check the current model's performance and save it.
                if (i + 1) % self.check_freq == 0:
                    print("当前自训练次数: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model(
                        './output/current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("新的最佳策略!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # Update the best policy.
                        self.policy_value_net.save_model(
                            './output/best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            # The baseline is beaten every game: make it
                            # stronger and start tracking the ratio afresh.
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
            plt.figure()
            plt.plot(losses)
            plt.savefig("./output/loss.png")
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline():
    """Self-play training pipeline that stores network parameters via pickle.

    The pipeline alternates between self-play data collection, mini-batch
    updates of the policy-value network, and evaluation against a pure-MCTS
    opponent.
    """

    def __init__(self):
        """Create the board, game, hyper-parameters, network and player."""
        # Params of the board and the game.
        self.board_width = BOARD_SIZE
        self.board_height = BOARD_SIZE
        self.board = Board()
        self.game = Game(self.board)
        # Training params.
        self.learn_rate = 5e-3
        self.lr_multiplier = 1.0  # adaptively adjusted based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 300  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.025
        self.check_freq = 1
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        self.episode_len = 0
        # Num of simulations used for the pure MCTS, which is the opponent
        # used to evaluate the trained policy.
        self.pure_mcts_playout_num = 300
        # Training always starts from a new policy-value net here; loading
        # previously pickled parameters is not wired in.
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        Args:
            play_data: Iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            A list with eight ``(state, flat_mcts_prob, winner_z)`` entries
            per input sample: four counter-clockwise quarter-turn rotations,
            each followed by its left-right mirror image.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # Rotate counterclockwise.
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_prob.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # Flip horizontally.
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training.

        Args:
            n_games: Number of self-play games to generate.
        """
        for _ in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            # The samples may arrive as a one-shot iterator (e.g. ``zip``),
            # which has no len(); materialise them before measuring.
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # Augment the data.
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """Update the policy-value net on one sampled mini-batch.

        Returns:
            ``(loss, entropy)`` as returned by the last ``train_step`` call.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # Adaptively adjust the learning rate.
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = 1 - np.var(
            np.array(winner_batch) - old_v.flatten()) / np.var(
                np.array(winner_batch))
        explained_var_new = 1 - np.var(
            np.array(winner_batch) - new_v.flatten()) / np.var(
                np.array(winner_batch))
        print(
            "kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}"
            .format(kl, self.lr_multiplier, loss, entropy, explained_var_old,
                    explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Evaluate the trained policy against the pure MCTS player.

        Note: this is only for monitoring the progress of training.

        Args:
            n_games: Number of evaluation games; ``start_player`` alternates
                between 0 and 1.

        Returns:
            Win ratio of the network player, counting a tie as half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    @staticmethod
    def _dump_params(net_params, path):
        """Pickle ``net_params`` to ``path`` and close the file afterwards."""
        with open(path, 'wb') as model_file:
            pickle.dump(net_params, model_file, pickle.HIGHEST_PROTOCOL)

    def run(self):
        """Run the training pipeline until done or interrupted with Ctrl-C."""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # Check the performance of the current model and save the
                # model params.
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    net_params = self.policy_value_net.get_policy_param()
                    self._dump_params(net_params, 'current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # Update the best policy.
                        self._dump_params(net_params, 'best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 1000):
                            # The baseline is beaten every game: make it
                            # stronger and start tracking the ratio afresh.
                            self.pure_mcts_playout_num += 100
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline():
    """Self-play trainer for a policy-value net on a 6x6, 4-in-a-row board."""

    def __init__(self, init_model=None):
        """Build the game, the hyper-parameters, the net and the MCTS player.

        Args:
            init_model: Optional model file forwarded to ``PolicyValueNet``
                as ``model_file``; a new net is created when it is falsy.
        """
        # Board / game geometry.
        self.board_width = 6
        self.board_height = 6
        self.n_in_row = 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)

        # Optimisation settings.
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # rescaled after each update from the KL
        self.kl_targ = 0.02
        self.epochs = 5  # train_step calls per update
        self.batch_size = 512  # mini-batch size

        # Self-play / search settings.
        self.temp = 1.0  # temperature handed to self-play
        self.n_playout = 400  # simulations per move
        self.c_puct = 5
        self.play_batch_size = 1

        # Replay buffer.
        self.buffer_size = 10000
        self.data_buffer = deque(maxlen=self.buffer_size)

        # Evaluation schedule and bookkeeping.
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # Simulations granted to the pure MCTS opponent used in evaluation.
        self.pure_mcts_playout_num = 1000

        if init_model:
            # Resume from an existing policy-value net.
            net = PolicyValueNet(self.board_width,
                                 self.board_height,
                                 model_file=init_model)
        else:
            # Start from a brand-new policy-value net.
            net = PolicyValueNet(self.board_width, self.board_height)
        self.policy_value_net = net
        self.mcts_player = MCTSPlayer(net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Expand each sample into its eight rotated / mirrored variants.

        Args:
            play_data: Iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            List of ``(state, flat_mcts_prob, winner_z)`` tuples: for every
            quarter-turn count 1..4 the rotated sample followed by its
            left-right mirror image.
        """
        augmented = []
        rows, cols = self.board_height, self.board_width
        for planes, probs, outcome in play_data:
            # The probability grid is flipped once up front and flipped back
            # right before it is flattened again.
            prob_grid = np.flipud(probs.reshape(rows, cols))
            for turns in range(1, 5):
                turned_planes = np.array(
                    [np.rot90(plane, turns) for plane in planes])
                turned_grid = np.rot90(prob_grid, turns)
                augmented.append(
                    (turned_planes, np.flipud(turned_grid).flatten(), outcome))

                mirrored_planes = np.array(
                    [np.fliplr(plane) for plane in turned_planes])
                mirrored_grid = np.fliplr(turned_grid)
                augmented.append(
                    (mirrored_planes, np.flipud(mirrored_grid).flatten(),
                     outcome))
        return augmented

    def collect_selfplay_data(self, n_games=1):
        """Generate ``n_games`` self-play games and push them to the buffer."""
        games_played = 0
        while games_played < n_games:
            winner, samples = self.game.start_self_play(self.mcts_player,
                                                        temp=self.temp)
            samples = list(samples)
            self.episode_len = len(samples)
            # Store the augmented version of the game.
            self.data_buffer.extend(self.get_equi_data(samples))
            games_played += 1

    def policy_update(self):
        """Train the net on one random mini-batch and adapt the step size.

        Returns:
            ``(loss, entropy)`` from the last ``train_step`` call.
        """
        batch = random.sample(self.data_buffer, self.batch_size)
        states = [sample[0] for sample in batch]
        target_probs = [sample[1] for sample in batch]
        outcomes = [sample[2] for sample in batch]

        net = self.policy_value_net
        old_probs, old_v = net.policy_value(states)
        for _ in range(self.epochs):
            loss, entropy = net.train_step(
                states, target_probs, outcomes,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = net.policy_value(states)
            log_ratio = np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)
            kl = np.mean(np.sum(old_probs * log_ratio, axis=1))
            if kl > self.kl_targ * 4:
                # Stop early when the policy drifted too far.
                break

        # Shrink or grow the learning-rate multiplier based on the KL.
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        def explained_var(values):
            # 1 - Var(target - prediction) / Var(target)
            return (1 -
                    np.var(np.array(outcomes) - values.flatten()) /
                    np.var(np.array(outcomes)))

        explained_var_old = explained_var(old_v)
        explained_var_new = explained_var(new_v)
        template = ("kl:{:.5f},"
                    "lr_multiplier:{:.3f},"
                    "loss:{},"
                    "entropy:{},"
                    "explained_var_old:{:.3f},"
                    "explained_var_new:{:.3f}")
        print(template.format(kl, self.lr_multiplier, loss, entropy,
                              explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Play the current net against pure MCTS and report the win ratio.

        Only used to monitor training progress.

        Args:
            n_games: Number of games; ``start_player`` alternates 0, 1, ...

        Returns:
            Wins plus half the ties, divided by ``n_games``.
        """
        challenger = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                c_puct=self.c_puct,
                                n_playout=self.n_playout)
        baseline = MCTS_Pure(c_puct=5, n_playout=self.pure_mcts_playout_num)
        tally = defaultdict(int)
        for game_idx in range(n_games):
            outcome = self.game.start_play(challenger,
                                           baseline,
                                           start_player=game_idx % 2,
                                           is_shown=0)
            tally[outcome] += 1
        wins, losses, ties = tally[1], tally[2], tally[-1]
        win_ratio = 1.0 * (wins + 0.5 * ties) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, wins, losses, ties))
        return win_ratio

    def run(self):
        """Drive the whole training loop; Ctrl-C ends it quietly."""
        try:
            for batch_idx in range(self.game_batch_num):
                batch_no = batch_idx + 1
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    batch_no, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()

                # Only every ``check_freq`` batches: evaluate and save.
                if batch_no % self.check_freq != 0:
                    continue
                print("current self-play batch: {}".format(batch_no))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('./current_policy.model')

                if not win_ratio > self.best_win_ratio:
                    continue
                print("New best policy!!!!!!!!")
                self.best_win_ratio = win_ratio
                # Keep a copy of the best policy seen so far.
                self.policy_value_net.save_model('./best_policy.model')
                beaten_every_game = self.best_win_ratio == 1.0
                if beaten_every_game and self.pure_mcts_playout_num < 5000:
                    # Strengthen the opponent and restart the ratio tracking.
                    self.pure_mcts_playout_num += 1000
                    self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline(object):
    """Self-play training pipeline for a Quoridor policy-value network.

    NOTE(review): this class was reconstructed from a whitespace-mangled
    source; statement nesting was inferred from data flow and should be
    confirmed against the original file.
    """

    def __init__(self, init_model=None):
        """Create the game, hyper-parameters, network and self-play player.

        Args:
            init_model: Optional model file forwarded to ``PolicyValueNet``
                as ``model_file``; a new network is created when falsy.
        """
        self.game = Quoridor()
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # rescaled in policy_update() based on kl
        self.temp = 1.0  # forwarded to start_self_play as `temp`
        self.n_playout = 200  # forwarded to MCTSPlayer as `n_playout`
        self.c_puct = 5
        self.buffer_size = 10000
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1  # games collected per batch in run()
        self.kl_targ = 0.02
        self.check_freq = 10  # run() saves the model every `check_freq` batches
        self.game_batch_num = 1000
        self.best_win_ratio = 0.0  # not read anywhere in this class
        self.pure_mcts_playout_num = 1000  # not read anywhere in this class
        # Probabilities kept across policy_update() calls for the KL term.
        self.old_probs = 0
        self.new_probs = 0
        self.first_trained = False  # set once policy_update() has trained

        if init_model:
            self.policy_value_net = PolicyValueNet(model_file=init_model)
        else:
            self.policy_value_net = PolicyValueNet()

        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, c_puct=self.c_puct, n_playout=self.n_playout, is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment samples with horizontal, vertical and combined flips.

        Args:
            play_data: Iterable of ``(state, mcts_prob, winner)`` tuples.
                ``state`` is indexed as a stack of planes and ``mcts_prob``
                as a flat vector whose first 12 entries are handled
                separately from the rest (presumably pawn moves followed by
                horizontal and vertical wall placements - confirm against
                the game encoding).

        Returns:
            A list with four ``(state, mcts_prob, winner)`` entries per
            sample: the original, the horizontal flip, and the vertical and
            horizontal+vertical flips, the latter two with ``winner``
            negated.
        """
        extend_data = []
        # NOTE(review): the index `i` below is unused and is overwritten by
        # the inner `for i in range(...)` loops.
        for i, (state, mcts_prob, winner) in enumerate(play_data):
            # Planes 0-2 cropped to (BOARD_SIZE-1) x (BOARD_SIZE-1).
            wall_state = state[:3,:BOARD_SIZE - 1,:BOARD_SIZE - 1]
            # Two single planes, reshaped to (1, BOARD_SIZE, BOARD_SIZE) so
            # they can be stacked with np.vstack below.
            dist_state1 = np.reshape(state[(6 + (WALL_NUM + 1) * 2), :BOARD_SIZE, :BOARD_SIZE], (1, BOARD_SIZE, BOARD_SIZE))
            dist_state2 = np.reshape(state[(7 + (WALL_NUM + 1) * 2), :BOARD_SIZE, :BOARD_SIZE], (1, BOARD_SIZE, BOARD_SIZE))

            # horizontally flipped game
            flipped_wall_state = []
            for i in range(3):
                wall_padded = np.fliplr(wall_state[i])
                # Pad one zero row and column at the end of each axis.
                wall_padded = np.pad(wall_padded, (0,1), mode='constant', constant_values=0)
                flipped_wall_state.append(wall_padded)
            flipped_wall_state = np.array(flipped_wall_state)

            player_position = state[3:5, :,:]
            flipped_player_position = []
            for i in range(2):
                flipped_player_position.append(np.fliplr(player_position[i]))
            flipped_player_position = np.array(flipped_player_position)

            # Planes from index 5 onwards are copied over unchanged.
            h_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5:, :,:]])

            h_equi_mcts_prob = np.copy(mcts_prob)
            # Swap the entries of the first 12 actions pairwise.
            h_equi_mcts_prob[11] = mcts_prob[10]  # SE to SW
            h_equi_mcts_prob[10] = mcts_prob[11]  # SW to SE
            h_equi_mcts_prob[9] = mcts_prob[8]    # NE to NW
            h_equi_mcts_prob[8] = mcts_prob[9]    # NW to NE
            h_equi_mcts_prob[7] = mcts_prob[6]    # EE to WW
            h_equi_mcts_prob[6] = mcts_prob[7]    # WW to EE
            h_equi_mcts_prob[3] = mcts_prob[2]    # E to W
            h_equi_mcts_prob[2] = mcts_prob[3]    # W to E

            # The remaining entries are split into two square grids, mirrored
            # left-right and written back.
            h_wall_actions = h_equi_mcts_prob[12:12 + (BOARD_SIZE-1) ** 2].reshape(BOARD_SIZE-1, BOARD_SIZE-1)
            v_wall_actions = h_equi_mcts_prob[12 + (BOARD_SIZE-1) ** 2:].reshape(BOARD_SIZE-1, BOARD_SIZE -1)

            flipped_h_wall_actions = np.fliplr(h_wall_actions)
            flipped_v_wall_actions = np.fliplr(v_wall_actions)

            h_equi_mcts_prob[12:] = np.hstack([flipped_h_wall_actions.flatten(), flipped_v_wall_actions.flatten()])

            # Vertically flipped game
            flipped_wall_state = []
            for i in range(3):
                wall_padded = np.flipud(wall_state[i])
                wall_padded = np.pad(wall_padded, (0,1), mode='constant', constant_values=0)
                flipped_wall_state.append(wall_padded)
            flipped_wall_state = np.array(flipped_wall_state)

            flipped_player_position = []
            for i in range(2):
                # `1-i` swaps the two position planes in addition to flipping.
                flipped_player_position.append(np.flipud(player_position[1-i]))
            flipped_player_position = np.array(flipped_player_position)

            # One minus the plane at index 5 + 2*(WALL_NUM+1).
            cur_player = (np.ones((BOARD_SIZE, BOARD_SIZE)) - state[5 + 2* (WALL_NUM+1),:,:]).reshape(-1,BOARD_SIZE, BOARD_SIZE)

            # The two (WALL_NUM+1)-plane groups and the two dist planes are
            # stacked in swapped order.
            v_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5+(WALL_NUM+1):5 + 2*(WALL_NUM+1), :,:], state[5:5+(WALL_NUM+1),:,:], cur_player, dist_state2, dist_state1])

            v_equi_mcts_prob = np.copy(mcts_prob)
            v_equi_mcts_prob[11] = mcts_prob[9]   # SE to NE
            v_equi_mcts_prob[10] = mcts_prob[8]   # SW to NW
            v_equi_mcts_prob[9] = mcts_prob[11]   # NE to SE
            v_equi_mcts_prob[8] = mcts_prob[10]   # NW to SW
            v_equi_mcts_prob[5] = mcts_prob[4]    # NN to SS
            v_equi_mcts_prob[4] = mcts_prob[5]    # SS to NN
            v_equi_mcts_prob[1] = mcts_prob[0]    # N to S
            v_equi_mcts_prob[0] = mcts_prob[1]    # S to N

            h_wall_actions = v_equi_mcts_prob[12:12 + (BOARD_SIZE-1) ** 2].reshape(BOARD_SIZE-1, BOARD_SIZE-1)
            v_wall_actions = v_equi_mcts_prob[12 + (BOARD_SIZE-1) ** 2:].reshape(BOARD_SIZE-1, BOARD_SIZE -1)

            flipped_h_wall_actions = np.flipud(h_wall_actions)
            flipped_v_wall_actions = np.flipud(v_wall_actions)

            v_equi_mcts_prob[12:] = np.hstack([flipped_h_wall_actions.flatten(), flipped_v_wall_actions.flatten()])

            ## Horizontally-vertically flipped game
            wall_state = state[:3,:BOARD_SIZE - 1,:BOARD_SIZE - 1]
            flipped_wall_state = []
            for i in range(3):
                wall_padded = np.fliplr(np.flipud(wall_state[i]))
                wall_padded = np.pad(wall_padded, (0,1), mode='constant', constant_values=0)
                flipped_wall_state.append(wall_padded)
            flipped_wall_state = np.array(flipped_wall_state)

            flipped_player_position = []
            for i in range(2):
                flipped_player_position.append(np.fliplr(np.flipud(player_position[1-i])))
            flipped_player_position = np.array(flipped_player_position)

            cur_player = (np.ones((BOARD_SIZE, BOARD_SIZE)) - state[5 + 2*(WALL_NUM+1),:,:]).reshape(-1,BOARD_SIZE, BOARD_SIZE)

            hv_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5 + (WALL_NUM+1):5 + 2*(WALL_NUM+1), :,:], state[5:5+(WALL_NUM+1),:,:], cur_player, dist_state2, dist_state1])

            hv_equi_mcts_prob = np.copy(mcts_prob)
            hv_equi_mcts_prob[11] = mcts_prob[8]  # SE to NW
            hv_equi_mcts_prob[10] = mcts_prob[9]  # SW to NE
            hv_equi_mcts_prob[9] = mcts_prob[10]  # NE to SW
            hv_equi_mcts_prob[8] = mcts_prob[11]  # NW to SE
            hv_equi_mcts_prob[7] = mcts_prob[6]   # EE to WW
            hv_equi_mcts_prob[6] = mcts_prob[7]   # WW to EE
            hv_equi_mcts_prob[5] = mcts_prob[4]   # NN to SS
            hv_equi_mcts_prob[4] = mcts_prob[5]   # SS to NN
            hv_equi_mcts_prob[3] = mcts_prob[2]   # E to W
            hv_equi_mcts_prob[2] = mcts_prob[3]   # W to E
            hv_equi_mcts_prob[1] = mcts_prob[0]   # N to S
            hv_equi_mcts_prob[0] = mcts_prob[1]   # S to N

            h_wall_actions = hv_equi_mcts_prob[12:12 + (BOARD_SIZE-1) ** 2].reshape(BOARD_SIZE-1, BOARD_SIZE-1)
            v_wall_actions = hv_equi_mcts_prob[12 + (BOARD_SIZE-1) ** 2:].reshape(BOARD_SIZE-1, BOARD_SIZE -1)

            flipped_h_wall_actions = np.fliplr(np.flipud(h_wall_actions))
            flipped_v_wall_actions = np.fliplr(np.flipud(v_wall_actions))

            hv_equi_mcts_prob[12:] = np.hstack([flipped_h_wall_actions.flatten(), flipped_v_wall_actions.flatten()])

            # The two variants built with swapped player planes carry the
            # negated winner value.
            extend_data.append((state, mcts_prob, winner))
            extend_data.append((h_equi_state, h_equi_mcts_prob, winner))
            extend_data.append((v_equi_state, v_equi_mcts_prob, winner * -1))
            extend_data.append((hv_equi_state, hv_equi_mcts_prob, winner * -1))

        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and append them to the buffer.

        Each game is materialised into a list, its length is stored in
        ``self.episode_len``, and the augmented samples are added to
        ``self.data_buffer``.
        """
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)
            print("{}th game finished. Current episode length: {}, Length of data buffer: {}".format(i, self.episode_len, len(self.data_buffer)))

    def policy_update(self):
        """Train on the whole buffer for up to ``NUM_EPOCHS`` epochs.

        Every batch is logged to the module-level ``writer`` using the
        module-level ``iter_count``, which is incremented per batch.

        Returns:
            ``(valloss_mean, polloss_mean, entropy_mean)``: accumulated
            losses divided by ``len(dataloader) * NUM_EPOCHS``. The divisor
            does not shrink when the epoch loop breaks early.
        """
        dataloader = DataLoader(self.data_buffer, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

        valloss_acc = 0
        polloss_acc = 0
        entropy_acc = 0

        for i in range(NUM_EPOCHS):

            self.old_probs = self.new_probs

            if self.first_trained:
                # NOTE(review): old_probs was assigned from new_probs on the
                # line above, so both operands are the same object and this
                # KL term looks like it always evaluates to 0 - which would
                # make the early stop unreachable and push lr_multiplier up
                # towards 10. Confirm whether that is intended.
                kl = np.mean(np.sum(self.old_probs * (np.log(self.old_probs + 1e-10) - np.log(self.new_probs + 1e-10)), axis=1))
                if kl > self.kl_targ * 4:
                    break

                if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
                    self.lr_multiplier /= 1.5
                elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                    self.lr_multiplier *= 1.5

            # NOTE(review): this loop reuses `i` from the epoch loop.
            for i, (state, mcts_prob, winner) in enumerate(dataloader):

                valloss, polloss, entropy = self.policy_value_net.train_step(state, mcts_prob, winner, self.learn_rate * self.lr_multiplier)

                # Only the probabilities of the most recent batch are kept.
                self.new_probs, new_v = self.policy_value_net.policy_value(state)

                global iter_count

                writer.add_scalar("Val Loss/train", valloss.item(), iter_count)
                writer.add_scalar("Policy Loss/train", polloss.item(), iter_count)
                writer.add_scalar("Entropy/train", entropy, iter_count)
                writer.add_scalar("LR Multiplier", self.lr_multiplier, iter_count)

                iter_count += 1

                valloss_acc += valloss.item()
                polloss_acc += polloss.item()
                entropy_acc += entropy.item()

            # NOTE(review): nesting level of this assignment (epoch loop vs.
            # batch loop) could not be determined from the source - confirm.
            self.first_trained = True

        valloss_mean = valloss_acc / (len(dataloader) * NUM_EPOCHS)
        polloss_mean = polloss_acc / (len(dataloader) * NUM_EPOCHS)
        entropy_mean = entropy_acc / (len(dataloader) * NUM_EPOCHS)

        return valloss_mean, polloss_mean, entropy_mean

    def run(self):
        """Run the training loop; Ctrl-C is caught and ends it.

        Three self-play games are collected up front. Each batch then
        collects ``self.play_batch_size`` games, trains when the buffer
        holds more than ``BATCH_SIZE`` samples, and saves the model every
        ``self.check_freq`` batches. No evaluation against an opponent is
        performed here.
        """
        try:
            self.collect_selfplay_data(3)
            count = 0
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(i + 1, self.episode_len))
                if len(self.data_buffer) > BATCH_SIZE:
                    valloss, polloss, entropy = self.policy_update()
                    print("VALUE LOSS: %0.3f " % valloss, "POLICY LOSS: %0.3f " % polloss, "ENTROPY: %0.3f" % entropy)

                # NOTE(review): `valloss`/`polloss` are only bound once an
                # update has run; whether this check sits inside the `if`
                # above could not be determined from the source - confirm.
                if (i + 1) % self.check_freq == 0:
                    count += 1
                    print("current self-play batch: {}".format(i + 1))
                    # The file name encodes the save count, the summed loss
                    # and the current local date.
                    self.policy_value_net.save_model('model_7x7_' + str(count) + '_' + str("%0.3f_" % (valloss+polloss) + str(time.strftime('%Y-%m-%d', time.localtime(time.time())))))
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline():
    """Self-play trainer for a Gomoku policy-value net with 2-D policy targets."""

    def __init__(self, size=(8, 8), init_model=None):
        """Build the game, the hyper-parameters, the net and the MCTS player.

        Args:
            size: Pair read as ``(board_height, board_width)``.
            init_model: Optional model file forwarded to ``PolicyValueNet``
                as ``model_file``; a new net is created when it is falsy.
        """
        print(size)
        # Board / game geometry.
        self.board_width = size[1]
        self.board_height = size[0]
        self.board = GomokuBoard(size=(self.board_width, self.board_height))
        self.game = GomokuGame(self.board)

        # Optimisation settings.
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # rescaled after each update from the KL
        self.kl_targ = 0.02
        self.epochs = 5  # train_step calls per update
        self.batch_size = 512  # mini-batch size

        # Self-play / search settings.
        self.temp = 1.0  # temperature handed to self-play
        self.n_playout = 400  # simulations per move
        self.c_puct = 5
        self.play_batch_size = 1

        # Replay buffer.
        self.buffer_size = 10000
        self.data_buffer = deque(maxlen=self.buffer_size)

        # Schedule and bookkeeping.
        self.check_freq = 50
        self.game_batch_num = 3000
        self.best_win_ratio = 0.0
        self.all_loss = []  # one entry per policy update
        # Simulations reserved for a pure MCTS evaluation opponent.
        self.pure_mcts_playout_num = 1000

        if init_model:
            # Resume from an existing policy-value net.
            net = PolicyValueNet(self.board_width,
                                 self.board_height,
                                 model_file=init_model)
        else:
            # Start from a brand-new policy-value net.
            net = PolicyValueNet(self.board_width, self.board_height)
        self.policy_value_net = net
        self.mcts_player = MCTSPlayer(net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Expand each sample into its eight rotated / mirrored variants.

        Args:
            play_data: Iterable of ``(state, mcts_prob, z)`` tuples.

        Returns:
            List of ``(state, mcts_prob_grid, z)`` tuples where the
            probabilities stay a 2-D grid: for every quarter-turn count 1..4
            the rotated sample followed by its left-right mirror image.
        """
        augmented = []
        rows, cols = self.board_height, self.board_width
        for planes, probs, outcome in play_data:
            # The probability grid is flipped once up front and flipped back
            # right before it is stored.
            prob_grid = np.flipud(probs.reshape(rows, cols))
            for turns in range(1, 5):
                turned_planes = np.array(
                    [np.rot90(plane, turns) for plane in planes])
                turned_grid = np.rot90(prob_grid, turns)
                augmented.append(
                    (turned_planes, np.flipud(turned_grid), outcome))

                mirrored_planes = np.array(
                    [np.fliplr(plane) for plane in turned_planes])
                mirrored_grid = np.fliplr(turned_grid)
                augmented.append(
                    (mirrored_planes, np.flipud(mirrored_grid), outcome))
        return augmented

    def collect_selfplay_data(self, n_games=1):
        """Generate ``n_games`` self-play games and push them to the buffer."""
        games_played = 0
        while games_played < n_games:
            result, samples = self.game.start_self_play(self.mcts_player,
                                                        temp=self.temp)
            print('The result:', result)
            samples = list(samples)
            self.episode_len = len(samples)
            # Store the augmented version of the game.
            self.data_buffer.extend(self.get_equi_data(samples))
            games_played += 1

    def policy_update(self):
        """Train the net on one random mini-batch and adapt the step size.

        Returns:
            ``(loss, entropy)`` from the last ``train_step`` call.
        """
        batch = random.sample(self.data_buffer, self.batch_size)
        states = [sample[0] for sample in batch]
        target_probs = [sample[1] for sample in batch]
        outcomes = [sample[2] for sample in batch]

        net = self.policy_value_net
        old_probs, old_v = net.policy_value(states)
        for _ in range(self.epochs):
            loss, entropy = net.train_step(
                states, target_probs, outcomes,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = net.policy_value(states)
            log_ratio = np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)
            # Probabilities are 2-D per sample, hence the sum over two axes.
            kl = np.mean(np.sum(old_probs * log_ratio, axis=(1, 2)))
            if kl > self.kl_targ * 4:
                # Stop early when the policy drifted too far.
                break

        # Shrink or grow the learning-rate multiplier based on the KL.
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        def explained_var(values):
            # 1 - Var(target - prediction) / Var(target)
            return (1 -
                    np.var(np.array(outcomes) - values.flatten()) /
                    np.var(np.array(outcomes)))

        explained_var_old = explained_var(old_v)
        explained_var_new = explained_var(new_v)
        template = ("kl:{:.5f},"
                    "lr_multiplier:{:.3f},"
                    "loss:{},"
                    "entropy:{},"
                    "explained_var_old:{:.3f},"
                    "explained_var_new:{:.3f}")
        print(template.format(kl, self.lr_multiplier, loss, entropy,
                              explained_var_old, explained_var_new))
        return loss, entropy

    def run(self):
        """Drive the training loop; Ctrl-C ends it quietly.

        Evaluation against a pure MCTS opponent is disabled in this variant;
        the model is simply saved every 10 batches.
        """
        try:
            for batch_idx in range(self.game_batch_num):
                batch_no = batch_idx + 1
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    batch_no, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    self.all_loss.append(loss)

                if batch_no % 10 != 0:
                    continue
                self.policy_value_net.save_model(
                    './model/best_policy08_08.model')
                print('save model.')
                # NOTE(review): the nesting of the two prints below was
                # reconstructed from a whitespace-mangled source - confirm
                # whether they belong here or after the loop.
                print(self.all_loss)
                print('finish')
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline:
    """Self-play training pipeline for a policy-value network on an n x n board.

    Each batch plays one self-play game, augments the samples with the eight
    rotations/reflections of the board, updates the network from a random
    mini-batch and, every ``check_freq`` batches, evaluates the network
    against a pure-MCTS opponent and pickles the network parameters.
    """

    def __init__(self, n: int, init_model=None):
        """Build the board, game, network and self-play player.

        Args:
            n: Board size handed to ``Board`` and ``PolicyValueNet``.
            init_model: Optional path to a pickle file holding network
                parameters to resume training from.
        """
        # params of the board and the game
        self.n = n
        self.board = Board(self.n)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 5e-3
        self.lr_multiplier = 1.0  # adaptively adjusted based on KL
        self.temp = 1.0  # the temperature param
        self.n_play_out = 400  # number of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.epochs = 5  # number of train_steps for each update
        self.kl_target = 0.025
        self.check_freq = 50
        self.game_batch_number = 10000
        self.best_win_ratio = 0.0
        self.episode_length = 0
        self.pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
        self.last_batch_number = 0
        # number of simulations used for the pure MCTS opponent that the
        # trained policy is evaluated against
        self.pure_mcts_play_out_number = 1000
        if init_model:
            # start training from an initial policy-value net
            # NOTE(review): unpickling runs arbitrary code; only load model
            # files that come from a trusted source.
            with open(init_model, 'rb') as model_file:
                policy_param = pickle.load(model_file)
            self.policy_value_net = PolicyValueNet(self.n,
                                                   net_params=policy_param)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.n)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_func,
                                      c_puct=self.c_puct,
                                      n_play_out=self.n_play_out,
                                      is_self_play=1)

    @staticmethod
    def _dump_params(net_params, path):
        """Pickle ``net_params`` to ``path``, closing the file even on error."""
        with open(path, 'wb') as model_file:
            pickle.dump(net_params, model_file, pickle.HIGHEST_PROTOCOL)

    def get_equivalent_data(self, play_data):
        """Augment the data set by rotation and flipping.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            A list with eight ``(state, mcts_prob, winner_z)`` tuples per
            input sample (four rotations, each with its horizontal mirror).
        """
        extend_data = []
        for state, mcts_probabilities, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equivalent_state = np.array([np.rot90(s, i) for s in state])
                # the probability grid is flipped vertically before rotating
                # and flipped back before flattening
                equivalent_mcts_prob = np.rot90(
                    np.flipud(mcts_probabilities.reshape(self.n, self.n)), i)
                extend_data.append(
                    (equivalent_state,
                     np.flipud(equivalent_mcts_prob).flatten(), winner))
                # flip horizontally
                equivalent_state = np.array(
                    [np.fliplr(s) for s in equivalent_state])
                equivalent_mcts_prob = np.fliplr(equivalent_mcts_prob)
                extend_data.append(
                    (equivalent_state,
                     np.flipud(equivalent_mcts_prob).flatten(), winner))
        return extend_data

    def collect_self_play_data(self):
        """Play one self-play game and push its augmented samples to the buffer."""
        play_data = list(
            self.game.start_self_play(self.mcts_player, temp=self.temp))
        self.episode_length = len(play_data)
        play_data = self.get_equivalent_data(play_data)
        self.data_buffer.extend(play_data)

    def policy_update(self):
        """Update the policy-value net from one random mini-batch.

        Returns:
            ``(loss, entropy)`` from the last training step that was run.
        """
        kl = 0
        new_v = 0
        loss = 0
        entropy = 0
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probabilities_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probabilities, old_v = self.policy_value_net.policy_value(
            state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probabilities_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probabilities, new_v = self.policy_value_net.policy_value(
                state_batch)
            # 1e-10 keeps log() finite when a probability is exactly zero
            kl = np.mean(
                np.sum(old_probabilities *
                       (np.log(old_probabilities + 1e-10) -
                        np.log(new_probabilities + 1e-10)),
                       axis=1))
            if kl > self.kl_target * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_target * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_target / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        winners = np.array(winner_batch)
        explained_var_old = 1 - np.var(winners - old_v.flatten()) / np.var(
            winners)
        explained_var_new = 1 - np.var(winners - new_v.flatten()) / np.var(
            winners)
        print_log(
            "kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}"
            .format(kl, self.lr_multiplier, loss, entropy, explained_var_old,
                    explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Evaluate the trained policy against the pure MCTS player.

        Note: this is only for monitoring the progress of training.

        Args:
            n_games: number of evaluation games dispatched to the pool.

        Returns:
            Win ratio, counting a result of ``-1`` as half a win.
        """
        current_mcts_player = MCTSPlayer(
            self.policy_value_net.policy_value_func,
            c_puct=self.c_puct,
            n_play_out=self.n_play_out)
        pure_mcts_player = MCTS_Pure(
            c_puct=5, n_play_out=self.pure_mcts_play_out_number)
        win_cnt = defaultdict(int)
        # each game receives one (player, opponent, game index) tuple
        results = self.pool.map(self.game.start_play,
                                [(current_mcts_player, pure_mcts_player, i)
                                 for i in range(n_games)])
        for winner in results:
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print_log("number_play_outs:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_play_out_number, win_cnt[1], win_cnt[2],
            win_cnt[-1]))
        return win_ratio

    def run(self):
        """Run the training pipeline until done, interrupted, or stopped.

        A file named ``done`` in the working directory stops the loop before
        the next batch starts.
        """
        try:
            for i in range(self.game_batch_number):
                if os.path.exists("done"):
                    break
                start_time = time.time()
                self.collect_self_play_data()
                print_log("batch i:{}, episode_len:{}, in:{}".format(
                    i + 1 + self.last_batch_number, self.episode_length,
                    time.time() - start_time))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    data_log(
                        str((i + 1 + self.last_batch_number, loss, entropy)))
                # check the performance of the current model and save its params
                if (i + 1) % self.check_freq == 0:
                    print_log("current self-play batch: {}".format(
                        i + 1 + self.last_batch_number))
                    start_time = time.time()
                    win_ratio = self.policy_evaluate()
                    net_params = self.policy_value_net.get_policy_parameter()
                    self._dump_params(net_params, 'current_policy.model')
                    print_log(str(time.time() - start_time))
                    if win_ratio > self.best_win_ratio:
                        self.best_win_ratio = win_ratio
                        self._dump_params(net_params, 'best_policy.model')
                        if self.best_win_ratio >= 0.8:
                            print_log("New best policy defeated " +
                                      str(self.pure_mcts_play_out_number) +
                                      " play out MCTS player ")
                            # raise the bar: stronger opponent, reset the record
                            self.best_win_ratio = 0.0
                            self.pure_mcts_play_out_number += 1000
        except KeyboardInterrupt:
            # Ctrl-C is the expected way to stop training early; exit quietly
            pass
class TrainPipeline():
    """Self-play training pipeline on a 6x6 board with 4 in a row to win."""

    def __init__(self, init_model=None):
        """Create the board, game, network and self-play player.

        Args:
            init_model: optional model file passed to ``PolicyValueNet``.
        """
        # board geometry and winning condition
        self.board_width = 6  # board width
        self.board_height = 6  # board height
        self.n_in_row = 4  # stones in a line needed to win
        # instantiate the board with its size and winning condition
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)

        # training hyper-parameters
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adapted from the KL divergence
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # train steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # simulations used by the pure MCTS opponent that the trained
        # policy is evaluated against
        self.pure_mcts_playout_num = 1000

        # the network lives for the whole run; the search player wraps it
        if not init_model:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        else:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Expand the samples with the rotations and mirrors of the board.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            List holding eight transformed tuples for every input tuple.
        """
        augmented = []
        for planes, probs, outcome in play_data:
            # vertically flipped probability grid, rotated below
            base_grid = np.flipud(
                probs.reshape(self.board_height, self.board_width))
            for turns in range(1, 5):
                # np.rot90 rotates counterclockwise by 90 degrees per turn
                turned_planes = np.array(
                    [np.rot90(plane, turns) for plane in planes])
                turned_grid = np.rot90(base_grid, turns)
                augmented.append(
                    (turned_planes, np.flipud(turned_grid).flatten(), outcome))
                # horizontal mirror of the rotated sample
                mirrored_planes = np.array(
                    [np.fliplr(plane) for plane in turned_planes])
                mirrored_grid = np.fliplr(turned_grid)
                augmented.append(
                    (mirrored_planes, np.flipud(mirrored_grid).flatten(),
                     outcome))
        return augmented

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and buffer their augmented samples."""
        for _ in range(n_games):
            _winner, raw_samples = self.game.start_self_play(
                self.mcts_player, temp=self.temp)
            samples = list(raw_samples)
            self.episode_len = len(samples)  # number of moves in the game
            self.data_buffer.extend(self.get_equi_data(samples))

    def policy_update(self):
        """Train the policy-value net on one random mini-batch.

        Returns:
            ``(loss, entropy)`` of the last training step.
        """
        # unpack the sampled tuples into three parallel lists
        state_batch, mcts_probs_batch, winner_batch = [], [], []
        for sample in random.sample(self.data_buffer, self.batch_size):
            state_batch.append(sample[0])
            mcts_probs_batch.append(sample[1])
            winner_batch.append(sample[2])

        # predictions before training, used as the KL reference
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            log_ratio = np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)
            kl = np.mean(np.sum(old_probs * log_ratio, axis=1))
            if kl > self.kl_targ * 4:
                # early stopping if D_KL diverges badly
                break

        # adapt the learning rate to the KL divergence of this update
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        targets = np.array(winner_batch)
        target_var = np.var(targets)
        explained_var_old = 1 - np.var(targets - old_v.flatten()) / target_var
        explained_var_new = 1 - np.var(targets - new_v.flatten()) / target_var
        report = ("kl:{:.5f},"
                  "lr_multiplier:{:.3f},"
                  "loss:{},"
                  "entropy:{},"
                  "explained_var_old:{:.3f},"
                  "explained_var_new:{:.3f}")
        print(report.format(kl, self.lr_multiplier, loss, entropy,
                            explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Play the current policy against pure MCTS and return the win ratio.

        Only used to monitor training progress. A result of ``-1`` counts
        as half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for game_idx in range(n_games):
            # alternate which player moves first
            result = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=game_idx % 2,
                                          is_shown=0)
            win_cnt[result] += 1
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num,
            win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """Run the training pipeline; Ctrl-C stops it with a short message."""
        try:
            for batch_idx in range(self.game_batch_num):
                batch_no = batch_idx + 1
                # gather self-play data: play_batch_size games per batch
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    batch_no, self.episode_len))
                # train once the buffer holds more than one mini-batch
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                if batch_no % self.check_freq != 0:
                    continue
                # periodic check of the current model, saving its params
                print("current self-play batch: {}".format(batch_no))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('./current_policy.model')
                if win_ratio > self.best_win_ratio:
                    print("New best policy!!!!!!!!")
                    self.best_win_ratio = win_ratio
                    # update the best_policy
                    self.policy_value_net.save_model('./best_policy.model')
                    if (self.best_win_ratio == 1.0 and
                            self.pure_mcts_playout_num < 5000):
                        self.pure_mcts_playout_num += 1000
                        self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline():
    """Self-play training pipeline on a 9x9 board with 5 in a row to win."""

    def __init__(self, init_model=None):
        """Create the board, game, network and self-play player.

        Args:
            init_model: optional model file passed to ``PolicyValueNet``.
        """
        # basic params of the board and the game
        self.board_width = 9
        self.board_height = 9
        self.n_in_row = 5
        # init the board and game
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 3e-3
        self.lr_multiplier = 1.0  # adaptively adjusted based on KL
        self.temp = 1e-3  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        # a number in (0, inf) that controls how quickly exploration
        # converges to the maximum-value policy. A higher value means
        # relying on the prior more.
        self.c_puct = 3
        self.buffer_size = 10000
        self.batch_size = 256  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1000
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 400
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            A list with eight transformed tuples per input tuple.
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_porb.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=25):
        """Collect self-play data for training from ``n_games`` games."""
        for _ in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """Update the policy-value net from one random mini-batch.

        Returns:
            ``(loss, entropy)`` of the last training step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            # 1e-10 keeps log() finite when a probability is exactly zero
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        winners = np.array(winner_batch)
        explained_var_old = (1 - np.var(winners - old_v.flatten()) /
                             np.var(winners))
        explained_var_new = (1 - np.var(winners - new_v.flatten()) /
                             np.var(winners))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, self.lr_multiplier, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=30):
        """Evaluate the trained policy by playing against the pure MCTS player.

        Note: this is only for monitoring the progress of training. The
        result line is printed and appended to the file named by the
        module-level ``logfile_name``.

        Returns:
            Win ratio, counting a result of ``-1`` as half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        summary = "num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1])
        # append: opening with 'w+' would wipe every earlier evaluation
        with open(logfile_name, 'a') as log_file:
            log_file.write(summary + "\n")
        print(summary)
        return win_ratio

    def run(self):
        """Run the training pipeline; Ctrl-C stops it with a short message."""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model(
                        './models/current_policy_{}.model'.format(i + 1))
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model(
                            './models/best_policy_{}.model'.format(i + 1))
                        if (self.best_win_ratio > 0.8 and
                                self.pure_mcts_playout_num < 25000):
                            print("stronger model to compete")
                            self.pure_mcts_playout_num += 500
                            self.best_win_ratio = 0.0
                    elif self.best_win_ratio == 0 and self.n_playout < 15000:
                        # The policy has not scored at all yet: give it more
                        # search instead of strengthening the opponent. The
                        # guard and the message both refer to n_playout.
                        # NOTE(review): self.mcts_player was built with the
                        # old value; confirm whether the self-play player
                        # should pick up the new n_playout too.
                        self.n_playout += 250
                        print("enhance the alphazero mcts")
                print('-------------------------training_outer_epoch!!!!!!',
                      i, "-----------------")
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline():
    """Self-play training pipeline (6x6, 4 in a row) that logs via ``logging``."""

    def __init__(self, init_model=None):
        """Create the board, game, logger, network and self-play player.

        Args:
            init_model: optional path of a model file to resume from.

        Raises:
            FileNotFoundError: if ``init_model`` is given but does not exist.
        """
        # params of the board and the game
        self.board_width = 6
        self.board_height = 6
        self.n_in_row = 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjusted based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        # output log: same format to output.log and to the console
        self.formatter = logging.Formatter(
            '%(asctime)s [%(module)s] %(levelname)s: %(message)s',
            '%Y-%m-%d %H:%M:%S')
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(level=logging.INFO)
        self.handler = logging.FileHandler("output.log")
        self.handler.setLevel(logging.INFO)
        self.handler.setFormatter(self.formatter)
        self.console = logging.StreamHandler()
        self.console.setLevel(logging.INFO)
        self.console.setFormatter(self.formatter)
        self.logger.addHandler(self.handler)
        self.logger.addHandler(self.console)
        if init_model:
            if not os.path.exists(init_model):
                self.logger.error("{} does not exists!\n".format(init_model))
                # __init__ must not return a value (returning -1 raises a
                # confusing TypeError), so report the problem explicitly.
                raise FileNotFoundError(init_model)
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            A list with eight transformed tuples per input tuple.
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training from ``n_games`` games."""
        for _ in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """Update the policy-value net from one random mini-batch.

        Returns:
            ``(loss, entropy)`` of the last training step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            # 1e-10 keeps log() finite when a probability is exactly zero
            kl = np.mean(np.sum(old_probs * (
                np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        winners = np.array(winner_batch)
        explained_var_old = (1 - np.var(winners - old_v.flatten()) /
                             np.var(winners))
        explained_var_new = (1 - np.var(winners - new_v.flatten()) /
                             np.var(winners))
        self.logger.info(
            ("kl:{:.5f}, lr_multiplier:{:.3f}, loss:{}, entropy:{}, "
             "explained_var_old:{:.3f}, explained_var_new:{:.3f}").format(
                kl, self.lr_multiplier, loss, entropy,
                explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Evaluate the trained policy by playing against the pure MCTS player.

        Note: this is only for monitoring the progress of training.

        Returns:
            Win ratio, counting a result of ``-1`` as half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        self.logger.info("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num,
            win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """Run the training pipeline; Ctrl-C stops it with a short message."""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                self.logger.info("batch i:{}, episode_len:{}".format(
                    i+1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i+1) % self.check_freq == 0:
                    self.logger.info(
                        "current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    # timestamped snapshot next to the rolling checkpoint
                    # (the original called the non-existent str.fromat)
                    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
                    self.policy_value_net.save_model(
                        './policy_{}_{}_{}_{}.model'.format(
                            self.board_width, self.board_height,
                            self.n_in_row, stamp))
                    if win_ratio > self.best_win_ratio:
                        self.logger.info("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
class Train():
    """Self-play training loop for a 4x4 game using a policy-value network.

    All progress output goes through single-argument ``print(...)`` calls,
    which produce the same text on Python 2 and Python 3.
    """

    def __init__(self, init_model=None):
        """Create the game, network and self-play player.

        Args:
            init_model: optional model file passed to ``PolicyValueNet``.
        """
        # params of the game
        self.width = 4
        self.height = 4
        self.game = Game()
        # params of training
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0
        self.temp = 1.0
        self.n_playout = 300
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 64
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 5000
        self.best_win_ratio = 0.0
        self.pure_mcts_playout_num = 500
        if init_model:
            self.policy_value_net = PolicyValueNet(self.width,
                                                   self.height,
                                                   model_file=init_model)
        else:
            self.policy_value_net = PolicyValueNet(self.width, self.height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games on fresh ``Game`` objects.

        The samples are stored without symmetry augmentation.
        """
        for _ in range(n_games):
            print("=====================Start====================")
            self.game = Game()
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            print("======================END=====================")
            play_data = list(play_data)
            self.episode_len = len(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """Update the policy-value net from one random mini-batch.

        Returns:
            ``(loss, entropy)`` of the last training step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            # 1e-10 keeps log() finite when a probability is exactly zero
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:
                # stop early when the policy moved too far in one update
                break
        # adapt the learning rate to the KL divergence of this update
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        winners = np.array(winner_batch)
        residual_var = np.var(winners - new_v.flatten())
        winner_var = np.var(winners)
        explained_var_old = (1 - np.var(winners - old_v.flatten()) /
                             winner_var)
        explained_var_new = 1 - residual_var / winner_var
        print("result-eval var= {} \twinner var= {}".format(
            residual_var, winner_var))
        print("kl= {} \tlr_mul= {}".format(kl, self.lr_multiplier))
        print("var_old : {:.3f}\tvar_new : {:.3f}".format(
            explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Play the current policy against pure MCTS and return the win ratio.

        A result of ``0`` is counted as a tie worth half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            print("winner {}".format(winner))
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[0]) / n_games
        print("win ratio = {}".format(win_ratio))
        print("num_playout:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[0]))
        return win_ratio

    def run(self, modelfile=None):
        """Run the training loop for ``game_batch_num`` batches.

        Args:
            modelfile: accepted for interface compatibility; not used.
        """
        for i in range(self.game_batch_num):
            self.collect_selfplay_data(self.play_batch_size)
            print("gamebatch : {} episode_len: {}".format(
                i + 1, self.episode_len))
            print("selfplayend,data_buffer len= {}".format(
                len(self.data_buffer)))
            if len(self.data_buffer) > self.batch_size:
                loss, entropy = self.policy_update()
                print("loss = {:.3f}\tentropy = {:.3f}".format(loss, entropy))
            if (i + 1) % self.check_freq == 0:
                print("current self-play batch:{}".format(i + 1))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('current.model')
                if win_ratio > self.best_win_ratio:
                    print("new best model")
                    self.best_win_ratio = win_ratio
                    self.policy_value_net.save_model("best.model")
                    if (self.best_win_ratio >= 0.8 and
                            self.pure_mcts_playout_num < 1000):
                        # raise the bar: stronger opponent, reset the record
                        print("Pure Harder")
                        self.pure_mcts_playout_num += 100
                        self.best_win_ratio = 0.0
class TrainPipeline():
    """Self-play training pipeline for a default ``Board`` with a GPU network."""

    def __init__(self, init_model=None):
        """Create the board, game, network and self-play player.

        Args:
            init_model: optional model file passed to ``PolicyValueNet``.
        """
        self.board = Board()
        self.game = Game(self.board)

        # training hyper-parameters
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adapted from the KL divergence
        self.temp = 1.0  # the temperature param
        self.n_playout = 1600  # simulations for each move
        self.c_puct = 5
        # NOTE(review): with maxlen 10 the buffer can never exceed
        # batch_size (512), so run() never reaches policy_update();
        # confirm whether this small value is intentional.
        self.buffer_size = 10
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # train steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 15000
        self.best_win_ratio = 0.0
        self.pure_mcts_playout_num = 1000

        # resume from a model file when one is given
        net_options = {'use_gpu': True}
        if init_model:
            net_options['model_file'] = init_model
        self.policy_value_net = PolicyValueNet(**net_options)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)
        print("init done")

    def get_equi_data(self, play_data):
        """Expand the samples with rotated and mirrored copies.

        Args:
            play_data: iterable of ``(state, mcts_prob, winner_z)`` tuples.

        Returns:
            List holding eight transformed tuples for every input tuple.

        NOTE(review): reads ``self.board_height`` / ``self.board_width``,
        which ``__init__`` does not set; the only call site in this class
        is commented out.
        """
        augmented = []
        print("play_data = {}".format(play_data))
        for state, probs, outcome in play_data:
            base_grid = np.flipud(
                probs.reshape(self.board_height, self.board_width))
            for turns in range(1, 5):
                # the whole state array is rotated at once (first two axes)
                turned_state = np.rot90(state, turns)
                turned_grid = np.rot90(base_grid, turns)
                augmented.append(
                    (turned_state, np.flipud(turned_grid).flatten(), outcome))
                # horizontal mirror of each slice of the rotated state
                mirrored_state = np.array(
                    [np.fliplr(part) for part in turned_state])
                mirrored_grid = np.fliplr(turned_grid)
                augmented.append(
                    (mirrored_state, np.flipud(mirrored_grid).flatten(),
                     outcome))
        return augmented

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` displayed self-play games and buffer the raw samples."""
        for _ in range(n_games):
            _winner, raw_samples = self.game.start_self_play(
                self.mcts_player, temp=self.temp, is_shown=1)
            samples = list(raw_samples)
            self.episode_len = len(samples)
            # samples are stored as played; augmentation is not applied
            self.data_buffer.extend(samples)

    def policy_update(self):
        """Train the policy-value net on one random mini-batch.

        Returns:
            ``(loss, entropy)`` of the last training step.
        """
        state_batch, mcts_probs_batch, winner_batch = [], [], []
        for sample in random.sample(self.data_buffer, self.batch_size):
            state_batch.append(sample[0])
            mcts_probs_batch.append(sample[1])
            winner_batch.append(sample[2])

        # predictions before training, used as the KL reference
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            log_ratio = np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)
            kl = np.mean(np.sum(old_probs * log_ratio, axis=1))
            if kl > self.kl_targ * 4:
                # early stopping if D_KL diverges badly
                break

        # adapt the learning rate to the KL divergence of this update
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        targets = np.array(winner_batch)
        target_var = np.var(targets)
        explained_var_old = 1 - np.var(targets - old_v.flatten()) / target_var
        explained_var_new = 1 - np.var(targets - new_v.flatten()) / target_var
        report = ("kl = {:.5f},"
                  "lr_multiplier = {:.3f},"
                  "loss = {},"
                  "entropy = {},"
                  "explained_var_old = {:.3f},"
                  "explained_var_new = {:.3f}")
        print(report.format(kl, self.lr_multiplier, loss, entropy,
                            explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Play the current policy against pure MCTS and return the win ratio.

        Only used to monitor training progress. A result of ``-1`` counts
        as half a win.
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for game_idx in range(n_games):
            print(game_idx)
            result = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          is_shown=0)
            win_cnt[result] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts = {}, win = {}, lose = {}, tie = {}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """Run the training pipeline; Ctrl-C stops it with a short message."""
        try:
            # the printed label means "local time is"
            print("本地时间为 :", time.asctime(time.localtime()))
            for batch_idx in range(self.game_batch_num):
                batch_no = batch_idx + 1
                print("selfplay....")
                self.collect_selfplay_data(self.play_batch_size)
                print("selfplay done")
                print("batch i = {}, episode_len = {}".format(
                    batch_no, self.episode_len))
                print("本地时间为 :", time.asctime(time.localtime()))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                if batch_no % self.check_freq != 0:
                    continue
                # periodic check of the current model, saving its params
                print("current self-play batch = {}".format(batch_no))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('./current_policy.model')
                if win_ratio > self.best_win_ratio:
                    print("New best policy!!!!!!!!")
                    self.best_win_ratio = win_ratio
                    # update the best_policy
                    self.policy_value_net.save_model('./best_policy.model')
                    if (self.best_win_ratio == 1.0 and
                            self.pure_mcts_playout_num < 5000):
                        self.pure_mcts_playout_num += 1000
                        self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline():
    """Self-play training pipeline for a policy-value net on a 9x9 omok board.

    Each batch plays ``play_batch_size`` self-play games with the MCTS player,
    augments the recorded positions by rotation and flipping, appends them to
    a bounded replay buffer and, once the buffer holds more than
    ``batch_size`` samples, runs one policy update. Every ``check_freq``
    batches the network weights and the pipeline object itself are saved.
    """

    def __init__(self):
        # Board / game (omok) parameters.
        self.board_width, self.board_height = 9, 9
        self.n_in_row = 5
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)

        # Training parameters.
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adapted from the KL divergence in policy_update
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.batch_size = 512  # mini-batch size: samples drawn from the buffer
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 500  # save the model every N batches (originally 100)
        self.game_batch_num = 3000  # maximum number of training batches
        self.train_num = 0  # number of batches trained so far

        # Start training from a fresh policy-value net.
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        play_data: [(state, mcts_prob, winner_z), ..., ...]

        Returns a list with eight entries per input sample: the four
        ``np.rot90`` rotations of the position, each followed by its
        left-right mirror. ``winner_z`` is copied unchanged.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            # The probability vector is viewed as a board and flipped
            # vertically before the transform (and flipped back afterwards).
            # NOTE(review): presumably because move indices count rows from
            # the opposite edge than the state planes -- confirm against Board.
            prob_grid = np.flipud(
                mcts_prob.reshape(self.board_height, self.board_width))
            for i in (1, 2, 3, 4):
                # rotate counterclockwise by i quarter turns
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(prob_grid, i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training.

        Plays ``n_games`` games, records the length of the last one in
        ``self.episode_len`` and appends the augmented samples to the right
        end of the replay buffer.
        """
        for _ in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # augment the data
            self.data_buffer.extend(self.get_equi_data(play_data))

    def policy_update(self):
        """Update the policy-value net on one random mini-batch.

        Runs up to ``self.epochs`` train steps, stopping early when the KL
        divergence between the pre-update and current policy exceeds
        ``4 * kl_targ``, then adapts ``lr_multiplier`` and prints statistics.

        Returns the ``(loss, entropy)`` pair of the last train step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        winners = np.array(winner_batch)
        explained_var_old = 1 - np.var(winners - old_v.flatten()) / np.var(winners)
        explained_var_new = 1 - np.var(winners - new_v.flatten()) / np.var(winners)
        # Precision specs (.5f / .3f): the former ":5f" / ":3f" only set a
        # minimum width, which never took effect on six-decimal output.
        print(f"kl:{kl:.5f}, "
              f"lr_multiplier:{self.lr_multiplier:.3f}, "
              f"loss:{loss}, "
              f"entropy:{entropy}, "
              f"explained_var_old:{explained_var_old:.3f}, "
              f"explained_var_new:{explained_var_new:.3f}")
        return loss, entropy

    def run(self):
        """Run the training loop for ``game_batch_num`` batches."""
        for i in range(self.game_batch_num):
            self.collect_selfplay_data(self.play_batch_size)
            self.train_num += 1
            print(f"batch i:{self.train_num}, episode_len:{self.episode_len}")
            if len(self.data_buffer) > self.batch_size:
                self.policy_update()
            # periodically save the model weights and the pipeline state
            if (i + 1) % self.check_freq == 0:
                print(f"★ {self.train_num}번째 batch에서 모델 저장 : {datetime.now()}")
                self.policy_value_net.save_model(
                    f'{model_path}/policy_9_{self.train_num}.model')
                # Context manager so the pickle file is flushed and closed
                # even if serialising the pipeline raises.
                with open(f'{train_path}/train_9_{self.train_num}.pickle',
                          'wb') as fh:
                    pickle.dump(self, fh, protocol=2)
def test():
    """Smoke-test one self-play game followed by one policy update.

    Builds a fresh policy-value net, plays a single displayed self-play game
    on a Quoridor board, trains on the recorded samples for up to five steps
    (with the same KL-based early stopping and learning-rate adaptation used
    by the training pipeline), prints the statistics and saves the model to
    ``./current_policy.model``.
    """
    from quoridor import Quoridor
    # Imported but not used below; kept in case the import has side effects.
    from pure_mcts import MCTSPlayer as MCTS_Pure
    from mcts_player import MCTSPlayer
    from policy_value_net import PolicyValueNet

    policy_value_net = PolicyValueNet(model_file=None)
    c_puct = 5
    n_playout = 800
    temp = 1.0
    board = Quoridor()
    game = Game(board)
    mcts_player = MCTSPlayer(policy_value_net.policy_value_fn,
                             c_puct=c_puct,
                             n_playout=n_playout,
                             is_selfplay=1)
    winner, play_data = game.start_self_play(mcts_player,
                                             is_shown=1,
                                             temp=temp)
    # Materialise once: play_data is traversed three times below, and a
    # one-shot iterator (e.g. a zip) would leave the second and third
    # batches empty.
    play_data = list(play_data)
    print(winner)
    print(play_data)

    state_batch = [data[0] for data in play_data]
    mcts_probs_batch = [data[1] for data in play_data]
    winner_batch = [data[2] for data in play_data]

    learn_rate = 2e-3
    lr_multiplier = 1.0
    kl_targ = 0.02
    old_probs, old_v = policy_value_net.policy_value(state_batch)
    for _ in range(5):
        loss, entropy = policy_value_net.train_step(
            state_batch,
            mcts_probs_batch,
            winner_batch,
            learn_rate * lr_multiplier)
        new_probs, new_v = policy_value_net.policy_value(state_batch)
        kl = np.mean(np.sum(old_probs * (
            np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
            axis=1))
        if kl > kl_targ * 4:  # early stopping if D_KL diverges badly
            break
    # adaptively adjust the learning rate
    if kl > kl_targ * 2 and lr_multiplier > 0.1:
        lr_multiplier /= 1.5
    elif kl < kl_targ / 2 and lr_multiplier < 10:
        lr_multiplier *= 1.5

    winners = np.array(winner_batch)
    explained_var_old = 1 - np.var(winners - old_v.flatten()) / np.var(winners)
    explained_var_new = 1 - np.var(winners - new_v.flatten()) / np.var(winners)
    print(("kl:{:.5f},"
           "lr_multiplier:{:.3f},"
           "loss:{},"
           "entropy:{},"
           "explained_var_old:{:.3f},"
           "explained_var_new:{:.3f}"
           ).format(kl,
                    lr_multiplier,
                    loss,
                    entropy,
                    explained_var_old,
                    explained_var_new))
    policy_value_net.save_model('./current_policy.model')
class TrainPipeline():
    """Self-play training loop for a policy-value net on a 5x5 board.

    Alternates between generating self-play games, training the network on
    mini-batches drawn from a bounded replay buffer, and periodically
    evaluating the network against a pure-MCTS opponent.
    """

    def __init__(self):
        # board geometry and game engine
        self.board_width = 5
        self.board_height = 5
        self.game = Game()

        # optimisation settings
        self.learn_rate = 0.001
        self.lr_multiplier = 1.0  # scaled up/down according to the KL divergence
        self.temp = 1.0  # self-play temperature
        self.n_playout = 500  # simulations per move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 128  # samples per training mini-batch
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # train steps per policy update
        self.kl_targ = 0.02

        # schedule and evaluation
        self.check_freq = 100
        self.game_batch_num = 2000
        self.best_win_ratio = 0.0
        # simulations given to the pure-MCTS opponent used for evaluation
        self.pure_mcts_playout_num = 3000

        # a brand-new network and the self-play agent driven by it
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height)
        self.mcts_player = MCTSPlayer(
            self.policy_value_net.policy_value_fn,
            c_puct=self.c_puct,
            n_playout=self.n_playout,
            is_selfplay=1,
        )

    def _save_model(self):
        """Write the current session to the network's model file."""
        net = self.policy_value_net
        net.saver.save(net.session, net.model_file)

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games (one by default) and buffer them."""
        for _ in range(n_games):
            _winner, samples = self.game.start_self_play(self.mcts_player,
                                                         temp=self.temp)
            self.data_buffer.extend(samples)

    def policy_update(self, verbose=False):
        """Train the policy-value net on one random mini-batch.

        With ``verbose`` set, a summary of the update (KL divergence, losses,
        entropy, explained variance) is printed; nothing is returned.
        """
        batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [sample[0] for sample in batch]
        mcts_probs_batch = [sample[1] for sample in batch]
        winner_batch = [sample[2] for sample in batch]

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        losses = []
        entropies = []
        for _ in range(self.epochs):
            step_loss, step_entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate * self.lr_multiplier,
            )
            losses.append(step_loss)
            entropies.append(step_entropy)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            log_ratio = np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)
            kl = np.mean(np.sum(old_probs * log_ratio, axis=1))
            if kl > self.kl_targ * 4:
                # the policy moved too far from where it started: stop early
                break

        # nudge the learning rate towards the KL target
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        if not verbose:
            return

        targets = np.array(winner_batch)
        target_var = np.var(targets)
        explained_var_old = 1 - np.var(targets - old_v.flatten()) / target_var
        explained_var_new = 1 - np.var(targets - new_v.flatten()) / target_var
        template = ("kl: {:.3f}, lr_multiplier: {:.3f}\n"
                    "last loss: {:.3f}, mean loss: {:.3f}, "
                    "mean entropy: {:.3f}\n"
                    "explained old: {:.3f}, explained new: {:.3f}\n")
        print(template.format(kl,
                              self.lr_multiplier,
                              losses[-1],
                              np.mean(losses),
                              np.mean(entropies),
                              explained_var_old,
                              explained_var_new))

    def policy_evaluate(self, n_games=10):
        """Play ``n_games`` against pure MCTS and return the win ratio.

        This only monitors training progress. The trained agent is always
        passed first (with ids 1 and 2 for the two players), while the
        starting player alternates between games.
        """
        trained = MCTSPlayer(self.policy_value_net.policy_value_fn,
                             c_puct=self.c_puct,
                             n_playout=3000)
        baseline = MCTS_Pure(c_puct=5,
                             n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for game_idx in range(n_games):
            first = (game_idx % 2) + 1
            winner = self.game.start_play(trained,
                                          baseline,
                                          1,
                                          2,
                                          start_player=first,
                                          is_show=0)
            print("winner is {}".format(winner))
            win_cnt[winner] += 1
        # win rate of the red side (the trained agent)
        win_ratio = win_cnt[1] / n_games
        print("num_playouts:{}, win: {}, lose: {}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2]))
        return win_ratio

    def run(self):
        """Run the training pipeline until done or interrupted."""
        try:
            for i in range(self.game_batch_num):
                print("game", i, 'start ...')
                started = time.time()
                self.collect_selfplay_data(self.play_batch_size)
                print('game', i, 'cost', int(time.time() - started), 's')

                if len(self.data_buffer) > self.batch_size:
                    print("#### batch i:{} ####\n".format(i + 1))
                    # five updates per batch; only the first one is verbose
                    for update_idx in range(5):
                        self.policy_update(update_idx == 0)

                if (i + 1) % self.check_freq != 0:
                    continue

                # every check_freq batches: save, then evaluate the model
                print("current self-play batch: {}".format(i + 1))
                self._save_model()
                win_ratio = self.policy_evaluate()
                print('*****win ration: {:.2f}%\n'.format(win_ratio * 100))
                if win_ratio > self.best_win_ratio:
                    print("New best policy!!!!!!!!")
                    self.best_win_ratio = win_ratio
                    self._save_model()
                    # after a perfect score, strengthen the opponent
                    if (self.best_win_ratio == 1.0
                            and self.pure_mcts_playout_num < 5000):
                        self.pure_mcts_playout_num += 100
                        self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            # save before quitting
            self._save_model()
            print('quit, Bye !')
class TrainPipeline():
    """Self-play training pipeline that checkpoints under ``checkpoint/``.

    Training resumes from ``checkpoint/current_policy.model`` when its
    ``.index`` file exists. During ``run`` the model with the lowest
    mini-batch loss seen so far is saved as the "best" policy, and the
    current model is saved every ``check_freq`` batches.
    """

    def __init__(self):
        # params of the board and the game
        self.board_width = 9
        self.board_height = 9
        self.board = Board(width=self.board_width, height=self.board_height)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjusted based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 800  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 3
        self.best_loss = None  # lowest loss returned by policy_update so far
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        init_model = 'checkpoint/current_policy.model'
        if os.path.isfile(init_model + '.index'):
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        play_data: [(state, mcts_prob, winner_z), ..., ...]

        Returns eight samples per input sample: the four ``np.rot90``
        rotations, each followed by its left-right mirror.
        """
        print('1')  # trace marker
        extend_data = []
        for state, mcts_prob, winner in play_data:
            # The probability vector is viewed as a board and flipped
            # vertically before the transform (and flipped back afterwards).
            # NOTE(review): presumably matches the board's move-index
            # layout -- confirm against Board.
            prob_grid = np.flipud(
                mcts_prob.reshape(self.board_height, self.board_width))
            for i in (1, 2, 3, 4):
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(prob_grid, i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training.

        Records the length of the last game in ``self.episode_len`` and
        appends the augmented samples to the replay buffer.
        """
        print('2')  # trace marker
        for _ in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # augment the data
            self.data_buffer.extend(self.get_equi_data(play_data))

    def policy_update(self):
        """Update the policy-value net on one random mini-batch.

        Returns the ``(loss, entropy)`` pair of the last train step.
        """
        print('3')  # trace marker
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        winners = np.array(winner_batch)
        explained_var_old = 1 - np.var(winners - old_v.flatten()) / np.var(winners)
        explained_var_new = 1 - np.var(winners - new_v.flatten()) / np.var(winners)
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl,
                                                  self.lr_multiplier,
                                                  loss,
                                                  entropy,
                                                  explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training

        Returns the fraction of the ``n_games`` games won by the trained
        player.
        """
        print('4')  # trace marker
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = 0
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            # Count only games the trained player won; previously every game
            # was counted, so the ratio was always 1.0.
            # NOTE(review): assumes start_play reports the first player
            # passed (current_mcts_player) as winner id 1 -- confirm
            # against Game.start_play.
            if winner == 1:
                win_cnt += 1
        win_ratio = win_cnt / n_games
        print("num_playouts:{}, win: {}".format(self.pure_mcts_playout_num,
                                                win_cnt))
        return win_ratio

    def run(self):
        """Run the training pipeline until done or interrupted."""
        print('go1')  # trace marker
        try:
            if not os.path.isdir('checkpoint'):
                os.makedirs('checkpoint')
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("{}: batch i:{}, episode_len:{}".format(
                    datetime.datetime.now(), i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    # keep a copy of the model with the lowest loss so far
                    if self.best_loss is None or loss < self.best_loss:
                        self.best_loss = loss
                        print("New best policy auto save at batch {}".format(
                            i + 1))
                        self.policy_value_net.save_model(
                            'checkpoint/best_policy.model')
                if (i + 1) % self.check_freq == 0:
                    print("current model auto save at batch {}".format(i + 1))
                    self.policy_value_net.save_model(
                        'checkpoint/current_policy.model')
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline(object):
    """Self-play training loop for the Quoridor policy-value network.

    Each batch plays self-play games, stores the samples in a bounded replay
    buffer and, once enough samples are available, runs a policy update
    whose loss is appended to ``loss.txt``.
    """

    def __init__(self, init_model=None):
        # game engine
        self.game = Quoridor()

        # training hyper-parameters
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adapts the effective learning rate
        self.temp = 1.0
        self.n_playout = 400
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 128  # (a value of 1 was used while testing)
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        self.pure_mcts_playout_num = 1000

        # resume from a saved model when one is given, else start fresh
        self.policy_value_net = (PolicyValueNet(model_file=init_model)
                                 if init_model else PolicyValueNet())

        # the MCTS agent that plays against itself
        self.mcts_player = MCTSPlayer(
            self.policy_value_net.policy_value_fn,
            c_puct=self.c_puct,
            n_playout=self.n_playout,
            is_selfplay=1,
        )

    # NOTE: data augmentation by board rotation/flip (a ``get_equi_data``
    # method working on 9x9 probability grids) is currently disabled.

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and add the samples to the buffer."""
        for _ in range(n_games):
            # run one self-play game
            _winner, samples = self.game.start_self_play(self.mcts_player,
                                                         temp=self.temp)
            samples = list(samples)
            self.episode_len = len(samples)
            # augmentation is disabled, so the raw samples go straight in
            self.data_buffer.extend(samples)

    def policy_update(self):
        """Train the policy-value network on one random mini-batch.

        Returns the ``(loss, entropy)`` pair of the last train step.
        """
        batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [sample[0] for sample in batch]       # states
        mcts_probs_batch = [sample[1] for sample in batch]  # move probabilities
        winner_batch = [sample[2] for sample in batch]      # game outcomes

        # The pre-update predictions are needed to measure the KL divergence
        # between old and new policy, which drives the learning-rate
        # annealing below.
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)

        for _ in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate * self.lr_multiplier,
            )
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            log_ratio = np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)
            kl = np.mean(np.sum(old_probs * log_ratio, axis=1))
            if kl > self.kl_targ * 4:
                # the KL divergence blew up: stop this update early
                break

        # adapt the learning-rate multiplier to the observed KL divergence
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        targets = np.array(winner_batch)
        explained_var_old = 1 - np.var(targets - old_v.flatten()) / np.var(targets)
        explained_var_new = 1 - np.var(targets - new_v.flatten()) / np.var(targets)
        summary = (
            "kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},"
            "explained_var_old:{:.3f},explained_var_new:{:.3f}"
        ).format(kl, self.lr_multiplier, loss, entropy,
                 explained_var_old, explained_var_new)
        print(summary)
        return loss, entropy

    def run(self):
        """Train until ``game_batch_num`` batches are done or interrupted."""
        try:
            for batch_idx in range(1, self.game_batch_num + 1):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    batch_idx, self.episode_len))

                if len(self.data_buffer) > self.batch_size:
                    loss, _entropy = self.policy_update()
                    print("LOSS:", loss)
                    # append the loss to the log file
                    with open('loss.txt', 'a') as f:
                        f.write(str(loss) + '\n')

                if batch_idx % self.check_freq == 0:
                    print("current self-play batch: {}".format(batch_idx))
                    # evaluation against pure MCTS is disabled; just save
                    self.policy_value_net.save_model('current_policy')
        except KeyboardInterrupt:
            print('\n\rquit')
class TrainPipeline():
    """Self-play training pipeline for a square n-in-a-row board.

    Every batch plays self-play games with a freshly built ``real_mcts``
    player on a deep copy of ``self.chess``, augments the samples by rotation
    and flipping, trains the policy-value net and, every ``check_freq``
    batches, evaluates it against a pure-MCTS player.
    """

    def __init__(self, init_model=None):
        # params of the board and the game
        self.board_length = 6
        self.n_in_row = 4
        self.num_history = 2
        self.chess = chessboard(self.board_length, self.n_in_row)
        # training params
        self.learn_rate = 5e-4
        self.lr_multiplier = 1.0  # KL-based adaptation is disabled; stays 1.0
        self.temperature = 1.0  # the temperature param
        self.cpuct = 5
        self.buffer_size = 10000
        self.batch_size = 512
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 10
        self.kl_targ = 0.02
        self.check_freq = 50
        self.best_win_ratio = 0.0
        self.game_batch_num = 4000
        self.loss_dict = {}  # batch index -> loss, for the sliding window
        self.loss_hold = 50  # size of the sliding loss window
        self.real_mcts_simulation_times = 400
        self.pure_mcts_simulation_times = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_length,
                                                   self.num_history,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_length,
                                                   self.num_history)
        # TODO: open question from the original author -- should self.chess
        # be deep-copied here, as collect_selfplay_data does?
        self.mcts_player = real_mcts(self.chess,
                                     self.policy_value_net.policy_value,
                                     self.cpuct,
                                     self.real_mcts_simulation_times,
                                     self.temperature,
                                     self.num_history,
                                     True)

    def get_equi_data(self, play_data):
        """Augment the data set by rotation and flipping.

        play_data: [(state, mcts_prob, winner_z), ..., ...]

        Returns eight samples per input sample: the four ``np.rot90``
        rotations, each followed by its left-right mirror.
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            # The probability vector is viewed as a board and flipped
            # vertically before the transform (and flipped back afterwards).
            # NOTE(review): presumably matches the board's move-index
            # layout -- confirm against chessboard.
            prob_grid = np.flipud(
                mcts_prob.reshape(self.board_length, self.board_length))
            for i in (1, 2, 3, 4):
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(prob_grid, i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and buffer the augmented samples.

        Each game uses a new interface, a deep copy of ``self.chess`` and a
        new ``real_mcts`` player. ``self.episode_len`` holds the length of
        the last game played.
        """
        for _ in range(n_games):
            inter = interface(self.board_length)
            current_board = copy.deepcopy(self.chess)
            current_real_mcts = real_mcts(current_board,
                                          self.policy_value_net.policy_value,
                                          self.cpuct,
                                          self.real_mcts_simulation_times,
                                          self.temperature,
                                          self.num_history,
                                          True)
            play_data = list(inter.start_self_play(player=current_real_mcts))
            self.episode_len = len(play_data)
            self.data_buffer.extend(self.get_equi_data(play_data))

    def policy_update(self):
        """Update the policy-value net on one random mini-batch.

        Runs ``self.epochs`` train steps (KL-based early stopping and
        learning-rate adaptation are intentionally disabled here) and prints
        the KL divergence, final loss, loss change and explained variance.

        Returns the ``(loss, entropy)`` pair of the last train step.
        """
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        first_loss = 0
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch,
                mcts_probs_batch,
                winner_batch,
                self.learn_rate * self.lr_multiplier)
            if i == 0:
                first_loss = loss
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1))
        winners = np.array(winner_batch)
        explained_var_old = 1 - np.var(winners - old_v.flatten()) / np.var(winners)
        explained_var_new = 1 - np.var(winners - new_v.flatten()) / np.var(winners)
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "loss_change:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        first_loss - loss,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """Play ``n_games`` against pure MCTS and return the win ratio.

        The ratio counts a result of 1 as a win and a result of 0 as half a
        win. The starting player alternates between games.
        """
        win_cnt = defaultdict(int)
        for i in range(n_games):
            inter = interface(self.board_length)
            current_board = copy.deepcopy(self.chess)
            current_real_mcts = real_mcts(current_board,
                                          self.policy_value_net.policy_value,
                                          self.cpuct,
                                          1000,
                                          self.temperature,
                                          self.num_history,
                                          False)
            current_pure_mcts = pure_mcts(current_board,
                                          self.pure_mcts_simulation_times)
            winner = inter.start_play(current_real_mcts,
                                      current_pure_mcts,
                                      start_player=i % 2)
            win_cnt[winner] += 1
            print('winner', winner)
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[0]) / n_games
        print("num_simulation_times:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_simulation_times,
            win_cnt[1], win_cnt[2], win_cnt[0]))
        return win_ratio

    def _track_loss(self, step, loss, total):
        """Add ``loss`` to the sliding window; return ``(total, mean)``.

        The entry recorded ``loss_hold`` steps earlier, if any, is evicted.
        The mean is taken over the losses actually held in the window, so it
        is not under-reported while the window is still filling up.
        """
        self.loss_dict[step] = loss
        total += loss
        stale = step - self.loss_hold
        if stale in self.loss_dict:
            total -= self.loss_dict.pop(stale)
        return total, total / len(self.loss_dict)

    def run(self):
        """Run the training loop for ``game_batch_num`` batches."""
        total = 0
        loss_hist = 0.0  # mean loss over the sliding window; 0.0 until training starts
        for i in range(self.game_batch_num):
            # decay the base learning rate every 100 batches
            if (i + 1) % 100 == 0:
                self.learn_rate = self.learn_rate * 0.85
            self.collect_selfplay_data(self.play_batch_size)
            if len(self.data_buffer) >= self.batch_size:
                loss, entropy = self.policy_update()
                total, loss_hist = self._track_loss(i, loss, total)
            print("batch i:{}, episode_len:{}, loss_hist:{}".format(
                i + 1, self.episode_len, loss_hist))
            if (i + 1) % self.check_freq == 0:
                print("current self-play batch: {}".format(i + 1))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('./current_policy.model')
                if win_ratio > self.best_win_ratio:
                    print("New best policy!!!!!!!!")
                    self.best_win_ratio = win_ratio
                    self.policy_value_net.save_model('./best_policy.model')
                    # after a perfect score, strengthen the opponent
                    if (self.best_win_ratio == 1.0 and
                            self.pure_mcts_simulation_times < 10000):
                        self.pure_mcts_simulation_times += 1000
                        self.best_win_ratio = 0.0