Example #1
class TrainPipeline:
    def __init__(self, init_model=None):
        # board parameters
        self.board_width = 8
        self.board_height = 8
        # self.n_in_row = 5
        self.board = chessboard(row=self.board_width, col=self.board_height)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0
        self.temp = 1.0
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 2
        # number of self-play batches
        self.game_batch_num = 1000
        self.best_win_ratio = 0.0
        # pure MCTS playout count, used as the evaluation baseline
        self.pure_mcts_playout_num = 400
        # start from a pre-trained model if one is given
        if init_model:
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    # augment the training data
    def get_equi_data(self, play_data):
        # augment the data set by rotation and flipping
        # play_data:[(state, mcts_prob, winner_z), ..., ...]
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_prob.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    # run one round of self-play
    def start_self_play(self, player, is_shown=0, temp=1e-3):
        self.board.reset()
        p1, p2 = self.board.players
        states, mcts_probs, current_players = [], [], []
        # for testing
        # t = 0
        while True:
            # t += 1
            # print(t)
            move, move_probs = player.get_action(self.board,
                                                 temp=temp,
                                                 return_prob=1)
            # print("测试", move_probs)
            # store the data
            states.append(self.board.current_state())
            mcts_probs.append(move_probs)
            current_players.append(self.board.current_player)
            # perform a move
            self.board.do_move(move)
            if is_shown:
                display(self.board)
            end, winner = self.board.game_end()
            # print(t, end, winner, self.board.count)
            if end:
                # winner from the perspective of the current player of each state
                winners_z = np.zeros(len(current_players))
                if winner != -1:
                    winners_z[np.array(current_players) == winner] = 1.0
                    winners_z[np.array(current_players) != winner] = -1.0
                # reset MCTS root node
                player.reset_player()
                if is_shown:
                    if winner != -1:
                        print("Game end. Winner is player:", winner)
                    else:
                        print("Game end. Tie")
                return winner, zip(states, mcts_probs, winners_z)

    # collect self-play data for training
    def collect_selfplay_data(self, n_games=1):
        for i in range(n_games):
            # print("测试", i)
            winner, play_data = self.start_self_play(self.mcts_player,
                                                     temp=self.temp,
                                                     is_shown=False)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    # update the policy-value net
    def policy_update(self):
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, self.lr_multiplier, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    # play one game between two players
    def start_play(self, player1, player2, start_player=1, is_shown=1):
        if start_player not in (1, 2):
            raise Exception('start_player should be either 1 (player1 first) '
                            'or 2 (player2 first)')
        self.board.reset(start_player)
        p1, p2 = self.board.players
        player1.set_player_ind(p1)
        player2.set_player_ind(p2)
        players = {p1: player1, p2: player2}
        if is_shown:
            display(self.board)
        while True:
            current_player = self.board.get_current_player()
            # print(current_player, players)
            player_in_turn = players[current_player]
            move = player_in_turn.get_action(self.board)
            self.board.do_move(move)
            if is_shown:
                display(self.board)
            end, winner = self.board.game_end()
            if end:
                if is_shown:
                    if winner != -1:
                        print("Game end. Winner is", players[winner])
                    else:
                        print("Game end. Tie")
                return winner

    # evaluate the policy against a pure MCTS baseline
    def policy_evaluate(self, n_games=10):
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.start_play(current_mcts_player,
                                     pure_mcts_player,
                                     start_player=i % 2 + 1,
                                     is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    # run the training pipeline
    @run.change_dir
    @run.timethis
    def run(self):
        try:
            losses = []
            for i in tqdm.tqdm(range(self.game_batch_num)):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                # for testing
                # self.policy_value_net.save_model('./output/best_policy.model')
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    losses.append(loss)
                    # print(i, loss)
                # check the performance of the current model and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("当前自训练次数: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model(
                        './output/current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("新的最佳策略!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model(
                            './output/best_policy.model')
                        if (self.best_win_ratio == 1.0
                                and self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
            plt.figure()
            plt.plot(losses)
            plt.savefig("./output/loss.png")
        except KeyboardInterrupt:
            print('\n\rquit')
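This example assumes several project-specific modules (the chessboard board class, display, PolicyValueNet, MCTSPlayer, MCTS_Pure, and a run module providing the @run.change_dir / @run.timethis decorators). A minimal launch sketch; the commented import paths are assumptions to adjust for the actual project layout:

# Hypothetical bootstrap for Example #1; the module paths below are assumptions.
import random
from collections import defaultdict, deque

import numpy as np
import tqdm
import matplotlib.pyplot as plt

# from chessboard import chessboard, display         # board implementation (assumed path)
# from policy_value_net import PolicyValueNet        # policy-value network (assumed path)
# from mcts_alphazero import MCTSPlayer              # net-guided MCTS player (assumed path)
# from mcts_pure import MCTSPlayer as MCTS_Pure      # plain MCTS baseline (assumed path)
# import run                                         # provides @run.change_dir / @run.timethis (assumed)

if __name__ == '__main__':
    pipeline = TrainPipeline()  # or TrainPipeline(init_model='./output/best_policy.model')
    pipeline.run()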
Example #2
class TrainPipeline():
    def __init__(self):
        # params of the board and the game
        self.board_width = BOARD_SIZE
        self.board_height = BOARD_SIZE
        self.board = Board()
        self.game = Game(self.board)
        # training params
        self.learn_rate = 5e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 300  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.025
        self.check_freq = 1
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        self.episode_len = 0
        # num of simulations used for the pure mcts, which is used as the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 300
        # start training from a given policy-value net
        #        policy_param = pickle.load(open('current_policy.model', 'rb'))
        #        self.policy_value_net = PolicyValueNet(self.board_width, self.board_height, net_params = policy_param)
        # start training from a new policy-value net
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """
        augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]"""
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_prob.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = 1 - np.var(
            np.array(winner_batch) - old_v.flatten()) / np.var(
                np.array(winner_batch))
        explained_var_new = 1 - np.var(
            np.array(winner_batch) - new_v.flatten()) / np.var(
                np.array(winner_batch))
        print(
            "kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}"
            .format(kl, self.lr_multiplier, loss, entropy, explained_var_old,
                    explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing games against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model, and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    net_params = self.policy_value_net.get_policy_param(
                    )  # get model params
                    pickle.dump(
                        net_params, open('current_policy.model', 'wb'),
                        pickle.HIGHEST_PROTOCOL)  # save model param to file
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        pickle.dump(
                            net_params, open('best_policy.model', 'wb'),
                            pickle.HIGHEST_PROTOCOL)  # update the best_policy
                        if self.best_win_ratio == 1.0 and self.pure_mcts_playout_num < 1000:
                            self.pure_mcts_playout_num += 100
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
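The core of policy_update() above is the KL-based early stop and the adaptive lr_multiplier. A self-contained sketch of just that logic, with made-up probability arrays (numpy only, not tied to any of the example projects):

import numpy as np

def kl_divergence(old_probs, new_probs):
    # mean over the batch of sum_a p_old(a) * (log p_old(a) - log p_new(a))
    return np.mean(np.sum(
        old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
        axis=1))

def adjust_lr_multiplier(kl, kl_targ, lr_multiplier):
    # shrink the step size when the policy moved too far, grow it when it barely moved
    if kl > kl_targ * 2 and lr_multiplier > 0.1:
        lr_multiplier /= 1.5
    elif kl < kl_targ / 2 and lr_multiplier < 10:
        lr_multiplier *= 1.5
    return lr_multiplier

old = np.array([[0.7, 0.2, 0.1]])
new = np.array([[0.5, 0.3, 0.2]])
kl = kl_divergence(old, new)  # about 0.085 for these made-up numbers
print(kl, kl > 0.025 * 4, adjust_lr_multiplier(kl, kl_targ=0.025, lr_multiplier=1.0))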
Example #3
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.board_width = 6
        self.board_height = 6
        self.n_in_row = 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_prob, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_prob.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                        i+1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i+1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
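In get_equi_data() the extra np.flipud around the probability map compensates for the state planes being stored upside down relative to the move-index grid (move = row * width + col), as in the reference Gomoku implementation these snippets follow. A small numpy check of that alignment, under that assumption about the plane orientation:

import numpy as np

H = W = 4
prob_grid = np.zeros((H, W))        # probabilities in move-index order (move = row * W + col)
prob_grid[1, 2] = 1.0               # all probability mass on the move at row 1, col 2
state_plane = np.flipud(prob_grid)  # assumed: the net's input plane is stored upside down

for i in [1, 2, 3, 4]:
    # rotation branch, mirroring the transforms used in get_equi_data()
    equi_state = np.rot90(state_plane, i)
    equi_prob = np.flipud(np.rot90(np.flipud(prob_grid), i)).flatten()
    assert np.argmax(equi_prob) == np.argmax(np.flipud(equi_state).flatten())

    # horizontal-flip branch
    equi_state = np.fliplr(equi_state)
    equi_prob = np.flipud(np.fliplr(np.rot90(np.flipud(prob_grid), i))).flatten()
    assert np.argmax(equi_prob) == np.argmax(np.flipud(equi_state).flatten())
print("augmented probabilities stay aligned with the augmented planes")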
Example #4
class TrainPipeline(object):
    def __init__(self, init_model=None):
        self.game = Quoridor()


        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0
        self.temp = 1.0
        self.n_playout = 200
        self.c_puct = 5
        self.buffer_size = 10000
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.kl_targ = 0.02
        self.check_freq = 10
        self.game_batch_num = 1000
        self.best_win_ratio = 0.0
        self.pure_mcts_playout_num = 1000

        self.old_probs = 0
        self.new_probs = 0

        self.first_trained = False

        if init_model:
            self.policy_value_net = PolicyValueNet(model_file=init_model)
        else:
            self.policy_value_net = PolicyValueNet()

        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, c_puct=self.c_puct,
                                      n_playout=self.n_playout, is_selfplay=1)

    def get_equi_data(self, play_data):

        extend_data = []
        for state, mcts_prob, winner in play_data:
            wall_state = state[:3,:BOARD_SIZE - 1,:BOARD_SIZE - 1]
            dist_state1 = np.reshape(state[(6 + (WALL_NUM + 1) * 2), :BOARD_SIZE, :BOARD_SIZE], (1, BOARD_SIZE, BOARD_SIZE))
            dist_state2 = np.reshape(state[(7 + (WALL_NUM + 1) * 2), :BOARD_SIZE, :BOARD_SIZE], (1, BOARD_SIZE, BOARD_SIZE))

            # horizontally flipped game
            flipped_wall_state = []

            for i in range(3):
                wall_padded = np.fliplr(wall_state[i])
                wall_padded = np.pad(wall_padded, (0,1), mode='constant', constant_values=0)
                flipped_wall_state.append(wall_padded)

            flipped_wall_state = np.array(flipped_wall_state)

            player_position = state[3:5, :,:]

            flipped_player_position = []
            for i in range(2):
                flipped_player_position.append(np.fliplr(player_position[i]))

            flipped_player_position = np.array(flipped_player_position)

            h_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5:, :,:]])

            h_equi_mcts_prob = np.copy(mcts_prob)

            h_equi_mcts_prob[11] = mcts_prob[10]  # SE to SW
            h_equi_mcts_prob[10] = mcts_prob[11]  # SW to SE
            h_equi_mcts_prob[9] = mcts_prob[8]    # NE to NW
            h_equi_mcts_prob[8] = mcts_prob[9]    # NW to NE
            h_equi_mcts_prob[7] = mcts_prob[6]    # EE to WW
            h_equi_mcts_prob[6] = mcts_prob[7]    # WW to EE
            h_equi_mcts_prob[3] = mcts_prob[2]    # E to W
            h_equi_mcts_prob[2] = mcts_prob[3]    # W to E

            h_wall_actions = h_equi_mcts_prob[12:12 + (BOARD_SIZE-1) ** 2].reshape(BOARD_SIZE-1, BOARD_SIZE-1)
            v_wall_actions = h_equi_mcts_prob[12 + (BOARD_SIZE-1) ** 2:].reshape(BOARD_SIZE-1, BOARD_SIZE -1)

            flipped_h_wall_actions = np.fliplr(h_wall_actions)
            flipped_v_wall_actions = np.fliplr(v_wall_actions)

            h_equi_mcts_prob[12:] = np.hstack([flipped_h_wall_actions.flatten(), flipped_v_wall_actions.flatten()])

            # Vertically flipped game
            flipped_wall_state = []

            for i in range(3):
                wall_padded = np.flipud(wall_state[i])
                wall_padded = np.pad(wall_padded, (0,1), mode='constant', constant_values=0)
                flipped_wall_state.append(wall_padded)

            flipped_wall_state = np.array(flipped_wall_state)


            flipped_player_position = []
            for i in range(2):
                flipped_player_position.append(np.flipud(player_position[1-i]))

            flipped_player_position = np.array(flipped_player_position)

            cur_player = (np.ones((BOARD_SIZE, BOARD_SIZE)) - state[5 + 2* (WALL_NUM+1),:,:]).reshape(-1,BOARD_SIZE, BOARD_SIZE)

            v_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5+(WALL_NUM+1):5 + 2*(WALL_NUM+1), :,:], state[5:5+(WALL_NUM+1),:,:], cur_player, dist_state2, dist_state1])
            # v_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5:(5 + (WALL_NUM+1) * 2), :, :], cur_player, state[:(6 + (WALL_NUM + 1) * 2), :, :]])


            v_equi_mcts_prob = np.copy(mcts_prob)

            v_equi_mcts_prob[11] = mcts_prob[9]  # SE to NE
            v_equi_mcts_prob[10] = mcts_prob[8]  # SW to NW
            v_equi_mcts_prob[9] = mcts_prob[11]  # NE to SE
            v_equi_mcts_prob[8] = mcts_prob[10]  # NW to SW
            v_equi_mcts_prob[5] = mcts_prob[4]   # NN to SS
            v_equi_mcts_prob[4] = mcts_prob[5]   # SS to NN
            v_equi_mcts_prob[1] = mcts_prob[0]   # N to S
            v_equi_mcts_prob[0] = mcts_prob[1]   # S to N

            h_wall_actions = v_equi_mcts_prob[12:12 + (BOARD_SIZE-1) ** 2].reshape(BOARD_SIZE-1, BOARD_SIZE-1)
            v_wall_actions = v_equi_mcts_prob[12 + (BOARD_SIZE-1) ** 2:].reshape(BOARD_SIZE-1, BOARD_SIZE -1)

            flipped_h_wall_actions = np.flipud(h_wall_actions)
            flipped_v_wall_actions = np.flipud(v_wall_actions)

            v_equi_mcts_prob[12:] = np.hstack([flipped_h_wall_actions.flatten(), flipped_v_wall_actions.flatten()])

            ## Horizontally-vertically flipped game

            wall_state = state[:3,:BOARD_SIZE - 1,:BOARD_SIZE - 1]
            flipped_wall_state = []

            for i in range(3):
                wall_padded = np.fliplr(np.flipud(wall_state[i]))
                wall_padded = np.pad(wall_padded, (0,1), mode='constant', constant_values=0)
                flipped_wall_state.append(wall_padded)

            flipped_wall_state = np.array(flipped_wall_state)



            flipped_player_position = []
            for i in range(2):
                flipped_player_position.append(np.fliplr(np.flipud(player_position[1-i])))

            flipped_player_position = np.array(flipped_player_position)

            cur_player = (np.ones((BOARD_SIZE, BOARD_SIZE)) - state[5 + 2*(WALL_NUM+1),:,:]).reshape(-1,BOARD_SIZE, BOARD_SIZE)

            hv_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5 + (WALL_NUM+1):5 + 2*(WALL_NUM+1), :,:], state[5:5+(WALL_NUM+1),:,:], cur_player, dist_state2, dist_state1])
            # hv_equi_state = np.vstack([flipped_wall_state, flipped_player_position, state[5:(5 + (WALL_NUM+1) * 2), :, :], cur_player, state[(6 + (WALL_NUM + 1) * 2):, :, :]])

            hv_equi_mcts_prob = np.copy(mcts_prob)

            hv_equi_mcts_prob[11] = mcts_prob[8]  # SE to NW
            hv_equi_mcts_prob[10] = mcts_prob[9]  # SW to NE
            hv_equi_mcts_prob[9] = mcts_prob[10]  # NE to SW
            hv_equi_mcts_prob[8] = mcts_prob[11]  # NW to SE
            hv_equi_mcts_prob[7] = mcts_prob[6]   # EE to WW
            hv_equi_mcts_prob[6] = mcts_prob[7]   # WW to EE
            hv_equi_mcts_prob[5] = mcts_prob[4]   # NN to SS
            hv_equi_mcts_prob[4] = mcts_prob[5]   # SS to NN
            hv_equi_mcts_prob[3] = mcts_prob[2]   # E to W
            hv_equi_mcts_prob[2] = mcts_prob[3]   # W to E
            hv_equi_mcts_prob[1] = mcts_prob[0]   # N to S
            hv_equi_mcts_prob[0] = mcts_prob[1]   # S to N

            h_wall_actions = hv_equi_mcts_prob[12:12 + (BOARD_SIZE-1) ** 2].reshape(BOARD_SIZE-1, BOARD_SIZE-1)
            v_wall_actions = hv_equi_mcts_prob[12 + (BOARD_SIZE-1) ** 2:].reshape(BOARD_SIZE-1, BOARD_SIZE -1)

            flipped_h_wall_actions = np.fliplr(np.flipud(h_wall_actions))
            flipped_v_wall_actions = np.fliplr(np.flipud(v_wall_actions))

            hv_equi_mcts_prob[12:] = np.hstack([flipped_h_wall_actions.flatten(), flipped_v_wall_actions.flatten()])

            ###########

            extend_data.append((state, mcts_prob, winner))
            extend_data.append((h_equi_state, h_equi_mcts_prob, winner))
            extend_data.append((v_equi_state, v_equi_mcts_prob, winner * -1))
            extend_data.append((hv_equi_state, hv_equi_mcts_prob, winner * -1))

        return extend_data

    def collect_selfplay_data(self, n_games=1):
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)

            play_data = self.get_equi_data(play_data)

            self.data_buffer.extend(play_data)
            print("{}th game finished. Current episode length: {}, Length of data buffer: {}".format(i, self.episode_len, len(self.data_buffer)))

    def policy_update(self):

        dataloader = DataLoader(self.data_buffer, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

        valloss_acc = 0
        polloss_acc = 0
        entropy_acc = 0

        for i in range(NUM_EPOCHS):

            self.old_probs = self.new_probs

            if self.first_trained:
                kl = np.mean(np.sum(self.old_probs * (np.log(self.old_probs + 1e-10) - np.log(self.new_probs + 1e-10)), axis=1))
                if kl > self.kl_targ * 4:
                    break

                if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
                    self.lr_multiplier /= 1.5
                elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
                    self.lr_multiplier *= 1.5


            for i, (state, mcts_prob, winner) in enumerate(dataloader):
                valloss, polloss, entropy = self.policy_value_net.train_step(state, mcts_prob, winner, self.learn_rate * self.lr_multiplier)
                self.new_probs, new_v = self.policy_value_net.policy_value(state)

                global iter_count

                writer.add_scalar("Val Loss/train", valloss.item(), iter_count)
                writer.add_scalar("Policy Loss/train", polloss.item(), iter_count)
                writer.add_scalar("Entropy/train", entropy, iter_count)
                writer.add_scalar("LR Multiplier", self.lr_multiplier, iter_count)

                iter_count += 1

                valloss_acc += valloss.item()
                polloss_acc += polloss.item()
                entropy_acc += entropy.item()

            self.first_trained = True

        valloss_mean = valloss_acc / (len(dataloader) * NUM_EPOCHS)
        polloss_mean = polloss_acc / (len(dataloader) * NUM_EPOCHS)
        entropy_mean = entropy_acc / (len(dataloader) * NUM_EPOCHS)

        #explained_var_old = 1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch))
        #explained_var_new = 1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch))
        #print( "kl:{:.5f}, lr_multiplier:{:.3f}, value loss:{}, policy loss:[], entropy:{}".format(
        #        kl, self.lr_multiplier, valloss, polloss, entropy, explained_var_old, explained_var_new))
        return valloss_mean, polloss_mean, entropy_mean

    def run(self):
        try:
            self.collect_selfplay_data(3)
            count = 0
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)    # collect_s
                print("batch i:{}, episode_len:{}".format(i + 1, self.episode_len))
                if len(self.data_buffer) > BATCH_SIZE:
                    valloss, polloss, entropy = self.policy_update()
                    print("VALUE LOSS: %0.3f " % valloss, "POLICY LOSS: %0.3f " % polloss, "ENTROPY: %0.3f" % entropy)

                    #writer.add_scalar("Val Loss/train", valloss.item(), i)
                    #writer.add_scalar("Policy Loss/train", polloss.item(), i)
                    #writer.add_scalar("Entory/train", entropy, i)

                if (i + 1) % self.check_freq == 0:
                    count += 1
                    print("current self-play batch: {}".format(i + 1))
                    # win_ratio = self.policy_evaluate()
                    # Add generation to filename
                    self.policy_value_net.save_model('model_7x7_' + str(count) + '_' + str("%0.3f_" % (valloss+polloss) + str(time.strftime('%Y-%m-%d', time.localtime(time.time())))))
        except KeyboardInterrupt:
            print('\n\rquit')
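The Quoridor example above remaps the twelve pawn-move probability slots index by index for each flip. A sketch of the same remapping written as permutation arrays, which is easier to audit; the slot meanings are read off the comments in the example and are otherwise an assumption:

import numpy as np

# Pawn-move slots, as commented in the example:
# 0:N 1:S 2:W 3:E 4:SS 5:NN 6:WW 7:EE 8:NW 9:NE 10:SW 11:SE
H_FLIP = np.array([0, 1, 3, 2, 4, 5, 7, 6, 9, 8, 11, 10])   # swap every W<->E pair
V_FLIP = np.array([1, 0, 2, 3, 5, 4, 6, 7, 10, 11, 8, 9])   # swap every N<->S pair
HV_FLIP = H_FLIP[V_FLIP]                                     # compose the two flips

def flip_pawn_probs(mcts_prob, perm):
    # return a copy whose first 12 entries are permuted; wall actions (index 12+) are untouched
    flipped = np.copy(mcts_prob)
    flipped[:12] = mcts_prob[perm]
    return flipped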
Example #5
class TrainPipeline():
    def __init__(self, size=(8, 8), init_model=None):
        # params of the board and the game
        print(size)
        self.board_width = size[1]
        self.board_height = size[0]
        self.board = GomokuBoard(size=(self.board_width, self.board_height))
        self.game = GomokuGame(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 3000
        self.best_win_ratio = 0.0
        self.all_loss = []
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_prob, z in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_prob.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob), z))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob), z))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            result, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            print('The result:', result)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        z_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, z_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=(1, 2)))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 - np.var(np.array(z_batch) - old_v.flatten()) /
                             np.var(np.array(z_batch)))
        explained_var_new = (1 - np.var(np.array(z_batch) - new_v.flatten()) /
                             np.var(np.array(z_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, self.lr_multiplier, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    # def policy_evaluate(self, n_games=10):
    #     """
    #     Evaluate the trained policy by playing against the pure MCTS player
    #     Note: this is only for monitoring the progress of training
    #     """
    #     current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
    #                                      c_puct=self.c_puct,
    #                                      n_playout=self.n_playout)
    #     pure_mcts_player = MCTS_Pure(c_puct=5,
    #                                  n_playout=self.pure_mcts_playout_num)
    #     win_cnt = defaultdict(int)
    #     for i in range(n_games):
    #         winner = self.game.start_play(current_mcts_player,
    #                                       pure_mcts_player,
    #                                       start_player=i % 2,
    #                                       is_shown=0)
    #         win_cnt[winner] += 1
    #     win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
    #     print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
    #             self.pure_mcts_playout_num,
    #             win_cnt[1], win_cnt[2], win_cnt[-1]))
    #     return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    self.all_loss.append(loss)
                # check the performance of the current model,
                # and save the model params
                # if (i+1) % self.check_freq == 0:
                #     print("current self-play batch: {}".format(i+1))
                #     win_ratio = self.policy_evaluate()
                #     self.policy_value_net.save_model('./current_policy.model')
                #     if win_ratio > self.best_win_ratio:
                #         print("New best policy!!!!!!!!")
                #         self.best_win_ratio = win_ratio
                # update the best_policy
                if (i + 1) % 10 == 0:
                    self.policy_value_net.save_model(
                        './model/best_policy08_08.model')
                    print('save model.')
                    print(self.all_loss)
            print('finish')
            #         if (self.best_win_ratio == 1.0 and
            #                 self.pure_mcts_playout_num < 5000):
            #             self.pure_mcts_playout_num += 1000
            #             self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
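This variant only prints self.all_loss every ten batches. If the loss curve should survive the run, a small helper like the sketch below can dump it to disk; matplotlib and the output paths are assumptions, not part of the original example:

import json

import matplotlib
matplotlib.use('Agg')  # headless backend, safe on a training server
import matplotlib.pyplot as plt

def save_loss_curve(all_loss, png_path='./model/loss_08_08.png',
                    json_path='./model/loss_08_08.json'):
    # keep both the raw numbers and a quick plot of the policy-update losses
    values = [float(loss) for loss in all_loss]
    with open(json_path, 'w') as f:
        json.dump(values, f)
    plt.figure()
    plt.plot(values)
    plt.xlabel('policy update')
    plt.ylabel('loss')
    plt.savefig(png_path)
    plt.close()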
Example #6
class TrainPipeline:
    def __init__(self, n: int, init_model=None):
        # params of the board and the game
        self.n = n
        self.board = Board(self.n)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 5e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_play_out = 400  # number of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.epochs = 5  # number of train_steps for each update
        self.kl_target = 0.025
        self.check_freq = 50
        self.game_batch_number = 10000
        self.best_win_ratio = 0.0
        self.episode_length = 0
        self.pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
        # number of simulations used for the pure mcts, which is used as the opponent to evaluate the trained policy
        self.last_batch_number = 0
        self.pure_mcts_play_out_number = 1000
        if init_model:
            # start training from an initial policy-value net
            policy_param = pickle.load(open(init_model, 'rb'))
            self.policy_value_net = PolicyValueNet(self.n,
                                                   net_params=policy_param)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.n)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_func,
                                      c_puct=self.c_puct,
                                      n_play_out=self.n_play_out,
                                      is_self_play=1)

    def get_equivalent_data(self, play_data):
        """
        augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]"""
        extend_data = []
        for state, mcts_probabilities, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equivalent_state = np.array([np.rot90(s, i) for s in state])
                equivalent_mcts_prob = np.rot90(
                    np.flipud(mcts_probabilities.reshape(self.n, self.n)), i)
                extend_data.append(
                    (equivalent_state,
                     np.flipud(equivalent_mcts_prob).flatten(), winner))
                # flip horizontally
                equivalent_state = np.array(
                    [np.fliplr(s) for s in equivalent_state])
                equivalent_mcts_prob = np.fliplr(equivalent_mcts_prob)
                extend_data.append(
                    (equivalent_state,
                     np.flipud(equivalent_mcts_prob).flatten(), winner))
        return extend_data

    def collect_self_play_data(self):
        """collect self-play data for training"""
        play_data = list(
            self.game.start_self_play(self.mcts_player, temp=self.temp))
        self.episode_length = len(play_data)
        play_data = self.get_equivalent_data(play_data)
        self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        kl = 0
        new_v = 0
        loss = 0
        entropy = 0

        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probabilities_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probabilities, old_v = self.policy_value_net.policy_value(
            state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probabilities_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probabilities, new_v = self.policy_value_net.policy_value(
                state_batch)
            kl = np.mean(
                np.sum(old_probabilities * (np.log(old_probabilities + 1e-10) -
                                            np.log(new_probabilities + 1e-10)),
                       axis=1))
            if kl > self.kl_target * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_target * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_target / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = 1 - np.var(
            np.array(winner_batch) - old_v.flatten()) / np.var(
                np.array(winner_batch))
        explained_var_new = 1 - np.var(
            np.array(winner_batch) - new_v.flatten()) / np.var(
                np.array(winner_batch))
        print_log(
            "kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}"
            .format(kl, self.lr_multiplier, loss, entropy, explained_var_old,
                    explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing games against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(
            self.policy_value_net.policy_value_func,
            c_puct=self.c_puct,
            n_play_out=self.n_play_out)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_play_out=self.pure_mcts_play_out_number)
        win_cnt = defaultdict(int)
        results = self.pool.map(self.game.start_play,
                                [(current_mcts_player, pure_mcts_player, i)
                                 for i in range(n_games)])
        for winner in results:
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print_log("number_play_outs:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_play_out_number, win_cnt[1], win_cnt[2],
            win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_number):
                if os.path.exists("done"):
                    break
                start_time = time.time()
                self.collect_self_play_data()
                print_log("batch i:{}, episode_len:{}, in:{}".format(
                    i + 1 + self.last_batch_number, self.episode_length,
                    time.time() - start_time))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    data_log(
                        str((i + 1 + self.last_batch_number, loss, entropy)))
                # check the performance of the current model,and save the model params
                if (i + 1) % self.check_freq == 0:
                    print_log("current self-play batch: {}".format(
                        i + 1 + self.last_batch_number))
                    start_time = time.time()
                    win_ratio = self.policy_evaluate()
                    net_params = self.policy_value_net.get_policy_parameter(
                    )  # get model params
                    pickle.dump(net_params, open('current_policy.model', 'wb'),
                                pickle.HIGHEST_PROTOCOL)
                    print_log(str(time.time() - start_time))
                    if win_ratio > self.best_win_ratio:
                        self.best_win_ratio = win_ratio
                        pickle.dump(net_params, open('best_policy.model',
                                                     'wb'),
                                    pickle.HIGHEST_PROTOCOL)
                        if self.best_win_ratio >= 0.8:
                            print_log("New best policy defeated " +
                                      str(self.pure_mcts_play_out_number) +
                                      " play out MCTS player ")
                            self.best_win_ratio = 0.0
                            self.pure_mcts_play_out_number += 1000

        except KeyboardInterrupt:
            pass
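One caveat about the parallel evaluation in this example: Pool.map passes each (current_mcts_player, pure_mcts_player, i) tuple to self.game.start_play as a single argument, so start_play must unpack the tuple itself, and the players plus the policy-value net have to be picklable across processes. If start_play instead takes separate positional arguments, Pool.starmap is the drop-in alternative; a sketch under that assumption:

# Sketch only: replaces the pool.map call inside policy_evaluate() when
# Game.start_play takes (player1, player2, game_index) as separate arguments.
results = self.pool.starmap(
    self.game.start_play,
    [(current_mcts_player, pure_mcts_player, i) for i in range(n_games)])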
Example #7
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.board_width = 6    # board width
        self.board_height = 6   # board height
        self.n_in_row = 4       # win condition: how many stones in a row count as a win

        # instantiate a Board with the given width, height and win condition
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)

        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000

        # initialize the network and the tree; the network is kept throughout,
        # but when the tree gets reset is not obvious from here
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    # purpose: augment the data, since the Gomoku board looks the same after rotations and flips.
    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        ##play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                # np.rot90: rotate the matrix by 90 degrees
                # np.flipud: flip the matrix upside down
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    # collect self-play data
    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        # play n_games games of self-play
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)  # number of moves in this game
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        # ====== unpack the data ============
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        #=========================
        # this seems to act like importance sampling: compute the KL divergence directly and stop early once it gets too large
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        # run the training step for epochs iterations
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
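            # KL(old || new) summed over moves and averaged over the batch; the 1e-10 term guards against log(0)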
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        # dynamically adjust the learning rate based on the KL divergence of the last update
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        entropy,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    # pit the pure MCTS player against the AlphaZero player to see which is stronger
    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
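        # winner == 1 is the trained player; a tie (winner == -1) counts as half a win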
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    #training pipeline
    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                # collect data play_batch_size times, playing n_games games each time.
                # every game builds a fresh search tree, and each move corresponds to a node in that tree.
                # every move runs _n_playout simulations
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                        i+1, self.episode_len))
                # once there is enough data, run the update; importance sampling allows several update steps per batch.
                # after the update, the old data is cleared once new collection starts.
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i+1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
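
The policy_update methods repeated across these examples all track the same three quantities: the mean KL divergence between the old and new policy (used for early stopping), an lr_multiplier that shrinks or grows depending on that KL, and the explained variance of the value head. The following is a minimal, self-contained sketch with synthetic numbers (every name in it is local to the sketch and not taken from any example above) showing how those quantities behave:

import numpy as np

kl_targ = 0.02
lr_multiplier = 1.0

# two fake positions with three legal moves each: policy before and after one train_step
old_probs = np.array([[0.70, 0.20, 0.10],
                      [0.50, 0.30, 0.20]])
new_probs = np.array([[0.50, 0.30, 0.20],
                      [0.35, 0.40, 0.25]])

# mean over the batch of KL(old || new); the 1e-10 term guards against log(0)
kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                 np.log(new_probs + 1e-10)), axis=1))

# the shared adaptive rule: a large KL means the policy moved too far, so take smaller steps;
# a tiny KL means it barely moved, so take larger steps
if kl > kl_targ * 2 and lr_multiplier > 0.1:
    lr_multiplier /= 1.5
elif kl < kl_targ / 2 and lr_multiplier < 10:
    lr_multiplier *= 1.5

# explained variance of the value head: 1.0 for a perfect predictor, 0.0 for a constant guess
winner = np.array([1.0, -1.0, 1.0, -1.0])

def explained_var(v):
    return 1 - np.var(winner - v) / np.var(winner)

print("kl:{:.5f}, lr_multiplier:{:.3f}".format(kl, lr_multiplier))  # kl is ~0.066 here, so the multiplier drops to ~0.667
print("perfect value head:", explained_var(winner.copy()),
      "constant value head:", explained_var(np.zeros_like(winner)))
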
Example #8
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        # basic params
        self.board_width = 9
        self.board_height = 9
        self.n_in_row = 5
        # init the board and game
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 3e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1e-3  # the temperature param
        # self.n_playout = 400  # num of simulations for each move
        self.n_playout = 400
        self.c_puct = 3  # a number in (0, inf) that controls how quickly exploration
        # converges to the maximum-value policy. A higher value means
        # relying on the prior more.
        self.buffer_size = 10000
        # self.batch_size = 512  # mini-batch size for training
        self.batch_size = 256
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1000
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 400
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_porb.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=25):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)

            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
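        # shrink the step when the policy moved too far (KL > 2x target), enlarge it when it barely moved (KL < target / 2)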
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, self.lr_multiplier, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=30):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        with open(logfile_name, 'w+') as file:
            file.write("num_playouts:{}, win: {}, lose: {}, tie:{}\n".format(
                self.pure_mcts_playout_num, win_cnt[1], win_cnt[2],
                win_cnt[-1]))

        print("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model(
                        './models/current_policy_{}.model'.format(i + 1))
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model(
                            './models/best_policy_{}.model'.format(i + 1))
                        if (self.best_win_ratio > 0.8
                                and self.pure_mcts_playout_num < 25000):
                            print("stronger model to compete")
                            self.pure_mcts_playout_num += 500
                            self.best_win_ratio = 0.0
                        elif self.best_win_ratio == 0 and self.n_playout < 15000:
                            self.pure_mcts_playout_num += 250
                            print("enhance the alphazero mcts")
                print('-------------------------training_outer_epoch!!!!!!', i,
                      "-----------------")

        except KeyboardInterrupt:
            print('\n\rquit')
Example #9
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.board_width = 6
        self.board_height = 6
        self.n_in_row = 4
        self.board = Board(width=self.board_width,
                           height=self.board_height,
                           n_in_row=self.n_in_row)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        # add output log
        self.formatter = logging.Formatter('%(asctime)s [%(module)s] %(levelname)s: %(message)s', '%Y-%m-%d %H:%M:%S')
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(level=logging.INFO)
        self.handler = logging.FileHandler("output.log")
        self.handler.setLevel(logging.INFO)
        self.handler.setFormatter(self.formatter)
        self.console = logging.StreamHandler()
        self.console.setLevel(logging.INFO)
        self.console.setFormatter(self.formatter)
        self.logger.addHandler(self.handler)
        self.logger.addHandler(self.console)
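        # with both handlers attached, log records go to output.log and to the console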
        
        if init_model:
            if os.path.exists(init_model):
                # start training from an initial policy-value net
                self.policy_value_net = PolicyValueNet(self.board_width,
                                                       self.board_height,
                                                       model_file=init_model)
            else:
                self.logger.error("{} does not exist!\n".format(init_model))
                raise FileNotFoundError(init_model)  # __init__ cannot return a value
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
            )
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        self.logger.info(("kl:{:.5f}, lr_multiplier:{:.3f}, loss:{}, "
                          "entropy:{}, explained_var_old:{:.3f}, "
                          "explained_var_new:{:.3f}").format(
                              kl, self.lr_multiplier, loss, entropy,
                              explained_var_old, explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0*(win_cnt[1] + 0.5*win_cnt[-1]) / n_games
        self.logger.info("num_playouts:{}, win: {}, lose: {}, tie:{}".format(
                self.pure_mcts_playout_num,
                win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                self.logger.info("batch i:{}, episode_len:{}".format(
                        i+1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params
                if (i+1) % self.check_freq == 0:
                    self.logger.info("current self-play batch: {}".format(i+1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    self.policy_value_net.save_model(
                        './policy_{}_{}_{}_{}.model'.format(
                            self.board_width, self.board_height, self.n_in_row,
                            datetime.datetime.strftime(datetime.datetime.now(),
                                                       "%Y%m%d%H%M%S")))
                    if win_ratio > self.best_win_ratio:
                        self.logger.info("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0 and
                                self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
Example #10
class Train():
    def __init__(self, init_model=None):
        # params of the game
        self.width = 4
        self.height = 4
        self.game = Game()
        # params of training
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0
        self.temp = 1.0
        self.n_playout = 300
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 64
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 5000
        self.best_win_ratio = 0.0

        self.pure_mcts_playout_num = 500

        if init_model:
            self.policy_value_net = PolicyValueNet(self.width,
                                                   self.height,
                                                   model_file=init_model)
        else:
            self.policy_value_net = PolicyValueNet(self.width, self.height)

        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def collect_selfplay_data(self, n_games=1):

        for i in range(n_games):
            print "=====================Start===================="
            self.game = Game()
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            #print "winner",winner,play_data
            print "======================END====================="
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            #play_data = self.get_qui_data(play_data)
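            # note: the symmetry augmentation is commented out in this variant, so only raw self-play positions enter the buffer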
            self.data_buffer.extend(play_data)

    def policy_update(self):
        #print "____policy___update_______"
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        #print "old_v = ",old_v
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:
                break
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5
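        # explained variance of the value head: 1 - Var(z - v) / Var(z); values near 1 mean the predictions track the game outcomes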
        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print "result-eval var=", np.var(
            np.array(winner_batch) - new_v.flatten()), "\twinner var=", np.var(
                np.array(winner_batch))
        print "kl=", kl, "\tlr_mul=", self.lr_multiplier
        print "var_old : {:.3f}\tvar_new : {:.3f}".format(
            explained_var_old, explained_var_new)
        return loss, entropy

    def policy_evaluate(self, n_games=10):

        #print "_____policy__evaluation________"

        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)

        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)

        win_cnt = defaultdict(int)

        for i in range(n_games):

            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            print "winner", winner
            win_cnt[winner] += 1

            win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[0]) / n_games

        print "win ratio =", win_ratio
        print("num_playout:{}, win: {}, lose: {}, tie:{}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[0]))
        return win_ratio

    def run(self, modelfile=None):
        for i in range(self.game_batch_num):
            self.collect_selfplay_data(self.play_batch_size)
            print "gamebatch :", i + 1, "episode_len:", self.episode_len
            print "selfplayend,data_buffer len=", len(self.data_buffer)

            if len(self.data_buffer) > self.batch_size:
                loss, entropy = self.policy_update()
                print "loss = {:.3f}\tentropy = {:.3f}".format(loss, entropy)

            if (i + 1) % self.check_freq == 0:
                print("current self-play batch:{}".format(i + 1))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('current.model')
                if win_ratio > self.best_win_ratio:
                    print("new best model")
                    self.best_win_ratio = win_ratio
                    self.policy_value_net.save_model("best.model")
                    if self.best_win_ratio >= 0.8 and self.pure_mcts_playout_num < 1000:
                        print "Pure Harder"
                        self.pure_mcts_playout_num += 100
                        self.best_win_ratio = 0.0
Example #11
class TrainPipeline():
    def __init__(self, init_model=None):
        self.board = Board()
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        #                                                      self.board_height,
        self.temp = 1.0  # the temperature param
        self.n_playout = 1600  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
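        # note: with maxlen=10 the buffer can never exceed batch_size (512), so the policy_update branch in run() is never reached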
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 15000
        self.best_win_ratio = 0.0
        self.pure_mcts_playout_num = 1000
        if init_model:
            self.policy_value_net = PolicyValueNet(model_file=init_model,
                                                   use_gpu=True)
        else:
            self.policy_value_net = PolicyValueNet(use_gpu=True)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)
        print("init done")

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        print("play_data = {}".format(play_data))
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                #                print("state[0] = {}".format(state[0]))
                #                print("state = {}".format(state))
                #                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_state = np.rot90(state, i)
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_porb.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp,
                                                          is_shown=1)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            #       play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(
            ("kl = {:.5f},"
             "lr_multiplier = {:.3f},"
             "loss = {},"
             "entropy = {},"
             "explained_var_old = {:.3f},"
             "explained_var_new = {:.3f}").format(kl, self.lr_multiplier, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            print(i)
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          is_shown=0)
            win_cnt[winner] += 1
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[-1]) / n_games
        print("num_playouts = {}, win = {}, lose = {}, tie = {}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2], win_cnt[-1]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            localtime = time.asctime(time.localtime(time.time()))
            print("本地时间为 :", localtime)
            for i in range(self.game_batch_num):
                print("selfplay....")
                self.collect_selfplay_data(self.play_batch_size)
                print("selfplay done")
                print("batch i = {}, episode_len = {}".format(
                    i + 1, self.episode_len))
                localtime = time.asctime(time.localtime(time.time()))
                print("本地时间为 :", localtime)
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                # check the performance of the current model,
                # and save the model params

                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch = {}".format(i + 1))
                    win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('./current_policy.model')
                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # update the best_policy
                        self.policy_value_net.save_model('./best_policy.model')
                        if (self.best_win_ratio == 1.0
                                and self.pure_mcts_playout_num < 5000):
                            self.pure_mcts_playout_num += 1000
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            print('\n\rquit')
Example #12
class TrainPipeline():
    def __init__(self):
        # parameters of the game (Gomoku / five-in-a-row)
        self.board_width, self.board_height = 9, 9
        self.n_in_row = 5
        self.board = Board(width=self.board_width, height=self.board_height, n_in_row=self.n_in_row)
        self.game = Game(self.board)
        
        # training parameters
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 400  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.batch_size = 512  # mini-batch size: draw 512 samples from the buffer
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 500  # check and save the model every check_freq batches; this was originally 100.
        self.game_batch_num = 3000  # maximum number of training batches
        self.train_num = 0 # current number of training batches
        
        # start training with the policy-value net
        self.policy_value_net = PolicyValueNet(self.board_width, self.board_height)
        
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn, c_puct=self.c_puct, n_playout=self.n_playout, is_selfplay=1)

    def get_equi_data(self, play_data):
        """
        augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(mcts_porb.reshape(self.board_height, self.board_width)), i)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data) # append at the right (end) of the deque

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(state_batch, mcts_probs_batch, winner_batch, self.learn_rate*self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)), axis=1))
            
            # stop early if D_KL diverges badly
            if kl > self.kl_targ * 4 : break
                
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1 : self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10 : self.lr_multiplier *= 1.5

        explained_var_old = (1 - np.var(np.array(winner_batch) - old_v.flatten()) / np.var(np.array(winner_batch)))
        explained_var_new = (1 - np.var(np.array(winner_batch) - new_v.flatten()) / np.var(np.array(winner_batch)))

        print(f"kl:{kl:5f}, lr_multiplier:{self.lr_multiplier:3f}, loss:{loss}, entropy:{entropy}, explained_var_old:{explained_var_old:3f}, explained_var_new:{explained_var_new:3f}")

        return loss, entropy

    def run(self):
        for i in range(self.game_batch_num):
            self.collect_selfplay_data(self.play_batch_size)
            self.train_num += 1
            print(f"batch i:{self.train_num}, episode_len:{self.episode_len}")

            if len(self.data_buffer) > self.batch_size : loss, entropy = self.policy_update()

            # check the performance of the current model and save the model params
            if (i+1) % self.check_freq == 0:
                print(f"★ {self.train_num}번째 batch에서 모델 저장 : {datetime.now()}")
                self.policy_value_net.save_model(f'{model_path}/policy_9_{self.train_num}.model')
                pickle.dump(self, open(f'{train_path}/train_9_{self.train_num}.pickle', 'wb'), protocol=2)
Example #13
def test():
    from quoridor import Quoridor
    from pure_mcts import MCTSPlayer as MCTS_Pure
    from mcts_player import MCTSPlayer
    from policy_value_net import PolicyValueNet
    policy_value_net = PolicyValueNet(model_file=None)
    c_puct = 5
    n_playout = 800
    temp = 1.0
    board = Quoridor()
    game = Game(board)
    mcts_player = MCTSPlayer(
        policy_value_net.policy_value_fn,
        c_puct=c_puct,
        n_playout=n_playout,
        is_selfplay=1)
    winner, play_data = game.start_self_play(
        mcts_player, is_shown=1, temp=temp)
    print(winner)
    print(play_data)
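    # play_data holds (state, mcts_prob, winner_z) triples, unpacked into batches below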

    state_batch = [data[0] for data in play_data]
    mcts_probs_batch = [data[1] for data in play_data]
    winner_batch = [data[2] for data in play_data]

    learn_rate = 2e-3
    lr_multiplier = 1.0
    kl_targ = 0.02

    old_probs, old_v = policy_value_net.policy_value(state_batch)
    for i in range(5):
        loss, entropy = policy_value_net.train_step(
                    state_batch,
                    mcts_probs_batch,
                    winner_batch,
                    learn_rate*lr_multiplier)
        new_probs, new_v = policy_value_net.policy_value(state_batch)
        kl = np.mean(np.sum(old_probs * (
                    np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                    axis=1)
        )
        if kl > kl_targ * 4:  # early stopping if D_KL diverges badly
            break
    # adaptively adjust the learning rate
    if kl > kl_targ * 2 and lr_multiplier > 0.1:
        lr_multiplier /= 1.5
    elif kl < kl_targ / 2 and lr_multiplier < 10:
        lr_multiplier *= 1.5

    explained_var_old = (1 -
                            np.var(np.array(winner_batch) - old_v.flatten()) /
                            np.var(np.array(winner_batch)))
    explained_var_new = (1 -
                            np.var(np.array(winner_batch) - new_v.flatten()) /
                            np.var(np.array(winner_batch)))
    print(("kl:{:.5f},"
            "lr_multiplier:{:.3f},"
            "loss:{},"
            "entropy:{},"
            "explained_var_old:{:.3f},"
            "explained_var_new:{:.3f}"
            ).format(kl,
                    lr_multiplier,
                    loss,
                    entropy,
                    explained_var_old,
                    explained_var_new))
    policy_value_net.save_model('./current_policy.model')
Example #14
class TrainPipeline():
    def __init__(self):
        # params of the board and the game
        self.board_width = 5
        self.board_height = 5
        self.game = Game()
        # training params
        self.learn_rate = 0.001
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 500  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 128  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 100
        self.game_batch_num = 2000
        self.best_win_ratio = 0.0
        # num of simulations used for the pure mcts, which is used as the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 3000

        # start training from a new policy-value net
        self.policy_value_net = PolicyValueNet(self.board_width,
                                               self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def collect_selfplay_data(self, n_games=1):
        """
        collect self-play data for training
        default collect one game data
        """
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            self.data_buffer.extend(play_data)

    def policy_update(self, verbose=False):
        """
        update the policy-value net
        verbose to show more details of the training steps, default not show
        """
        # ipdb.set_trace()
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]

        old_probs, old_v = self.policy_value_net.policy_value(state_batch)

        loss_list = []
        entropy_list = []
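        # keep the per-epoch loss/entropy so the verbose report can show both the last and the mean values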
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)

            loss_list.append(loss)
            entropy_list.append(entropy)

            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break

        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        if verbose:
            explained_var_old = (
                1 - np.var(np.array(winner_batch) - old_v.flatten()) /
                np.var(np.array(winner_batch)))
            explained_var_new = (
                1 - np.var(np.array(winner_batch) - new_v.flatten()) /
                np.var(np.array(winner_batch)))

            print(("kl: {:.3f}, "
                   "lr_multiplier: {:.3f}\n"
                   "last loss: {:.3f}, "
                   "mean loss: {:.3f}, "
                   "mean entropy: {:.3f}\n"
                   "explained old: {:.3f}, "
                   "explained new: {:.3f}\n").format(kl, self.lr_multiplier,
                                                     loss_list[-1],
                                                     np.mean(loss_list),
                                                     np.mean(entropy_list),
                                                     explained_var_old,
                                                     explained_var_new))

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing games against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=3000)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = defaultdict(int)
        for i in range(n_games):
            # alphazero always red, but change the first player in the game
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          1,
                                          2,
                                          start_player=(i % 2) + 1,
                                          is_show=0)
            print("winner is {}".format(winner))
            win_cnt[winner] += 1
        # compute the win rate of the red (AlphaZero) side
        win_ratio = win_cnt[1] / n_games
        print("num_playouts:{}, win: {}, lose: {}".format(
            self.pure_mcts_playout_num, win_cnt[1], win_cnt[2]))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        try:
            for i in range(self.game_batch_num):
                print("game", i, 'start ...')
                bt = time.time()
                self.collect_selfplay_data(self.play_batch_size)
                print('game', i, 'cost', int(time.time() - bt), 's')

                if len(self.data_buffer) > self.batch_size:
                    print("#### batch i:{} ####\n".format(i + 1))
                    for vi in range(5):
                        verbose = vi % 5 == 0
                        self.policy_update(verbose)

                # check the performance of the current model, and save the model params
                # evaluated every check_freq self-play batches
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    self.policy_value_net.saver.save(
                        self.policy_value_net.session,
                        self.policy_value_net.model_file)
                    win_ratio = self.policy_evaluate()
                    print('*****win ratio: {:.2f}%\n'.format(win_ratio * 100))

                    if win_ratio > self.best_win_ratio:
                        print("New best policy!!!!!!!!")
                        self.best_win_ratio = win_ratio
                        # save the model
                        self.policy_value_net.saver.save(
                            self.policy_value_net.session,
                            self.policy_value_net.model_file
                        )  # update the best_policy
                        if self.best_win_ratio == 1.0 and self.pure_mcts_playout_num < 5000:
                            self.pure_mcts_playout_num += 100
                            self.best_win_ratio = 0.0
        except KeyboardInterrupt:
            # save before quit
            self.policy_value_net.saver.save(self.policy_value_net.session,
                                             self.policy_value_net.model_file)
            print('quit, Bye !')
Example #15
class TrainPipeline():
    def __init__(self):
        # params of the board and the game
        self.board_width = 9
        self.board_height = 9
        self.board = Board(width=self.board_width, height=self.board_height)
        self.game = Game(self.board)
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temp = 1.0  # the temperature param
        self.n_playout = 800  # num of simulations for each move
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 512  # mini-batch size for training
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5  # num of train_steps for each update
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 3
        self.best_loss = None
        # num of simulations used for the pure mcts, which is used as
        # the opponent to evaluate the trained policy
        self.pure_mcts_playout_num = 1000
        init_model = 'checkpoint/current_policy.model'
        if os.path.isfile(init_model + '.index'):
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_width,
                                                   self.board_height)
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    def get_equi_data(self, play_data):
        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """
        print('1')
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(
                    np.flipud(
                        mcts_porb.reshape(self.board_height,
                                          self.board_width)), i)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append(
                    (equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
        """collect self-play data for training"""
        print('2')
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(self.mcts_player,
                                                          temp=self.temp)
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # augment the data
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        print('3')
        mini_batch = random.sample(self.data_buffer, self.batch_size)
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "entropy:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}").format(kl, self.lr_multiplier, loss,
                                                  entropy, explained_var_old,
                                                  explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """
        print('4')
        current_mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                         c_puct=self.c_puct,
                                         n_playout=self.n_playout)
        pure_mcts_player = MCTS_Pure(c_puct=5,
                                     n_playout=self.pure_mcts_playout_num)
        win_cnt = 0
        for i in range(n_games):
            winner = self.game.start_play(current_mcts_player,
                                          pure_mcts_player,
                                          start_player=i % 2,
                                          is_shown=0)
            if winner == 1:  # assume the trained current_mcts_player is board player 1
                win_cnt += 1
        win_ratio = win_cnt / n_games
        print("num_playouts:{}, win: {}".format(self.pure_mcts_playout_num,
                                                win_cnt))
        return win_ratio

    def run(self):
        """run the training pipeline"""
        print('go1')
        try:
            if not os.path.isdir('checkpoint'):
                os.makedirs('checkpoint')
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("{}: batch i:{}, episode_len:{}".format(
                    datetime.datetime.now(), i + 1, self.episode_len))

                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    if self.best_loss is None or loss < self.best_loss:
                        self.best_loss = loss
                        print("New best policy auto save at batch {}".format(
                            i + 1))
                        self.policy_value_net.save_model(
                            'checkpoint/best_policy.model')

                if (i + 1) % self.check_freq == 0:
                    print("current model auto save at batch {}".format(i + 1))
                    self.policy_value_net.save_model(
                        'checkpoint/current_policy.model')

        except KeyboardInterrupt:
            print('\n\rquit')
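For reference, a minimal launch sketch for a pipeline like the one above (the __main__ guard is not part of the original listing and assumes the class and its imports live in a single training script):

if __name__ == '__main__':
    training_pipeline = TrainPipeline()
    training_pipeline.run()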
Example #16
class TrainPipeline(object):
    def __init__(self, init_model=None):
        # params of the board and the game
        self.game = Quoridor()
        # training params
        self.learn_rate = 2e-3
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate
        self.temp = 1.0
        self.n_playout = 400
        self.c_puct = 5
        self.buffer_size = 10000
        self.batch_size = 128  # set to 1 for quick testing
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 5
        self.kl_targ = 0.02
        self.check_freq = 50
        self.game_batch_num = 1500
        self.best_win_ratio = 0.0
        self.pure_mcts_playout_num = 1000
        if init_model:
            self.policy_value_net = PolicyValueNet(model_file=init_model)
        else:
            self.policy_value_net = PolicyValueNet()
        # set up the self-play (AI) player
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)

    # def get_equi_data(self, play_data):
    #     """
    #     数据集增强,获取旋转后的数据,因为五子棋也是对称的
    #     play_data: [(state, mcts_prob, winner_z), ..., ...]"""
    #     extend_data = []
    #     for state, mcts_porb, winner in play_data:
    #         equi_state = np.array([np.rot90(s,2) for s in state])
    #         equi_mcts_prob = np.rot90(np.flipud(mcts_porb.reshape(9, 9)), 2)
    #         extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
    #         # flip horizontally
    #         equi_state = np.array([np.fliplr(s) for s in equi_state])
    #         equi_mcts_prob = np.fliplr(equi_mcts_prob)
    #         extend_data.append((equi_state, np.flipud(equi_mcts_prob).flatten(), winner))
    #     return extend_data

    def collect_selfplay_data(self, n_games=1):
        """收集训练数据"""
        for i in range(n_games):
            winner, play_data = self.game.start_self_play(
                self.mcts_player, temp=self.temp)  # run one self-play game
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            # data augmentation
            # play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """训练策略价值网络"""
        mini_batch = random.sample(self.data_buffer,
                                   self.batch_size)  # draw a mini-batch
        state_batch = [data[0] for data in mini_batch]  # element 0: board states
        mcts_probs_batch = [data[1] for data in mini_batch]  # element 1: MCTS move probabilities
        winner_batch = [data[2] for data in mini_batch]  # element 2: game outcomes
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        # the pre-update probabilities and values are kept so that the KL divergence
        # between the old and new policy can be used below to anneal the learning rate
        # train for self.epochs passes over the mini-batch
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(
                state_batch, mcts_probs_batch, winner_batch,
                self.learn_rate * self.lr_multiplier)
            new_probs, new_v = self.policy_value_net.policy_value(
                state_batch)  # post-update probabilities and values
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
            if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
                break
        # adaptively adjust the learning rate based on the KL divergence
        if kl > self.kl_targ * 2 and self.lr_multiplier > 0.1:
            self.lr_multiplier /= 1.5
        elif kl < self.kl_targ / 2 and self.lr_multiplier < 10:
            self.lr_multiplier *= 1.5

        explained_var_old = 1 - np.var(
            np.array(winner_batch) - old_v.flatten()) / np.var(
                np.array(winner_batch))
        explained_var_new = 1 - np.var(
            np.array(winner_batch) - new_v.flatten()) / np.var(
                np.array(winner_batch))
        print(
            "kl:{:.5f},lr_multiplier:{:.3f},loss:{},entropy:{},explained_var_old:{:.3f},explained_var_new:{:.3f}"
            .format(kl, self.lr_multiplier, loss, entropy, explained_var_old,
                    explained_var_new))
        return loss, entropy

    def run(self):
        """训练"""
        try:
            for i in range(self.game_batch_num):
                self.collect_selfplay_data(self.play_batch_size)
                print("batch i:{}, episode_len:{}".format(
                    i + 1, self.episode_len))
                if len(self.data_buffer) > self.batch_size:
                    loss, entropy = self.policy_update()
                    print("LOSS:", loss)
                    # log the loss
                    with open('loss.txt', 'a') as f:
                        f.writelines(str(loss) + '\n')
                if (i + 1) % self.check_freq == 0:
                    print("current self-play batch: {}".format(i + 1))
                    # win_ratio = self.policy_evaluate()
                    self.policy_value_net.save_model('current_policy')  # save the model
        except KeyboardInterrupt:
            print('\n\rquit')
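Because this pipeline only appends raw loss values to loss.txt, the training curve can be inspected afterwards with a small follow-up script. A minimal sketch, assuming matplotlib is available (it is not imported by the pipeline itself):

import matplotlib.pyplot as plt

with open('loss.txt') as f:
    losses = [float(line) for line in f if line.strip()]

plt.plot(losses)                # one point per policy update
plt.xlabel('policy update')
plt.ylabel('training loss')
plt.show()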
Example #17
class TrainPipeline():
    def __init__(self, init_model=None):
        # params of the board and the game
        self.board_length = 6
        self.n_in_row = 4
        self.num_history = 2
        self.chess = chessboard(self.board_length, self.n_in_row)
        # training params
        self.learn_rate = 5e-4
        self.lr_multiplier = 1.0  # adaptively adjust the learning rate based on KL
        self.temperature = 1.0  # the temperature param
        self.cpuct = 5
        self.buffer_size = 10000
        self.batch_size = 512
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.play_batch_size = 1
        self.epochs = 10
        self.kl_targ = 0.02
        self.check_freq = 50
        self.best_win_ratio = 0.0
        self.game_batch_num = 4000
        self.loss_dict = {}
        self.loss_hold = 50
        
        self.real_mcts_simulation_times = 400
        self.pure_mcts_simulation_times = 1000
        if init_model:
            # start training from an initial policy-value net
            self.policy_value_net = PolicyValueNet(self.board_length,
                                                   self.num_history,
                                                   model_file=init_model)
        else:
            # start training from a new policy-value net
            self.policy_value_net = PolicyValueNet(self.board_length,
                                                   self.num_history)
        # TODO: decide whether self.chess should be deep-copied before handing it to the MCTS player
        self.mcts_player = real_mcts(self.chess,
                            self.policy_value_net.policy_value,
                            self.cpuct,
                            self.real_mcts_simulation_times,
                            self.temperature,
                            self.num_history,
                            True)
# =============================================================================
#         self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
#                                       c_puct=self.c_puct,
#                                       n_playout=self.n_playout,
#                                       is_selfplay=1)
# =============================================================================

    def get_equi_data(self, play_data):
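        """augment the data set by rotation and flipping
        play_data: [(state, mcts_prob, winner_z), ..., ...]
        """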
        extend_data = []
        for state, mcts_porb, winner in play_data:
            for i in [1, 2, 3, 4]:
                # rotate counterclockwise
                equi_state = np.array([np.rot90(s, i) for s in state])
                equi_mcts_prob = np.rot90(np.flipud(
                    mcts_porb.reshape(self.board_length, self.board_length)), i)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
                # flip horizontally
                equi_state = np.array([np.fliplr(s) for s in equi_state])
                equi_mcts_prob = np.fliplr(equi_mcts_prob)
                extend_data.append((equi_state,
                                    np.flipud(equi_mcts_prob).flatten(),
                                    winner))
        return extend_data

    def collect_selfplay_data(self, n_games=1):
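        """collect self-play data for training"""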
        for i in range(n_games):
            inter = interface(self.board_length)
            current_board = copy.deepcopy(self.chess)
            current_real_mcts = real_mcts(current_board,
                                          self.policy_value_net.policy_value,
                                          self.cpuct,
                                          self.real_mcts_simulation_times,
                                          self.temperature,
                                          self.num_history,
                                          True)
            play_data = inter.start_self_play(player=current_real_mcts)
# =============================================================================
#             play_data = start_self_play(player = current_real_mcts)
# =============================================================================
            play_data = list(play_data)[:]
            self.episode_len = len(play_data)
            play_data = self.get_equi_data(play_data)
            self.data_buffer.extend(play_data)

    def policy_update(self):
        """update the policy-value net"""
        mini_batch = random.sample(self.data_buffer, self.batch_size)
# =============================================================================
#         mini_batch = self.data_buffer
# =============================================================================
        state_batch = [data[0] for data in mini_batch]
        mcts_probs_batch = [data[1] for data in mini_batch]
        winner_batch = [data[2] for data in mini_batch]
        old_probs, old_v = self.policy_value_net.policy_value(state_batch)
        first_loss = 0
        for i in range(self.epochs):
            loss, entropy = self.policy_value_net.train_step(state_batch,
                                                             mcts_probs_batch,
                                                             winner_batch,
                                                             self.learn_rate * self.lr_multiplier)
            if i == 0:
                first_loss = loss
# =============================================================================
#             if i % 10 == 0:
#                 print('loss: ', loss, ' entropy: ', entropy)
# =============================================================================
# =============================================================================
#             print('loss: ', loss, ' entropy: ', entropy)
# =============================================================================
            new_probs, new_v = self.policy_value_net.policy_value(state_batch)
            kl = np.mean(
                np.sum(old_probs *
                       (np.log(old_probs + 1e-10) - np.log(new_probs + 1e-10)),
                       axis=1))
# =============================================================================
#             if kl > self.kl_targ * 4:  # early stopping if D_KL diverges badly
#                 break
#         # adaptively adjust the learning rate
#         if kl > self.kl_targ * 2 and self.lr_multiplier > 0.01:
#             self.lr_multiplier /= 1.5
#         elif kl < self.kl_targ / 2 and self.lr_multiplier < 100:
#             self.lr_multiplier *= 1.5
# =============================================================================

        explained_var_old = (1 -
                             np.var(np.array(winner_batch) - old_v.flatten()) /
                             np.var(np.array(winner_batch)))
        explained_var_new = (1 -
                             np.var(np.array(winner_batch) - new_v.flatten()) /
                             np.var(np.array(winner_batch)))
        print(("kl:{:.5f},"
               "lr_multiplier:{:.3f},"
               "loss:{},"
               "loss_change:{},"
               "explained_var_old:{:.3f},"
               "explained_var_new:{:.3f}"
               ).format(kl,
                        self.lr_multiplier,
                        loss,
                        first_loss - loss,
                        explained_var_old,
                        explained_var_new))
        return loss, entropy

    def policy_evaluate(self, n_games=10):
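        """
        Evaluate the trained policy by playing against the pure MCTS player
        Note: this is only for monitoring the progress of training
        """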
        win_cnt = defaultdict(int)
        
        for i in range(n_games):
            inter = interface(self.board_length)
            current_board = copy.deepcopy(self.chess)
            current_real_mcts = real_mcts(current_board,
                                self.policy_value_net.policy_value,
                                self.cpuct,
                                1000,
                                self.temperature,
                                self.num_history,
                                False)
            current_pure_mcts = pure_mcts(current_board,
                                          self.pure_mcts_simulation_times)
            winner = inter.start_play(current_real_mcts,
                                      current_pure_mcts,
                                      start_player=i % 2)
            win_cnt[winner] += 1
            print('winner', winner)
        win_ratio = 1.0 * (win_cnt[1] + 0.5 * win_cnt[0]) / n_games
        print("num_simulation_times:{}, win: {}, lose: {}, tie:{}".format(self.pure_mcts_simulation_times,win_cnt[1], win_cnt[2], win_cnt[0]))
        return win_ratio

    def run(self):
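        """run the training pipeline"""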
        total = 0
        for i in range(self.game_batch_num):
            if (i + 1) % 100 == 0:
                self.learn_rate = self.learn_rate * 0.85
# =============================================================================
#             start = time.time()
# =============================================================================
            self.collect_selfplay_data(self.play_batch_size)
            if len(self.data_buffer) >= self.batch_size:
                loss, entropy = self.policy_update()
                self.loss_dict[i] = loss
                total += loss
            if (i - self.loss_hold) in self.loss_dict:
                total -= self.loss_dict[i - self.loss_hold]
                self.loss_dict.pop(i - self.loss_hold)
            print("batch i:{}, episode_len:{}, loss_hist:{}".format(i + 1, self.episode_len, total / self.loss_hold))
            if (i + 1) % self.check_freq == 0:
                print("current self-play batch: {}".format(i+1))
                win_ratio = self.policy_evaluate()
                self.policy_value_net.save_model('./current_policy.model')
                if win_ratio > self.best_win_ratio:
                    print("New best policy!!!!!!!!")
                    self.best_win_ratio = win_ratio
                    self.policy_value_net.save_model('./best_policy.model')
                    if (self.best_win_ratio == 1.0 and self.pure_mcts_simulation_times < 10000):
                        self.pure_mcts_simulation_times += 1000
                        self.best_win_ratio = 0.0
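This last pipeline keeps the KL-based learning-rate schedule commented out and instead decays the base learning rate by 0.85 every 100 batches. For comparison, a standalone sketch of the adaptive schedule the earlier pipelines use (the helper name and default kl_targ are illustrative; the thresholds are copied from the code above):

import numpy as np

def adapt_lr_multiplier(old_probs, new_probs, lr_multiplier, kl_targ=0.02):
    """Return the mean KL divergence between old and new policy outputs
    and an updated learning-rate multiplier."""
    kl = np.mean(np.sum(old_probs * (np.log(old_probs + 1e-10) -
                                     np.log(new_probs + 1e-10)), axis=1))
    if kl > kl_targ * 2 and lr_multiplier > 0.1:
        lr_multiplier /= 1.5   # the policy moved too far: cool the learning rate
    elif kl < kl_targ / 2 and lr_multiplier < 10:
        lr_multiplier *= 1.5   # the policy barely moved: heat it up
    return kl, lr_multiplier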