Example #1
File: test_.py Project: wut0n9/ccks-2020
def ner_train_data():
    """Compare NER / segmentation output from HanLP, LAC, and jieba against
    the gold entities parsed out of each question's SPARQL query."""
    ent2mention = json_load(Config.ent2mention_json)
    tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
    recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
    from LAC import LAC
    lac = LAC(mode='lac')  # load the LAC model
    jieba.enable_paddle()  # enable paddle mode (supported since jieba 0.40; earlier versions lack it)
    _ent_patten = re.compile(r'["<](.*?)[>"]')
    for q, sparql, a in load_data():
        q_text = question_patten.findall(q)[0]
        hanlp_entities = recognizer([list(q_text)])
        hanlp_words = tokenizer(q_text)
        lac_results = lac.run(q_text)
        q_entities = _ent_patten.findall(sparql)
        jieba_results = list(jieba.cut_for_search(q_text))
        mentions = [ent2mention.get(ent) for ent in q_entities]
        print(f"q_text: {q_text}\nq_entities: {q_entities}, "
              f"\nlac_results:{lac_results}"
              f"\nhanlp_words: {hanlp_words}, "
              f"\njieba_results: {jieba_results}, "
              f"\nhanlp_entities: {hanlp_entities}, "
              f"\nmentions: {mentions}")
        import ipdb
        ipdb.set_trace()
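These listings omit the project's module-level imports (hanlp, jieba, re, logging, pandas) and several in-house helpers: question_patten, load_data, json_load, and Config. A hedged sketch of plausible stand-ins, inferred only from the call sites; the question format, file name, and triple grouping below are assumptions, not the project's actual code:

import json
import re

# Assumed question format "qN:<question text>" (cf. q.split(':') in Example #4).
question_patten = re.compile(r'q\d+[::](.*)')

def json_load(path):
    # Assumed to be a thin wrapper over the standard json module.
    with open(path, encoding='utf-8') as f:
        return json.load(f)

def load_data(tqdm_prefix=''):
    # Assumed layout: question / SPARQL / answer on consecutive non-empty lines.
    with open('task1_train.txt', encoding='utf-8') as f:  # hypothetical path
        lines = [ln.rstrip('\n') for ln in f if ln.strip()]
    for i in range(0, len(lines), 3):
        yield lines[i], lines[i + 1], lines[i + 2]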
Example #2
def create_graph_csv():
    """Generate the CSV files for a neo4j-admin bulk import.
    cd /home/wangshengguang/neo4j-community-3.4.5/bin/
    ./neo4j-admin import --database=graph.db --nodes /home/wangshengguang/ccks-2020/data/graph_entity.csv  --relationships /home/wangshengguang/ccks-2020/data/graph_relation2.csv --ignore-duplicate-nodes=true --id-type INTEGER --ignore-missing-nodes=true
    CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.id IS UNIQUE;
    CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.name IS UNIQUE;
    CREATE CONSTRAINT ON (r:Relation) ASSERT r.name IS UNIQUE;
    """
    logging.info('start load Config.entity2id ..')
    entity2id = json_load(Config.entity2id)
    entity2id = json_load(Config.entity2id)
    pd.DataFrame.from_records(
        [(id, ent, 'Entity') for ent, id in entity2id.items()],
        columns=["id:ID(Entity)", "name:String",
                 ":LABEL"]).to_csv(Config.graph_entity_csv,
                                   index=False,
                                   encoding='utf_8_sig')
    # relation rows: (start id, end id, type, relation name)
    records = [[
        entity2id[head_name], entity2id[tail_name], 'Relation', rel_name
    ] for head_name, rel_name, tail_name in iter_triples(
        tqdm_prefix='gen relation csv')]
    del entity2id
    gc.collect()
    pd.DataFrame.from_records(records,
                              columns=[
                                  ":START_ID(Entity)", ":END_ID(Entity)",
                                  ":TYPE", "name:String"
                              ]).to_csv(Config.graph_relation_csv,
                                        index=False,
                                        encoding='utf_8_sig')
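iter_triples is another helper that is not shown. A minimal sketch under the assumption that it streams tab-separated (head, relation, tail) lines from the KB dump with a progress bar; the path and layout are guesses:

from tqdm import tqdm

def iter_triples(path='kb_data.txt', tqdm_prefix=''):  # hypothetical path
    """Yield (head, relation, tail) triples from a TSV-style KB dump."""
    with open(path, encoding='utf-8') as f:
        for line in tqdm(f, desc=tqdm_prefix):
            head, rel, tail = line.rstrip('\n').split('\t')[:3]
            yield head, rel, tail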
Example #3
def create_lac_custom_dict():
    """Generate the custom segmentation dictionary files for LAC."""
    logging.info('create_lac_custom_dict start...')
    ent_counter, rel_counter, mention_counter = _get_top_counter()
    mention_count = mention_counter.most_common(50 * 10000)  # top 500k mentions
    customization_dict = {
        mention: 'MENTION'
        for mention, count in mention_count if 2 <= len(mention) <= 8
    }
    del mention_count, mention_counter
    logging.info('create ent&rel customization_dict ...')
    _ent_pattrn = re.compile(r'["<](.*?)[>"]')
    # customization_dict.update({' '.join(_ent_pattrn.findall(rel)): 'REL'
    #                            for rel, count in rel_counter.most_common(10000)  # top 10k relations
    #                            if 2 <= len(rel) <= 8})
    # del rel_counter
    customization_dict.update({
        ' '.join(_ent_pattrn.findall(ent)): 'ENT'
        for ent, count in ent_counter.most_common(50 * 10000)  # top 500k entities
        if 2 <= len(ent) <= 8
    })
    q_entity2id = json_load(Config.q_entity2id_json)
    q_entity2id.update(json_load(Config.a_entity2id_json))
    customization_dict.update({
        ' '.join(_ent_pattrn.findall(ent)): 'ENT'
        for ent, _id in q_entity2id.items()
    })
    del ent_counter
    with open(Config.lac_custom_dict_txt, 'w', encoding='utf-8') as f:
        for e, t in customization_dict.items():
            if len(e) >= 3:
                f.write(f"{e}/{t}\n")
    logging.info('attr_custom_dict gen start ...')
    entity2attrs = json_load(Config.entity2attrs_json)
    all_attrs = set()
    for attrs in entity2attrs.values():
        all_attrs.update(attrs)
    name_patten = re.compile('"(.*?)"')
    with open(Config.lac_attr_custom_dict_txt, 'w', encoding='utf-8') as f:
        for _attr in all_attrs:
            attr = ' '.join(name_patten.findall(_attr))
            if len(attr) >= 2:
                f.write(f"{attr}/ATTR\n")
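A minimal sketch of consuming the generated dictionary with LAC's load_customization (the file path stands in for Config.lac_custom_dict_txt); each line of the file is one word/TAG entry:

from LAC import LAC

lac = LAC(mode='lac')
lac.load_customization('lac_custom_dict.txt')  # stand-in path
words, tags = lac.run('周杰伦的出生日期是什么')
print(list(zip(words, tags)))  # dictionary hits come back tagged MENTION / ENT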
Example #4
def lac_test():
    from LAC import LAC
    print('start ...')
    mention_count = Counter(json_load('mention2count.json')).most_common(10000)
    customization_dict = {
        mention: 'MENTION'
        for mention, count in mention_count if len(mention) >= 2
    }
    print('second ...')
    ent_count = Counter(json_load('ent2count.json')).most_common(300000)
    ent_pattrn = re.compile(r'["<](.*?)[>"]')
    customization_dict.update(
        {' '.join(ent_pattrn.findall(ent)): 'ENT'
         for ent, count in ent_count})
    with open('./customization_dict.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join([
            f"{e}/{t}" for e, t in customization_dict.items() if len(e) >= 3
        ]))
    import time
    before = time.time()
    print(f'before ...{before}')
    lac = LAC(mode='lac')
    lac.load_customization('./customization_dict.txt')  # ~21 s for 200k entries; ~47 s for 300k
    lac_raw = LAC(mode='lac')
    print(f'after ...{time.time() - before}')
    # compare raw vs. customized LAC on a few samples
    test_count = 10
    for q, sparql, a in load_data():
        q_text = q.split(':', 1)[1]  # keep everything after the "qN:" prefix
        print('---' * 10)
        print(q_text)
        print(sparql)
        print(a)
        words, tags = lac_raw.run(q_text)
        print(list(zip(words, tags)))
        words, tags = lac.run(q_text)
        print(list(zip(words, tags)))
        if not test_count:
            break
        test_count -= 1
    import ipdb
    ipdb.set_trace()
Example #5
 def load_cache(self):
     """Load the cached one-hop / two-hop relation-name maps if present."""
     if os.path.isfile(Config.neo4j_query_cache):
         data_dict = json_load(Config.neo4j_query_cache)
         self._one_hop_relNames_map = data_dict['_one_hop_relNames_map']
         self._two_hop_relNames_map = data_dict['_two_hop_relNames_map']
         self.total_count = self.get_total_entity_count()
         logging.info(f'load neo4j_query_cache, total: {self.total_count}')
     else:
         logging.info(
             f'not found neo4j_query_cache: {Config.neo4j_query_cache}')
         self.total_count = 0
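get_total_entity_count is not shown either. A hedged sketch of what it might do, assuming the class holds a py2neo Graph handle (the driver choice and Cypher query are assumptions):

 def get_total_entity_count(self):
     # Assumed: self.graph is a py2neo Graph.
     return self.graph.run('MATCH (n:Entity) RETURN count(n)').evaluate()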
Example #6
 def load_custom_dict(self):
     if not os.path.isfile(Config.jieba_custom_dict):
         all_ent_mention_words = set(
             json_load(Config.mention2count_json).keys())
         entity_set = set(json_load(Config.entity2id).keys())
         for ent in tqdm(entity_set,
                         total=len(entity_set),
                         desc='create jieba words '):
             if ent.startswith('<'):
                 ent = ent[1:-1]
             all_ent_mention_words.add(ent)
         # Mimic jieba.add_word but inline its logic for speed: write directly
         # to jieba.dt.FREQ / jieba.dt.total instead of calling add_word per word.
         # jieba.dt.FREQ = {}
         # jieba.dt.total = 0
         for word in tqdm(all_ent_mention_words,
                          desc='jieba custom create '):
             freq = len(word) * 3
             jieba.dt.FREQ[word] = freq
             jieba.dt.total += freq
             for ch in range(len(word)):
                 wfrag = word[:ch + 1]
                 if wfrag not in jieba.dt.FREQ:
                     jieba.dt.FREQ[wfrag] = 0
         del all_ent_mention_words
         gc.collect()
         json_dump({
             'dt.FREQ': jieba.dt.FREQ,
             'dt.total': jieba.dt.total
         }, Config.jieba_custom_dict)
         logging.info('create jieba custom dict done ...')
     # load the cached dict and patch it into jieba's default tokenizer
     jieba_custom = json_load(Config.jieba_custom_dict)
     jieba.dt.check_initialized()
     jieba.dt.FREQ = jieba_custom['dt.FREQ']
     jieba.dt.total = jieba_custom['dt.total']
     logging.info('load jieba custom dict done ...')
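For contrast, the stock jieba API does the same thing one word at a time; the code above bypasses it because calling add_word once per word is slow over millions of entries. A minimal sketch of the standard route:

import jieba

jieba.initialize()
jieba.add_word('自定义实体词', freq=18)  # updates jieba.dt.FREQ and jieba.dt.total internally
print(jieba.lcut('这是一个自定义实体词测试'))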
Example #7
def _get_top_counter():
    """Count how often each entity, relation, and mention occurs in the ~26 GB KB;
    the high-frequency items feed later filtering and the LAC custom dictionary
    (taking the top-N yields the final segmentation lexicon).
    """
    logging.info('kb_count_top_dict start ...')
    if not (os.path.isfile(Config.entity2count_json)
            and os.path.isfile(Config.relation2count_json)):
        entities = []
        relations = []
        for head_ent, rel, tail_ent in iter_triples(
                tqdm_prefix='kb_top_count_dict '):
            entities.extend([head_ent, tail_ent])
            relations.append(rel)
        ent_counter = Counter(entities)
        json_dump(dict(ent_counter), Config.entity2count_json)
        del entities
        rel_counter = Counter(relations)
        del relations
        json_dump(dict(rel_counter), Config.relation2count_json)
    else:
        ent_counter = Counter(json_load(Config.entity2count_json))
        rel_counter = Counter(json_load(Config.relation2count_json))
    # mention counts
    if not os.path.isfile(Config.mention2count_json):
        mentions = []
        for line in tqdm_iter_file(
                mention2ent_txt,
                prefix='count_top_dict iter mention2ent.txt '):
            mention, ent, rank = line.split('\t')  # some lines are malformed: mention may be an empty string
            mentions.append(mention)
        mention_counter = Counter(mentions)
        del mentions
        json_dump(dict(mention_counter), Config.mention2count_json)
    else:
        mention_counter = Counter(json_load(Config.mention2count_json))
    return ent_counter, rel_counter, mention_counter
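tqdm_iter_file and mention2ent_txt are likewise not shown; a hedged stand-in, assuming a plain line iterator over a text file with a tqdm progress bar:

from tqdm import tqdm

mention2ent_txt = 'mention2ent.txt'  # hypothetical path

def tqdm_iter_file(path, prefix=''):
    """Yield lines (without trailing newline) with a progress bar."""
    with open(path, encoding='utf-8') as f:
        for line in tqdm(f, desc=prefix):
            yield line.rstrip('\n')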
Example #8
 def data2samples(self, neg_rate=3, test_size=0.1):
     if os.path.isfile(
             Config.get_relation_score_sample_csv('train', neg_rate)):
         return
     questions = []
     sim_questions = []
     labels = []
     all_relations = list(json_load(Config.relation2id))
     _entity_pattern = re.compile(r'["<](.*?)[>"]')
     for q, sparql, a in load_data(tqdm_prefix='data2samples '):
         q_text = question_patten.findall(q)[0]
         q_entities = _entity_pattern.findall(sparql)
         questions.append(q_text)
         sim_questions.append('的'.join(q_entities))
         labels.append(1)
         # negative samples
         for neg_relation in random.sample(all_relations, neg_rate):
             questions.append(q_text)
             neg_question = '的'.join(q_entities[:-1] +
                                     [neg_relation])  # randomly replace the <relation>
             sim_questions.append(neg_question)
             labels.append(0)
     data_df = pd.DataFrame({
         'question': questions,
         'sim_question': sim_questions,
         'label': labels
     })
     data_df.to_csv(Config.relation_score_sample_csv,
                    encoding='utf_8_sig',
                    index=False)
     train_df, test_df = train_test_split(data_df, test_size=test_size)
     test_df.to_csv(Config.get_relation_score_sample_csv('test', neg_rate),
                    encoding='utf_8_sig',
                    index=False)
     train_df.to_csv(Config.get_relation_score_sample_csv(
         'train', neg_rate),
                     encoding='utf_8_sig',
                     index=False)
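The method above additionally assumes these module-level imports (standard library plus pandas and scikit-learn):

import random

import pandas as pd
from sklearn.model_selection import train_test_split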
Example #9
 def __init__(self):
     """Load the entity-id and mention-to-entity lookup tables."""
     self.entity2id = json_load(Config.entity2id)
     self.mention2entity = json_load(Config.mention2ent_json)