class TrainPipeline:
    def __init__(self, init_model=None):
        # params of the board and the game
        self.parallel_games = 1
        # self.pool = Pool()
        self.board_width = 8
        self.board_height = 8
        self.n_in_row = 5
        # training params
        self.learn_rate = 1e-4
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 200
        self.batch_size = 128  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.001
        self.check_freq = 1000
        self.game_batch_num = 150000
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_params=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        params = self.policy_value_net.get_policy_param()
        infos = (self.board_height, self.board_width, self.n_in_row,
                 self.temp, self.c_puct, self.n_playout)
        logging.info('hello......0')
        # self.mcts_players = [Actor.remote('gamer_' + str(gi), 2, infos, params)
        #                      for gi in range(self.parallel_games)]
        self.mcts_players = [Actor('gamer_' + str(gi), 2, infos, params)
                             for gi in range(self.parallel_games)]
        self.mcts_evaluater = Actor('evaluater', 2, infos, params)
        logging.info('hello......1')

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
""" extend_data = [] for state, mcts_porb, winner in play_data: for i in [1, 2, 3, 4]: # rotate counterclockwise equi_state = np.array([np.rot90(s, i) for s in state]) equi_mcts_prob = np.rot90(np.flipud( mcts_porb.reshape(self.board_height, self.board_width)), i) extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) # flip horizontally equi_state = np.array([np.fliplr(s) for s in equi_state]) equi_mcts_prob = np.fliplr(equi_mcts_prob) extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) return extend_data def collect_selfplay_data(self): """collect self-play data for training""" logging.info('collect_selfplay_data....0') #datas = ray.get([self.mcts_players[pgi].Play.remote() for pgi in range(self.parallel_games)]) datas = [self.mcts_players[pgi].Play() for pgi in range(self.parallel_games)] #datas = [self.pool.apply(now_play, (self.mcts_players[pgi],)) for pgi in range(self.parallel_games)] #datas = [self.mcts_players[pgi].start() for pgi in range(self.parallel_games)] #datas = [self.mcts_players[pgi].join() for pgi in range(self.parallel_games)] logging.info('collect_selfplay_data....1') self.episode_len = 0 for pgi in range(self.parallel_games): play_data = datas[pgi] _len = len(play_data) self.episode_len += _len print('game ', pgi, _len) play_data = self.get_equi_data(play_data) self.data_buffer.extend(play_data) def policy_update(self): """update the policy-value net""" mini_batch = random.sample(self.data_buffer, self.batch_size) state_batch = [data[0] for data in mini_batch] mcts_probs_batch = [data[1] for data in mini_batch] winner_batch = [data[2] for data in mini_batch] old_probs, old_v = self.policy_value_net.policy_value(state_batch) learn_rate = self.learn_rate*self.lr_multiplier for i in range(self.epochs): loss, entropy = self.policy_value_net.train_step( state_batch, mcts_probs_batch, winner_batch, learn_rate) new_probs, new_v = self.policy_value_net.policy_value(state_batch) kl = np.mean(np.sum(old_probs * ( np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1) ) if kl > self.kl_targ * 1: # early stopping if D_KL diverges badly print('early stopping:', i, self.epochs) break # adaptively adjust the learning rate if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05: self.lr_multiplier /= 1.0 elif kl < self.kl_targ / 2 and self.lr_multiplier < 20: self.lr_multiplier *= 1.0 explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch))) explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch))) print(("kl:{:.4f}," "lr:{:.1e}," "loss:{}," "entropy:{}," "explained_var_old:{:.3f}," "explained_var_new:{:.3f}" ).format(kl, learn_rate, loss, entropy, explained_var_old, explained_var_new)) return loss, entropy def policy_evaluate(self, n_games=10): """ Evaluate the trained policy by playing against the pure MCTS player Note: this is only for monitoring the progress of training """ current_mcts_player = MCTSPlayer(self.mcts_evaluater.selfplay_net.policy_value_fn, c_puct=self.c_puct, n_playout=self.n_playout) pure_mcts_player = MCTS_Pure(c_puct=5, n_playout=self.pure_mcts_playout_num) win_cnt = defaultdict(int) for i in range(n_games): winner = self.mcts_evaluater.game.start_play(current_mcts_player, pure_mcts_player, start_player=i % 2, is_shown=0) win_cnt[winner] += 1 win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games print("num_playouts:{}, win: {}, lose: {}, tie:{}".format( self.pure_mcts_playout_num, win_cnt[1], 
            win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data()
                print("batch i:{}, episode_len:{}, buffer_len:{}".format(
                    i + 1, self.episode_len, len(self.data_buffer)))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    params = self.policy_value_net.get_policy_param()
                    self.mcts_evaluater.Set_Params(params)
                    for spi in range(self.parallel_games):
                        gamer = self.mcts_players[spi]
                        # gamer.Set_Params.remote(params)
                        gamer.Set_Params(params)
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % 50 == 0:
                    self.policy_value_net.save_model('./current_policy.model')
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
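
# A minimal entry-point sketch for the single-process version above (an assumption,
# not part of the original listing): it presumes this class sits in its own module
# together with PolicyValueNet, Actor, MCTSPlayer and MCTS_Pure, and simply drives
# the training loop.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    training_pipeline = TrainPipeline()  # or TrainPipeline(init_model=...) to resume
    training_pipeline.run()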
class TrainPipeline:
    def __init__(self, init_model=None):
        # params of the board and the game
        self.train_sampling_times = 20
        self.selfplay_count = 0
        self.parallel_games = 1
        # self.pool = Pool()
        self.board_width = 8
        self.board_height = 8
        self.n_in_row = 5
        # training params
        self.learn_rate = 1e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.agent_sampling_size = 128
        self.batch_size = self.agent_sampling_size * comm_size  # mini-batch size for training
        if comm_rank == 0:
            self.buffer_size = self.agent_sampling_size * (comm_size * self.train_sampling_times)
            self.data_buffer = deque(maxlen=self.buffer_size)
        else:
            self.buffer_size = 0
            self.data_buffer = []
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 1
        self.game_batch_num = 150000
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        self.policy_value_net = None
        if comm_rank == 0:
            if init_model:
                # start training from an initial policy-value net
                self.policy_value_net = PolicyValueNet(0, self.batch_size,
                                                       self.board_width,
                                                       self.board_height,
                                                       model_params=init_model)
            else:
                # start training from a new policy-value net
                self.policy_value_net = PolicyValueNet(0, self.batch_size,
                                                       self.board_width,
                                                       self.board_height)
        self.mcts_player = None
        self.mcts_evaluater = None
        self.params = None
        infos = (self.board_height, self.board_width, self.n_in_row,
                 self.temp, self.c_puct, self.n_playout)
        if comm_rank > 0:
            logging.info('rank ' + str(comm_rank) + ' before recv')
            self.params = comm.recv(source=0)
            logging.info('rank ' + str(comm_rank) + ' after recv')
            self.mcts_player = Actor('gamer_' + str(comm_rank), 1, infos, self.params)
        if self.policy_value_net and comm_rank == 0:
            self.params = self.policy_value_net.get_policy_param()
            logging.info('rank ' + str(comm_rank) + ' before bcast')
            # comm.bcast('params', root=0)
            for pi in range(1, comm_size):
                comm.send(self.params, dest=pi)
            logging.info('rank ' + str(comm_rank) + ' after bcast')
            self.mcts_player = Actor('gamer_' + str(comm_rank), 0, infos, self.params)
            self.mcts_evaluater = Actor('evaluater', 1, infos, self.params)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
""" extend_data = [] for state, mcts_porb, winner in play_data: for i in [1, 2, 3, 4]: # rotate counterclockwise equi_state = np.array([np.rot90(s, i) for s in state]) equi_mcts_prob = np.rot90(np.flipud( mcts_porb.reshape(self.board_height, self.board_width)), i) extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) # flip horizontally equi_state = np.array([np.fliplr(s) for s in equi_state]) equi_mcts_prob = np.fliplr(equi_mcts_prob) extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner)) return extend_data def collect_selfplay_data(self): """collect self-play data for training""" datas = self.mcts_player.Play() self.episode_len = 0 _len = len(datas) self.episode_len += _len play_data = self.get_equi_data(datas) if comm_rank == 0: self.data_buffer.extend(play_data) logging.info('gamer_%d %d collection finished.'%(comm_rank, self.episode_len)) if comm_rank > 0: self.data_buffer = play_data logging.info('gamer_%d %d sending data started...'%(comm_rank, self.episode_len)) comm.send(self.data_buffer, dest=0) logging.info('gamer_%d %d sending data finished...'%(comm_rank, self.episode_len)) def policy_update_old(self): """update the policy-value net""" mini_batch = random.sample(self.data_buffer, self.batch_size) state_batch = [data[0] for data in mini_batch] mcts_probs_batch = [data[1] for data in mini_batch] winner_batch = [data[2] for data in mini_batch] old_probs, old_v = self.policy_value_net.policy_value(state_batch) learn_rate = self.learn_rate*self.lr_multiplier for i in range(self.epochs): loss, entropy = self.policy_value_net.train_step( state_batch, mcts_probs_batch, winner_batch, learn_rate) new_probs, new_v = self.policy_value_net.policy_value(state_batch) kl = np.mean(np.sum(old_probs * ( np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1) ) if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly #if kl > self.kl_targ: # early stopping if D_KL diverges badly logging.info('early stopping:%d, %d'%(i, self.epochs)) break # adaptively adjust the learning rate if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05: self.lr_multiplier /= 1.5 elif kl < self.kl_targ / 2 and self.lr_multiplier < 20: self.lr_multiplier *= 1.5 explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch))) explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch))) logging.info(("kl:{:.4f}," "lr:{:.1e}," "loss:{}," "entropy:{}," "explained_var_old:{:.3f}," "explained_var_new:{:.3f}" ).format(kl, learn_rate, loss, entropy, explained_var_old, explained_var_new)) return loss, entropy def policy_update(self, train_i): """update the policy-value net""" mini_batch = random.sample(self.data_buffer, self.batch_size) state_batch = [data[0] for data in mini_batch] mcts_probs_batch = [data[1] for data in mini_batch] winner_batch = [data[2] for data in mini_batch] old_probs, old_v = self.policy_value_net.policy_value(state_batch) learn_rate = self.learn_rate*self.lr_multiplier for i in range(self.epochs): loss, entropy = self.policy_value_net.train_step( state_batch, mcts_probs_batch, winner_batch, learn_rate) new_probs, new_v = self.policy_value_net.policy_value(state_batch) kl = np.mean(np.sum(old_probs * ( np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1) ) if kl > self.kl_targ * 4: # early stopping if D_KL diverges badly #if kl > self.kl_targ: # early stopping if D_KL diverges badly logging.info('early stopping:%d, %d'%(i, self.epochs)) break 
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.05:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 20:
            self.lr_multiplier *= 1.5
        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        if train_i % 4 == 0:
            logging.info(("kl:{:.4f},"
                          "lr:{:.1e},"
                          "loss:{},"
                          "entropy:{},"
                          "explained_var_old:{:.3f},"
                          "explained_var_new:{:.3f}"
                          ).format(kl, learn_rate, loss, entropy,
                                   explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(
            self.mcts_evaluater.selfplay_net.policy_value_fn,
            c_puct=self.c_puct,
            n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.mcts_evaluater.game.start_play(current_mcts_player,
                                                         pure_mcts_player,
                                                         start_player=i % 2,
                                                         is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        logging.info("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1],
            win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                recv_count = 1
                logging.info('sending params to actors...')
                for nodei in range(1, comm_size):
                    # logging.info('sending params to actor %d ...' % (nodei))
                    comm.send(self.params, dest=nodei)
                self.collect_selfplay_data()
                logging.info('receiving data from actors...')
                for nodei in range(1, comm_size):
                    data = comm.recv(source=nodei)
                    logging.info('receiving data from actor %d: %d...' % (nodei, len(data)))
                    self.data_buffer.extend(data)
                    recv_count += 1
                logging.info("batch i:{}, batchsize:{}, recv count:{}, buffer_len:{}".format(
                    i + 1, self.batch_size, recv_count, len(self.data_buffer)))
                if len(self.data_buffer) >= self.buffer_size:
                    for train_i in range(self.train_sampling_times):
                        loss, entropy = self.policy_update(train_i)
                    self.params = self.policy_value_net.get_policy_param()
                    self.mcts_evaluater.Set_Params(self.params)
                    self.mcts_player.Set_Params(self.params)
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % 10 == 0:
                    self.policy_value_net.save_model('./current_policy.model')
                if (i + 1) % self.check_freq == 0:
                    logging.info("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    if win_ratio > self.best_win_ratio:
                        logging.info("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')

    def run_selfplay(self):
        logging.info('selfplayer ' + str(comm_rank) + ' is started.....')
        while True:
            cmd = comm.recv(source=0)
            if cmd is not None:
                self.params = cmd
                self.mcts_player.Set_Params(self.params)
            self.selfplay_count += 1
            logging.info('start selfplaying...%d' % (self.selfplay_count))
            self.collect_selfplay_data()
            logging.info('finished selfplaying...%d' % (self.selfplay_count))
            time.sleep(1)
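
# A minimal MPI launcher sketch (an assumption, not part of the original listing):
# the comm / comm_rank / comm_size names used above match mpi4py's COMM_WORLD API,
# and presumably already exist near the top of the original module; they are
# repeated here only to keep the sketch self-contained. A typical driver splits
# the ranks as below and is started with e.g. `mpiexec -n <num_ranks> python train.py`
# (the module name is hypothetical).
from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    pipeline = TrainPipeline()
    if comm_rank == 0:
        pipeline.run()           # rank 0: gathers self-play data, trains and evaluates the net
    else:
        pipeline.run_selfplay()  # ranks 1..comm_size-1: self-play workers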