class SelfPlayer(Process):
    """Worker process that generates self-play games for training.

    Repeatedly plays games against itself using the most recent network
    weights pulled from ``model_queue`` and pushes each batch of resulting
    samples onto ``sample_queue``.
    """

    def __init__(self, config, sample_queue, model_queue):
        super().__init__()
        self.config = config
        # Exploration temperature passed to MCTS during self-play.
        self.temp = config['temperature']
        self.sample_queue = sample_queue
        self.model_queue = model_queue
        self.board = Board(width=config['board_width'],
                           height=config['board_height'],
                           n_in_row=config['n_in_row'])
        self.game = Game(self.board)

    def collect_selfplay_data(self, n_games=1):
        """Collect self-play data for training.

        Returns a flat list of the samples produced by ``n_games``
        self-play games.  Requires ``self.mcts_player`` to exist, i.e.
        must be called from within ``run()``.
        """
        samples = []
        for _ in range(n_games):
            _, play_data = self.game.start_self_play(self.mcts_player,
                                                     temp=self.temp)
            # extend() consumes the iterable directly; the original
            # list(...)[:] made a pointless second copy.
            samples.extend(play_data)
        return samples

    def run(self):
        # Build the network and player here (not in __init__) so they are
        # created inside the child process, not the parent.
        self.policy_value_net = PolicyValueNet(
            self.config['board_width'],
            self.config['board_height'],
            model_file=self.config['init_model'])
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.config['c_puct'],
                                      n_playout=self.config['n_playout'],
                                      is_selfplay=1)
        print("running")
        while True:
            # Drain the queue so that only the latest weights are kept.
            weights = None
            while not self.model_queue.empty():
                weights = self.model_queue.get()
            # Explicit None check: a valid weight container could be falsy
            # (e.g. empty list-like) and must not be silently skipped.
            if weights is not None:
                self.policy_value_net.set_weight(weights)
            # Sample a fresh game and hand the data to the trainer.
            samples = self.collect_selfplay_data()
            self.sample_queue.put(samples)
class Evaluator(Process):
    """Worker process that benchmarks candidate weights against pure MCTS.

    Blocks on ``weight_queue`` for new network weights, evaluates them,
    always saves the current policy, and additionally saves it as the best
    policy whenever the win ratio improves.  When the policy scores a
    perfect 1.0, the pure-MCTS opponent is strengthened by 1000 playouts
    (up to 10000) and the bar is reset.
    """

    def __init__(self, config, weight_queue):
        super().__init__()
        self.config = config
        self.queue = weight_queue
        self.best_win_ratio = 0.0
        # Mutable difficulty of the pure-MCTS opponent; escalated in run().
        self.pure_mcts_playout_num = self.config['pure_mcts_playout_num']

    def run(self):
        self.policy_value_net = PolicyValueNet(
            self.config['board_width'],
            self.config['board_height'],
            model_file=self.config['init_model'])
        while True:
            weight = self.queue.get()  # blocks until new weights arrive
            self.policy_value_net.set_weight(weight)
            win_ratio = self.policy_evaluate()
            self.policy_value_net.save_model(
                self.config['current_policy_name'])
            if win_ratio > self.best_win_ratio:
                print("New best policy!!!!!!!!")
                self.best_win_ratio = win_ratio
                # update the best_policy
                self.policy_value_net.save_model(
                    self.config['best_policy_name'])
                if (self.best_win_ratio == 1.0 and
                        self.pure_mcts_playout_num < 10000):
                    # Opponent fully beaten: raise its strength and reset
                    # the win-ratio bar for the harder setting.
                    self.pure_mcts_playout_num += 1000
                    self.best_win_ratio = 0.0

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        self.evaluate_game = Game(
            Board(width=self.config['board_width'],
                  height=self.config['board_height'],
                  n_in_row=self.config['n_in_row']))
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.config['c_puct'],
                                         n_playout=self.config['n_playout'])
        # BUG FIX: use the escalating self.pure_mcts_playout_num instead of
        # the static config value — previously the +1000 difficulty increase
        # in run() never affected the actual opponent strength.
        pure_mcts_player = MCTS_Pure(
            c_puct=5,
            n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            # Alternate which side moves first across games.
            winner = self.evaluate_game.start_play(current_mcts_player,
                                                   pure_mcts_player,
                                                   start_player=i % 2,
                                                   is_shown=0)
            win_cnt[winner] += 1
        # Per the print labels below: winner==1 is a win, winner==2 a loss,
        # winner==-1 a tie; a tie counts as half a win.
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio