Example #1
File: test_.py  Project: wut0n9/ccks-2020
import re

import hanlp
import jieba

# json_load, Config, load_data and question_patten are project-internal;
# load_data and question_patten live in ckbqa.dataset.data_prepare.

def ner_train_data():
    ent2mention = json_load(Config.ent2mention_json)
    tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
    recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
    from LAC import LAC
    lac = LAC(mode='lac')  # load the LAC model
    jieba.enable_paddle()  # enable paddle mode (supported since jieba 0.40)
    _ent_patten = re.compile(r'["<](.*?)[>"]')
    for q, sparql, a in load_data():
        q_text = question_patten.findall(q)[0]
        hanlp_entities = recognizer([list(q_text)])
        hanlp_words = tokenizer(q_text)
        lac_results = lac.run(q_text)
        q_entities = _ent_patten.findall(sparql)
        jieba_results = list(jieba.cut_for_search(q_text))
        mentions = [ent2mention.get(ent) for ent in q_entities]
        print(f"q_text: {q_text}\nq_entities: {q_entities}, "
              f"\nlac_results:{lac_results}"
              f"\nhanlp_words: {hanlp_words}, "
              f"\njieba_results: {jieba_results}, "
              f"\nhanlp_entities: {hanlp_entities}, "
              f"\nmentions: {mentions}")
        import ipdb
        ipdb.set_trace()  # drop into the debugger to inspect each sample
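The snippet above relies on `question_patten` from `ckbqa.dataset.data_prepare` to strip the question id prefix. A minimal sketch of what that pattern plausibly looks like, assuming the "qN:question text" format that `q.split(':')[1]` in the later examples also implies (the exact regex is an assumption, not the project's code):

import re

# Hypothetical reconstruction: capture everything after the "qN:" prefix.
question_patten = re.compile(r'q\d+[::](.*)')

raw_q = 'q1:《哈利·波特》的作者是谁?'
print(question_patten.findall(raw_q)[0])  # -> 《哈利·波特》的作者是谁?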
Example #2
import logging

from ckbqa.dataset.data_prepare import load_data
from ckbqa.qa.el import EL

def test_recgnizer():
    print('start')
    logging.info('test start ...')
    el = EL()
    for q, sparql, a in load_data():
        q_text = q.split(':')[1]
        rec_entities = el.ceg.get_ent2mention(q_text)
        print(rec_entities)
        print(q)
        print(sparql)
        print(a)
        import ipdb
        ipdb.set_trace()
Example #3
import logging

from ckbqa.dataset.data_prepare import load_data
# Recognizer is a project-internal entity recognizer.

def test_recgnizer():
    print('start')
    logging.info('test start ...')
    recognizer = Recognizer()
    for q, sparql, a in load_data():
        q_text = q.split(':')[1]
        rec_entities = recognizer.find_entities(q_text)
        print(rec_entities)
        print(q)
        print(sparql)
        print(a)
        import ipdb
        ipdb.set_trace()
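Every example iterates the same `load_data` generator from `ckbqa.dataset.data_prepare`. A minimal stand-in showing the (question, sparql, answer) triples it yields; the toy sample and the signature are assumptions, and the real loader also wraps the loop in a tqdm bar labelled by `tqdm_prefix`:

def load_data(tqdm_prefix=''):
    """Hypothetical sketch: yield (question, sparql, answer) triples."""
    samples = [
        ('q1:《哈利·波特》的作者是谁?',
         'select ?x where { <哈利·波特> <作者> ?x . }',
         '"J·K·罗琳"'),
    ]
    for question, sparql, answer in samples:
        yield question, sparql, answer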
Example #4
File: qa.py  Project: wut0n9/ccks-2020
import logging
import traceback

import pandas as pd

# logging_config and ResultSaver are project-internal helpers.

def train_qa():
    """Answer the training-data questions and report the metrics."""
    logging_config('train_qa.log', stream_log=True)
    from ckbqa.qa.qa import QA
    from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern
    # from ckbqa.qa.evaluation_matrics import get_metrics
    logging.info('* start run ...')
    qa = QA()
    data = []
    for question, sparql, answer in load_data(tqdm_prefix='test qa'):
        print('\n' * 2)
        print('*****' * 20)
        print(f" question  : {question}")
        print(f" sparql    : {sparql}")
        print(f" standard answer   : {answer}")
        q_text = question_patten.findall(question)[0]
        standard_subject_entities = entity_pattern.findall(
            sparql) + attr_pattern.findall(sparql)
        standard_answer_entities = entity_pattern.findall(
            answer) + attr_pattern.findall(answer)
        try:
            (result_entities, candidate_entities, candidate_out_paths,
             candidate_in_paths) = qa.run(q_text, return_candidates=True)
        except KeyboardInterrupt:
            exit('Ctrl-C, exit')
        except Exception:  # log the traceback but keep evaluating the rest
            logging.info(f'ERROR: {traceback.format_exc()}')
            result_entities = []
            candidate_entities = []
        # print(f" result answer   : {result_entities}")
        # precision, recall, f1 = get_metrics(subject_entities, candidate_entities)
        # if recall == 0 or len(set(standard_entities) & set(candidate_entities)) == 0:
        #     print(f"question: {question}\n"
        #           f"subject_entities: {subject_entities}, candidate_entities: {candidate_entities}"
        #           f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n")
        # import ipdb
        # ipdb.set_trace()
        data.append([
            question, standard_subject_entities,
            list(candidate_entities), standard_answer_entities, result_entities
        ])
    data_df = pd.DataFrame(data,
                           columns=[
                               'question', 'standard_subject_entities',
                               'candidate_entities',
                               'standard_answer_entities', 'result_entities'
                           ])
    data_df.to_csv(ResultSaver().train_result_csv,
                   index=False,
                   encoding='utf_8_sig')
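Once the run finishes, the per-question results can be audited offline. A hedged sketch, assuming the CSV written above and that pandas serialized the list columns as Python literals (`ast.literal_eval` parses them back):

import ast

import pandas as pd

df = pd.read_csv('train_result.csv')  # hypothetical local copy of ResultSaver().train_result_csv
for col in ('standard_answer_entities', 'result_entities'):
    df[col] = df[col].apply(ast.literal_eval)
hit = df.apply(lambda r: bool(set(r['standard_answer_entities']) &
                              set(r['result_entities'])), axis=1)
print(f"answer hit rate: {hit.mean():.3f}")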
Example #5
import pandas as pd

# logging_config is a project-internal helper.

def ceg():
    logging_config('train_evaluate.log', stream_log=True)
    from ckbqa.models.evaluation_matrics import get_metrics
    from ckbqa.qa.el import CEG
    from ckbqa.dataset.data_prepare import load_data, question_patten, entity_pattern, attr_pattern
    ceg = CEG()  # Candidate Entity Generation
    ceg_precisions, ceg_recalls, ceg_f1_scores = [], [], []
    ceg_csv = "./ceg.csv"
    data = []
    for q, sparql, a in load_data(tqdm_prefix='ceg evaluate '):
        q_entities = entity_pattern.findall(sparql) + attr_pattern.findall(
            sparql)

        q_text = ''.join(question_patten.findall(q))
        # Fix for stripping <> from entities earlier: the wrappers were removed
        # when parsing questions but not at prediction time, so match the bare
        # text rather than the <...> / "..." wrapped forms.
        ent2mention = ceg.get_ent2mention(q_text)
        # CEG: Candidate Entity Generation
        precision, recall, f1 = get_metrics(q_entities, ent2mention)
        ceg_precisions.append(precision)
        ceg_recalls.append(recall)
        ceg_f1_scores.append(f1)
        #
        data.append([q, q_entities, list(ent2mention.keys())])
        if recall == 0:
            # ceg.memory.entity2id
            # ceg.memory.mention2entity
            print(
                f"question: {q}\n"
                f"subject_entities: {q_entities}, candidate_entities: {ent2mention}\n"
                f"precision: {precision:.4f}, recall: {recall:.4f}, f1: {f1:.4f}\n\n"
            )
            # import ipdb
            # ipdb.set_trace()
            print('\n\n')
        # import time
        # time.sleep(2)
    pd.DataFrame(data, columns=['question', 'q_entities',
                                'ceg']).to_csv(ceg_csv,
                                               index=False,
                                               encoding='utf_8_sig')
    ave_precision = sum(ceg_precisions) / len(ceg_precisions)
    ave_recall = sum(ceg_recalls) / len(ceg_recalls)
    ave_f1_score = sum(ceg_f1_scores) / len(ceg_f1_scores)
    print(f"ave_precision: {ave_precision:.3f}, "
          f"ave_recall: {ave_recall:.3f}, "
          f"ave_f1_score:{ave_f1_score:.3f}")
Example #6
import re
from collections import Counter

# json_load is a project-internal helper; load_data is in ckbqa.dataset.data_prepare.

def lac_test():
    from LAC import LAC
    print('start ...')
    mention_count = Counter(json_load('mention2count.json')).most_common(10000)
    customization_dict = {
        mention: 'MENTION'
        for mention, count in mention_count if len(mention) >= 2
    }
    print('second ...')
    ent_count = Counter(json_load('ent2count.json')).most_common(300000)
    ent_pattern = re.compile(r'["<](.*?)[>"]')
    customization_dict.update(
        {' '.join(ent_pattern.findall(ent)): 'ENT'
         for ent, count in ent_count})
    with open('./customization_dict.txt', 'w') as f:
        f.write('\n'.join([
            f"{e}/{t}" for e, t in customization_dict.items() if len(e) >= 3
        ]))
    import time
    before = time.time()
    print(f'before ...{before}')
    lac = LAC(mode='lac')
    lac.load_customization('./customization_dict.txt')  # ~21s for 200k entries; ~47s for 300k
    lac_raw = LAC(mode='lac')
    print(f'after ...{time.time() - before}')
    ##
    test_count = 10
    for q, sparql, a in load_data():
        q_text = q.split(':')[1]
        print('---' * 10)
        print(q_text)
        print(sparql)
        print(a)
        words, tags = lac_raw.run(q_text)
        print(list(zip(words, tags)))
        words, tags = lac.run(q_text)
        print(list(zip(words, tags)))
        if not test_count:
            break
        test_count -= 1
    import ipdb
    ipdb.set_trace()
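The customization file written above holds one "phrase/TAG" entry per line; after `load_customization`, LAC segments those phrases as single tokens carrying the given tag. A minimal end-to-end check (the sample entries are made up):

from LAC import LAC

with open('./customization_dict.txt', 'w') as f:
    f.write('哈利·波特/ENT\n身高/MENTION')  # hypothetical entries in the phrase/TAG format

lac = LAC(mode='lac')
lac.load_customization('./customization_dict.txt')
words, tags = lac.run('哈利·波特的身高')
print(list(zip(words, tags)))  # 哈利·波特 and 身高 come back as single tagged tokens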
Example #7
 def gen_train_data(self):
     X_train = []
     Y_label = []
     from ckbqa.qa.el import EL  # imported here to avoid a circular import
     el = EL()
     for q, sparql, a in load_data(tqdm_prefix='EntityScore train data '):
         # a_entities = entity_pattern.findall(a)
         q_entities = set(
             entity_pattern.findall(sparql) +
             attr_pattern.findall(sparql))  # attributes count as gold entities too
         q_text = question_patten.findall(q)[0]
         candidate_entities = el.el(q_text)
         for ent_name, feature_dict in candidate_entities.items():
             feature = feature_dict['feature']
             label = 1 if ent_name in q_entities else 0  # not every candidate is a gold entity
             X_train.append(feature)
             Y_label.append(label)
     pkl_dump({
         'x_data': X_train,
         'y_label': Y_label
     }, Config.entity_score_data_pkl)
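The pickle produced above pairs one feature vector with one binary label per candidate entity, which is exactly the shape a binary classifier expects. A hedged sketch of fitting a scorer on it; the file path and the choice of LogisticRegression are assumptions, not the project's actual EntityScore model:

import pickle

from sklearn.linear_model import LogisticRegression

with open('entity_score_data.pkl', 'rb') as f:  # hypothetical path for Config.entity_score_data_pkl
    data = pickle.load(f)
clf = LogisticRegression(max_iter=1000)
clf.fit(data['x_data'], data['y_label'])
print(f"train accuracy: {clf.score(data['x_data'], data['y_label']):.3f}")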
Example #8
 def data2samples(self, neg_rate=3, test_size=0.1):
     if os.path.isfile(
             Config.get_relation_score_sample_csv('train', neg_rate)):
         return
     questions = []
     sim_questions = []
     labels = []
     all_relations = list(json_load(Config.relation2id))
     _entity_pattern = re.compile(r'["<](.*?)[>"]')
     for q, sparql, a in load_data(tqdm_prefix='data2samples '):
         q_text = question_patten.findall(q)[0]
         q_entities = _entity_pattern.findall(sparql)
         questions.append(q_text)
         sim_questions.append('的'.join(q_entities))
         labels.append(1)
         #
         for neg_relation in random.sample(all_relations, neg_rate):
             questions.append(q_text)
             neg_question = '的'.join(q_entities[:-1] +
                                      [neg_relation])  # randomly swap in a wrong <relation>
             sim_questions.append(neg_question)
             labels.append(0)
     data_df = pd.DataFrame({
         'question': questions,
         'sim_question': sim_questions,
         'label': labels
     })
     data_df.to_csv(Config.relation_score_sample_csv,
                    encoding='utf_8_sig',
                    index=False)
     train_df, test_df = train_test_split(data_df, test_size=test_size)
     test_df.to_csv(Config.get_relation_score_sample_csv('test', neg_rate),
                    encoding='utf_8_sig',
                    index=False)
     train_df.to_csv(Config.get_relation_score_sample_csv(
         'train', neg_rate),
                     encoding='utf_8_sig',
                     index=False)
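For each gold question the loop above emits one positive pair plus `neg_rate` negatives in which the final parsed entity (the relation) is swapped for a random one. A toy trace of that sampling, with made-up entities and relations:

import random

q_entities = ['哈利·波特', '作者']          # parsed from the SPARQL, <>/"" wrappers stripped
all_relations = ['导演', '国籍', '出版社']  # stand-in for json_load(Config.relation2id)

print('的'.join(q_entities))                # positive sim_question: 哈利·波特的作者
for neg_relation in random.sample(all_relations, 3):
    print('的'.join(q_entities[:-1] + [neg_relation]))  # negatives: 哈利·波特的导演, ...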