def ner_train_data():
    ent2mention = json_load(Config.ent2mention_json)
    # recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
    tokenizer = hanlp.load('PKU_NAME_MERGED_SIX_MONTHS_CONVSEG')
    recognizer = hanlp.load(hanlp.pretrained.ner.MSRA_NER_BERT_BASE_ZH)
    from LAC import LAC
    # load the LAC model
    lac = LAC(mode='lac')
    # enable paddle mode; supported since jieba 0.40, not in earlier versions
    jieba.enable_paddle()
    _ent_patten = re.compile(r'["<](.*?)[>"]')
    for q, sparql, a in load_data():
        q_text = question_patten.findall(q)[0]
        hanlp_entities = recognizer([list(q_text)])
        hanlp_words = tokenizer(q_text)
        lac_results = lac.run(q_text)
        q_entities = _ent_patten.findall(sparql)
        jieba_results = list(jieba.cut_for_search(q_text))
        mentions = [ent2mention.get(ent) for ent in q_entities]
        print(f"q_text: {q_text}\nq_entities: {q_entities}, "
              f"\nlac_results: {lac_results}"
              f"\nhanlp_words: {hanlp_words}, "
              f"\njieba_results: {jieba_results}, "
              f"\nhanlp_entities: {hanlp_entities}, "
              f"\nmentions: {mentions}")
        import ipdb
        ipdb.set_trace()
def create_graph_csv():
    """Generate the bulk-import files for the graph database.

    cd /home/wangshengguang/neo4j-community-3.4.5/bin/
    ./neo4j-admin import --database=graph.db \
        --nodes /home/wangshengguang/ccks-2020/data/graph_entity.csv \
        --relationships /home/wangshengguang/ccks-2020/data/graph_relation2.csv \
        --ignore-duplicate-nodes=true --id-type INTEGER --ignore-missing-nodes=true

    CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.id IS UNIQUE;
    CREATE CONSTRAINT ON (ent:Entity) ASSERT ent.name IS UNIQUE;
    CREATE CONSTRAINT ON (r:Relation) ASSERT r.name IS UNIQUE;
    """
    logging.info('start load Config.entity2id ..')
    entity2id = json_load(Config.entity2id)
    pd.DataFrame.from_records(
        [(id, ent, 'Entity') for ent, id in entity2id.items()],
        columns=["id:ID(Entity)", "name:String", ":LABEL"]
    ).to_csv(Config.graph_entity_csv, index=False, encoding='utf_8_sig')
    # relation rows reference the entity ids generated above
    records = [[entity2id[head_name], entity2id[tail_name], 'Relation', rel_name]
               for head_name, rel_name, tail_name in iter_triples(
                   tqdm_prefix='gen relation csv')]
    del entity2id
    gc.collect()
    pd.DataFrame.from_records(
        records,
        columns=[":START_ID(Entity)", ":END_ID(Entity)", ":TYPE", "name:String"]
    ).to_csv(Config.graph_relation_csv, index=False, encoding='utf_8_sig')
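# Optional sanity check (a sketch, not part of the original pipeline; it reuses
# the same Config paths as create_graph_csv above): confirm the generated CSV
# headers match what the neo4j-admin import command in the docstring expects.
def _check_graph_csv_headers():
    entity_cols = list(pd.read_csv(Config.graph_entity_csv, nrows=0).columns)
    relation_cols = list(pd.read_csv(Config.graph_relation_csv, nrows=0).columns)
    assert entity_cols == ["id:ID(Entity)", "name:String", ":LABEL"]
    assert relation_cols == [":START_ID(Entity)", ":END_ID(Entity)",
                             ":TYPE", "name:String"]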
def create_lac_custom_dict():
    """Generate the custom segmentation dictionary for LAC."""
    logging.info('create_lac_custom_dict start...')
    ent_counter, rel_counter, mention_counter = _get_top_counter()
    mention_count = mention_counter.most_common(50 * 10000)  # top 500,000 mentions
    customization_dict = {
        mention: 'MENTION'
        for mention, count in mention_count
        if 2 <= len(mention) <= 8
    }
    del mention_count, mention_counter
    logging.info('create ent&rel customization_dict ...')
    _ent_pattrn = re.compile(r'["<](.*?)[>"]')
    # customization_dict.update({
    #     ' '.join(_ent_pattrn.findall(rel)): 'REL'
    #     for rel, count in rel_counter.most_common(10000)
    #     if 2 <= len(rel) <= 8})
    # del rel_counter
    customization_dict.update({
        ' '.join(_ent_pattrn.findall(ent)): 'ENT'
        for ent, count in ent_counter.most_common(50 * 10000)  # top 500,000 entities
        if 2 <= len(ent) <= 8
    })
    q_entity2id = json_load(Config.q_entity2id_json)
    q_entity2id.update(json_load(Config.a_entity2id_json))
    customization_dict.update({
        ' '.join(_ent_pattrn.findall(ent)): 'ENT'
        for ent, _id in q_entity2id.items()
    })
    del ent_counter
    with open(Config.lac_custom_dict_txt, 'w') as f:
        for e, t in customization_dict.items():
            if len(e) >= 3:
                f.write(f"{e}/{t}\n")
    logging.info('attr_custom_dict gen start ...')
    entity2attrs = json_load(Config.entity2attrs_json)
    all_attrs = set()
    for attrs in entity2attrs.values():
        all_attrs.update(attrs)
    name_patten = re.compile('"(.*?)"')
    with open(Config.lac_attr_custom_dict_txt, 'w') as f:
        for _attr in all_attrs:
            attr = ' '.join(name_patten.findall(_attr))
            if len(attr) >= 2:
                f.write(f"{attr}/ATTR\n")
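# A minimal sketch (assuming the LAC package is installed) of how the dictionary
# written above is consumed: LAC.load_customization() reads one "word/TAG" entry
# per line, which is the format produced by create_lac_custom_dict().
def _load_lac_with_custom_dict():
    from LAC import LAC
    lac = LAC(mode='lac')
    lac.load_customization(Config.lac_custom_dict_txt)
    return lac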
def lac_test():
    from LAC import LAC
    print('start ...')
    mention_count = Counter(json_load('mention2count.json')).most_common(10000)
    customization_dict = {
        mention: 'MENTION'
        for mention, count in mention_count
        if len(mention) >= 2
    }
    print('second ...')
    ent_count = Counter(json_load('ent2count.json')).most_common(300000)
    ent_pattrn = re.compile(r'["<](.*?)[>"]')
    customization_dict.update(
        {' '.join(ent_pattrn.findall(ent)): 'ENT' for ent, count in ent_count})
    with open('./customization_dict.txt', 'w') as f:
        f.write('\n'.join([
            f"{e}/{t}" for e, t in customization_dict.items() if len(e) >= 3
        ]))
    import time
    before = time.time()
    print(f'before ...{before}')
    lac = LAC(mode='lac')
    # loading ~200k custom entries takes ~21s, ~300k takes ~47s
    lac.load_customization('./customization_dict.txt')
    lac_raw = LAC(mode='lac')
    print(f'after ...{time.time() - before}')

    test_count = 10
    for q, sparql, a in load_data():
        q_text = q.split(':')[1]
        print('---' * 10)
        print(q_text)
        print(sparql)
        print(a)
        words, tags = lac_raw.run(q_text)
        print(list(zip(words, tags)))
        words, tags = lac.run(q_text)
        print(list(zip(words, tags)))
        if not test_count:
            break
        test_count -= 1
    import ipdb
    ipdb.set_trace()
def load_cache(self):
    # self.queue = Queue(maxsize=1)
    if os.path.isfile(Config.neo4j_query_cache):
        data_dict = json_load(Config.neo4j_query_cache)
        self._one_hop_relNames_map = data_dict['_one_hop_relNames_map']
        self._two_hop_relNames_map = data_dict['_two_hop_relNames_map']
        self.total_count = self.get_total_entity_count()
        logging.info(f'load neo4j_query_cache, total: {self.total_count}')
    else:
        logging.info(
            f'not found neo4j_query_cache: {Config.neo4j_query_cache}')
        self.total_count = 0
def load_custom_dict(self):
    if not os.path.isfile(Config.jieba_custom_dict):
        all_ent_mention_words = set(
            json_load(Config.mention2count_json).keys())
        entity_set = set(json_load(Config.entity2id).keys())
        for ent in tqdm(entity_set, total=len(entity_set),
                        desc='create jieba words '):
            if ent.startswith('<'):
                ent = ent[1:-1]
            all_ent_mention_words.add(ent)
        # Re-implement the logic of jieba.add_word (which updates dt.FREQ and
        # dt.total) directly, to speed up loading the large word list.
        # jieba.dt.FREQ = {}
        # jieba.dt.total = 0
        for word in tqdm(all_ent_mention_words, desc='jieba custom create '):
            freq = len(word) * 3
            jieba.dt.FREQ[word] = freq
            jieba.dt.total += freq
            for ch in range(len(word)):
                wfrag = word[:ch + 1]
                if wfrag not in jieba.dt.FREQ:
                    jieba.dt.FREQ[wfrag] = 0
        del all_ent_mention_words
        gc.collect()
        json_dump({
            'dt.FREQ': jieba.dt.FREQ,
            'dt.total': jieba.dt.total
        }, Config.jieba_custom_dict)
        logging.info('create jieba custom dict done ...')
    # load the cached custom dict
    jieba_custom = json_load(Config.jieba_custom_dict)
    jieba.dt.check_initialized()
    jieba.dt.FREQ = jieba_custom['dt.FREQ']
    jieba.dt.total = jieba_custom['dt.total']
    logging.info('load jieba custom dict done ...')
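# For reference, a sketch of the equivalent (but much slower) path through
# jieba's public API; the frequency heuristic mirrors the len(word) * 3 used
# in load_custom_dict above.
def _add_words_via_public_api(words):
    for word in words:
        jieba.add_word(word, freq=len(word) * 3)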
def _get_top_counter():
    """Count entity and mention frequencies over the ~26 GB KB.

    The high-frequency entities and mentions are used later for filtering and
    for the LAC custom dictionary: taking the top-N gives the final
    segmentation dictionary.
    """
    logging.info('kb_count_top_dict start ...')
    if not (os.path.isfile(Config.entity2count_json)
            and os.path.isfile(Config.relation2count_json)):
        entities = []
        relations = []
        for head_ent, rel, tail_ent in iter_triples(
                tqdm_prefix='kb_top_count_dict '):
            entities.extend([head_ent, tail_ent])
            relations.append(rel)
        ent_counter = Counter(entities)
        json_dump(dict(ent_counter), Config.entity2count_json)
        del entities
        rel_counter = Counter(relations)
        del relations
        json_dump(dict(rel_counter), Config.relation2count_json)
    else:
        ent_counter = Counter(json_load(Config.entity2count_json))
        rel_counter = Counter(json_load(Config.relation2count_json))
    # mention counts
    if not os.path.isfile(Config.mention2count_json):
        mentions = []
        for line in tqdm_iter_file(
                mention2ent_txt,
                prefix='count_top_dict iter mention2ent.txt '):
            mention, ent, rank = line.split('\t')
            # some rows are malformed: the mention is an empty string
            mentions.append(mention)
        mention_counter = Counter(mentions)
        del mentions
        json_dump(dict(mention_counter), Config.mention2count_json)
    else:
        mention_counter = Counter(json_load(Config.mention2count_json))
    return ent_counter, rel_counter, mention_counter
def data2samples(self, neg_rate=3, test_size=0.1):
    if os.path.isfile(
            Config.get_relation_score_sample_csv('train', neg_rate)):
        return
    questions = []
    sim_questions = []
    labels = []
    all_relations = list(json_load(Config.relation2id))
    _entity_pattern = re.compile(r'["<](.*?)[>"]')
    for q, sparql, a in load_data(tqdm_prefix='data2samples '):
        q_text = question_patten.findall(q)[0]
        q_entities = _entity_pattern.findall(sparql)
        questions.append(q_text)
        sim_questions.append('的'.join(q_entities))
        labels.append(1)
        # negative samples
        for neg_relation in random.sample(all_relations, neg_rate):
            questions.append(q_text)
            # randomly replace the final <relation>
            neg_question = '的'.join(q_entities[:-1] + [neg_relation])
            sim_questions.append(neg_question)
            labels.append(0)
    data_df = pd.DataFrame({
        'question': questions,
        'sim_question': sim_questions,
        'label': labels
    })
    data_df.to_csv(Config.relation_score_sample_csv,
                   encoding='utf_8_sig', index=False)
    train_df, test_df = train_test_split(data_df, test_size=test_size)
    test_df.to_csv(Config.get_relation_score_sample_csv('test', neg_rate),
                   encoding='utf_8_sig', index=False)
    train_df.to_csv(Config.get_relation_score_sample_csv('train', neg_rate),
                    encoding='utf_8_sig', index=False)
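# Optional helper (a sketch, not in the original code): read back the generated
# splits and confirm the positive/negative label ratio is roughly 1:neg_rate.
def _report_sample_balance(neg_rate=3):
    for split in ('train', 'test'):
        df = pd.read_csv(Config.get_relation_score_sample_csv(split, neg_rate))
        print(split, df['label'].value_counts().to_dict())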
def __init__(self):
    self.entity2id = json_load(Config.entity2id)
    self.mention2entity = json_load(Config.mention2ent_json)