Example #1
0
class SelfPlayer(Process):
    """Worker process that generates self-play training samples.

    Repeatedly drains ``model_queue`` to pick up the newest network
    weights, plays self-play games with an MCTS player, and pushes the
    resulting sample batches onto ``sample_queue`` for the trainer.
    """

    def __init__(self, config, sample_queue, model_queue):
        super(SelfPlayer, self).__init__()

        self.config = config
        # Softmax temperature used during self-play move selection.
        self.temp = config['temperature']

        self.sample_queue = sample_queue  # producer side: samples out to trainer
        self.model_queue = model_queue    # consumer side: fresh weights in

        self.board = Board(width=config['board_width'],
                           height=config['board_height'],
                           n_in_row=config['n_in_row'])
        self.game = Game(self.board)

    def collect_selfplay_data(self, n_games=1):
        """Play ``n_games`` self-play games and return the collected samples.

        Returns a flat list of per-move samples produced by
        ``Game.start_self_play`` (shape/content defined by that method).
        """
        samples = []
        for _ in range(n_games):
            _, play_data = self.game.start_self_play(self.mcts_player,
                                                     temp=self.temp)
            # play_data may be a generator; extend consumes it directly.
            # (The original `list(play_data)[:]` made a redundant copy.)
            samples.extend(play_data)
        return samples

    def run(self):
        # Build the network and player here (not in __init__) so they are
        # created inside the child process after fork/spawn.
        self.policy_value_net = PolicyValueNet(
            self.config['board_width'],
            self.config['board_height'],
            model_file=self.config['init_model'])
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.config['c_puct'],
                                      n_playout=self.config['n_playout'],
                                      is_selfplay=1)

        print("running")
        while True:
            # Drain the queue so only the most recent weights are applied.
            weights = None
            while not self.model_queue.empty():
                weights = self.model_queue.get()
            # BUG FIX: truth-testing the weights is wrong — numpy arrays
            # raise ValueError on bool(), and a falsy-but-valid container
            # would be silently skipped. Compare against None explicitly.
            if weights is not None:
                self.policy_value_net.set_weight(weights)

            # Generate fresh self-play samples and hand them to the trainer.
            samples = self.collect_selfplay_data()
            self.sample_queue.put(samples)
Example #2
0
class Evaluator(Process):
    """Worker process that evaluates candidate policies against pure MCTS.

    Blocks on ``weight_queue`` for new network weights, evaluates them by
    playing against a pure-MCTS baseline, saves the current policy every
    round, and saves/promotes the best policy when the win ratio improves.
    """

    def __init__(self, config, weight_queue):
        super(Evaluator, self).__init__()
        self.config = config
        self.queue = weight_queue

        self.best_win_ratio = 0.0
        # Strength of the pure-MCTS opponent; escalated as the trained
        # policy starts winning every game (see run()).
        self.pure_mcts_playout_num = self.config['pure_mcts_playout_num']

    def run(self):
        # Build the network here (not in __init__) so it is created inside
        # the child process after fork/spawn.
        self.policy_value_net = PolicyValueNet(
            self.config['board_width'],
            self.config['board_height'],
            model_file=self.config['init_model'])

        while True:
            weight = self.queue.get()  # blocks until the trainer sends weights
            self.policy_value_net.set_weight(weight)
            win_ratio = self.policy_evaluate()
            self.policy_value_net.save_model(
                self.config['current_policy_name'])

            if win_ratio > self.best_win_ratio:
                print("New best policy!!!!!!!!")
                self.best_win_ratio = win_ratio
                # Promote this model as the new best policy.
                self.policy_value_net.save_model(
                    self.config['best_policy_name'])
                if (self.best_win_ratio == 1.0
                        and self.pure_mcts_playout_num < 10000):
                    # Policy beats the baseline every game: make the
                    # baseline stronger and reset the bar.
                    self.pure_mcts_playout_num += 1000
                    self.best_win_ratio = 0.0

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player.

        Returns the win ratio over ``n_games`` (a draw counts as half a
        win). Note: this is only for monitoring the progress of training.
        """
        self.evaluate_game = Game(
            Board(width=self.config['board_width'],
                  height=self.config['board_height'],
                  n_in_row=self.config['n_in_row']))

        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.config['c_puct'],
                                         n_playout=self.config['n_playout'])

        # BUG FIX: use the escalating self.pure_mcts_playout_num, not the
        # static config value — otherwise the opponent-strength increase
        # performed in run() never takes effect.
        pure_mcts_player = MCTS_Pure(
            c_puct=5, n_playout=self.pure_mcts_playout_num)

        win_cnt = defaultdict(int)
        for i in range(n_games):
            # Alternate who moves first to remove first-move advantage.
            winner = self.evaluate_game.start_play(current_mcts_player,
                                                   pure_mcts_player,
                                                   start_player=i % 2,
                                                   is_shown=0)
            win_cnt[winner] += 1

        # winner == 1: trained policy; winner == 2: pure MCTS; -1: draw.
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2],
            win_cnt[-1]))
        return win_ratio