Example No. 1
def create_extend_train_file():
    print("start create extend train file...")
    # train_file = open(fileConfig.dir_data + fileConfig.file_train_data, 'r', encoding='utf-8')
    train_file = com_utils.pickle_load(fileConfig.dir_data + fileConfig.file_train_pkl)
    test_file = com_utils.pickle_load(fileConfig.dir_data + fileConfig.file_test_pkl)
    extend_out_file = open(fileConfig.dir_data + fileConfig.file_extend_train_data, 'w', encoding='utf-8')
    extend_test_file = open(fileConfig.dir_data + fileConfig.file_extend_test_data, 'w', encoding='utf-8')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    for line in tqdm(train_file, desc='extend train file'):
        jstr = ujson.loads(line)
        extend_lines = data_utils.get_extend_ner_train_list(kb_dict, jstr)
        for lines in extend_lines:
            if len(lines['text']) > (nerConfig.max_seq_length - 2):
                continue
            extend_out_file.write(ujson.dumps(lines, ensure_ascii=False))
            extend_out_file.write('\n')
    for line in tqdm(test_file, desc='extend test file'):
        jstr = ujson.loads(line)
        extend_test_file.write(ujson.dumps(jstr, ensure_ascii=False))
        extend_test_file.write('\n')

    # jstr = ujson.loads('{"text_id": "27755", "text": "副军级海军大校何永明担任解放军驻海南某部", "mention_data": [{"kb_id": "NIL", "mention": "副军级", "offset": "0"}, {"kb_id": "346365", "mention": "海军", "offset": "3"}, {"kb_id": "163745", "mention": "大校", "offset": "5"}, {"kb_id": "183299", "mention": "何永明", "offset": "7"}, {"kb_id": "253615", "mention": "担任", "offset": "10"}, {"kb_id": "101210", "mention": ">解放军驻海南某部", "offset": "12"}, {"kb_id": "193906", "mention": "驻", "offset": "15"}, {"kb_id": "155589", "mention": "海南", "offset": "16"}]}')
    # extend_lines = data_utils.get_extend_ner_train_list(kb_dict, jstr)
    # for lines in extend_lines:
    #     extend_out_file.write(ujson.dumps(lines, ensure_ascii=False))
    #     extend_out_file.write('\n')
    extend_out_file.close()
    extend_test_file.close()
    print('success create extend train file')
Example No. 2
def create_nel_train_data():
    if not os.path.exists(fileConfig.dir_nel):
        os.mkdir(fileConfig.dir_nel)
    train_data = open(fileConfig.dir_data + fileConfig.file_train_data, 'r')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    pd_df = pd.read_csv(fileConfig.dir_kb_info + fileConfig.file_kb_pandas_csv)
    data_list = []
    for line in tqdm(train_data, desc='create entity link train data'):
        # for line in train_data:
        jstr = ujson.loads(line)
        text_id = jstr['text_id']
        text = jstr['text']
        mention_datas = jstr['mention_data']
        for mention_data in mention_datas:
            kb_id = mention_data['kb_id']
            mention = mention_data['mention']
            start = mention_data['offset']
            end = int(start) + len(mention) - 1
            kb_entity = kb_dict.get(kb_id)
            if kb_entity is not None:
                entity_cands, entity_ids, entity_text = data_utils.get_entity_cands(kb_entity, kb_id, pd_df)
            else:
                continue
            data_list.append({'text_id': text_id, 'mention_text': text, 'mention': mention,
                              'mention_position': [start, end], 'entity_cands': entity_cands,
                              'entity_text': entity_text, 'entity_ids': entity_ids})
    com_utils.pickle_save(data_list, fileConfig.dir_nel + fileConfig.file_nel_entity_link_train_data)
    print("success create nel entity link train data")
Example No. 3
def create_fasttext_sup_train_data(index, train_data_file, kb_dict_file, kb_alia_file, stopword_file, out_file,
                                   mode=fasttextConfig.create_data_word):
    print("create {} sup train data".format(index))
    kb_alias_df = pd.read_csv(kb_alia_file)
    stopwords = data_utils.get_stopword_list(stopword_file)
    train_datas = open(train_data_file, 'r', encoding='utf-8').readlines()
    kb_dict = com_utils.pickle_load(kb_dict_file)
    train_out_file = open(out_file, 'w', encoding='utf-8')
    text_ids = {}
    max_extend_count = 3
    for line in tqdm(train_datas, desc='deal {} train file'.format(index)):
        jstr = ujson.loads(line)
        text = jstr['text']
        text_id = jstr['text_id']
        if text_ids.get(text_id) == max_extend_count:
            continue
        mentions = jstr['mention_data']
        for mention in mentions:
            mention_id = mention['kb_id']
            mention_text = mention['mention']
            neighbor_text = com_utils.get_neighbor_sentence(text, mention_text)
            # true values
            kb_entity = kb_dict.get(mention_id)
            if kb_entity is not None:
                out_str = com_utils.get_entity_mention_pair_text(kb_entity['text'], neighbor_text, stopwords,
                                                                 cut_client,
                                                                 fasttextConfig.label_true, mode)
                train_out_file.write(out_str)
            # false values
            alia_ids = []
            alia_count = 0
            alias_df = kb_alias_df[kb_alias_df['subject'] == com_utils.cht_to_chs(mention_text)]
            for _, item in alias_df.iterrows():
                a_id = str(item['subject_id'])
                if a_id != mention_id:
                    alia_ids.append(a_id)
                    alia_count += 1
                    if alia_count == max_extend_count:
                        break
            if len(alia_ids) > 0:
                for alia_id in alia_ids:
                    alia_entity = kb_dict.get(alia_id)
                    if alia_entity is not None:
                        out_str = com_utils.get_entity_mention_pair_text(alia_entity['text'], neighbor_text, stopwords,
                                                                         cut_client,
                                                                         fasttextConfig.label_false, mode)
                        train_out_file.write(out_str)
        # add text
        text_ids = com_utils.dict_add(text_ids, text_id)
    # release resources
    train_out_file.close()
    train_datas = None
    train_out_file = None
    kb_alias_df = None
    stopwords = None
    kb_dict = None
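
The per-text cap above relies on com_utils.dict_add to count how many lines have been processed for each text_id; a minimal sketch of such a counter helper, assuming it simply increments and returns the dict (the real com_utils implementation may differ):

def dict_add(counter, key):
    # Hypothetical sketch of com_utils.dict_add: increment an occurrence counter.
    counter[key] = counter.get(key, 0) + 1
    return counter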
Example No. 4
def analysis_train_data():
    print("start use the fasttext model to predict test data")
    if not os.path.exists(fileConfig.dir_analysis):
        os.mkdir(fileConfig.dir_analysis)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info +
                                    fileConfig.file_kb_dict)
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data,
                      'r',
                      encoding='utf-8')
    out_file = open(fileConfig.dir_analysis +
                    fileConfig.file_analysis_train_untind,
                    'w',
                    encoding='utf-8')
    count = 1
    for line in tqdm(train_file, 'find unmatch'):
        jstr = ujson.loads(line)
        mention_data = jstr['mention_data']
        for mention in mention_data:
            mention_text = mention['mention']
            mention_id = mention['kb_id']
            kb_entity = kb_dict.get(mention_id)
            is_match = False
            if kb_entity is not None:
                kb_subject = kb_entity['subject']
                kb_alias = kb_entity['alias']
                if kb_subject == mention_text:
                    is_match = True
                if not is_match:
                    for alia in kb_alias:
                        if alia == mention_text:
                            is_match = True
                if not is_match:
                    out_file.write('-' * 20)
                    out_file.write('\n')
                    out_file.write("num:{}--text_id:{}--text:{}".format(
                        count, jstr['text_id'], jstr['text']))
                    out_file.write('\n')
                    out_file.write("not match:")
                    out_file.write('\n')
                    out_file.write('*' * 20)
                    out_file.write('\n')
                    out_file.write('mention_original: {}'.format(
                        ujson.dumps(mention, ensure_ascii=False)))
                    out_file.write('\n')
                    out_file.write("kb: {}".format(
                        'subject:{} alias:{}'.format(kb_entity['subject'],
                                                     kb_entity['alias'])))
                    out_file.write('\n')
                    out_file.write('*' * 20)
                    out_file.write('\n')
                    count += 1
    train_file.close()
    out_file.close()
    print("success analysis train file find miss match entities")
Example No. 5
def create_pandas_kb_alias_data():
    kb_file = open(fileConfig.dir_data + fileConfig.file_kb_data, 'r', encoding='utf-8')
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data, 'r', encoding='utf-8')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    subject_id_list = []
    subject_list = []
    subjects = {}
    # from kb file
    for line in tqdm(kb_file, desc='deal kb_file'):
        jstr = ujson.loads(line)
        subject_id = jstr['subject_id']
        subject = com_utils.cht_to_chs(jstr['subject'].strip().lower())
        subject_id_list.append(subject_id)
        subject_list.append(subject)
        alias = jstr['alias']
        subjects[subject] = 1
        for alia in alias:
            alia_str = com_utils.cht_to_chs(alia.strip().lower())
            if subjects.get(alia_str) is not None:
                continue
            else:
                subjects[alia_str] = 1
                subject_id_list.append(subject_id)
                subject_list.append(alia_str)
    # from train file
    for line in tqdm(train_file, desc='deal train file'):
        jstr = ujson.loads(line)
        mention_data = jstr['mention_data']
        for mention in mention_data:
            mention_text = mention['mention']
            mention_text = com_utils.cht_to_chs(mention_text.lower())
            kb_id = mention['kb_id']
            kb_entity = kb_dict.get(kb_id)
            is_match = False
            if kb_entity is not None:
                kb_subject = kb_entity['subject']
                kb_alias = kb_entity['alias']
                if kb_subject == mention_text:
                    is_match = True
                if not is_match:
                    for alia in kb_alias:
                        if alia == mention_text:
                            is_match = True
                if not is_match:
                    if subjects.get(mention_text) is not None:
                        continue
                    else:
                        subjects[mention_text] = 1
                        subject_id_list.append(kb_id)
                        subject_list.append(mention_text)
    pandas_dict = {'subject_id': subject_id_list, 'subject': subject_list}
    df = pd.DataFrame.from_dict(pandas_dict)
    df.to_csv(fileConfig.dir_kb_info + fileConfig.file_kb_pandas_alias_data)
    print("success create pandas kb alia data file")
Example No. 6
def create_nel_vocab():
    # create mention entity vocab
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data, mode='r', encoding='utf-8')
    dev_file = open(fileConfig.dir_data + fileConfig.file_dev_data, mode='r', encoding='utf-8')
    out_file = open(fileConfig.dir_nel + fileConfig.file_nel_mention_context_vocab, mode='w', encoding='utf-8')
    vocab_dict = {}
    for line in tqdm(train_file, desc='deal train file'):
        jstr = ujson.loads(line)
        text = jstr['text']
        words = text.strip()
        for word in words:
            if word not in vocab_dict:
                vocab_dict[word] = 1
            else:
                vocab_dict[word] += 1
    for line in tqdm(dev_file, desc='deal dev file'):
        jstr = ujson.loads(line)
        text = jstr['text']
        words = text.strip()
        for word in words:
            if word not in vocab_dict:
                vocab_dict[word] = 1
            else:
                vocab_dict[word] += 1
    out_file.write('[PAD]\n')
    out_file.write('[UNK]\n')
    vocab_length = len(vocab_dict)
    for i, item in enumerate(Counter(vocab_dict).most_common()):
        if i < vocab_length - 1:
            out_file.write(item[0] + '\n')
        else:
            out_file.write(item[0])
    out_file.close()
    print("success create mention entity vocab data")

    # create entity context vocab
    entity_dict = set()
    text_dict = defaultdict(int)
    kb_data = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    for key, value in kb_data.items():
        text = value['text']
        subject = key
        if subject not in entity_dict:
            entity_dict.add(subject)
        else:
            raise Exception(f'entity : {subject} duplicated!!!')
        for t in text:
            text_dict[t] += 1
    entity_vocab = open(fileConfig.dir_nel + fileConfig.file_nel_entity_vocab, 'w', encoding='utf-8')
    entity_context_vocab = open(fileConfig.dir_nel + fileConfig.file_nel_entity_context_vocab, 'w', encoding='utf-8')
    for entity in entity_dict:
        entity_vocab.write(entity + '\n')
    entity_context_vocab.write('[PAD]\n')
    entity_context_vocab.write('[UNK]\n')
    for token, value in Counter(text_dict).most_common():
        entity_context_vocab.write(token + '\n')
    print("success create entity context vocab / entity vacab data")
Example No. 7
def split_train_data(train_file_path=None, out_train_file=None, out_dev_file=None, is_split=True):
    data_list = com_utils.pickle_load(train_file_path)
    if not is_split:
        dev_list = com_utils.pickle_load(fileConfig.dir_ner + fileConfig.file_extend_ner_dev_data)
    data_len = len(data_list)
    # train_size = int(data_len * comConfig.train_ratio)
    test_size = 10000
    random.seed(comConfig.random_seed)
    random.shuffle(data_list)

    # train_data = data_list[:train_size]
    if is_split:
        train_data = data_list[:data_len - test_size]
        dev_data = data_list[data_len - test_size:data_len]
        com_utils.pickle_save(train_data, out_train_file)
        com_utils.pickle_save(dev_data, out_dev_file)
    else:
        train_data = data_list
        com_utils.pickle_save(train_data, out_train_file)
        com_utils.pickle_save(dev_list, out_dev_file)
    print("success split data set")
Example No. 8
def analysis_test_result():
    print("start analysis test result...")
    test_result_file = open(fileConfig.dir_result +
                            fileConfig.file_result_fasttext_test,
                            'r',
                            encoding='utf-8')
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info +
                                    fileConfig.file_kb_dict)
    error_results = []
    more_dict = {}
    for line in tqdm(test_result_file, 'read from file'):
        jstr = ujson.loads(line)
        gen_mentions = jstr['mention_data']
        original_mentions = jstr['mention_data_original']
        error_list = get_result_error_list(gen_mentions, original_mentions,
                                           more_dict)
        if (len(error_list)) > 0:
            error_results.append({
                'text_id': jstr['text_id'],
                'text': jstr['text'],
                'errors': error_list
            })
    test_result_file.close()
    test_result_file = None
    out_file = open(fileConfig.dir_result +
                    fileConfig.file_result_fasttext_test_analysis,
                    'w',
                    encoding='utf-8')
    for item in tqdm(error_results, 'write result'):
        out_file.write('-' * 20)
        out_file.write('\n')
        out_file.write("text_id:{}--text:{}".format(item['text_id'],
                                                    item['text']))
        out_file.write('\n')
        out_file.write("errors:")
        out_file.write('\n')
        for error in item['errors']:
            out_file.write('*' * 20)
            out_file.write('\n')
            out_file.write('error type:{}'.format(
                get_error_type(error['error_type'])))
            out_file.write('\n')
            out_file.write(get_error_content(kb_dict, error))
            out_file.write('*' * 20)
            out_file.write('\n')
        out_file.write('-' * 20)
        out_file.write('\n')
    out_more = open(fileConfig.dir_result + fileConfig.file_analysis_gen_more,
                    'w',
                    encoding='utf-8')
    for item in tqdm(Counter(more_dict).most_common(), 'write more'):
        out_more.write(item[0])
        out_more.write('\n')
    out_file.close()
    out_more.close()
Example No. 9
def split_eval_mention(num):
    dev_mention_data = com_utils.pickle_load(
        fileConfig.dir_ner + fileConfig.file_ner_eval_mention_data)
    data_len = len(dev_mention_data)
    block_size = data_len / num
    for i in range(1, num + 1):
        data_iter = dev_mention_data[int((i - 1) * block_size):int(i *
                                                                   block_size)]
        com_utils.pickle_save(
            data_iter, fileConfig.dir_ner_split +
            fileConfig.file_ner_eval_mention_split.format(i))
    print("success split test mention to:{} files".format(num))
Example No. 10
def create_eval_cands_from_split(num):
    out_file = open(fileConfig.dir_ner + fileConfig.file_ner_eval_cands_data,
                    'w',
                    encoding='utf-8')
    for i in range(1, num + 1):
        datas = com_utils.pickle_load(
            fileConfig.dir_ner_split +
            fileConfig.file_ner_eval_cands_split.format(i))
        for line in datas:
            text = ujson.dumps(line, ensure_ascii=False)
            out_file.write(text)
            out_file.write('\n')
    print("merge eval cands data success!")
Example No. 11
def train():
    datas = com_utils.pickle_load(fileConfig.dir_kb_info + fileConfig.file_kb_dict)
    vectorizer = TfidfVectorizer()
    train_sentence = []
    print("prepare train data")
    for key, data in tqdm(datas.items(), desc='init train data'):
        train_sentence.append(' '.join(cut_client.cut_text(data['text'])))
    print("start train tfidf model")
    X = vectorizer.fit_transform(train_sentence)
    print("save model and keyword")
    tfidf_save_data = [X, vectorizer]
    if not os.path.exists(fileConfig.dir_tfidf):
        os.mkdir(fileConfig.dir_tfidf)
    com_utils.pickle_save(tfidf_save_data, fileConfig.dir_tfidf + fileConfig.file_tfidf_save_data)
    print("success train and save tfidf file")
Example No. 12
def test():
    print("start test the tfidf model")
    tfidf_data = com_utils.pickle_load(fileConfig.dir_tfidf + fileConfig.file_tfidf_save_data)
    vectorizer = tfidf_data[1]
    X = tfidf_data[0]
    # init test data
    ratio = 0.01
    print("init test datas use ratio:{}".format(ratio))
    test_datas = []
    train_file = open(fileConfig.dir_data + fileConfig.file_train_data, 'r', encoding='utf-8')
    for line in train_file:
        test_datas.append(ujson.loads(line))
    test_data_len = int(len(test_datas) * ratio)
    random.seed(comConfig.random_seed)
    random.shuffle(test_datas)
    test_datas = test_datas[:test_data_len]
    mentions = []
    for data in test_datas:
        mention_datas = data['mention_data']
        for mention in mention_datas:
            if mention['kb_id'] != 'NIL':
                mention_copy = mention.copy()
                mention_copy['sentence'] = data['text']
                mentions.append(mention_copy)
    # start test model
    print("start find mention")
    y_pred = []
    y_true = []
    for mention in tqdm(mentions, desc='find mention'):
        # for mention in mentions:
        y_true.append(int(mention['kb_id']))
        text = mention['mention']
        text_len = len(text)
        sentence = mention['sentence']
        for i in range(len(sentence) - text_len + 1):
            if sentence[i:i + text_len] == text:
                if i > 10 and i + text_len < len(sentence) - 9:
                    neighbor_sentence = sentence[i - 10:i + text_len + 9]
                elif i < 10:
                    neighbor_sentence = sentence[:20]
                else:
                    # fallback keeps neighbor_sentence defined for boundary positions
                    neighbor_sentence = sentence[-20:]
                kb_id = get_entityid(neighbor_sentence, vectorizer, X)
                y_pred.append(kb_id)
                break
    # calc the f1
    acc, f1 = acc_f1(y_pred, y_true)
    print("acc:{:.4f} f1:{:.4f}".format(acc, f1))
Example No. 13
def gen_simi_subject_list(file_path):
    print('start gen similar subject...')
    file_datas = com_utils.pickle_load(file_path)
    gensim_model = word2vec.Word2VecKeyedVectors.load(
        fileConfig.dir_fasttext + fileConfig.file_gensim_tencent_unsup_model)
    for item in tqdm(file_datas, 'gen simi subject'):
        mention_data = item['mention_data']
        for mention in mention_data:
            mention_text = mention['mention']
            try:
                mention['gen_subjects'] = get_simi_subject_list(
                    gensim_model.most_similar(positive=[mention_text], topn=5))
            except BaseException:
                mention['gen_subjects'] = []
    com_utils.pickle_save(file_datas, file_path)
    print('success gen similar subject...')
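
get_simi_subject_list receives gensim's most_similar output, a list of (word, score) tuples; a plausible reading is that it just keeps the candidate words. A sketch under that assumption:

def get_simi_subject_list(similar_items):
    # Hypothetical sketch: keep only the words from (word, score) pairs.
    return [word for word, _score in similar_items]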
Example No. 14
def create_nel_batch_iter(mode, entity_context_vocab, mention_context_vocab,
                          entity_vocab):
    if mode == nelConfig.mode_train:
        datas = com_utils.pickle_load(
            fileConfig.dir_nel + fileConfig.file_nel_entity_link_train_data)
        data_len = len(datas)
        train_size = int(data_len * comConfig.train_ratio)
        random.seed(comConfig.random_seed)
        random.shuffle(datas)
        train_data = datas[:train_size]
        dev_data = datas[train_size:]
        print("create nel data train len:{} dev len:{}".format(
            len(train_data), len(dev_data)))
        train_dataset = DataSet(train_data,
                                transform=transforms.Compose([
                                    NELToTensor(entity_context_vocab,
                                                mention_context_vocab,
                                                entity_vocab)
                                ]))
        dev_dataset = DataSet(dev_data,
                              transform=transforms.Compose([
                                  NELToTensor(entity_context_vocab,
                                              mention_context_vocab,
                                              entity_vocab)
                              ]))
        train_iter = DataLoader(train_dataset,
                                num_workers=4,
                                batch_size=nelConfig.batch_size,
                                shuffle=True,
                                collate_fn=collate_fn_entity_link)
        dev_iter = DataLoader(dev_dataset,
                              num_workers=4,
                              batch_size=nelConfig.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn_entity_link)
        return train_iter, dev_iter
Example No. 15
def analysis_test_ner_result():
    print("start analysis test ner result...")
    ner_test_datas = com_utils.pickle_load(
        fileConfig.dir_ner + fileConfig.file_ner_test_predict_tag)
    jieba_dict = com_utils.pickle_load(fileConfig.dir_kb_info +
                                       fileConfig.file_jieba_kb)
    out_file = open(fileConfig.dir_result +
                    fileConfig.file_ner_test_result_analysis,
                    'w',
                    encoding='utf-8')
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword +
                                             fileConfig.file_stopword)
    gen_more_words = data_utils.get_stopword_list(
        fileConfig.dir_stopword + fileConfig.file_analysis_gen_more)
    text_id = 1
    for data in tqdm(ner_test_datas, 'find entity'):
        text = ''.join(data['text'])
        tag_list = data['tag']
        start_index = 0
        mention_length = 0
        is_find = False
        mentions = []
        type_dict = {}
        # use tag find
        for i, tag in enumerate(tag_list):
            # if tag == nerConfig.B_seg + nerConfig.KB_seg:
            if tag.find(nerConfig.B_seg) > -1 or (
                    tag.find(nerConfig.I_seg) > -1 and not is_find):
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length += 1
                is_find = True
            elif tag.find(nerConfig.E_seg) > -1 and not is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                # mention = data_utils.strip_punctuation(mention)
                type_list = Counter(type_dict).most_common()
                mentions.append({
                    'T': 'NER',
                    'mention': mention,
                    'offset': str(start_index),
                    'type': type_list[0][0]
                })
                is_find = False
                mention_length = 0
                type_dict = {}
            # elif tag == nerConfig.I_seg + nerConfig.KB_seg and is_find:
            elif tag.find(nerConfig.I_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
            # elif tag == nerConfig.E_seg + nerConfig.KB_seg and is_find:
            elif tag.find(nerConfig.E_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                # mention = data_utils.strip_punctuation(mention)
                type_list = Counter(type_dict).most_common()
                mentions.append({
                    'T': 'NER',
                    'mention': mention,
                    'offset': str(start_index),
                    'type': type_list[0][0]
                })
                is_find = False
                mention_length = 0
                type_dict = {}
            elif tag == nerConfig.O_seg:
                is_find = False
                mention_length = 0
                type_dict = {}
        # use jieba find
        jieba_entities = cut_client.cut_text(text)
        for i, tag in enumerate(tag_list):
            # if tag == nerConfig.B_seg + nerConfig.KB_seg or tag == nerConfig.I_seg + nerConfig.KB_seg or tag == nerConfig.E_seg + nerConfig.KB_seg:
            # if tag.find(nerConfig.B_seg) > -1 or tag.find(nerConfig.I_seg) > -1 or tag.find(nerConfig.E_seg) > -1:
            jieba_offset = i
            jieba_char = text[i]
            jieba_text = get_jieba_mention(jieba_entities, jieba_char)
            if jieba_text is None:
                continue
            elif jieba_text == '_' or jieba_text == '-':
                continue
            elif len(jieba_text) == 1:
                continue
            elif stopwords.get(jieba_text) is not None:
                continue
            elif gen_more_words.get(jieba_text) is not None:
                continue
            jieba_offset = jieba_offset - jieba_text.find(jieba_char)
            if len(jieba_text) <= comConfig.max_jieba_cut_len and (
                    jieba_dict.get(jieba_text) is not None):
                type_str = tag.split('_')[1] if tag.find('_') > -1 else 'O'
                if jieba_text is None:
                    continue
                if not is_already_find_mention(mentions, jieba_text,
                                               jieba_offset):
                    # jieba_offset = text.find(jieba_text)
                    mentions.append({
                        'T': 'JIEBA',
                        'mention': jieba_text,
                        'offset': str(jieba_offset),
                        'type': type_str
                    })
        # find inner brackets mentions
        bracket_mentions = data_utils.get_mention_inner_brackets(
            text, tag_list)
        for mention in bracket_mentions:
            mention['T'] = 'bracket'
        if len(bracket_mentions) > 0:
            mentions += bracket_mentions
        # completion mentions
        # mentions_com = []
        # for mention in mentions:
        #     mention_str = mention['mention']
        #     try:
        #         for find in re.finditer(mention_str, text):
        #             find_offset = find.span()[0]
        #             if find_offset != int(mention['offset']):
        #                 mentions_com.append(
        #                     {'T': 'COM', 'mention': mention['mention'], 'offset': str(find_offset),
        #                      'type': mention['type']})
        #     except BaseException:
        #         # print("occur error when match mention str in completion mentions, error value:{} text:{}".format(
        #         #     mention_str, text))
        #         pass
        #     mentions_com.append(mention)
        # mentions = mentions_com
        # completion mentions
        out_file.write('\n')
        result_str = ''
        for i in range(len(text)):
            result_str += text[i] + '-' + tag_list[i] + ' '
        out_file.write(' text_id:{}, text:{} '.format(text_id, result_str))
        out_file.write('\n')
        out_file.write(' gen_mentions:{} '.format(
            ujson.dumps(mentions, ensure_ascii=False)))
        out_file.write('\n')
        text_id += 1
    out_file.close()
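
get_jieba_mention maps a single character back to the jieba token that contains it; a plausible minimal version simply scans the cut result. (Example No. 23 calls it with an extra offset argument, so treat this purely as a sketch.)

def get_jieba_mention(jieba_entities, jieba_char):
    # Hypothetical sketch: return the first jieba token containing the character.
    for token in jieba_entities:
        if jieba_char in token:
            return token
    return None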
Example No. 16
def get_train_examples(self):
    lines = com_utils.pickle_load(fileConfig.dir_ner +
                                  fileConfig.file_ner_train_data)
    examples = self._create_example(lines, nerConfig.mode_train)
    return examples
Example No. 17
def test_sup(mode=fasttextConfig.create_data_word):
    print("start use the fasttext model/supervise model to predict test data")
    if not os.path.exists(fileConfig.dir_result):
        os.mkdir(fileConfig.dir_result)
    unsup_model_fasttext = fastText.load_model(
        fileConfig.dir_fasttext +
        fileConfig.file_fasttext_model.format(fasttextConfig.choose_model))
    unsup_model_gensim = word2vec.Word2VecKeyedVectors.load(
        fileConfig.dir_fasttext + fileConfig.file_gensim_tencent_unsup_model)
    sup_model = fastText.load_model(fileConfig.dir_fasttext +
                                    fileConfig.file_fasttext_sup_word_model)
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword +
                                             fileConfig.file_stopword)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info +
                                    fileConfig.file_kb_dict)
    dev_file = open(fileConfig.dir_ner + fileConfig.file_ner_test_cands_data,
                    'r',
                    encoding='utf-8')
    out_file = open(fileConfig.dir_result +
                    fileConfig.file_result_fasttext_test,
                    'w',
                    encoding='utf-8')
    # f1 params
    gen_mention_count = 0
    original_mention_count = 0
    correct_mention_count = 0
    # count = 0
    # entity disambiguation
    for line in tqdm(dev_file, 'entity disambiguation'):
        # count += 1
        # if count < 3456:
        #     continue
        jstr = ujson.loads(line)
        dev_entity = {}
        text = com_utils.cht_to_chs(jstr['text'].lower())
        dev_entity['text_id'] = jstr['text_id']
        dev_entity['text'] = jstr['text']
        mention_data = jstr['mention_data']
        original_mention_data = jstr['mention_data_original']
        mentions = []
        for mention in mention_data:
            mention_text = mention['mention']
            if mention_text is None:
                continue
            cands = mention['cands']
            if len(cands) == 0:
                continue
            # use supervised model to choose mention
            supervise_cands = []
            for cand in cands:
                neighbor_text = com_utils.get_neighbor_sentence(
                    text, com_utils.cht_to_chs(mention_text.lower()))
                cand_entity = kb_dict.get(cand['cand_id'])
                if cand_entity is not None:
                    out_str = com_utils.get_entity_mention_pair_text(
                        com_utils.cht_to_chs(cand_entity['text'].lower()),
                        neighbor_text,
                        stopwords,
                        cut_client,
                        mode=mode)
                    # print(out_str)
                    result = sup_model.predict(out_str.replace('\n',
                                                               ' '))[0][0]
                    if result == fasttextConfig.label_true:
                        supervise_cands.append(cand)
            # unsupervise model choose item
            max_cand = None
            if len(supervise_cands) == 0:
                supervise_cands = cands
            # score list
            score_list = []
            mention_neighbor_sentence = text
            for i, cand in enumerate(supervise_cands):
                # score_fasttext = fasttext_get_sim(unsup_model_fasttext, mention_neighbor_sentence,
                #                          com_utils.cht_to_chs(cand['cand_text'].lower()), stopwords)
                score_gensim = gensim_get_sim(
                    unsup_model_gensim, mention_neighbor_sentence,
                    com_utils.cht_to_chs(cand['cand_text'].lower()), stopwords)
                # score = (0.8 * score_gensim) + (0.2 * score_fasttext)
                score = score_gensim
                # if score > max_score:
                #     max_score = score
                #     max_index = score
                if score < fasttextConfig.min_entity_similarity_threshold:
                    continue
                score_list.append({
                    'cand_id': cand['cand_id'],
                    'cand_score': score,
                    'cand_type': cand['cand_type']
                })
            # if max_score < fasttextConfig.min_entity_similarity_threshold:
            #     continue
            # find the best cand
            # find_type = False
            score_list.sort(key=get_socre_key, reverse=True)
            # for item in score_list:
            #     if item['cand_type'] == mention['type']:
            #         find_type = True
            # if find_type:
            #     for item in score_list:
            #         if item['cand_score'] > fasttextConfig.choose_entity_similarity_threshold:
            #             max_cand = item
            if max_cand is None:
                if len(score_list) > 0:
                    max_cand = score_list[0]
            # find the best cand
            if max_cand is not None:
                mentions.append({
                    'kb_id': max_cand['cand_id'],
                    'mention': mention['mention'],
                    'offset': mention['offset']
                })
        # optim mentions
        delete_mentions = []
        mentions.sort(key=get_mention_len)
        for mention in mentions:
            mention_offset = int(mention['offset'])
            mention_len = len(mention['mention'])
            for sub_mention in mentions:
                if mention_offset != int(sub_mention['offset']) and int(
                        sub_mention['offset']) in range(
                            mention_offset, mention_offset + mention_len):
                    if not data_utils.is_mention_already_in_list(
                            delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
                if mention_offset == int(sub_mention['offset']) and len(
                        mention['mention']) > len(sub_mention['mention']):
                    if not data_utils.is_mention_already_in_list(
                            delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
        if len(delete_mentions) > 0:
            change_mentions = []
            for mention in mentions:
                if not data_utils.is_mention_already_in_list(
                        delete_mentions, mention):
                    change_mentions.append(mention)
            mentions = change_mentions
        change_mentions = []
        for mention in mentions:
            if not data_utils.is_mention_already_in_list(
                    change_mentions, mention
            ) and mention['mention'] not in comConfig.punctuation:
                change_mentions.append(mention)
        mentions = change_mentions
        mentions.sort(key=get_mention_offset)
        # optim mentions
        # calc f1
        for mention in mentions:
            if is_find_correct_entity(mention['kb_id'], original_mention_data):
                correct_mention_count += 1
        gen_mention_count += len(mentions)
        for orginal_mention in original_mention_data:
            if orginal_mention['kb_id'] != 'NIL':
                original_mention_count += 1
        # out result
        dev_entity['mention_data'] = mentions
        dev_entity['mention_data_original'] = original_mention_data
        out_file.write(ujson.dumps(dev_entity, ensure_ascii=False))
        out_file.write('\n')
    precision = correct_mention_count / gen_mention_count
    recall = correct_mention_count / original_mention_count
    f1 = 2 * precision * recall / (precision + recall)
    print("success create test result, p:{:.4f} r:{:.4f} f1:{:.4f}".format(
        precision, recall, f1))
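
The sort keys used above (get_socre_key, get_mention_len, get_mention_offset) are not shown; given how they are used, they plausibly just expose the score, the mention length and the numeric offset. Sketches under that assumption:

def get_socre_key(item):
    # Hypothetical sketch: sort candidates by similarity score.
    return item['cand_score']

def get_mention_len(mention):
    # Hypothetical sketch: sort mentions by surface-form length.
    return len(mention['mention'])

def get_mention_offset(mention):
    # Hypothetical sketch: sort mentions by character offset.
    return int(mention['offset'])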
Example No. 18
def test():
    print("start use the fasttext model to predict test data")
    if not os.path.exists(fileConfig.dir_result):
        os.mkdir(fileConfig.dir_result)
    model = fastText.load_model(
        fileConfig.dir_fasttext +
        fileConfig.file_fasttext_model.format(fasttextConfig.choose_model))
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword +
                                             fileConfig.file_stopword)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info +
                                    fileConfig.file_kb_dict)
    dev_file = open(fileConfig.dir_ner + fileConfig.file_ner_test_cands_data,
                    'r',
                    encoding='utf-8')
    out_file = open(fileConfig.dir_result +
                    fileConfig.file_result_fasttext_test,
                    'w',
                    encoding='utf-8')
    # f1 params
    gen_mention_count = 0
    original_mention_count = 0
    correct_mention_count = 0
    # entity disambiguation
    for line in tqdm(dev_file, 'entity disambiguation'):
        jstr = ujson.loads(line)
        dev_entity = {}
        text = jstr['text']
        dev_entity['text_id'] = jstr['text_id']
        dev_entity['text'] = jstr['text']
        mention_data = jstr['mention_data']
        original_mention_data = jstr['mention_data_original']
        mentions = []
        for mention in mention_data:
            mention_text = mention['mention']
            cands = mention['cands']
            if len(cands) == 0:
                continue
            # if len(cands) == 1:
            #     mentions.append(
            #         {'kb_id': str(cands[0]['cand_id']), 'mention': mention['mention'],
            #          'offset': str(mention['offset'])})
            #     continue
            max_index = 0
            max_score = 0.0
            max_cand = None
            # mention_neighbor_sentence = get_neighbor_sentence(text, mention_text)
            # score list
            score_list = []
            mention_neighbor_sentence = text
            for i, cand in enumerate(cands):
                score = fasttext_get_sim(model, mention_neighbor_sentence,
                                         cand['cand_text'], stopwords)
                # if score > max_score:
                #     max_score = score
                #     max_index = i
                if score < fasttextConfig.min_entity_similarity_threshold:
                    continue
                score_list.append({
                    'cand_id': cand['cand_id'],
                    'cand_score': score,
                    'cand_type': cand['cand_type']
                })
            # if max_score < fasttextConfig.min_entity_similarity_threshold:
            #     continue
            # find the best cand
            find_type = False
            score_list.sort(key=get_socre_key, reverse=True)
            for item in score_list:
                if item['cand_type'] == mention['type']:
                    find_type = True
            if find_type:
                for item in score_list:
                    if item['cand_score'] > fasttextConfig.choose_entity_similarity_threshold:
                        max_cand = item
            if max_cand is None:
                if len(score_list) > 0:
                    max_cand = score_list[0]
            # find the best cand
            if max_cand is not None:
                if is_find_correct_entity(max_cand['cand_id'],
                                          original_mention_data):
                    correct_mention_count += 1
                mentions.append({
                    'kb_id': max_cand['cand_id'],
                    'mention': mention['mention'],
                    'offset': mention['offset']
                })
        # calc f1 params
        gen_mention_count += len(mentions)
        original_mention_count += len(original_mention_data)

        dev_entity['mention_data'] = mentions
        dev_entity['mention_data_original'] = original_mention_data
        out_file.write('-' * 20)
        out_file.write('\n')
        out_file.write("text_id:{}--text:{}".format(dev_entity['text_id'],
                                                    dev_entity['text']))
        out_file.write('\n')
        out_file.write("mention_data:")
        out_file.write('\n')
        # generate mention
        for mention in dev_entity['mention_data']:
            kb_mention = ''
            if mention['kb_id'] != 'NIL':
                kb_mention = ujson.dumps(kb_dict[mention['kb_id']],
                                         ensure_ascii=False)
            out_file.write('*' * 20)
            out_file.write('\n')
            out_file.write('mention_original: {}'.format(mention))
            out_file.write('\n')
            out_file.write("kb: {}".format(kb_mention))
            out_file.write('\n')
            out_file.write('*' * 20)
            out_file.write('\n')
        # original mention
        out_file.write("kb_data:")
        out_file.write('\n')
        for mention in dev_entity['mention_data_original']:
            kb_mention = ''
            if mention['kb_id'] != 'NIL':
                kb_mention = ujson.dumps(kb_dict[mention['kb_id']],
                                         ensure_ascii=False)
            out_file.write('*' * 20)
            out_file.write('\n')
            out_file.write('kb_original: {}'.format(mention))
            out_file.write('\n')
            out_file.write("kb: {}".format(kb_mention))
            out_file.write('\n')
            out_file.write('*' * 20)
            out_file.write('\n')
        out_file.write('-' * 20)
        out_file.write('\n')
    precision = correct_mention_count / gen_mention_count
    recall = correct_mention_count / original_mention_count
    f1 = 2 * precision * recall / (precision + recall)
    print("success create test result, p:{:.4f} r:{:.4f} f1:{:.4f}".format(
        precision, recall, f1))
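
fasttext_get_sim compares the mention's sentence against a candidate entity's description; a plausible implementation is cosine similarity between fastText sentence vectors after stopword filtering. This is a sketch under that assumption, not the repository's actual helper:

import numpy as np

def fasttext_get_sim(model, sentence, cand_text, stopwords):
    # Hypothetical sketch: cosine similarity of fastText sentence vectors.
    # Assumes cut_client.cut_text tokenizes and stopwords maps word -> 1.
    def to_vec(text):
        words = [w for w in cut_client.cut_text(text) if stopwords.get(w) is None]
        return model.get_sentence_vector(' '.join(words))
    v1, v2 = to_vec(sentence), to_vec(cand_text)
    denom = float(np.linalg.norm(v1) * np.linalg.norm(v2)) or 1.0
    return float(np.dot(v1, v2)) / denom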
Example No. 19
def eval_sup(mode=fasttextConfig.create_data_word):
    print("start use the fasttext/supervised model to predict eval data")
    if not os.path.exists(fileConfig.dir_result):
        os.mkdir(fileConfig.dir_result)
    # unsup_model = fastText.load_model(
    #     fileConfig.dir_fasttext + fileConfig.file_fasttext_model.format(fasttextConfig.model_skipgram))
    unsup_model = word2vec.Word2VecKeyedVectors.load(
        fileConfig.dir_fasttext + fileConfig.file_gensim_tencent_unsup_model)
    sup_model = fastText.load_model(fileConfig.dir_fasttext +
                                    fileConfig.file_fasttext_sup_word_model)
    kb_dict = com_utils.pickle_load(fileConfig.dir_kb_info +
                                    fileConfig.file_kb_dict)
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword +
                                             fileConfig.file_stopword)
    dev_file = open(fileConfig.dir_ner + fileConfig.file_ner_eval_cands_data,
                    'r',
                    encoding='utf-8')
    out_file = open(fileConfig.dir_result + fileConfig.file_result_eval_data,
                    'w',
                    encoding='utf-8')
    # entity disambiguation
    for line in tqdm(dev_file, 'entity disambiguation'):
        if len(line.strip('\n')) == 0:
            continue
        jstr = ujson.loads(line)
        dev_entity = {}
        text = com_utils.cht_to_chs(jstr['text'].lower())
        dev_entity['text_id'] = jstr['text_id']
        dev_entity['text'] = jstr['text']
        mention_data = jstr['mention_data']
        mentions = []
        for mention in mention_data:
            mention_text = mention['mention']
            if mention_text is None:
                continue
            cands = mention['cands']
            if len(cands) == 0:
                continue
            # use supervised model to choose mention
            supervise_cands = []
            for cand in cands:
                neighbor_text = com_utils.get_neighbor_sentence(
                    text, com_utils.cht_to_chs(mention_text.lower()))
                cand_entity = kb_dict.get(cand['cand_id'])
                if cand_entity is not None:
                    out_str = com_utils.get_entity_mention_pair_text(
                        com_utils.cht_to_chs(cand_entity['text'].lower()),
                        neighbor_text,
                        stopwords,
                        cut_client,
                        mode=mode)
                    result = sup_model.predict(out_str.strip('\n'))[0][0]
                    if result == fasttextConfig.label_true:
                        supervise_cands.append(cand)
            if len(supervise_cands) == 0:
                supervise_cands = cands
            # unsupervise model choose item
            max_cand = None
            # score list
            score_list = []
            mention_neighbor_sentence = text
            for i, cand in enumerate(supervise_cands):
                # score = fasttext_get_sim(unsup_model, mention_neighbor_sentence,
                #                          com_utils.cht_to_chs(cand['cand_text'].lower()), stopwords)
                score = gensim_get_sim(
                    unsup_model, mention_neighbor_sentence,
                    com_utils.cht_to_chs(cand['cand_text'].lower()), stopwords)
                if score < fasttextConfig.min_entity_similarity_threshold:
                    continue
                score_list.append({
                    'cand_id': cand['cand_id'],
                    'cand_score': score,
                    'cand_type': cand['cand_type']
                })
            score_list.sort(key=get_socre_key, reverse=True)
            if len(score_list) > 0:
                max_cand = score_list[0]
            # find the best cand
            if max_cand is not None:
                mentions.append({
                    'kb_id': max_cand['cand_id'],
                    'mention': mention['mention'],
                    'offset': mention['offset']
                })
        # optim mentions
        delete_mentions = []
        mentions.sort(key=get_mention_len)
        for optim_mention in mentions:
            mention_offset = int(optim_mention['offset'])
            mention_len = len(optim_mention['mention'])
            for sub_mention in mentions:
                if mention_offset != int(sub_mention['offset']) and int(
                        sub_mention['offset']) in range(
                            mention_offset, mention_offset + mention_len):
                    if not data_utils.is_mention_already_in_list(
                            delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
        if len(delete_mentions) > 0:
            change_mentions = []
            for optim_mention in mentions:
                if not data_utils.is_mention_already_in_list(
                        delete_mentions, optim_mention):
                    change_mentions.append(optim_mention)
            mentions = change_mentions
        change_mentions = []
        for optim_mention in mentions:
            if not data_utils.is_mention_already_in_list(
                    change_mentions, optim_mention
            ) and optim_mention['mention'] not in comConfig.punctuation:
                change_mentions.append(optim_mention)
        mentions = change_mentions
        mentions.sort(key=get_mention_offset)
        dev_entity['mention_data'] = mentions
        out_file.write(ujson.dumps(dev_entity, ensure_ascii=False))
        out_file.write('\n')
    print("success create supervised eval result")
Example No. 20
def get_extend_dev_examples(self):
    lines = com_utils.pickle_load(fileConfig.dir_ner +
                                  fileConfig.file_ner_extend_dev_data)
    examples = self._create_example(lines, nerConfig.mode_dev)
    return examples
Example No. 21
def create_dev_mention_cands_data(index, mention_file, pd_file, alia_kb_df,
                                  out_file):
    print("start create {} mention cands".format(index))
    dev_mention_data = com_utils.pickle_load(mention_file)
    print("{} data length is {}".format(index, len(dev_mention_data)))
    pd_df = pandas.read_csv(pd_file)
    alia_kb_df = pandas.read_csv(alia_kb_df)
    alia_kb_df = alia_kb_df.fillna('')
    count = 0
    for dev_data in tqdm(dev_mention_data, desc='find {} cands'.format(index)):
        # count += 1
        # if (count < 465):
        #     continue
        mention_data = dev_data['mention_data']
        for mention in mention_data:
            mention_text = mention['mention']
            if mention_text is None:
                continue
            cands = []
            cand_ids = {}
            # match orginal
            mention_text_proc = com_utils.cht_to_chs(mention_text.lower())
            mention_text_proc = com_utils.complete_brankets(mention_text_proc)
            # print(mention_text_proc)
            mention_text_proc_extend = mention_text_proc[
                0:len(mention_text_proc) - 1]
            subject_df = data_utils.pandas_query(pd_df, 'subject',
                                                 mention_text_proc)
            for _, item in subject_df.iterrows():
                s_id = str(item['subject_id'])
                if cand_ids.get(s_id) is not None:
                    continue
                cand_ids[s_id] = 1
                subject = item['subject']
                # text = data_utils.get_text(ast.literal_eval(item['data']), item['subject'])
                text = data_utils.get_all_text(item['subject'],
                                               ast.literal_eval(item['data']))
                cands.append({
                    'cand_id': s_id,
                    'cand_subject': subject,
                    'cand_text': text,
                    'cand_type': com_utils.get_kb_type(ast.literal_eval(item['type']))
                })
            # match more
            # subject_df = data_utils.pandas_query(pd_df, 'subject', mention_text_proc_extend)
            # for _, item in subject_df.iterrows():
            #     s_id = str(item['subject_id'])
            #     if cand_ids.get(s_id) is not None:
            #         continue
            #     cand_ids[s_id] = 1
            #     subject = item['subject']
            #     # text = data_utils.get_text(ast.literal_eval(item['data']), item['subject'])
            #     text = data_utils.get_all_text(item['subject'], ast.literal_eval(item['data']))
            #     cands.append({'cand_id': s_id, 'cand_subject': subject, 'cand_text': text,
            #                   'cand_type': com_utils.get_kb_type(ast.literal_eval(item['type']))})
            # match alias
            alias_subject_ids = []
            # match orginal
            alias_df = data_utils.pandas_query(alia_kb_df, 'subject',
                                               mention_text_proc)
            for _, item in alias_df.iterrows():
                a_id = str(item['subject_id'])
                if a_id in alias_subject_ids:
                    continue
                alias_subject_ids.append(a_id)
            # match more
            # alias_df = data_utils.pandas_query(alia_kb_df, 'subject', mention_text_proc_extend)
            # for _, item in alias_df.iterrows():
            #     a_id = str(item['subject_id'])
            #     if alias_subject_ids.__contains__(a_id):
            #         continue
            #     alias_subject_ids.append(a_id)
            for alia_id in alias_subject_ids:
                alias_df = pd_df[pd_df['subject_id'] == int(alia_id)]
                for _, item in alias_df.iterrows():
                    b_id = str(item['subject_id'])
                    if cand_ids.get(b_id) is not None:
                        continue
                    cand_ids[b_id] = 1
                    subject = item['subject']
                    # text = data_utils.get_text(ast.literal_eval(item['data']), item['subject'])
                    text = data_utils.get_all_text(
                        item['subject'], ast.literal_eval(item['data']))
                    cands.append({
                        'cand_id': b_id,
                        'cand_subject': subject,
                        'cand_text': text,
                        'cand_type': com_utils.get_kb_type(ast.literal_eval(item['type']))
                    })
            # match gen subject
            # gen_subject_ids = []
            # for gen_subject in mention['gen_subjects']:
            #     gen_text = com_utils.cht_to_chs(gen_subject.lower())
            #     alias_df = alia_kb_df[alia_kb_df['subject'] == gen_text]
            #     for _, item in alias_df.iterrows():
            #         a_id = str(item['subject_id'])
            #         if gen_subject_ids.__contains__(a_id):
            #             continue
            #         gen_subject_ids.append(a_id)
            #     for alia_id in gen_subject_ids:
            #         alias_df = pd_df[pd_df['subject_id'] == int(alia_id)]
            #         for _, item in alias_df.iterrows():
            #             b_id = str(item['subject_id'])
            #             if cand_ids.get(b_id) is not None:
            #                 continue
            #             cand_ids[b_id] = 1
            #             subject = item['subject']
            #             # text = data_utils.get_text(ast.literal_eval(item['data']), item['subject'])
            #             text = data_utils.get_all_text(item['subject'], ast.literal_eval(item['data']))
            #             cands.append({'cand_id': b_id, 'cand_subject': subject, 'cand_text': text,
            #                           'cand_type': com_utils.get_kb_type(ast.literal_eval(item['type']))})
            mention['cands'] = cands
    com_utils.pickle_save(dev_mention_data, out_file)
    print("success create {} dev data with mention and cands!".format(index))
Example No. 22
def get_eval_examples(self):
    lines = com_utils.pickle_load(fileConfig.dir_ner +
                                  fileConfig.file_ner_eval_data)
    examples = self._create_example(lines, nerConfig.mode_predict)
    return examples
Example No. 23
def create_dev_mention_data(mode, ner_datas, out_file):
    ner_datas = com_utils.pickle_load(ner_datas)
    jieba_dict = com_utils.pickle_load(fileConfig.dir_kb_info +
                                       fileConfig.file_jieba_kb)
    stopwords = data_utils.get_stopword_list(fileConfig.dir_stopword +
                                             fileConfig.file_stopword)
    gen_more_words = data_utils.get_stopword_list(
        fileConfig.dir_stopword + fileConfig.file_analysis_gen_more)
    text_id = 1
    dev_mention_data = []
    # count = 0
    for data in tqdm(ner_datas, 'find entity'):
        # count += 1
        # if count < 1496:
        #     continue
        text = ''.join(data['text'])
        tag_list = data['tag']
        start_index = 0
        mention_length = 0
        is_find = False
        mentions = []
        type_dict = {}
        # use tag find
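        # decode B/I/E/O tags into mention spans; a span's type is the most frequent tag type seen inside it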
        for i, tag in enumerate(tag_list):
            # if tag == nerConfig.B_seg + nerConfig.KB_seg:
            if tag.find(nerConfig.B_seg) > -1 or (
                    tag.find(nerConfig.I_seg) > -1 and not is_find):
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length = 1
                is_find = True
            elif tag.find(nerConfig.E_seg) > -1 and not is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                start_index = i
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                mention = data_utils.strip_punctuation(mention)
                type_list = Counter(type_dict).most_common()
                mentions.append({
                    'mention': mention,
                    'offset': str(start_index),
                    'type': type_list[0][0]
                })
                is_find = False
                mention_length = 0
                type_dict = {}
            # elif tag == nerConfig.I_seg + nerConfig.KB_seg and is_find:
            elif tag.find(nerConfig.I_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
            # elif tag == nerConfig.E_seg + nerConfig.KB_seg and is_find:
            elif tag.find(nerConfig.E_seg) > -1 and is_find:
                type_str = tag.split('_')[1]
                type_dict = com_utils.dict_add(type_dict, type_str)
                mention_length += 1
                mention = text[start_index:start_index + mention_length]
                mention = data_utils.strip_punctuation(mention)
                type_list = Counter(type_dict).most_common()
                mentions.append({
                    'mention': mention,
                    'offset': str(start_index),
                    'type': type_list[0][0]
                })
                is_find = False
                mention_length = 0
                type_dict = {}
            elif tag == nerConfig.O_seg:
                is_find = False
                mention_length = 0
                type_dict = {}
        # use jieba find
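        # supplement the tag-based spans with jieba tokens that overlap a tagged position and exist in the KB dictionary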
        jieba_entities = cut_client.cut_text(text)
        for i, tag in enumerate(tag_list):
            # if tag == nerConfig.B_seg + nerConfig.KB_seg or tag == nerConfig.I_seg + nerConfig.KB_seg or tag == nerConfig.E_seg + nerConfig.KB_seg:
            if tag.find(nerConfig.B_seg) > -1 or tag.find(
                    nerConfig.I_seg) > -1 or tag.find(nerConfig.E_seg) > -1:
                jieba_offset = i
                jieba_char = text[i]
                jieba_text = get_jieba_mention(jieba_entities, jieba_char,
                                               jieba_offset)
                if jieba_text is None:
                    continue
                elif jieba_text == '_' or jieba_text == '-':
                    continue
                elif data_utils.is_punctuation(jieba_text):
                    continue
                elif len(jieba_text) == 1:
                    continue
                elif stopwords.get(jieba_text) is not None:
                    continue
                # elif gen_more_words.get(jieba_text) is not None:
                #     continue
                jieba_offset = jieba_offset - jieba_text.find(jieba_char)
                if len(jieba_text) <= comConfig.max_jieba_cut_len and (
                        jieba_dict.get(jieba_text) is not None):
                    type_str = tag.split('_')[1] if tag.find('_') > -1 else 'O'
                    if jieba_text is None:
                        continue
                    if not is_already_find_mention(mentions, jieba_text,
                                                   jieba_offset):
                        mentions.append({
                            'mention': jieba_text,
                            'offset': str(jieba_offset),
                            'type': type_str
                        })
        # find inner brackets mentions
        bracket_mentions = data_utils.get_mention_inner_brackets(
            text, tag_list)
        if len(bracket_mentions) > 0:
            mentions += bracket_mentions
        # completion mentions
        # mentions_com = []
        # for mention in mentions:
        #     mention_str = mention['mention']
        #     try:
        #         for find in re.finditer(mention_str, text):
        #             find_offset = find.span()[0]
        #             if find_offset != int(mention['offset']):
        #                 mentions_com.append(
        #                     {'mention': mention['mention'], 'offset': str(find_offset), 'type': mention['type']})
        #     except BaseException:
        #         # print("occur error when match mention str in completion mentions, error value:{} text:{}".format(
        #         #     mention_str, text))
        #         pass
        #     mentions_com.append(mention)
        # mentions = mentions_com
        # optim mentions
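        # drop mentions nested inside a longer mention, and shorter duplicates that share the same offset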
        delete_mentions = []
        mentions.sort(key=get_mention_len)
        for mention in mentions:
            mention_offset = int(mention['offset'])
            mention_len = len(mention['mention'])
            for sub_mention in mentions:
                if mention_offset != int(sub_mention['offset']) and int(
                        sub_mention['offset']) in range(
                            mention_offset, mention_offset + mention_len):
                    if not data_utils.is_mention_already_in_list(
                            delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
                if mention_offset == int(sub_mention['offset']) and len(
                        mention['mention']) > len(sub_mention['mention']):
                    if not data_utils.is_mention_already_in_list(
                            delete_mentions, sub_mention):
                        delete_mentions.append(sub_mention)
        if len(delete_mentions) > 0:
            change_mentions = []
            for mention in mentions:
                if not data_utils.is_mention_already_in_list(
                        delete_mentions, mention):
                    change_mentions.append(mention)
            mentions = change_mentions
        change_mentions = []
        for mention in mentions:
            if not data_utils.is_mention_already_in_list(
                    change_mentions, mention
            ) and mention['mention'] not in comConfig.punctuation:
                change_mentions.append(mention)
        mentions = change_mentions
        # optim mentions
        # sort mentions
        mentions.sort(key=get_offset)
        # optimize the mention data
        mentions_optim = []
        for mention in mentions:
            mentions_optim.append({
                'mention': get_optim_mention_text(jieba_entities, mention['mention']),
                'offset': mention['offset'],
                'type': mention['type']
            })
        if mode == 1 or mode == 3:
            dev_mention_data.append({
                'text_id': str(text_id),
                'text': text,
                'mention_data': mentions_optim
            })
        elif mode == 2:
            dev_mention_data.append({
                'text_id': str(text_id),
                'text': text,
                'mention_data': mentions_optim,
                'mention_data_original': data['mention_data_original']
            })
        text_id += 1
    com_utils.pickle_save(dev_mention_data, out_file)
    print("success create dev data with mentions, mode:{}".format(mode))