# Third-party imports used throughout this module. Project-internal names (fileConfig,
# comConfig, nerConfig, nelConfig, fasttextConfig, com_utils, data_utils, cut_client,
# DataSet, NELToTensor, collate_fn_entity_link, ...) are imported elsewhere in the project.
import os
import ast
import random
from collections import Counter, defaultdict

import fastText
import pandas as pd
import ujson
from gensim.models import word2vec
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.utils.data import DataLoader
from torchvision import transforms  # Compose is used below; torchvision is an assumption
from tqdm import tqdm


def create_extend_train_file():
    print("start create extend train file...")
    # train_file = open(fileConfig.dir_data + fileConfig.file_train_data, 'r', encoding='utf-8')
    train_file = com_utils.pickle_load(fileConfig.dir_data + fileConfig.file_train_pkl)
    test_file = com_utils.pickle_load(fileConfig.dir_data + fileConfig.file_test_pkl)
    extend_out_file = open(fileConfig.dir_data + fileConfig.file_extend_train_data, 'w', encoding='utf-8')
    extend_test_file = open(fileConfig.dir_data + fileConfig.file_extend_test_data, 'w', encoding='utf-8')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    for line in tqdm(train_file, desc='extend train file'):
        jstr = ujson.loads(line)
        extend_lines = data_utils.get_extend_ner_train_list(kb_dict, jstr)
        for lines in extend_lines:
            # skip samples that exceed the NER sequence budget (two slots are reserved,
            # presumably for special tokens)
            if len(lines['text']) > (nerConfig.max_seq_length - 2):
                continue
            extend_out_file.write(ujson.dumps(lines, ensure_ascii=False))
            extend_out_file.write('\n')
    for line in tqdm(test_file, desc='extend test file'):
        jstr = ujson.loads(line)
        extend_test_file.write(ujson.dumps(jstr, ensure_ascii=False))
        extend_test_file.write('\n')
    print('success create extend train file')
def create_nel_train_data():
    if not os.path.exists(fileConfig.dir_nel):
        os.mkdir(fileConfig.dir_nel)
    train_data = open(fileConfig.dir_data + fileConfig.file_train_data, 'r')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    pd_df = pd.read_csv(fileConfig.dir_kb_info + fileConfig.file_kb_pandas_csv)
    data_list = []
    for line in tqdm(train_data, desc='create entity link train data'):
        jstr = ujson.loads(line)
        text_id = jstr['text_id']
        text = jstr['text']
        mention_datas = jstr['mention_data']
        for mention_data in mention_datas:
            kb_id = mention_data['kb_id']
            mention = mention_data['mention']
            start = mention_data['offset']
            end = int(start) + len(mention) - 1  # inclusive end offset
            kb_entity = kb_dict.get(kb_id)
            if kb_entity is None:
                continue
            entity_cands, entity_ids, entity_text = data_utils.get_entity_cands(kb_entity, kb_id, pd_df)
            data_list.append({'text_id': text_id, 'mention_text': text, 'mention': mention,
                              'mention_position': [start, end], 'entity_cands': entity_cands,
                              'entity_text': entity_text, 'entity_ids': entity_ids})
    com_utils.pickle_save(data_list, fileConfig.dir_nel + fileConfig.file_nel_entity_link_train_data)
    print("success create nel entity link train data")
def create_fasttext_sup_train_data(index, train_data_file, kb_dict_file, kb_alia_file, stopword_file, out_file,
                                   mode=fasttextConfig.create_data_word):
    print("create {} sup train data".format(index))
    kb_alias_df = pd.read_csv(kb_alia_file)
    stopwords = data_utils.get_stopword_list(stopword_file)
    train_datas = open(train_data_file, 'r', encoding='utf-8').readlines()
    kb_dict = com_utils.pickle_load(kb_dict_file)
    train_out_file = open(out_file, 'w', encoding='utf-8')
    text_ids = {}
    max_extend_count = 3
    for line in tqdm(train_datas, desc='deal {} train file'.format(index)):
        jstr = ujson.loads(line)
        text = jstr['text']
        text_id = jstr['text_id']
        if text_ids.get(text_id) == max_extend_count:
            continue
        mentions = jstr['mention_data']
        for mention in mentions:
            mention_id = mention['kb_id']
            mention_text = mention['mention']
            neighbor_text = com_utils.get_neighbor_sentence(text, mention_text)
            # positive sample: the entity the mention is actually linked to
            kb_entity = kb_dict.get(mention_id)
            if kb_entity is not None:
                out_str = com_utils.get_entity_mention_pair_text(kb_entity['text'], neighbor_text, stopwords,
                                                                 cut_client, fasttextConfig.label_true, mode)
                train_out_file.write(out_str)
            # negative samples: up to max_extend_count other entities sharing the surface form
            alia_ids = []
            alia_count = 0
            alias_df = kb_alias_df[kb_alias_df['subject'] == com_utils.cht_to_chs(mention_text)]
            for _, item in alias_df.iterrows():
                a_id = str(item['subject_id'])
                if a_id != mention_id:
                    alia_ids.append(a_id)
                    alia_count += 1
                    if alia_count == max_extend_count:
                        break
            for alia_id in alia_ids:
                alia_entity = kb_dict.get(alia_id)
                if alia_entity is not None:
                    out_str = com_utils.get_entity_mention_pair_text(alia_entity['text'], neighbor_text, stopwords,
                                                                     cut_client, fasttextConfig.label_false, mode)
                    train_out_file.write(out_str)
        # record that this text has been used once more
        text_ids = com_utils.dict_add(text_ids, text_id)
    # release resources
    train_out_file.close()
    train_datas = None
    train_out_file = None
    kb_alias_df = None
    stopwords = None
    kb_dict = None
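# How the pair file written above is consumed (a sketch, not the project's actual training
# script): fastText's supervised mode expects one "__label__X token token ..." line per sample,
# which is the format get_entity_mention_pair_text presumably emits via label_true/label_false.
# The hyperparameters below are illustrative only.
def train_sup_model_sketch(pair_file, model_out):
    model = fastText.train_supervised(input=pair_file, epoch=5, lr=0.5, wordNgrams=2)
    model.save_model(model_out)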
def analysis_train_data():
    print("start analysis train data: find mentions that do not match the kb...")
    if not os.path.exists(fileConfig.dir_analysis):
        os.mkdir(fileConfig.dir_analysis)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data, 'r', encoding='utf-8')
    out_file = open(fileConfig.dir_analysis + fileConfig.file_analysis_train_untind, 'w', encoding='utf-8')
    count = 1
    for line in tqdm(train_file, 'find unmatch'):
        jstr = ujson.loads(line)
        mention_data = jstr['mention_data']
        for mention in mention_data:
            mention_text = mention['mention']
            mention_id = mention['kb_id']
            kb_entity = kb_dict.get(mention_id)
            is_match = False
            if kb_entity is not None:
                kb_subject = kb_entity['subject']
                kb_alias = kb_entity['alias']
                if kb_subject == mention_text:
                    is_match = True
                if not is_match:
                    for alia in kb_alias:
                        if alia == mention_text:
                            is_match = True
                if not is_match:
                    out_file.write('-' * 20)
                    out_file.write('\n')
                    out_file.write("num:{}--text_id:{}--text:{}".format(count, jstr['text_id'], jstr['text']))
                    out_file.write('\n')
                    out_file.write("not match:")
                    out_file.write('\n')
                    out_file.write('*' * 20)
                    out_file.write('\n')
                    out_file.write('mention_original: {}'.format(ujson.dumps(mention, ensure_ascii=False)))
                    out_file.write('\n')
                    out_file.write("kb: {}".format(
                        'subject:{} alias:{}'.format(kb_entity['subject'], kb_entity['alias'])))
                    out_file.write('\n')
                    out_file.write('*' * 20)
                    out_file.write('\n')
                    count += 1
    train_file.close()
    out_file.close()
    print("success analysis train file, found miss-matched entities")
def create_pandas_kb_alias_data():
    kb_file = open(fileConfig.dir_data + fileConfig.file_kb_data, 'r', encoding='utf-8')
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data, 'r', encoding='utf-8')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    subject_id_list = []
    subject_list = []
    subjects = {}
    # from kb file
    for line in tqdm(kb_file, desc='deal kb_file'):
        jstr = ujson.loads(line)
        subject_id = jstr['subject_id']
        subject = com_utils.cht_to_chs(jstr['subject'].strip().lower())
        subject_id_list.append(subject_id)
        subject_list.append(subject)
        alias = jstr['alias']
        subjects[subject] = 1
        for alia in alias:
            alia_str = com_utils.cht_to_chs(alia.strip().lower())
            if subjects.get(alia_str) is not None:
                continue
            subjects[alia_str] = 1
            subject_id_list.append(subject_id)
            subject_list.append(alia_str)
    # from train file
    for line in tqdm(train_file, desc='deal train file'):
        jstr = ujson.loads(line)
        mention_data = jstr['mention_data']
        for mention in mention_data:
            mention_text = com_utils.cht_to_chs(mention['mention'].lower())
            kb_id = mention['kb_id']
            kb_entity = kb_dict.get(kb_id)
            is_match = False
            if kb_entity is not None:
                kb_subject = kb_entity['subject']
                kb_alias = kb_entity['alias']
                if kb_subject == mention_text:
                    is_match = True
                if not is_match:
                    for alia in kb_alias:
                        if alia == mention_text:
                            is_match = True
                if not is_match:
                    if subjects.get(mention_text) is not None:
                        continue
                    subjects[mention_text] = 1
                    subject_id_list.append(kb_id)
                    subject_list.append(mention_text)
    pandas_dict = {'subject_id': subject_id_list, 'subject': subject_list}
    df = pd.DataFrame.from_dict(pandas_dict)
    df.to_csv(fileConfig.dir_kb_info + fileConfig.file_kb_pandas_alias_data)
    print("success create pandas kb alias data file")
def create_nel_vocab():
    # create mention context vocab (character level)
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data, mode='r', encoding='utf-8')
    dev_file = open(fileConfig.dir_data + fileConfig.file_dev_data, mode='r', encoding='utf-8')
    out_file = open(fileConfig.dir_nel + fileConfig.file_nel_mention_context_vocab, mode='w', encoding='utf-8')
    vocab_dict = {}
    for line in tqdm(train_file, desc='deal train file'):
        jstr = ujson.loads(line)
        for word in jstr['text'].strip():
            if word not in vocab_dict:
                vocab_dict[word] = 1
            else:
                vocab_dict[word] += 1
    for line in tqdm(dev_file, desc='deal dev file'):
        jstr = ujson.loads(line)
        for word in jstr['text'].strip():
            if word not in vocab_dict:
                vocab_dict[word] = 1
            else:
                vocab_dict[word] += 1
    out_file.write('[PAD]\n')
    out_file.write('[UNK]\n')
    vocab_length = len(vocab_dict)
    for i, item in enumerate(Counter(vocab_dict).most_common()):
        # no trailing newline after the last token
        if i < vocab_length - 1:
            out_file.write(item[0] + '\n')
        else:
            out_file.write(item[0])
    out_file.close()
    print("success create mention entity vocab data")
    # create entity vocab and entity context vocab
    entity_dict = set()
    text_dict = defaultdict(int)
    kb_data = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    for key, value in kb_data.items():
        text = value['text']
        subject = key
        if subject in entity_dict:
            raise Exception(f'entity : {subject} duplicated!!!')
        entity_dict.add(subject)
        for t in text:
            text_dict[t] += 1
    entity_vocab = open(fileConfig.dir_nel + fileConfig.file_nel_entity_vocab, 'w', encoding='utf-8')
    entity_context_vocab = open(fileConfig.dir_nel + fileConfig.file_nel_entity_context_vocab, 'w', encoding='utf-8')
    for entity in entity_dict:
        entity_vocab.write(entity + '\n')
    entity_context_vocab.write('[PAD]\n')
    entity_context_vocab.write('[UNK]\n')
    for token, value in Counter(text_dict).most_common():
        entity_context_vocab.write(token + '\n')
    print("success create entity context vocab / entity vocab data")
def split_train_data(train_file_path=None, out_train_file=None, out_dev_file=None, is_split=True):
    data_list = com_utils.pickle_load(train_file_path)
    if not is_split:
        dev_list = com_utils.pickle_load(fileConfig.dir_ner + fileConfig.file_extend_ner_dev_data)
    data_len = len(data_list)
    test_size = 10000
    random.seed(comConfig.random_seed)
    random.shuffle(data_list)
    if is_split:
        # hold out the last test_size shuffled samples as the dev set
        train_data = data_list[:data_len - test_size]
        dev_data = data_list[data_len - test_size:data_len]
        com_utils.pickle_save(train_data, out_train_file)
        com_utils.pickle_save(dev_data, out_dev_file)
    else:
        # keep everything for training and use the pre-built extend dev data as the dev set
        train_data = data_list
        com_utils.pickle_save(train_data, out_train_file)
        com_utils.pickle_save(dev_list, out_dev_file)
    print("success split data set")
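# The fixed comConfig.random_seed above makes the held-out slice reproducible across runs;
# a tiny self-contained illustration of the same idea:
def reproducible_split_demo(items, seed=42, test_size=2):
    data = list(items)
    random.seed(seed)
    random.shuffle(data)
    # the same seed yields the same partition on every call
    return data[:-test_size], data[-test_size:]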
def analysis_test_result():
    print("start analysis test result...")
    test_result_file = open(fileConfig.dir_result + fileConfig.file_result_fasttext_test, 'r', encoding='utf-8')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    error_results = []
    more_dict = {}
    for line in tqdm(test_result_file, 'read from file'):
        jstr = ujson.loads(line)
        gen_mentions = jstr['mention_data']
        original_mentions = jstr['mention_data_original']
        error_list = get_result_error_list(gen_mentions, original_mentions, more_dict)
        if len(error_list) > 0:
            error_results.append({'text_id': jstr['text_id'], 'text': jstr['text'], 'errors': error_list})
    test_result_file.close()
    test_result_file = None
    out_file = open(fileConfig.dir_result + fileConfig.file_result_fasttext_test_analysis, 'w', encoding='utf-8')
    for item in tqdm(error_results, 'write result'):
        out_file.write('-' * 20)
        out_file.write('\n')
        out_file.write("text_id:{}--text:{}".format(item['text_id'], item['text']))
        out_file.write('\n')
        out_file.write("errors:")
        out_file.write('\n')
        for error in item['errors']:
            out_file.write('*' * 20)
            out_file.write('\n')
            out_file.write('error type:{}'.format(get_error_type(error['error_type'])))
            out_file.write('\n')
            out_file.write(get_error_content(kb_dict, error))
            out_file.write('*' * 20)
            out_file.write('\n')
        out_file.write('-' * 20)
        out_file.write('\n')
    out_more = open(fileConfig.dir_result + fileConfig.file_analysis_gen_more, 'w', encoding='utf-8')
    for item in tqdm(Counter(more_dict).most_common(), 'write more'):
        out_more.write(item[0])
        out_more.write('\n')
def split_eval_mention(num):
    dev_mention_data = com_utils.pickle_load(fileConfig.dir_ner + fileConfig.file_ner_eval_mention_data)
    data_len = len(dev_mention_data)
    block_size = data_len / num
    for i in range(1, num + 1):
        data_iter = dev_mention_data[int((i - 1) * block_size):int(i * block_size)]
        com_utils.pickle_save(data_iter,
                              fileConfig.dir_ner_split + fileConfig.file_ner_eval_mention_split.format(i))
    print("success split test mention to:{} files".format(num))
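# A quick sanity check (illustrative, not part of the pipeline): the int() rounding of the
# fractional block_size keeps the blocks contiguous and covering every index, e.g. 10 items
# split into 3 blocks become [0:3], [3:6], [6:10].
def check_split_coverage(data_len=10, num=3):
    block_size = data_len / num
    bounds = [(int((i - 1) * block_size), int(i * block_size)) for i in range(1, num + 1)]
    assert bounds[0][0] == 0 and bounds[-1][1] == data_len
    assert all(bounds[i][1] == bounds[i + 1][0] for i in range(num - 1))
    return bounds  # [(0, 3), (3, 6), (6, 10)]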
def create_eval_cands_from_split(num):
    out_file = open(fileConfig.dir_ner + fileConfig.file_ner_eval_cands_data, 'w', encoding='utf-8')
    for i in range(1, num + 1):
        datas = com_utils.pickle_load(fileConfig.dir_ner_split + fileConfig.file_ner_eval_cands_split.format(i))
        for line in datas:
            out_file.write(ujson.dumps(line, ensure_ascii=False))
            out_file.write('\n')
    print("merge eval cands data success!")
def train():
    datas = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    vectorizer = TfidfVectorizer()
    train_sentence = []
    print("prepare train data")
    for key, data in tqdm(datas.items(), desc='init train data'):
        train_sentence.append(' '.join(cut_client.cut_text(data['text'])))
    print("start train tfidf model")
    X = vectorizer.fit_transform(train_sentence)
    print("save model and keyword")
    tfidf_save_data = [X, vectorizer]
    if not os.path.exists(fileConfig.dir_tfidf):
        os.mkdir(fileConfig.dir_tfidf)
    com_utils.pickle_save(tfidf_save_data, fileConfig.dir_tfidf + fileConfig.file_tfidf_save_data)
    print("success train and save tfidf file")
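# get_entityid() is called by test() below but is not defined in this file. A minimal sketch of
# what it plausibly does, assuming the rows of X follow the iteration order of the kb dict and
# that a parallel kb_ids list (hypothetical here) was saved alongside the model: embed the
# sentence with the fitted vectorizer and return the id of the most similar KB text by cosine
# similarity.
from sklearn.metrics.pairwise import cosine_similarity

def get_entityid_sketch(sentence, vectorizer, X, kb_ids):
    query = vectorizer.transform([' '.join(cut_client.cut_text(sentence))])
    scores = cosine_similarity(query, X)[0]  # similarity to every KB entry
    return int(kb_ids[scores.argmax()])      # id of the best-matching entry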
def test():
    print("start test the tfidf model")
    tfidf_data = com_utils.pickle_load(fileConfig.dir_tfidf + fileConfig.file_tfidf_save_data)
    vectorizer = tfidf_data[1]
    X = tfidf_data[0]
    # init test data
    ratio = 0.01
    print("init test datas use ratio:{}".format(ratio))
    test_datas = []
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data, 'r', encoding='utf-8')
    for line in train_file:
        test_datas.append(ujson.loads(line))
    test_data_len = int(len(test_datas) * ratio)
    random.seed(comConfig.random_seed)
    random.shuffle(test_datas)
    test_datas = test_datas[:test_data_len]
    mentions = []
    for data in test_datas:
        for mention in data['mention_data']:
            if mention['kb_id'] != 'NIL':
                mention_copy = mention.copy()
                mention_copy['sentence'] = data['text']
                mentions.append(mention_copy)
    # start test model
    print("start find mention")
    y_pred = []
    y_true = []
    for mention in tqdm(mentions, desc='find mention'):
        text = mention['mention']
        text_len = len(text)
        sentence = mention['sentence']
        for i in range(len(sentence) - text_len + 1):
            if sentence[i:i + text_len] == text:
                # take roughly ten characters of context on each side of the mention;
                # the else branch guarantees neighbor_sentence is always bound
                if i > 10 and i + text_len < len(sentence) - 9:
                    neighbor_sentence = sentence[i - 10:i + text_len + 9]
                elif i <= 10:
                    neighbor_sentence = sentence[:20]
                else:
                    neighbor_sentence = sentence[-20:]
                kb_id = get_entityid(neighbor_sentence, vectorizer, X)
                # append y_true together with y_pred so the two lists stay aligned
                y_true.append(int(mention['kb_id']))
                y_pred.append(kb_id)
                break
    # calc the f1
    acc, f1 = acc_f1(y_pred, y_true)
    print("acc:{:.4f} f1:{:.4f}".format(acc, f1))
def gen_simi_subject_list(file_path):
    print('start gen similar subject...')
    file_datas = com_utils.pickle_load(file_path)
    gensim_model = word2vec.Word2VecKeyedVectors.load(
        fileConfig.dir_fasttext + fileConfig.file_gensim_tencent_unsup_model)
    for item in tqdm(file_datas, 'gen simi subject'):
        for mention in item['mention_data']:
            mention_text = mention['mention']
            try:
                mention['gen_subjects'] = get_simi_subject_list(
                    gensim_model.most_similar(positive=[mention_text], topn=5))
            except BaseException:
                # out-of-vocabulary mentions raise KeyError; fall back to an empty list
                mention['gen_subjects'] = []
    com_utils.pickle_save(file_datas, file_path)
    print('success gen similar subject...')
def create_nel_batch_iter(mode, entity_context_vocab, mention_context_vocab, entity_vocab):
    if mode == nelConfig.mode_train:
        datas = com_utils.pickle_load(fileConfig.dir_nel + fileConfig.file_nel_entity_link_train_data)
        data_len = len(datas)
        train_size = int(data_len * comConfig.train_ratio)
        random.seed(comConfig.random_seed)
        random.shuffle(datas)
        train_data = datas[:train_size]
        dev_data = datas[train_size:]
        print("create nel data train len:{} dev len:{}".format(len(train_data), len(dev_data)))
        transform = transforms.Compose([
            NELToTensor(entity_context_vocab, mention_context_vocab, entity_vocab)])
        train_dataset = DataSet(train_data, transform=transform)
        dev_dataset = DataSet(dev_data, transform=transform)
        train_iter = DataLoader(train_dataset, num_workers=4, batch_size=nelConfig.batch_size,
                                shuffle=True, collate_fn=collate_fn_entity_link)
        dev_iter = DataLoader(dev_dataset, num_workers=4, batch_size=nelConfig.batch_size,
                              shuffle=True, collate_fn=collate_fn_entity_link)
        return train_iter, dev_iter
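# A minimal wiring sketch for the iterators returned above. The batch layout is defined by
# collate_fn_entity_link (not shown here), so `model`, `optimizer` and `nelConfig.num_epochs`
# are placeholders for the project's own NEL model, optimizer and epoch setting.
def nel_train_loop_sketch(model, optimizer, entity_context_vocab, mention_context_vocab, entity_vocab):
    train_iter, dev_iter = create_nel_batch_iter(nelConfig.mode_train, entity_context_vocab,
                                                 mention_context_vocab, entity_vocab)
    for epoch in range(nelConfig.num_epochs):  # assumed config attribute
        for batch in train_iter:
            optimizer.zero_grad()
            loss = model(batch)  # assumes the model returns its training loss for a collated batch
            loss.backward()
            optimizer.step()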
def analysis_test_ner_result():
    print("start analysis test ner result...")
    ner_test_datas = com_utils.pickle_load(fileConfig.dir_ner + fileConfig.file_ner_test_predict_tag)
    jieba_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_jieba_kb)
    out_file = open(fileConfig.dir_result + fileConfig.file_ner_test_result_analysis, 'w', encoding='utf-8')
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword + fileConfig.file_stopword)
    gen_more_words = data_utils.get_stopword_list(fileConfig.dir_stopword + fileConfig.file_analysis_gen_more)
    text_id = 1
    for data in tqdm(ner_test_datas, 'find entity'):
        text = ''.join(data['text'])
        tag_list = data['tag']
        start_index = 0
        mention_length = 0
        is_find = False
        mentions = []
        type_dict = {}
        # use tag find
        for i, tag in enumerate(tag_list):
            if tag.find(nerConfig.B_seg) > -1 or (tag.find(nerConfig.I_seg) > -1 and not is_find):
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length += 1
                is_find = True
            elif tag.find(nerConfig.E_seg) > -1 and not is_find:
                # an E tag without a preceding B/I: treat it as a mention ending here
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                type_list = Counter(type_dict).most_common()
                mentions.append({'T': 'NER', 'mention': mention, 'offset': str(start_index),
                                 'type': type_list[0][0]})
                is_find = False
                mention_length = 0
                type_dict = {}
            elif tag.find(nerConfig.I_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
            elif tag.find(nerConfig.E_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                type_list = Counter(type_dict).most_common()
                mentions.append({'T': 'NER', 'mention': mention, 'offset': str(start_index),
                                 'type': type_list[0][0]})
                is_find = False
                mention_length = 0
                type_dict = {}
            elif tag == nerConfig.O_seg:
                is_find = False
                mention_length = 0
                type_dict = {}
        # use jieba find (every position is scanned; the tag filter used in
        # create_dev_mention_data is disabled here)
        jieba_entities = cut_client.cut_text(text)
        for i, tag in enumerate(tag_list):
            jieba_offset = i
            jieba_char = text[i]
            jieba_text = get_jieba_mention(jieba_entities, jieba_char)
            if jieba_text is None:
                continue
            elif jieba_text == '_' or jieba_text == '-':
                continue
            elif len(jieba_text) == 1:
                continue
            elif stopwords.get(jieba_text) is not None:
                continue
            elif gen_more_words.get(jieba_text) is not None:
                continue
            jieba_offset = jieba_offset - jieba_text.find(jieba_char)
            if len(jieba_text) <= comConfig.max_jieba_cut_len and (jieba_dict.get(jieba_text) is not None):
                type_str = tag.split('_')[1] if tag.find('_') > -1 else 'O'
                if not is_already_find_mention(mentions, jieba_text, jieba_offset):
                    mentions.append({'T': 'JIEBA', 'mention': jieba_text, 'offset': str(jieba_offset),
                                     'type': type_str})
        # find inner brackets mentions
        bracket_mentions = data_utils.get_mention_inner_brackets(text, tag_list)
        for mention in bracket_mentions:
            mention['T'] = 'bracket'
        if len(bracket_mentions) > 0:
            mentions += bracket_mentions
        # (disabled) completion mentions: re-scan the text for repeated mention surface forms
        # and add the extra offsets
        # write the tagged text and the generated mentions
        out_file.write('\n')
        result_str = ''
        for i in range(len(text)):
            result_str += text[i] + '-' + tag_list[i] + ' '
        out_file.write(' text_id:{}, text:{} '.format(text_id, result_str))
        out_file.write('\n')
        out_file.write(' gen_mentions:{} '.format(ujson.dumps(mentions, ensure_ascii=False)))
        out_file.write('\n')
        text_id += 1
def get_train_examples(self):
    lines = com_utils.pickle_load(fileConfig.dir_ner + fileConfig.file_ner_train_data)
    examples = self._create_example(lines, nerConfig.mode_train)
    return examples
def test_sup(mode=fasttextConfig.create_data_word):
    print("start use the fasttext model/supervise model to predict test data")
    if not os.path.exists(fileConfig.dir_result):
        os.mkdir(fileConfig.dir_result)
    unsup_model_fasttext = fastText.load_model(
        fileConfig.dir_fasttext + fileConfig.file_fasttext_model.format(fasttextConfig.choose_model))
    unsup_model_gensim = word2vec.Word2VecKeyedVectors.load(
        fileConfig.dir_fasttext + fileConfig.file_gensim_tencent_unsup_model)
    sup_model = fastText.load_model(fileConfig.dir_fasttext + fileConfig.file_fasttext_sup_word_model)
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword + fileConfig.file_stopword)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    dev_file = open(fileConfig.dir_ner + fileConfig.file_ner_test_cands_data, 'r', encoding='utf-8')
    out_file = open(fileConfig.dir_result + fileConfig.file_result_fasttext_test, 'w', encoding='utf-8')
    # f1 params
    gen_mention_count = 0
    original_mention_count = 0
    correct_mention_count = 0
    # entity disambiguation
    for line in tqdm(dev_file, 'entity disambiguation'):
        jstr = ujson.loads(line)
        dev_entity = {}
        text = com_utils.cht_to_chs(jstr['text'].lower())
        dev_entity['text_id'] = jstr['text_id']
        dev_entity['text'] = jstr['text']
        mention_data = jstr['mention_data']
        original_mention_data = jstr['mention_data_original']
        mentions = []
        for mention in mention_data:
            mention_text = mention['mention']
            if mention_text is None:
                continue
            cands = mention['cands']
            if len(cands) == 0:
                continue
            # use supervised model to filter candidates
            supervise_cands = []
            for cand in cands:
                neighbor_text = com_utils.get_neighbor_sentence(text, com_utils.cht_to_chs(mention_text.lower()))
                cand_entity = kb_dict.get(cand['cand_id'])
                if cand_entity is not None:
                    out_str = com_utils.get_entity_mention_pair_text(
                        com_utils.cht_to_chs(cand_entity['text'].lower()), neighbor_text, stopwords,
                        cut_client, mode=mode)
                    result = sup_model.predict(out_str.replace('\n', ' '))[0][0]
                    if result == fasttextConfig.label_true:
                        supervise_cands.append(cand)
            if len(supervise_cands) == 0:
                supervise_cands = cands
            # unsupervised model chooses among the remaining candidates
            max_cand = None
            score_list = []
            mention_neighbor_sentence = text
            for i, cand in enumerate(supervise_cands):
                # (disabled) blend fastText and gensim scores, e.g. 0.8 * gensim + 0.2 * fasttext
                score = gensim_get_sim(unsup_model_gensim, mention_neighbor_sentence,
                                       com_utils.cht_to_chs(cand['cand_text'].lower()), stopwords)
                if score < fasttextConfig.min_entity_similarity_threshold:
                    continue
                score_list.append({'cand_id': cand['cand_id'], 'cand_score': score,
                                   'cand_type': cand['cand_type']})
            # find the best cand
            # (disabled) prefer candidates whose type matches the mention type and whose score
            # exceeds choose_entity_similarity_threshold
            score_list.sort(key=get_socre_key, reverse=True)
            if max_cand is None and len(score_list) > 0:
                max_cand = score_list[0]
            if max_cand is not None:
                mentions.append({'kb_id': max_cand['cand_id'], 'mention': mention['mention'],
                                 'offset': mention['offset']})
        # optim mentions: drop mentions nested inside a longer span, then dedupe and filter
        # pure punctuation
        delete_mentions = []
        mentions.sort(key=get_mention_len)
        for mention in mentions:
            mention_offset = int(mention['offset'])
            mention_len = len(mention['mention'])
            for sub_mention in mentions:
                if mention_offset != int(sub_mention['offset']) and \
                        int(sub_mention['offset']) in range(mention_offset, mention_offset + mention_len):
                    if not data_utils.is_mention_already_in_list(delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
                if mention_offset == int(sub_mention['offset']) and \
                        len(mention['mention']) > len(sub_mention['mention']):
                    if not data_utils.is_mention_already_in_list(delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
        if len(delete_mentions) > 0:
            change_mentions = []
            for mention in mentions:
                if not data_utils.is_mention_already_in_list(delete_mentions, mention):
                    change_mentions.append(mention)
            mentions = change_mentions
        change_mentions = []
        for mention in mentions:
            if not data_utils.is_mention_already_in_list(change_mentions, mention) \
                    and mention['mention'] not in comConfig.punctuation:
                change_mentions.append(mention)
        mentions = change_mentions
        mentions.sort(key=get_mention_offset)
        # calc f1
        for mention in mentions:
            if is_find_correct_entity(mention['kb_id'], original_mention_data):
                correct_mention_count += 1
        gen_mention_count += len(mentions)
        for original_mention in original_mention_data:
            if original_mention['kb_id'] != 'NIL':
                original_mention_count += 1
        # out result
        dev_entity['mention_data'] = mentions
        dev_entity['mention_data_original'] = original_mention_data
        out_file.write(ujson.dumps(dev_entity, ensure_ascii=False))
        out_file.write('\n')
    precision = correct_mention_count / gen_mention_count
    recall = correct_mention_count / original_mention_count
    f1 = 2 * precision * recall / (precision + recall)
    print("success create test result, p:{:.4f} r:{:.4f} f1:{:.4f}".format(precision, recall, f1))
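# gensim_get_sim() is used above but defined elsewhere; a minimal sketch of the idea, assuming
# cut_client word segmentation and gensim's KeyedVectors.n_similarity (cosine between the mean
# word vectors of the two token sets). Stopwords and out-of-vocabulary tokens are dropped first.
def gensim_get_sim_sketch(model, sentence, cand_text, stopwords):
    words_a = [w for w in cut_client.cut_text(sentence) if w not in stopwords and w in model]
    words_b = [w for w in cut_client.cut_text(cand_text) if w not in stopwords and w in model]
    if len(words_a) == 0 or len(words_b) == 0:
        return 0.0
    return float(model.n_similarity(words_a, words_b))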
def test():
    print("start use the fasttext model to predict test data")
    if not os.path.exists(fileConfig.dir_result):
        os.mkdir(fileConfig.dir_result)
    model = fastText.load_model(
        fileConfig.dir_fasttext + fileConfig.file_fasttext_model.format(fasttextConfig.choose_model))
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword + fileConfig.file_stopword)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    dev_file = open(fileConfig.dir_ner + fileConfig.file_ner_test_cands_data, 'r', encoding='utf-8')
    out_file = open(fileConfig.dir_result + fileConfig.file_result_fasttext_test, 'w', encoding='utf-8')
    # f1 params
    gen_mention_count = 0
    original_mention_count = 0
    correct_mention_count = 0
    # entity disambiguation
    for line in tqdm(dev_file, 'entity disambiguation'):
        jstr = ujson.loads(line)
        dev_entity = {}
        text = jstr['text']
        dev_entity['text_id'] = jstr['text_id']
        dev_entity['text'] = jstr['text']
        mention_data = jstr['mention_data']
        original_mention_data = jstr['mention_data_original']
        mentions = []
        for mention in mention_data:
            cands = mention['cands']
            if len(cands) == 0:
                continue
            # (disabled) single-candidate shortcut: accept a lone candidate without scoring
            max_cand = None
            # score list
            score_list = []
            mention_neighbor_sentence = text
            for i, cand in enumerate(cands):
                score = fasttext_get_sim(model, mention_neighbor_sentence, cand['cand_text'], stopwords)
                if score < fasttextConfig.min_entity_similarity_threshold:
                    continue
                score_list.append({'cand_id': cand['cand_id'], 'cand_score': score,
                                   'cand_type': cand['cand_type']})
            # find the best cand
            find_type = False
            score_list.sort(key=get_socre_key, reverse=True)
            for item in score_list:
                if item['cand_type'] == mention['type']:
                    find_type = True
            if find_type:
                for item in score_list:
                    if item['cand_score'] > fasttextConfig.choose_entity_similarity_threshold:
                        max_cand = item
            if max_cand is None:
                if len(score_list) > 0:
                    max_cand = score_list[0]
            if max_cand is not None:
                if is_find_correct_entity(max_cand['cand_id'], original_mention_data):
                    correct_mention_count += 1
                mentions.append({'kb_id': max_cand['cand_id'], 'mention': mention['mention'],
                                 'offset': mention['offset']})
        # calc f1 params
        gen_mention_count += len(mentions)
        original_mention_count += len(original_mention_data)
        dev_entity['mention_data'] = mentions
        dev_entity['mention_data_original'] = original_mention_data
        out_file.write('-' * 20)
        out_file.write('\n')
        out_file.write("text_id:{}--text:{}".format(dev_entity['text_id'], dev_entity['text']))
        out_file.write('\n')
        out_file.write("mention_data:")
        out_file.write('\n')
        # generated mentions
        for mention in dev_entity['mention_data']:
            kb_mention = ''
            if mention['kb_id'] != 'NIL':
                kb_mention = ujson.dumps(kb_dict[mention['kb_id']], ensure_ascii=False)
            out_file.write('*' * 20)
            out_file.write('\n')
            out_file.write('mention_original: {}'.format(mention))
            out_file.write('\n')
            out_file.write("kb: {}".format(kb_mention))
            out_file.write('\n')
            out_file.write('*' * 20)
            out_file.write('\n')
        # original mentions
        out_file.write("kb_data:")
        out_file.write('\n')
        for mention in dev_entity['mention_data_original']:
            kb_mention = ''
            if mention['kb_id'] != 'NIL':
                kb_mention = ujson.dumps(kb_dict[mention['kb_id']], ensure_ascii=False)
            out_file.write('*' * 20)
            out_file.write('\n')
            out_file.write('kb_original: {}'.format(mention))
            out_file.write('\n')
            out_file.write("kb: {}".format(kb_mention))
            out_file.write('\n')
            out_file.write('*' * 20)
            out_file.write('\n')
        out_file.write('-' * 20)
        out_file.write('\n')
    precision = correct_mention_count / gen_mention_count
    recall = correct_mention_count / original_mention_count
    f1 = 2 * precision * recall / (precision + recall)
    print("success create test result, p:{:.4f} r:{:.4f} f1:{:.4f}".format(precision, recall, f1))
def eval_sup(mode=fasttextConfig.create_data_word):
    print("start use the fasttext/supervised model to predict eval data")
    if not os.path.exists(fileConfig.dir_result):
        os.mkdir(fileConfig.dir_result)
    # (disabled) fastText skipgram vectors; the Tencent gensim vectors are used instead
    unsup_model = word2vec.Word2VecKeyedVectors.load(
        fileConfig.dir_fasttext + fileConfig.file_gensim_tencent_unsup_model)
    sup_model = fastText.load_model(fileConfig.dir_fasttext + fileConfig.file_fasttext_sup_word_model)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword + fileConfig.file_stopword)
    dev_file = open(fileConfig.dir_ner + fileConfig.file_ner_eval_cands_data, 'r', encoding='utf-8')
    out_file = open(fileConfig.dir_result + fileConfig.file_result_eval_data, 'w', encoding='utf-8')
    # entity disambiguation
    for line in tqdm(dev_file, 'entity disambiguation'):
        if len(line.strip('\n')) == 0:
            continue
        jstr = ujson.loads(line)
        dev_entity = {}
        text = com_utils.cht_to_chs(jstr['text'].lower())
        dev_entity['text_id'] = jstr['text_id']
        dev_entity['text'] = jstr['text']
        mention_data = jstr['mention_data']
        mentions = []
        for mention in mention_data:
            mention_text = mention['mention']
            if mention_text is None:
                continue
            cands = mention['cands']
            if len(cands) == 0:
                continue
            # use supervised model to filter candidates
            supervise_cands = []
            for cand in cands:
                neighbor_text = com_utils.get_neighbor_sentence(text, com_utils.cht_to_chs(mention_text.lower()))
                cand_entity = kb_dict.get(cand['cand_id'])
                if cand_entity is not None:
                    out_str = com_utils.get_entity_mention_pair_text(
                        com_utils.cht_to_chs(cand_entity['text'].lower()), neighbor_text, stopwords,
                        cut_client, mode=mode)
                    result = sup_model.predict(out_str.strip('\n'))[0][0]
                    if result == fasttextConfig.label_true:
                        supervise_cands.append(cand)
            if len(supervise_cands) == 0:
                supervise_cands = cands
            # unsupervised model chooses among the remaining candidates
            max_cand = None
            score_list = []
            mention_neighbor_sentence = text
            for i, cand in enumerate(supervise_cands):
                score = gensim_get_sim(unsup_model, mention_neighbor_sentence,
                                       com_utils.cht_to_chs(cand['cand_text'].lower()), stopwords)
                if score < fasttextConfig.min_entity_similarity_threshold:
                    continue
                score_list.append({'cand_id': cand['cand_id'], 'cand_score': score,
                                   'cand_type': cand['cand_type']})
            score_list.sort(key=get_socre_key, reverse=True)
            if len(score_list) > 0:
                max_cand = score_list[0]
            # find the best cand
            if max_cand is not None:
                mentions.append({'kb_id': max_cand['cand_id'], 'mention': mention['mention'],
                                 'offset': mention['offset']})
        # optim mentions: drop mentions nested inside a longer span, then dedupe and filter
        # pure punctuation
        delete_mentions = []
        mentions.sort(key=get_mention_len)
        for optim_mention in mentions:
            mention_offset = int(optim_mention['offset'])
            mention_len = len(optim_mention['mention'])
            for sub_mention in mentions:
                if mention_offset != int(sub_mention['offset']) and \
                        int(sub_mention['offset']) in range(mention_offset, mention_offset + mention_len):
                    if not data_utils.is_mention_already_in_list(delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
        if len(delete_mentions) > 0:
            change_mentions = []
            for optim_mention in mentions:
                if not data_utils.is_mention_already_in_list(delete_mentions, optim_mention):
                    change_mentions.append(optim_mention)
            mentions = change_mentions
        change_mentions = []
        for optim_mention in mentions:
            if not data_utils.is_mention_already_in_list(change_mentions, optim_mention) \
                    and optim_mention['mention'] not in comConfig.punctuation:
                change_mentions.append(optim_mention)
        mentions = change_mentions
        mentions.sort(key=get_mention_offset)
        dev_entity['mention_data'] = mentions
        out_file.write(ujson.dumps(dev_entity, ensure_ascii=False))
        out_file.write('\n')
    print("success create supervised eval result")
def get_extend_dev_examples(self):
    lines = com_utils.pickle_load(fileConfig.dir_ner + fileConfig.file_ner_extend_dev_data)
    examples = self._create_example(lines, nerConfig.mode_dev)
    return examples
def create_dev_mention_cands_data(index, mention_file, pd_file, alia_kb_df, out_file):
    print("start create {} mention cands".format(index))
    dev_mention_data = com_utils.pickle_load(mention_file)
    print("{} data length is {}".format(index, len(dev_mention_data)))
    pd_df = pd.read_csv(pd_file)
    alia_kb_df = pd.read_csv(alia_kb_df)
    alia_kb_df = alia_kb_df.fillna('')  # fillna returns a new frame; the result must be assigned
    for dev_data in tqdm(dev_mention_data, desc='find {} cands'.format(index)):
        mention_data = dev_data['mention_data']
        for mention in mention_data:
            mention_text = mention['mention']
            if mention_text is None:
                continue
            cands = []
            cand_ids = {}
            # match original subject
            mention_text_proc = com_utils.cht_to_chs(mention_text.lower())
            mention_text_proc = com_utils.complete_brankets(mention_text_proc)
            # prefix form used by the disabled "match more" branches below
            mention_text_proc_extend = mention_text_proc[0:len(mention_text_proc) - 1]
            subject_df = data_utils.pandas_query(pd_df, 'subject', mention_text_proc)
            for _, item in subject_df.iterrows():
                s_id = str(item['subject_id'])
                if cand_ids.get(s_id) is not None:
                    continue
                cand_ids[s_id] = 1
                subject = item['subject']
                text = data_utils.get_all_text(item['subject'], ast.literal_eval(item['data']))
                cands.append({'cand_id': s_id, 'cand_subject': subject, 'cand_text': text,
                              'cand_type': com_utils.get_kb_type(ast.literal_eval(item['type']))})
            # (disabled) match more: repeat the subject query with mention_text_proc_extend
            # match alias
            alias_subject_ids = []
            alias_df = data_utils.pandas_query(alia_kb_df, 'subject', mention_text_proc)
            for _, item in alias_df.iterrows():
                a_id = str(item['subject_id'])
                if a_id in alias_subject_ids:
                    continue
                alias_subject_ids.append(a_id)
            # (disabled) match more aliases with mention_text_proc_extend
            for alia_id in alias_subject_ids:
                alias_df = pd_df[pd_df['subject_id'] == int(alia_id)]
                for _, item in alias_df.iterrows():
                    b_id = str(item['subject_id'])
                    if cand_ids.get(b_id) is not None:
                        continue
                    cand_ids[b_id] = 1
                    subject = item['subject']
                    text = data_utils.get_all_text(item['subject'], ast.literal_eval(item['data']))
                    cands.append({'cand_id': b_id, 'cand_subject': subject, 'cand_text': text,
                                  'cand_type': com_utils.get_kb_type(ast.literal_eval(item['type']))})
            # (disabled) match the generated similar subjects from mention['gen_subjects']
            mention['cands'] = cands
    com_utils.pickle_save(dev_mention_data, out_file)
    print("success create {} dev data with mention and cands!".format(index))
def get_eval_examples(self):
    lines = com_utils.pickle_load(fileConfig.dir_ner + fileConfig.file_ner_eval_data)
    examples = self._create_example(lines, nerConfig.mode_predict)
    return examples
def create_dev_mention_data(mode, ner_datas, out_file):
    ner_datas = com_utils.pickle_load(ner_datas)
    jieba_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_jieba_kb)
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword + fileConfig.file_stopword)
    gen_more_words = data_utils.get_stopword_list(fileConfig.dir_stopword + fileConfig.file_analysis_gen_more)
    text_id = 1
    dev_mention_data = []
    for data in tqdm(ner_datas, 'find entity'):
        text = ''.join(data['text'])
        tag_list = data['tag']
        start_index = 0
        mention_length = 0
        is_find = False
        mentions = []
        type_dict = {}
        # use tag find
        for i, tag in enumerate(tag_list):
            if tag.find(nerConfig.B_seg) > -1 or (tag.find(nerConfig.I_seg) > -1 and not is_find):
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length = 1
                is_find = True
            elif tag.find(nerConfig.E_seg) > -1 and not is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                mention = data_utils.strip_punctuation(mention)
                type_list = Counter(type_dict).most_common()
                mentions.append({'mention': mention, 'offset': str(start_index), 'type': type_list[0][0]})
                is_find = False
                mention_length = 0
                type_dict = {}
            elif tag.find(nerConfig.I_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
            elif tag.find(nerConfig.E_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                mention = data_utils.strip_punctuation(mention)
                type_list = Counter(type_dict).most_common()
                mentions.append({'mention': mention, 'offset': str(start_index), 'type': type_list[0][0]})
                is_find = False
                mention_length = 0
                type_dict = {}
            elif tag == nerConfig.O_seg:
                is_find = False
                mention_length = 0
                type_dict = {}
        # use jieba find
        jieba_entities = cut_client.cut_text(text)
        for i, tag in enumerate(tag_list):
            if tag.find(nerConfig.B_seg) > -1 or tag.find(nerConfig.I_seg) > -1 \
                    or tag.find(nerConfig.E_seg) > -1:
                jieba_offset = i
                jieba_char = text[i]
                jieba_text = get_jieba_mention(jieba_entities, jieba_char, jieba_offset)
                if jieba_text is None:
                    continue
                elif jieba_text == '_' or jieba_text == '-':
                    continue
                elif data_utils.is_punctuation(jieba_text):
                    continue
                elif len(jieba_text) == 1:
                    continue
                elif stopwords.get(jieba_text) is not None:
                    continue
                # (disabled) also skip words from the generated gen_more_words list
                jieba_offset = jieba_offset - jieba_text.find(jieba_char)
                if len(jieba_text) <= comConfig.max_jieba_cut_len and (jieba_dict.get(jieba_text) is not None):
                    type_str = tag.split('_')[1] if tag.find('_') > -1 else 'O'
                    if not is_already_find_mention(mentions, jieba_text, jieba_offset):
                        mentions.append({'mention': jieba_text, 'offset': str(jieba_offset), 'type': type_str})
        # find inner brackets mentions
        bracket_mentions = data_utils.get_mention_inner_brackets(text, tag_list)
        if len(bracket_mentions) > 0:
            mentions += bracket_mentions
        # (disabled) completion mentions: re-scan the text for repeated mention surface forms
        # and add the extra offsets
        # optim mentions: drop nested or duplicated mentions, keeping the longer span
        delete_mentions = []
        mentions.sort(key=get_mention_len)
        for mention in mentions:
            mention_offset = int(mention['offset'])
            mention_len = len(mention['mention'])
            for sub_mention in mentions:
                if mention_offset != int(sub_mention['offset']) and \
                        int(sub_mention['offset']) in range(mention_offset, mention_offset + mention_len):
                    if not data_utils.is_mention_already_in_list(delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
                if mention_offset == int(sub_mention['offset']) and \
                        len(mention['mention']) > len(sub_mention['mention']):
                    if not data_utils.is_mention_already_in_list(delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
        if len(delete_mentions) > 0:
            change_mentions = []
            for mention in mentions:
                if not data_utils.is_mention_already_in_list(delete_mentions, mention):
                    change_mentions.append(mention)
            mentions = change_mentions
        change_mentions = []
        for mention in mentions:
            if not data_utils.is_mention_already_in_list(change_mentions, mention) \
                    and mention['mention'] not in comConfig.punctuation:
                change_mentions.append(mention)
        mentions = change_mentions
        # sort mentions
        mentions.sort(key=get_offset)
        # optimize the mention text with the jieba segmentation
        mentions_optim = []
        for mention in mentions:
            mentions_optim.append({'mention': get_optim_mention_text(jieba_entities, mention['mention']),
                                   'offset': mention['offset'], 'type': mention['type']})
        if mode == 1:
            dev_mention_data.append({'text_id': str(text_id), 'text': text, 'mention_data': mentions_optim})
        elif mode == 2:
            dev_mention_data.append({'text_id': str(text_id), 'text': text, 'mention_data': mentions_optim,
                                     'mention_data_original': data['mention_data_original']})
        elif mode == 3:
            dev_mention_data.append({'text_id': str(text_id), 'text': text, 'mention_data': mentions_optim})
        text_id += 1
    com_utils.pickle_save(dev_mention_data, out_file)
    print("success create dev data with mentions, mode:{}".format(mode))
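# A usage sketch: mode 2 is the variant above that carries the gold 'mention_data_original'
# through for scoring by test()/test_sup(). The input pickle name follows the predict-tag dump
# used in analysis_test_ner_result(); the output file name here is hypothetical.
def build_test_mentions_sketch():
    create_dev_mention_data(mode=2,
                            ner_datas=fileConfig.dir_ner + fileConfig.file_ner_test_predict_tag,
                            out_file=fileConfig.dir_ner + 'test_mention_data.pkl')  # hypothetical name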