Example #1
def data_prepare():
    logging_config('data_prepare.log', stream_log=True)
    from ckbqa.dataset.data_prepare import fit_on_texts, data_convert, create_db_tabels
    # map_mention_entity()
    # data2samples(neg_rate=3)
    data_convert()
    fit_on_texts()
    create_db_tabels()
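Every entry point above begins with logging_config, imported from ckbqa.utils.logger (see Example #10). Its implementation is not shown in these snippets; the following is a minimal sketch of what such a helper typically wires up, assuming file-plus-console handler behavior (the project's actual version may differ):

import logging

def logging_config(filename, stream_log=False, level=logging.INFO):
    # Assumed behavior: log to a file, and optionally echo to the console.
    handlers = [logging.FileHandler(filename, encoding='utf-8')]
    if stream_log:
        handlers.append(logging.StreamHandler())
    logging.basicConfig(level=level,
                        format='%(asctime)s %(levelname)s %(message)s',
                        handlers=handlers)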
Example #2
def run(model_name, mode):
    logging_config(f'{model_name}_{mode}.log', stream_log=True)
    if model_name in ['bert_match', 'bert_match2']:
        from ckbqa.models.relation_score.trainer import RelationScoreTrainer
        RelationScoreTrainer(model_name).train_match_model()
    elif model_name == 'entity_score':
        from ckbqa.models.entity_score.model import EntityScore
        EntityScore().train()
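These functions are task entry points rather than library code; a command-line dispatcher along the following lines would invoke them, though the argparse wiring here is purely illustrative and not taken from the ckbqa source:

import argparse

def main():
    # Hypothetical dispatcher; flag names and choices are assumptions.
    parser = argparse.ArgumentParser(description='ckbqa tasks')
    parser.add_argument('--task', required=True,
                        choices=['data_prepare', 'train'])
    parser.add_argument('--model_name', default='bert_match')
    parser.add_argument('--mode', default='train')
    args = parser.parse_args()
    if args.task == 'data_prepare':
        data_prepare()
    elif args.task == 'train':
        run(args.model_name, args.mode)

if __name__ == '__main__':
    main()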
Example #3
def kb_data_prepare():
    logging_config('kb_data_prepare.log', stream_log=True)
    from ckbqa.dataset.kb_data_prepare import (candidate_words, fit_triples)
    from ckbqa.dataset.kb_data_prepare import create_graph_csv
    fit_triples()  # generate the dictionaries
    candidate_words()  # attributes
    # create_lac_custom_dict()  # custom word-segmentation dictionary

    create_graph_csv()  # generate the database import files
Example #4
def train_data():
    logging_config('train_evaluate.log', stream_log=True)
    from ckbqa.models.evaluation_matrics import get_metrics
    #
    partten = re.compile(r'["<](.*?)[>"]')
    #
    _paths = ResultSaver(find_exist_path=True).train_result_csv
    print(_paths)
    train_df = pd.read_csv(_paths[0])
    ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], []
    answer_precisions, answer_recalls, answer_f1_scores = [], [], []
    for index, row in tqdm(train_df.iterrows(),
                           total=train_df.shape[0],
                           desc='evaluate '):
        subject_entities = partten.findall(
            row['standard_subject_entities'])  # match the bare entity text
        if not subject_entities:
            subject_entities = eval(row['standard_subject_entities'])
        # Fix an earlier issue where the <> around entities was stripped:
        # removed during question parsing but not during prediction, so
        # match the bare text rather than the <> / "" wrappers.
        # CEG  Candidate Entity Generation
        candidate_entities = eval(row['candidate_entities']) + partten.findall(
            row['candidate_entities'])
        precision, recall, f1 = get_metrics(subject_entities,
                                            candidate_entities)
        ceg_precisions.append(precision)
        ceg_recalls.append(recall)
        ceg_f1_scores.append(f1)
        # Answer
        standard_entities = eval(row['standard_answer_entities'])
        result_entities = eval(row['result_entities'])
        precision, recall, f1 = get_metrics(standard_entities, result_entities)
        answer_precisions.append(precision)
        answer_recalls.append(recall)
        answer_f1_scores.append(f1)
        #
        # print(f"question: {row['question']}\n"
        #       f"subject_entities: {subject_entities}, candidate_entities: {candidate_entities}"
        #       f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n")
        # import time
        # time.sleep(2)
    ave_ceg_precision = sum(ceg_precisions) / len(ceg_precisions)
    ave_ceg_recall = sum(ceg_recalls) / len(ceg_recalls)
    ave_ceg_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores)
    print(f"ave_ceg_precision: {ave_ceg_precision:.3f}, "
          f"ave_ceg_recall: {ave_ceg_recall:.3f}, "
          f"ave_ceg_f1_score:{ave_ceg_f1_score:.3f}")
    #
    ave_answer_precision = sum(answer_precisions) / len(answer_precisions)
    ave_answer_recall = sum(answer_recalls) / len(answer_recalls)
    ave_answer_f1_score = sum(answer_f1_scores) / len(answer_f1_scores)
    print(f"ave_result_precision: {ave_answer_precision:.3f}, "
          f"ave_result_recall: {ave_answer_recall:.3f}, "
          f"ave_result_f1_score:{ave_answer_f1_score:.3f}")
Example #5
def train_qa():
    """训练数据进行回答;做指标测试"""
    logging_config('train_qa.log', stream_log=True)
    from ckbqa.qa.qa import QA
    from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern
    # from ckbqa.qa.evaluation_matrics import get_metrics
    #
    logging.info('* start run ...')
    qa = QA()
    data = []
    for question, sparql, answer in load_data(tqdm_prefix='test qa'):
        print('\n' * 2)
        print('*****' * 20)
        print(f" question  : {question}")
        print(f" sparql    : {sparql}")
        print(f" standard answer   : {answer}")
        q_text = question_patten.findall(question)[0]
        standard_subject_entities = entity_pattern.findall(
            sparql) + attr_pattern.findall(sparql)
        standard_answer_entities = entity_pattern.findall(
            answer) + attr_pattern.findall(answer)
        try:
            (result_entities, candidate_entities, candidate_out_paths,
             candidate_in_paths) = qa.run(q_text, return_candidates=True)
        except KeyboardInterrupt:
            exit('Ctrl C , exit')
        except Exception:
            logging.info(f'ERROR: {traceback.format_exc()}')
            result_entities = []
            candidate_entities = []
        # print(f" result answer   : {result_entities}")
        # precision, recall, f1 = get_metrics(subject_entities, candidate_entities)
        # if recall == 0 or len(set(standard_entities) & set(candidate_entities)) == 0:
        #     print(f"question: {question}\n"
        #           f"subject_entities: {subject_entities}, candidate_entities: {candidate_entities}"
        #           f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n")
        # import ipdb
        # ipdb.set_trace()
        data.append([
            question, standard_subject_entities,
            list(candidate_entities), standard_answer_entities, result_entities
        ])
    data_df = pd.DataFrame(data,
                           columns=[
                               'question', 'standard_subject_entities',
                               'candidate_entities',
                               'standard_answer_entities', 'result_entities'
                           ])
    data_df.to_csv(ResultSaver().train_result_csv,
                   index=False,
                   encoding='utf_8_sig')
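Note that train_qa stores Python lists directly in CSV cells: pandas serializes them with str() on to_csv, which is why Example #4 recovers them with eval() or a regex. A small round-trip demonstration:

import pandas as pd

df = pd.DataFrame({'ents': [['<基督山伯爵>', '"大仲马"']]})
df.to_csv('demo.csv', index=False)
cell = pd.read_csv('demo.csv').loc[0, 'ents']
print(type(cell))  # <class 'str'> -- the list came back as its str() form
print(eval(cell))  # ['<基督山伯爵>', '"大仲马"'] -- back to a Python list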
Example #6
def ceg():
    logging_config('train_evaluate.log', stream_log=True)
    from ckbqa.models.evaluation_matrics import get_metrics
    from ckbqa.qa.el import CEG
    from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern  #
    ceg = CEG()  # Candidate Entity Generation
    ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], []
    ceg_csv = "./ceg.csv"
    data = []
    for q, sparql, a in load_data(tqdm_prefix='ceg evaluate '):
        q_entities = entity_pattern.findall(sparql) + attr_pattern.findall(
            sparql)

        q_text = ''.join(question_patten.findall(q))
        # Fix an earlier issue where the <> around entities was stripped:
        # removed during question parsing but not during prediction, so
        # match the bare text rather than the <> / "" wrappers.
        ent2mention = ceg.get_ent2mention(q_text)
        # CEG  Candidate Entity Generation
        precision, recall, f1 = get_metrics(q_entities, ent2mention)
        ceg_precisions.append(precision)
        ceg_recalls.append(recall)
        ceg_f1_scores.append(f1)
        #
        data.append([q, q_entities, list(ent2mention.keys())])
        if recall == 0:
            # ceg.memory.entity2id
            # ceg.memory.mention2entity
            print(
                f"question: {q}\n"
                f"subject_entities: {q_entities}, candidate_entities: {ent2mention}"
                f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n"
            )
            # import ipdb
            # ipdb.set_trace()
            print('\n\n')
        # import time
        # time.sleep(2)
    pd.DataFrame(data, columns=['question', 'q_entities',
                                'ceg']).to_csv(ceg_csv,
                                               index=False,
                                               encoding='utf_8_sig')
    ave_precision = sum(ceg_precisions) / len(ceg_precisions)
    ave_recall = sum(ceg_recalls) / len(ceg_recalls)
    ave_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores)
    print(f"ave_precision: {ave_precision:.3f}, "
          f"ave_recall: {ave_recall:.3f}, "
          f"ave_f1_score:{ave_f1_score:.3f}")
Example #7
def test():
    logging_config('test.log', stream_log=True)
    from ckbqa.qa.qa import QA
    from ckbqa.dataset.data_prepare import question_patten
    # from ckbqa.qa.evaluation_matrics import get_metrics
    #
    logging.info('* start run ...')
    qa = QA()
    q188 = 'q188:墨冰仙是哪个门派的?'
    q189 = 'q189:在金庸小说《天龙八部》中,斗转星移的修习者是谁?'
    q190 = 'q190:《基督山伯爵》的作者是谁?'  # this one runs without issues
    for question in [q188, q189, q190]:
        q_text = question_patten.findall(question)[0]
        (result_entities, candidate_entities, candidate_out_paths,
         candidate_in_paths) = qa.run(q_text, return_candidates=True)
        print(question)
        import ipdb
        ipdb.set_trace()
Example #8
def valid2submit():
    '''
        Convert validation-set output into the answer submission format.
    '''
    logging_config('valid2submit.log', stream_log=True)
    data_df = pd.read_csv(
        ResultSaver(find_exist_path=True).valid_result_csv[0])
    with open(ResultSaver().submit_result_txt, 'w', encoding='utf-8') as f:
        for index, row in data_df.iterrows():
            ents = []
            for ent in eval(row['result'])[:10]:  # keep only the top 10
                if ent.startswith('<'):
                    ents.append(ent)
                elif ent.startswith('"'):
                    ent = ent.strip('"')
                    ents.append(f'"{ent}"')
            if ents:
                f.write('\t'.join(ents) + '\n')
            else:
                f.write('""\n')
            f.flush()
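The loop above keeps only entries wrapped in <> or "" and silently drops anything else. A worked example with an illustrative (not real) result cell:

row_result = """['<基督山伯爵>', '"大仲马"', '裸文本']"""  # illustrative value
ents = []
for ent in eval(row_result)[:10]:  # keep only the top 10
    if ent.startswith('<'):
        ents.append(ent)
    elif ent.startswith('"'):
        ents.append('"%s"' % ent.strip('"'))
print('\t'.join(ents) if ents else '""')
# <基督山伯爵>	"大仲马"   (the bare third entry is dropped)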
Example #9
def valid_qa():
    """验证数据答案;验证集做回答;得到提交数据  """
    logging_config('valid_qa.log', stream_log=True)
    from ckbqa.qa.qa import QA
    from ckbqa.dataset.data_prepare import question_patten
    #
    data_df = pd.DataFrame([
        question.strip()
        for question in open(valid_question_txt, 'r', encoding='utf-8')
    ],
                           columns=['question'])
    logging.info(f"data_df.shape: {data_df.shape}")
    qa = QA()
    valid_datas = {'question': [], 'result': []}
    with open(ResultSaver().submit_result_txt, 'w', encoding='utf-8') as f:
        for index, row in tqdm(data_df.iterrows(),
                               total=data_df.shape[0],
                               desc='qa answer'):
            question = row['question']
            q_text = question_patten.findall(question)[0]
            try:
                result_entities = qa.run(q_text)
                result_entities = [
                    ent if ent.startswith('<') else f'"{ent}"'
                    for ent in result_entities
                ]
                f.write('\t'.join(result_entities) + '\n')
            except KeyboardInterrupt:
                exit('Ctrl C , exit')
            except Exception:
                result_entities = []
                f.write('""\n')
                logging.info(traceback.format_exc())
            f.flush()
            valid_datas['question'].append(question)
            valid_datas['result'].append(result_entities)
    pd.DataFrame(valid_datas).to_csv(ResultSaver().valid_result_csv,
                                     index=False,
                                     encoding='utf_8_sig')
Example #10
        cur_dir = cur_dir.parent
        times -= 1
    print(cur_dir)
    sys.path.append(str(cur_dir))


add_root_path()

import logging

from ckbqa.qa.lac_tools import BaiduLac, JiebaLac
from ckbqa.utils.logger import logging_config
from ckbqa.utils.tools import pkl_dump, ProcessManager
from config import Config

logging_config('lac_test.log', stream_log=True)


def lac_model():
    # self.lac_seg = LAC(mode='seg')
    logging.info(
        f' load lac_custom_dict from  {Config.lac_custom_dict_txt} start ...')
    baidu_lac = BaiduLac(mode='lac', _load_customization=True,
                         reload=True)  # load the LAC model
    save_path = baidu_lac._save_customization()  # Config.lac_custom_dict_txt
    logging.info(f'load lac_custom_dict done, save to {save_path}...')


def lac_test():
    logging.info(
        f' load lac_custom_dict from  {Config.lac_custom_dict_txt} start ...')
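Example #10 is truncated: the body of add_root_path starts mid-loop and lac_test is cut off. A plausible reconstruction of the missing helper, with the signature and default depth as assumptions inferred from the visible lines:

import sys
from pathlib import Path

def add_root_path(times=2):
    # Walk `times` directories up from this file and put the project
    # root on sys.path (signature and default are assumptions).
    cur_dir = Path(__file__).absolute().parent
    while times > 0:
        cur_dir = cur_dir.parent
        times -= 1
    print(cur_dir)
    sys.path.append(str(cur_dir))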
Example #11
def task():
    logging_config('qa_task.log', stream_log=True)
    train_qa()
    valid_qa()
Example #12
def task():
    logging_config('ceg.log', stream_log=True)
    ceg()