Example #1
    def start_selfplay(self,
                       batch_num=10000,
                       c_puct=5,
                       n_playout=400,
                       best_model=None):
        """
        启动持续的selfplay,用于为模型train生成训练数据
        Params:
            batch_num   selfplay对战次数
            c_puct      MCTS child搜索深度
            n_playout   模型训练时每个action的mcts模拟次数
        """
        logging.info("__start_selfplay__")
        # 1.init net & ai player
        model_last_mdy_time = os.stat(best_model).st_mtime if os.path.exists(
            best_model) else time.time()  # last modification time of the model file
        policy_value_net = self._load_policy_value_net(best_model)
        ai_player = AIPlayer(policy_value_net.policy_value_fn,
                             c_puct=c_puct,
                             n_playout=n_playout,
                             is_selfplay=1)

        # 2.start selfplay
        try:
            for i in range(batch_num):  # number of self-play games
                # 2.1 run one self-play game using MCTS (Monte Carlo tree search)
                logging.info("selfplay batch start: {}".format(i + 1))
                winner, play_data = self._selfplay(ai_player)
                logging.info(
                    "selfplay batch res. batch:{}, winner:{}, step_num:{}".
                    format(i + 1, winner, len(play_data)))
                # 2.2 save this game's data to a file under the databuffer directory
                data_file = self._get_databuffer_file(event=n_playout,
                                                      winner=winner,
                                                      step_num=len(play_data))
                utils.pickle_dump(play_data, data_file)
                logging.info("selfplay batch save. batch:{}, file:{}".format(
                    i + 1, data_file))
                # 2.3 check whether a newer model needs to be reloaded
                model_time = os.stat(best_model).st_mtime if os.path.exists(
                    best_model) else time.time()  # last modification time of the model file
                if model_time > model_last_mdy_time:
                    logging.info(
                        "selfplay reload model! new:{} > old:{}".format(
                            utils.get_date(model_time),
                            utils.get_date(model_last_mdy_time)))
                    model_last_mdy_time = model_time  # last modification time of the model file
                    policy_value_net = self._load_policy_value_net(best_model)
                    ai_player = AIPlayer(policy_value_net.policy_value_fn,
                                         c_puct=c_puct,
                                         n_playout=n_playout,
                                         is_selfplay=1)

        except KeyboardInterrupt:
            logging.info('\n\rselfplay quit')
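The loop above pickles each finished game's play_data (one file per game) so a separate training process can pick them up later. As a rough sketch of that consumer side, assuming the files live under a databuffer/ directory and were written with plain pickle (which the commented-out pickle.dump call in the save_model examples below suggests), the buffer could be rebuilt like this; the glob pattern and buffer size are illustrative, not taken from the project:

import glob
import pickle
from collections import deque

def load_databuffer(pattern="databuffer/*.data", max_len=10000):
    """Rebuild a training buffer from per-game pickle files (hypothetical helper)."""
    buffer = deque(maxlen=max_len)
    for path in sorted(glob.glob(pattern)):
        with open(path, "rb") as f:
            play_data = pickle.load(f)  # list of (actions, mcts_probs, winner_z) tuples
        buffer.extend(play_data)
    return buffer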
Example #2
    def save_model(self, model_file):
        """Save model parameters to a file."""
        net_params = self.get_policy_param()
        #pickle.dump(net_params, open(model_file, 'wb'), protocol=4)
        utils.pickle_dump(net_params, model_file)
Example #3
    def save_model(self, model_file):
        """Save model params to file."""
        net_params = self.get_policy_param()  # get model params
        #pickle.dump(net_params, open(model_file, 'wb'), protocol=4)
        utils.pickle_dump(net_params, model_file)
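Both save_model variants replace the commented-out pickle.dump(net_params, open(model_file, 'wb'), protocol=4) call with utils.pickle_dump. The project's utils module is not shown in this listing; a minimal sketch of what such a helper might look like, inferred only from that commented line, is:

import os
import pickle

def pickle_dump(obj, dump_file):
    """Serialize obj to dump_file; a guess at what utils.pickle_dump may do."""
    dump_dir = os.path.dirname(dump_file)
    if dump_dir:
        os.makedirs(dump_dir, exist_ok=True)  # assumption: create parent dirs if missing
    with open(dump_file, 'wb') as f:
        pickle.dump(obj, f, protocol=4)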
Example #4
    def execute(self):
        """执行业务逻辑"""
        logging.info('API REQUEST INFO[' + self.request.path + '][' +
                     self.request.method + '][' + self.request.remote_ip +
                     '][' + str(self.request.arguments) + ']')
        session_id = self.get_argument('session_id', '')
        res = {
            'session_id': session_id,
            'player': -1,
            'step': 0,
            'move': '',
            'san': '',
            'end': False,
            'winner': -1,
            'curr_player': 0,
            'state': {},
            'ponder': '',
            'score': -1
        }
        move = self.get_argument('move', '')
        if session_id == '':
            return {'code': 2, 'msg': 'session_id must not be empty', 'data': res}

        try:
            # 1. new game session
            session = {}
            if session_id not in self.games:
                logging.info("[{}] init new game!".format(session_id))
                # plays id
                session['human_player_id'] = int(
                    self.get_argument('human_player_id', '1'))  # human plays black by default
                session['ai_player_id'] = (session['human_player_id'] +
                                           1) % 2  # AI takes the opposite side
                session['players'] = {
                    session['human_player_id']: 'Human',
                    session['ai_player_id']: 'AI'
                }
                session['step'] = 0
                session['actions'], session['mcts_probs'] = [], []
                # initialize the board
                session['game'] = Game()
                session['game'].board.init_board()
                # initialize the AI player
                #session['ai_player'] = AIPlayer(self.best_policy.policy_value_fn, n_playout=50)
                session['ai_player'] = StockfishPlayer()
                self.games[session_id] = session
            else:
                session = self.games[session_id]
                # clear old games
                for k in list(self.games.keys()):
                    if int(time.time()) - int(k) / 1000 > 60 * 40:  # clear sessions older than 40 minutes
                        del self.games[k]
                        logging.warning("[{}] timeout clear!".format(k))
            # 2.get ai move
            res['players'], res['human_player_id'], res[
                'ai_player_id'] = session['players'], session[
                    'human_player_id'], session['ai_player_id']
            res['curr_player'] = session['game'].board.current_player_id
            res['availables'] = [
                session['game'].board.action_to_move(act)
                for act in session['game'].board.availables
            ]
            res['state'] = session['game'].board.state()
            action = -1
            if res['curr_player'] == session[
                    'ai_player_id']:  # AI's turn: ignore the incoming move parameter
                action, probs, ponder, res['score'] = session[
                    'ai_player'].get_action(session['game'].board,
                                            return_prob=1,
                                            return_ponder=1,
                                            return_score=1)
                move = session['game'].board.action_to_move(action)
                res['ponder'] = session['game'].board.action_to_move(ponder)
                logging.info(
                    "[{}] {} AI move: {}  Score: {} Ponder: {}".format(
                        session_id, res['curr_player'], move, res['score'],
                        res['ponder']))
                # save state
                session['actions'].append(
                    session['game'].board.current_actions())
                session['mcts_probs'].append(probs)
            else:  # human's turn to move
                if len(move) < 4:  # no move was passed in
                    logging.info("[{}] {} Human needs to give a move!".format(
                        session_id, res['curr_player']))
                    return {'code': 2, 'msg': "it is the human's turn to move", 'data': res}
                logging.info("[{}] {} Human move: {}".format(
                    session_id, res['curr_player'], move))
                action = session['game'].board.move_to_action(move)
                if action not in session[
                        'game'].board.availables:  # human action is illegal
                    logging.info(
                        "[{}] {} Human action ({},{}) invalid !".format(
                            session_id, res['curr_player'], move, action))
                    return {
                        'code': 3,
                        'msg': 'invalid move: {}'.format(move),
                        'data': res
                    }
                # save state
                session['actions'].append(
                    session['game'].board.current_actions())
                probs = np.zeros(session['game'].board.action_ids_size)
                probs[action] = 0.01
                session['mcts_probs'].append(probs)

            # 3.do move
            if len(move) >= 4 and action != -1:
                # do move
                res['san'] = session['game'].board.move_to_san(move)
                session['game'].board.do_move(action)  # do move
                try:
                    if len(res['ponder']) > 2:
                        res['ponder'] = session['game'].board.move_to_san(
                            res['ponder'])
                except Exception:
                    logging.warning(utils.get_trace())
                session['step'] += 1
                res['player'], res['move'], res['step'] = res[
                    'curr_player'], move, session['step']
                res['end'], res['winner'] = session['game'].board.game_end()
                res['curr_player'] = session['game'].board.current_player_id
                res['availables'] = [
                    session['game'].board.action_to_move(act)
                    for act in session['game'].board.availables
                ]
                res['state'] = session['game'].board.state()  # res['state'][move[:2]] + move[:2]
                # save state -> databuffer
                if res['end']:
                    # determine winner values from each step's player perspective
                    winners_z = np.zeros(
                        len(session['game'].board.book_variations['all']))
                    if res['winner'] != -1:  # not a draw
                        for i in range(len(winners_z)):
                            if (i + res['winner']) % 2 == 0:
                                winners_z[i] = 1.0  # steps played by the winner -> 1
                            else:
                                winners_z[i] = -1.0  # steps played by the loser -> -1
                    play_data = list(
                        zip(session['actions'], session['mcts_probs'],
                            winners_z))[:]
                    data_file = session['game']._get_databuffer_file(
                        event='vs',
                        winner=res['winner'],
                        white=session['players'][0],
                        black=session['players'][1],
                        step_num=len(play_data))
                    utils.pickle_dump(play_data, data_file)
                    logging.info(
                        "api vs play save to databuffer: {}".format(data_file))

                return {'code': 0, 'msg': 'success', 'data': res}
        except Exception:
            logging.error('execute fail [' + str(move) + '][' + session_id +
                          '] ' + utils.get_trace())
            return {'code': 5, 'msg': 'request failed', 'data': res}

        # build the response
        return {'code': 0, 'msg': 'success', 'data': res}
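The winners_z labelling near the end of execute (and reused in the PGN import below) marks each step with +1 if the eventual winner was the player to move and -1 otherwise, leaving zeros for a draw; the (i + winner) % 2 test works because player 0 (white) moves on even step indices. A tiny self-contained illustration of that indexing:

import numpy as np

def winners_to_z(num_steps, winner):
    """Per-step value targets: +1 where the winner was to move, -1 elsewhere, all 0 for a draw."""
    winners_z = np.zeros(num_steps)
    if winner != -1:  # not a draw
        for i in range(num_steps):
            winners_z[i] = 1.0 if (i + winner) % 2 == 0 else -1.0
    return winners_z

# a 5-ply game won by player 0 (white): white moved at steps 0, 2 and 4
print(winners_to_z(5, 0))  # [ 1. -1.  1. -1.  1.]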
Example #5
    def pgn_to_databuffer(self, pgn_file):
        """将pgn棋谱转为databuffer用于模型训练"""
        logging.info("__pgn_to_databuffer__ {}".format(pgn_file))
        from xpinyin import Pinyin
        pinyin = Pinyin()
        # 1. load the PGN file
        pgn = open(CUR_PATH + "/data/pgn/" + pgn_file)
        # 2. read the first game
        game = chess.pgn.read_game(pgn)
        batch = 0
        while game:  # a PGN file may contain multiple games
            batch += 1
            logging.info(game)
            logging.info(game.headers)
            winner = self._get_pgn_winner(game.headers['Result'])
            event = pinyin.get_pinyin(
                game.headers['Event'].replace(' ', '').replace(
                    '.', '').replace('-', '').replace('/', '').replace(
                        '(', '').replace(')', '').replace("'", ""), "")
            white = pinyin.get_pinyin(
                game.headers['White'].replace(' ', '').replace(
                    ',', '').replace('-', '').replace('/', '').replace(
                        '(', '').replace(')', '').replace("'", ""), "")
            black = pinyin.get_pinyin(
                game.headers['Black'].replace(' ', '').replace(
                    ',', '').replace('-', '').replace('/', '').replace(
                        '(', '').replace(')', '').replace("'", ""), "")
            players = [white, black]
            # 3. replay the game to build the play data
            # initialize the board
            self.board.init_board()
            self.board.graphic()
            actions, mcts_probs = [], []
            # replay the moves
            step = 0
            # for move in game.mainline_moves():
            moves = iter(game.mainline_moves())
            move = next(moves, None)
            while move:
                step += 1
                logging.info("step: {},  curr: {} {},  winner: {}".format(
                    step, Board.PLAYERS[self.board.current_player_id].upper(),
                    players[self.board.current_player_id],
                    self.board.current_player_id == winner))
                actions.append(self.board.current_actions())
                # apply the move
                action = self.board.move_to_action(move)
                if action == -1:
                    logging.error("invalid move! {}".format(move))
                    break
                probs = np.zeros(self.board.action_ids_size)
                if self.board.current_player_id == winner:
                    probs[action] = 1.0
                else:
                    probs[action] = 0.8
                if pgn_file == 'chessease.pgn':  # not a top-master game collection, lower its weight
                    probs[action] *= 0.1
                mcts_probs.append(probs)
                logging.info("{}'s probs: {}".format(
                    Board.PLAYERS[self.board.current_player_id].upper(),
                    {self.board.action_to_move(action): probs[action]}))
                self.board.do_move(action)
                self.board.graphic()
                # next move
                move = next(moves, None)
                # check whether the game has ended
                end, win = self.board.game_end()
                agreement = ""
                if end or move is None:
                    if end is False and move is None:  # resignation or draw by agreement
                        agreement = "Agreement"
                    # determine winner values from each step's player perspective
                    winners_z = np.zeros(len(
                        self.board.book_variations['all']))
                    if winner != -1:  # not a draw
                        for i in range(len(winners_z)):
                            if (i + winner) % 2 == 0:
                                winners_z[i] = 1.0  # steps played by the winner -> 1
                            else:
                                winners_z[i] = -1.0  # steps played by the loser -> -1
                    if winner != -1:
                        logging.info("Game end. {} Winner is {}".format(
                            agreement, Board.PLAYERS[winner]))
                    else:
                        logging.info("Game end. {} Tie".format(agreement))
                    # print(actions, mcts_probs, winners_z)
                    # print(list(zip(actions, mcts_probs, winners_z))[:])
                    play_data = list(zip(actions, mcts_probs, winners_z))[:]
                    if len(play_data) < 7:  # fewer than 7 plies cannot end in mate, so it must be an agreed result; useless for training
                        continue
                    if len(play_data) < 30 and winner == -1:  # a draw within 30 plies must be by agreement; useless for training
                        continue

                    # 4. save this game's data to a file under the databuffer directory
                    data_file = self._get_databuffer_file(
                        date=game.headers['Date'].replace('.', '').replace(
                            '?', '0'),
                        event=event,
                        winner=winner,
                        white=white,
                        black=black,
                        step_num=len(play_data),
                        agreement=agreement)
                    utils.pickle_dump(play_data, data_file)
                    logging.info(
                        "pgn_to_databuffer save. pgn:{}, batch:{}, databuffer:{}"
                        .format(pgn_file, batch, data_file))
            # 5. read the next game from the PGN file
            game = chess.pgn.read_game(pgn)
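pgn_to_databuffer relies on self._get_pgn_winner to turn the PGN Result header into a player id. That helper is not included in this listing; a minimal sketch consistent with the conventions visible above (players = [white, black], winner == -1 for a draw) might look like the following, though the project's actual mapping could differ:

    def _get_pgn_winner(self, result):
        """Map a PGN Result header to a winner id: 0 = white, 1 = black, -1 = draw/unknown (assumed convention)."""
        if result == '1-0':
            return 0
        if result == '0-1':
            return 1
        return -1  # '1/2-1/2', '*' or anything else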