def run():
    n = 5
    width, height = 15, 15
    model_file = 'best_policy.model'
    try:
        board = Board(width=width, height=height, n_in_row=n)
        game = Game(board)

        # ############### human VS AI ###################
        # load the trained policy_value_net in either Theano/Lasagne, PyTorch or TensorFlow

        # best_policy = PolicyValueNet(width, height, model_file=model_file)
        # mcts_player = MCTSPlayer(best_policy.policy_value_fn, c_puct=5, n_playout=400)

        # load the provided model (trained in Theano/Lasagne) into a MCTS player written in pure numpy
        try:
            policy_param = pickle.load(open(model_file, 'rb'))
        except:
            policy_param = pickle.load(open(model_file, 'rb'),
                                       encoding='bytes')  # To support python3
        best_policy = PolicyValueNet(width, height, policy_param)
        # set a larger n_playout for better performance
        mcts_player = MCTSPlayer(best_policy.policy_value_fn,
                                 c_puct=5,
                                 n_playout=400)

        # uncomment the following line to play with pure MCTS
        # (it's much weaker even with a larger n_playout)
        # mcts_player = MCTS_Pure(c_puct=5, n_playout=1000)

        # human player, input your move in the format: 2,3
        human = Human()

        # set start_player=0 for human first
        game.start_play(human, mcts_player, start_player=1, is_shown=1)
    except KeyboardInterrupt:
        print('\n\rquit')
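# The run() snippet above assumes roughly these imports; the module names are
# assumptions based on the identifiers used (Board, Game, Human, MCTSPlayer,
# PolicyValueNet) and may not match the actual project layout:
#
#   import pickle
#   from game import Board, Game
#   from mcts_alphaZero import MCTSPlayer
#   from mcts_pure import MCTSPlayer as MCTS_Pure
#   from policy_value_net_numpy import PolicyValueNetNumpy as PolicyValueNet
#
# with Human defined alongside run(); the script would then end with:
if __name__ == '__main__':
    run()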
def __init__(self, init_model=None):
    # params of the board and the game
    self.parallel_games = 1
    # self.pool = Pool()
    self.board_width = 8
    self.board_height = 8
    self.n_in_row = 5
    # training params
    self.learn_rate = 1e-4
    self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
    self.temp = 1.0  # the temperature param
    self.n_playout = 400  # num of simulations for each move
    self.c_puct = 5
    self.buffer_size = 200
    self.batch_size = 128  # mini-batch size for training
    self.data_buffer = deque(maxlen=self.buffer_size)
    self.play_batch_size = 1
    self.epochs = 5  # num of train_steps for each update
    self.kl_targ = 0.001
    self.check_freq = 1000
    self.game_batch_num = 150000
    self.best_win_ratio = 0.0
    # num of simulations used for the pure mcts, which is used as
    # the opponent to evaluate the trained policy
    self.pure_mcts_playout_num = 1000
    if init_model:
        # start training from an initial policy-value net
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height,
                                               model_params=init_model)
    else:
        # start training from a new policy-value net
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height)
    params = self.policy_value_net.get_policy_param()
    infos = (self.board_height, self.board_width, self.n_in_row,
             self.temp, self.c_puct, self.n_playout)
    logging.info('hello......0')
    # self.mcts_players = [Actor.remote('gamer_' + str(gi), 2, infos, params)
    #                      for gi in range(self.parallel_games)]
    self.mcts_players = [Actor('gamer_' + str(gi), 2, infos, params)
                         for gi in range(self.parallel_games)]
    self.mcts_evaluater = Actor('evaluater', 2, infos, params)
    logging.info('hello......1')
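# The Actor used above (and in the MPI pipeline below) is not defined in these
# snippets. Based solely on how it is called -- Actor(name, mode, infos, params),
# Play(), Set_Params(), plus the .selfplay_net and .game attributes used during
# evaluation -- its interface is roughly the following. This is a sketch of the
# assumed interface, not the project's actual implementation; set_policy_param
# is an assumed setter name on PolicyValueNet.
class Actor(object):
    def __init__(self, name, mode, infos, params):
        board_height, board_width, n_in_row, temp, c_puct, n_playout = infos
        self.name = name
        self.mode = mode
        self.temp = temp
        self.board = Board(width=board_width, height=board_height,
                           n_in_row=n_in_row)
        self.game = Game(self.board)
        # a policy-value net driven by the parameters received from the learner
        # (constructor signature here is assumed)
        self.selfplay_net = PolicyValueNet(board_width, board_height,
                                           model_params=params)
        self.mcts_player = MCTSPlayer(self.selfplay_net.policy_value_fn,
                                      c_puct=c_puct,
                                      n_playout=n_playout,
                                      is_selfplay=1)

    def Set_Params(self, params):
        # refresh the network weights with parameters sent by the learner
        self.selfplay_net.set_policy_param(params)  # assumed setter name

    def Play(self):
        # run one self-play game and return [(state, mcts_prob, winner_z), ...]
        winner, play_data = self.game.start_self_play(self.mcts_player,
                                                      temp=self.temp)
        return list(play_data)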
            # assign the game outcome to every recorded state, from the
            # perspective of the player to move at that state
            winners_z[np.array(current_players) == winner] = 1.0
            winners_z[np.array(current_players) != winner] = -1.0
            # reset MCTS root node
            player.reset_player()
            if is_shown:
                if winner != -1:
                    print("Game end. Winner is player:", winner)
                else:
                    print("Game end. Tie")
            # go_on = input('go on:')
            # winner 1:2
            return warning, winner, zip(states, mcts_probs, winners_z)


if __name__ == '__main__':
    model_file = 'current_policy.model'
    policy_value_net = PolicyValueNet(15, 15)
    mcts_player = MCTSPlayer(policy_value_net.policy_value_fn,
                             c_puct=3,
                             n_playout=2,
                             is_selfplay=1)
    board = Board(width=15, height=15, n_in_row=5)
    game = Game(board)
    sgf_home = current_relative_path('./sgf_data')
    file_name = '1000_white_.sgf'
    # start_self_play returns (warning, winner, play_data) as above
    warning, winner, play_data = game.start_self_play(mcts_player,
                                                      is_shown=1,
                                                      temp=1.0,
                                                      sgf_home=sgf_home,
                                                      file_name=file_name)
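# Standalone illustration (not part of the pipeline) of how winners_z labels
# each stored state relative to the player who was to move at that state.
import numpy as np

current_players = [1, 2, 1, 2, 1]  # player to move at each recorded state
winner = 1                         # suppose player 1 won the game
winners_z = np.zeros(len(current_players))
winners_z[np.array(current_players) == winner] = 1.0
winners_z[np.array(current_players) != winner] = -1.0
print(winners_z)  # -> [ 1. -1.  1. -1.  1.]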
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.board_width = 15
        self.board_height = 15
        self.n_in_row = 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 50  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_porb.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)
            # print(len(self.data_buffer), n_games)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        learn_rate = self.learn_rate * self.lr_multiplier
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch, learn_rate)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                    np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                print('early stopping:', i, self.epochs)
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.1
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.1

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr:{:.6f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, learn_rate, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
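# The TrainPipeline above assumes roughly the following imports (module names
# are assumptions inferred from the identifiers used and may differ in the
# actual project):
#
#   import random
#   import numpy as np
#   from collections import defaultdict, deque
#   from game import Board, Game
#   from mcts_pure import MCTSPlayer as MCTS_Pure
#   from mcts_alphaZero import MCTSPlayer
#   from policy_value_net import PolicyValueNet
#
# A minimal entry point would then be:
if __name__ == '__main__':
    training_pipeline = TrainPipeline()
    training_pipeline.run()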
class TrainPipeline():
    def __init__(self, conf, init_model=None):
        # params of the board and the game
        self.board_width = conf['board_width']
        self.board_height = conf['board_height']
        self.n_in_row = conf['n_in_row']
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        self.game_ai = Game_AI(self.board)
        # training params
        self.learn_rate = conf['learn_rate']
        self.lr_multiplier = conf[
            'lr_multiplier']  # adaptively adjust the learning rate based on KL
        self.temp = conf['temp']  # the temperature param
        self.n_playout = conf[
            'n_playout']  # 500 # num of simulations for each move
        self.c_puct = conf['c_puct']
        self.buffer_size = conf['buffer_size']
        self.batch_size = conf['batch_size']  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = conf['play_batch_size']
        self.epochs = conf['epochs']  # num of train_steps for each update
        self.kl_targ = conf['kl_targ']
        self.check_freq = conf['check_freq']
        self.game_batch_num = conf['game_batch_num']
        self.best_win_ratio = 0.0
        # multiprocessing related
        self._cpu_count = mp.cpu_count() - 8
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = conf['pure_mcts_playout_num']
        # training data files
        self._sgf_home = current_relative_path(conf['sgf_dir'])
        _logger.info('path: %s' % self._sgf_home)
        self._ai_data_home = current_relative_path(conf['ai_data_dir'])
        # load human game (SGF) data
        self._load_training_data(self._sgf_home)
        # load saved self-play data
        # self._load_pickle_data(self._ai_data_home)
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   self.batch_size,
                                                   n_blocks=10,
                                                   n_filter=128,
                                                   model_params=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   self.batch_size,
                                                   n_blocks=10,
                                                   n_filter=128)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def _load_training_data(self, data_dir):
        file_list = os.listdir(data_dir)
        self._training_data = [
            item for item in file_list
            if item.endswith('.sgf') and
            os.path.isfile(os.path.join(data_dir, item))
        ]
        random.shuffle(self._training_data)
        self._length_train_data = len(self._training_data)

    """
    def _load_pickle_data(self, data_dir):
        file_list = os.listdir(data_dir)
        txt_list = [item for item in file_list
                    if item.endswith('.txt') and
                    os.path.isfile(os.path.join(data_dir, item))]
        self._ai_history_data = []
        for txt_f in txt_list:
            with open(os.path.join(data_dir, txt_f), 'rb') as f_object:
                d = pickle.load(f_object)
                self._ai_history_data += d
            f_object.close()
    """

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_porb.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1, training_index=None):
        """collect SGF file data for training"""
        data_index = training_index % self._length_train_data
        if data_index == 0:
            random.shuffle(self._training_data)
        for i in range(n_games):
            warning, winner, play_data = self.game.start_self_play(
                self.mcts_player,
                temp=self.temp,
                sgf_home=self._sgf_home,
                file_name=self._training_data[data_index])
            if warning:
                _logger.error(
                    '\033[0;41m %s \033[0m anxingle_training_index: %s, data_index: %s, file: %s'
                    % ('WARNING', training_index, data_index,
                       self._training_data[data_index]))
            else:
                _logger.info('winner: %s, file: %s ' %
                             (winner, self._training_data[data_index]))
            # print('play_data: ', play_data)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)
            _logger.info('game_batch_index: %s, length of data_buffer: %s' %
                         (training_index, len(self.data_buffer)))

    """
    def collect_selfplay_data_pickle(self, n_games=1, training_index=None):
        # load AI self-play data (auto-saved every N games)
        data_index = training_index % len(self._ai_history_data)
        if data_index == 0:
            random.shuffle(self._ai_history_data)
        for i in range(n_games):
            play_data = self._ai_history_data[data_index]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)
    """

    def collect_selfplay_data_ai(self, n_games=1, training_index=None):
        """collect AI self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game_ai.start_self_play(self.mcts_player,
                                                             temp=self.temp)
            _logger.info('training_index: %s, winner is: %s' %
                         (training_index, winner))
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    # def _multiprocess_collect_selfplay_data(self, q, process_index):
    #     """
    #     TODO: CUDA multiprocessing have bugs!
    #     """
    #     winner, play_data = self.game.start_self_play(self.mcts_player,
    #                                                   temp=self.temp)
    #     play_data = list(play_data)[:]
    #     self.episode_len = len(play_data)
    #     # augment the data
    #     play_data = self.get_equi_data(play_data)
    #     q.put(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        learn_rate = self.learn_rate * self.lr_multiplier
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch, learn_rate)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                    np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                _logger.info('early stopping. i:%s. epochs: %s' %
                             (i, self.epochs))
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 20:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        _logger.info(("kl:{:.4f},"
                      "lr:{:.1e},"
                      "loss:{},"
                      "entropy:{},"
                      "explained_var_old:{:.3f},"
                      "explained_var_new:{:.3f}").format(
                          kl, learn_rate, loss, entropy, explained_var_old,
                          explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        _logger.info("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                current_time = time.time()
                # the 400 here is a temporary value; it should be the real
                # number of sgf games
                if i < 100:
                    self.collect_selfplay_data(1, training_index=i)
                else:
                    self.collect_selfplay_data_ai(10, training_index=i)
                _logger.info('collection cost time: %d ' %
                             (time.time() - current_time))
                _logger.info(
                    "batch i:{}, episode_len:{}, buffer_len:{}".format(
                        i + 1, self.episode_len, len(self.data_buffer)))
                if len(self.data_buffer) > self.batch_size:
                    batch_time = time.time()
                    loss, entropy = self.policy_update()
                    _logger.info('train batch cost time: %d' %
                                 (time.time() - batch_time))
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % 50 == 0:
                    self.policy_value_net.save_model(
                        './logs/current_policy.model')
                if (i + 1) % self.check_freq == 0:
                    check_time = time.time()
                    _logger.info("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    _logger.info('evaluate the network cost time: %s ',
                                 int(time.time() - check_time))
                    if win_ratio > self.best_win_ratio:
                        _logger.info("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model(
                            './logs/best_policy_%s.model' % i)
                        if (self.best_win_ratio >= 0.98 and
                                self.pure_mcts_playout_num < 8000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            _logger.info('\n\rquit')
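# Example configuration for the conf-driven TrainPipeline above. The keys are
# exactly the ones read in __init__; the values shown are illustrative
# placeholders, not the settings used to train any released model.
conf = {
    'board_width': 15,
    'board_height': 15,
    'n_in_row': 5,
    'learn_rate': 2e-3,
    'lr_multiplier': 1.0,
    'temp': 1.0,
    'n_playout': 400,
    'c_puct': 5,
    'buffer_size': 10000,
    'batch_size': 512,
    'play_batch_size': 1,
    'epochs': 5,
    'kl_targ': 0.02,
    'check_freq': 50,
    'game_batch_num': 1500,
    'pure_mcts_playout_num': 1000,
    'sgf_dir': './sgf_data',     # directory of human .sgf game records
    'ai_data_dir': './ai_data',  # directory of saved self-play data (illustrative)
}

if __name__ == '__main__':
    training_pipeline = TrainPipeline(conf)
    training_pipeline.run()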
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.train_sampling_times = 20
        self.selfplay_count = 0
        self.parallel_games = 1
        # self.pool = Pool()
        self.board_width = 8
        self.board_height = 8
        self.n_in_row = 5
        # training params
        self.learn_rate = 1e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.agent_sampling_size = 128
        self.batch_size = self.agent_sampling_size * comm_size  # mini-batch size for training
        if comm_rank == 0:
            self.buffer_size = self.agent_sampling_size * (comm_size * self.train_sampling_times)
            self.data_buffer = deque(maxlen=self.buffer_size)
        else:
            self.buffer_size = 0
            self.data_buffer = []
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 1
        self.game_batch_num = 150000
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        self.policy_value_net = None
        if comm_rank == 0:
            if init_model:
                # start training from an initial policy-value net
                self.policy_value_net = PolicyValueNet(0, self.batch_size,
                                                       self.board_width,
                                                       self.board_height,
                                                       model_params=init_model)
            else:
                # start training from a new policy-value net
                self.policy_value_net = PolicyValueNet(0, self.batch_size,
                                                       self.board_width,
                                                       self.board_height)
        self.mcts_player = None
        self.mcts_evaluater = None
        self.params = None
        infos = (self.board_height, self.board_width, self.n_in_row,
                 self.temp, self.c_puct, self.n_playout)
        if comm_rank > 0:
            logging.info('rank ' + str(comm_rank) + ' before recv ')
            self.params = comm.recv(source=0)
            logging.info('rank ' + str(comm_rank) + ' after recv ')
            self.mcts_player = Actor('gamer_' + str(comm_rank), 1, infos,
                                     self.params)
        if self.policy_value_net and comm_rank == 0:
            self.params = self.policy_value_net.get_policy_param()
            logging.info('rank ' + str(comm_rank) + ' before bcast')
            # comm.bcast('params', root=0)
            for pi in range(1, comm_size):
                comm.send(self.params, dest=pi)
            logging.info('rank ' + str(comm_rank) + ' after bcast')
            self.mcts_player = Actor('gamer_' + str(comm_rank), 0, infos,
                                     self.params)
            self.mcts_evaluater = Actor('evaluater', 1, infos, self.params)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self):
        """collect self-play data for training"""
        datas = self.mcts_player.Play()
        self.episode_len = 0
        _len = len(datas)
        self.episode_len += _len
        play_data = self.get_equi_data(datas)
        if comm_rank == 0:
            self.data_buffer.extend(play_data)
            logging.info('gamer_%d %d collection finished.' %
                         (comm_rank, self.episode_len))
        if comm_rank > 0:
            self.data_buffer = play_data
            logging.info('gamer_%d %d sending data started...' %
                         (comm_rank, self.episode_len))
            comm.send(self.data_buffer, dest=0)
            logging.info('gamer_%d %d sending data finished...' %
                         (comm_rank, self.episode_len))

    def policy_update_old(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        learn_rate = self.learn_rate * self.lr_multiplier
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch, learn_rate)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                # if kl > self.kl_targ:  # early stopping if D_KL diverges badly
                logging.info('early stopping:%d, %d' % (i, self.epochs))
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 20:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        logging.info(("kl:{:.4f},"
                      "lr:{:.1e},"
                      "loss:{},"
                      "entropy:{},"
                      "explained_var_old:{:.3f},"
                      "explained_var_new:{:.3f}"
                      ).format(kl, learn_rate, loss, entropy,
                               explained_var_old, explained_var_new))
        return loss, entropy

    def policy_update(self, train_i):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        learn_rate = self.learn_rate * self.lr_multiplier
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch, learn_rate)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                # if kl > self.kl_targ:  # early stopping if D_KL diverges badly
                logging.info('early stopping:%d, %d' % (i, self.epochs))
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 20:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        if train_i % 4 == 0:
            logging.info(("kl:{:.4f},"
                          "lr:{:.1e},"
                          "loss:{},"
                          "entropy:{},"
                          "explained_var_old:{:.3f},"
                          "explained_var_new:{:.3f}"
                          ).format(kl, learn_rate, loss, entropy,
                                   explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(
            self.mcts_evaluater.selfplay_net.policy_value_fn,
            c_puct=self.c_puct,
            n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.mcts_evaluater.game.start_play(current_mcts_player,
                                                         pure_mcts_player,
                                                         start_player=i % 2,
                                                         is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        logging.info("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                recv_count = 1
                logging.info('sending params to actors...')
                for nodei in range(1, comm_size):
                    # logging.info('sending params to actor %d ...' % (nodei))
                    comm.send(self.params, dest=nodei)
                self.collect_selfplay_data()
                logging.info('receiving data from actors...')
                for nodei in range(1, comm_size):
                    data = comm.recv(source=nodei)
                    logging.info('receiving data from actor %d: %d...' %
                                 (nodei, len(data)))
                    self.data_buffer.extend(data)
                    recv_count += 1
                logging.info("batch i:{}, batchsize:{}, recv count:{}, buffer_len:{}".format(
                    i + 1, self.batch_size, recv_count, len(self.data_buffer)))
                if len(self.data_buffer) >= self.buffer_size:
                    for train_i in range(self.train_sampling_times):
                        loss, entropy = self.policy_update(train_i)
                    self.params = self.policy_value_net.get_policy_param()
                    self.mcts_evaluater.Set_Params(self.params)
                    self.mcts_player.Set_Params(self.params)
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % 10 == 0:
                    self.policy_value_net.save_model('./current_policy.model')
                if (i + 1) % self.check_freq == 0:
                    logging.info("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    if win_ratio > self.best_win_ratio:
                        logging.info("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')

    def run_selfplay(self):
        logging.info('selfplayer ' + str(comm_rank) + ' is started.....')
        while True:
            cmd = comm.recv(source=0)
            if cmd is not None:
                self.params = cmd
                self.mcts_player.Set_Params(self.params)
            self.selfplay_count += 1
            logging.info('start selfplaying...%d' % (self.selfplay_count))
            self.collect_selfplay_data()
            logging.info('finished selfplaying...%d' % (self.selfplay_count))
            time.sleep(1)
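# Minimal sketch of the MPI plumbing the class above relies on: comm, comm_rank
# and comm_size are assumed to be module-level globals created with mpi4py, and
# the rank-0 process acts as the learner (run) while every other rank loops in
# run_selfplay. The script name in the launch command is illustrative:
#
#   mpirun -np 4 python train_mpi.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

if __name__ == '__main__':
    training_pipeline = TrainPipeline()
    if comm_rank == 0:
        training_pipeline.run()           # learner: receives data, trains, sends params
    else:
        training_pipeline.run_selfplay()  # worker: plays games, sends data to rank 0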