Exemplo n.º 1
0
def text_match(attribute_list, answer_list, sentence):
    """Return the first attribute found in ``sentence`` with its answer.

    ``attribute_list`` and ``answer_list`` are parallel sequences and must
    have the same length (checked with ``assert``).  Attributes are tried in
    order; the first one for which ``attribute in sentence`` holds wins.

    Returns:
        ``(attribute, answer)`` for the first hit, or ``("", "")`` when no
        attribute occurs in the sentence.
    """
    assert len(attribute_list) == len(answer_list)

    # Position of the first attribute contained in the sentence, if any.
    hit = next(
        (pos for pos, candidate in enumerate(attribute_list)
         if candidate in sentence),
        None,
    )
    if hit is None:
        return "", ""
    return attribute_list[hit], answer_list[hit]


# Module-level setup, executed once at import time: label processor,
# tokenizer and NER model are created here.
ner_processor = NerProcessor()
# sim_processor = SimProcessor()
# Extra keyword arguments forwarded to BertTokenizer.from_pretrained below.
tokenizer_kwards = {
    'do_lower_case': False,
    'max_len': 64,
    # 'vocab_file': args.vob_file
}
tokenizer = BertTokenizer.from_pretrained('voidful/albert_chinese_tiny',
                                          **tokenizer_kwards)

# The number of output labels is taken from the processor's label list.
ner_model = get_ner_model(config_file='voidful/albert_chinese_tiny',
                          pre_train_model='best_ner.bin',
                          label_num=len(ner_processor.get_labels()))
# `device` is not defined in this snippet; it must already exist at module
# level when this code runs.
ner_model = ner_model.to(device)
# Put the model in evaluation mode.
ner_model.eval()
Exemplo n.º 2
0
def main():
    """Answer one hard-coded question end to end and print the result.

    Steps:
      1. Extract the entity from the question with the NER model.
      2. Fetch all rows for that entity from the ``nlpccqa`` table.
      3. Try a literal match of each attribute against the question
         (``text_match``).
      4. If nothing matches literally, load the similarity model and let
         ``semantic_matching`` pick the attribute index (-1 means no match).

    Nothing is returned; every outcome is reported with ``print``.
    """
    tokenizer_inputs = ()
    tokenizer_kwards = {
        'do_lower_case': False,
        'max_len': 64,
        'vocab_file': './input/config/bert-base-chinese-vocab.txt'
    }
    ner_processor = NerProcessor()  # original author's note: just reads the files
    sim_processor = SimProcessor()
    tokenizer = BertTokenizer(*tokenizer_inputs, **tokenizer_kwards)

    ner_model = get_ner_model(
        config_file='./input/config/bert-base-chinese-config.json',
        pre_train_model='./output/best_ner.bin',
        label_num=len(ner_processor.get_labels()))
    ner_model = ner_model.to(device)
    # Inference only: put the NER model in eval mode, exactly as is done for
    # sim_model further down.
    ner_model.eval()

    # Alternative sample questions kept from the original for manual testing.
    # sent = "我想知道戴维斯是什么国家的人?"
    # sent = "你知道因为有你是谁的作的曲吗?"
    # sent = "陈刚哪里人"
    # sent = "神雕侠侣是什么类型"
    # sent = "李鑫别名"
    # sent = "王磊生辰"
    sent = "西游记作者的"

    entity = get_entity(model=ner_model,
                        tokenizer=tokenizer,
                        sentence=sent,
                        max_len=64)
    print("entity", entity)
    if '' == entity:
        print("未发现实体")
        return

    # The query is assembled as text because select_database() only accepts a
    # finished SQL string.  Doubling single quotes keeps an entity such as
    # "O'Neal" from terminating the literal early.
    # NOTE(review): a parameterized query inside select_database() would be
    # the robust fix; backslash handling is database specific.
    safe_entity = entity.replace("'", "''")
    sql_str = "select * from nlpccqa where entity = '{}'".format(safe_entity)

    triple_list = list(select_database(sql_str))
    if 0 == len(triple_list):
        print("未找到 {} 相关信息".format(entity))
        return
    # Transpose rows into columns: column 1 is used as the attribute list and
    # column 2 as the answer list.
    triple_list = list(zip(*triple_list))
    print("triple_list:", triple_list)
    attribute_list = triple_list[1]
    answer_list = triple_list[2]

    attribute, answer = text_match(attribute_list, answer_list, sent)

    if attribute != '' and answer != '':
        ret = "直接匹配出来:{}的{}是{}".format(entity, attribute, answer)

    else:
        # The similarity model is only loaded when the literal match fails.
        sim_model = get_sim_model(
            config_file='./input/config/bert-base-chinese-config.json',
            pre_train_model='./output/best_sim.bin',
            label_num=len(sim_processor.get_labels()))

        sim_model = sim_model.to(device)
        sim_model.eval()
        attribute_idx = semantic_matching(sim_model, tokenizer, sent,
                                          attribute_list, answer_list,
                                          64).item()
        # code.interact(local = locals())
        if -1 == attribute_idx:
            ret = ''
        else:
            attribute = attribute_list[attribute_idx]
            answer = answer_list[attribute_idx]
            ret = "语义匹配:{}的{}是{}".format(entity, attribute, answer)
    if '' == ret:
        print("未找到{}相关信息".format(entity))
    else:
        print(ret)
Exemplo n.º 3
0
from BERT_CRF import BertCrf
from transformers import BertTokenizer
from NER_main import NerProcessor, statistical_real_sentences, flatten, CrfInputFeatures
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from sklearn.metrics import classification_report
import torch
import numpy as np
from tqdm import tqdm, trange

# Module-level evaluation setup: tokenizer, BERT-CRF model restored from a
# checkpoint, and the cached dev-set features.
processor = NerProcessor()
tokenizer_inputs = ()
tokenizer_kwards = {
    'do_lower_case': False,
    'max_len': 64,
    'vocab_file': './input/config/bert-base-chinese-vocab.txt'
}
tokenizer = BertTokenizer(*tokenizer_inputs, **tokenizer_kwards)

# Choose the device before loading so the checkpoint tensors can be mapped
# onto it directly.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = BertCrf(config_name='./input/config/bert-base-chinese-config.json',
                num_tags=len(processor.get_labels()),
                batch_first=True)
# map_location allows a checkpoint that was saved on a GPU to be loaded on a
# CPU-only machine; without it torch.load tries to restore the tensors on the
# device they were saved from.
model.load_state_dict(torch.load('./output/best_ner.bin',
                                 map_location=device))
model = model.to(device)

# features = torch.load(cached_features_file)
# NOTE(review): this cache holds pickled Python objects (their attributes are
# read below); only load cache files from a trusted source.
features = torch.load('./input/data/ner_data/cached_dev_64')

# Stack the per-example token ids into one long tensor.
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features],
Exemplo n.º 4
0
def main():
    """Run an interactive question-answering loop on stdin/stdout.

    Both models are loaded once, then each question typed by the user goes
    through: entity extraction, lookup of the entity's rows in ``nlpccqa``,
    literal attribute matching (``text_match``) and, if that fails, semantic
    matching with the similarity model.  Typing ``quit`` ends the loop.

    The whole session runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        tokenizer_inputs = ()
        tokenizer_kwards = {'do_lower_case': False,
                            'max_len': 64,
                            'vocab_file': './input/config/bert-base-chinese-vocab.txt'}
        ner_processor = NerProcessor()
        sim_processor = SimProcessor()
        tokenizer = BertTokenizer(*tokenizer_inputs, **tokenizer_kwards)

        ner_model = get_ner_model(config_file='./input/config/bert-base-chinese-config.json',
                                  pre_train_model='./output/best_ner.bin',
                                  label_num=len(ner_processor.get_labels()))
        ner_model = ner_model.to(device)
        ner_model.eval()

        # Loaded once here and reused for every question; the original
        # reloaded the same checkpoint inside the loop on each fallback.
        sim_model = get_sim_model(config_file='./input/config/bert-base-chinese-config.json',
                                  pre_train_model='./output/best_sim.bin',
                                  label_num=len(sim_processor.get_labels()))
        sim_model = sim_model.to(device)
        sim_model.eval()

        while True:
            print("====="*10)
            raw_text = input("问题:\n").strip()
            if "quit" == raw_text:
                print("quit")
                return

            entity = get_entity(model=ner_model, tokenizer=tokenizer, sentence=raw_text, max_len=64)
            print("实体:", entity)
            if '' == entity:
                print("未发现实体")
                continue

            # The entity comes from user input and select_database() only
            # takes a finished SQL string, so double single quotes to keep
            # them from terminating the literal.
            # NOTE(review): a parameterized query inside select_database()
            # would be the robust fix; backslash handling is database specific.
            safe_entity = entity.replace("'", "''")
            sql_str = "select * from nlpccqa where entity = '{}'".format(safe_entity)
            triple_list = list(select_database(sql_str))
            if 0 == len(triple_list):
                print("未找到 {} 相关信息".format(entity))
                continue

            # Transpose rows into columns: column 1 is used as the attribute
            # list, column 2 as the answer list.
            triple_list = list(zip(*triple_list))
            attribute_list = triple_list[1]
            answer_list = triple_list[2]

            attribute, answer = text_match(attribute_list, answer_list, raw_text)
            if attribute != '' and answer != '':
                ret = "{}的{}是{}".format(entity, attribute, answer)
            else:
                attribute_idx = semantic_matching(sim_model, tokenizer, raw_text,
                                                  attribute_list, answer_list, 64).item()
                # -1 means no attribute was selected.
                if -1 == attribute_idx:
                    ret = ''
                else:
                    attribute = attribute_list[attribute_idx]
                    answer = answer_list[attribute_idx]
                    ret = "{}的{}是{}".format(entity, attribute, answer)

            if '' == ret:
                print("未找到{}相关信息".format(entity))
            else:
                print("回答:", ret)
0
def main():
    """Answer one hard-coded question using a Neo4j graph as knowledge base.

    Steps:
      1. Extract the entity from the question with the NER model.
      2. Query Neo4j for every relation leaving a node whose name matches
         the entity and group the target names by relation.
      3. Try a literal match of each relation name against the question
         (``text_match``).
      4. Otherwise load the similarity model and let ``semantic_matching``
         choose the relation index (-1 means no match).

    Nothing is returned; every outcome is reported with ``print``.
    """
    tokenizer_inputs = ()
    tokenizer_kwards = {
        'do_lower_case':
        False,
        'max_len':
        64,
        'vocab_file':
        '/home/daiyizheng/.cache/torch/transformers/bert-pretrainmodel/bert/bert-base-chinese/vocab.txt'
    }
    ner_processor = NerProcessor()  # original author's note: just reads the files
    sim_processor = SimProcessor()
    tokenizer = BertTokenizer(*tokenizer_inputs, **tokenizer_kwards)

    ner_model = get_ner_model(
        config_file=
        '/home/daiyizheng/.cache/torch/transformers/bert-pretrainmodel/bert/bert-base-chinese/config.json',
        pre_train_model='./output/best_ner.bin',
        label_num=len(ner_processor.get_labels()))
    ner_model = ner_model.to(device)
    # Inference only: put the NER model in eval mode, exactly as is done for
    # sim_model further down.
    ner_model.eval()

    sent = "高催乳素血症属于什么?"
    # Alternative sample questions kept from the original for manual testing.
    # sent = "你知道因为有你是谁的作的曲吗?"
    # sent = "陈刚哪里人"
    # sent = "神雕侠侣是什么类型"
    # sent = "李鑫别名"
    # sent = "王磊生辰"
    # sent = "西游记作者的"

    entity = get_entity(model=ner_model,
                        tokenizer=tokenizer,
                        sentence=sent,
                        max_len=64)
    print("entity", entity)
    if '' == entity:
        print("未发现实体")
        return

    # MySQL variant kept from the original for reference.
    # sql_str = 'select * from nlpccQA where entity = "{}"'.format(entity)
    # triple_list = select_database(sql_str)

    conn = Neo4jObj()
    # NOTE(review): the entity is interpolated into a Cypher regex literal
    # as-is; quotes or regex metacharacters in it will change the query.
    sql_str = 'match (p:Entity)-[r]->(d:Entity) WHERE p.name=~".*{}.*" return r.name, d.name'.format(
        entity)
    triple_dict = conn.query(sql_str)

    # Strips the characters [ ] ' < > from the returned names.  Same effect
    # as the original alternation, written as a raw-string character class
    # and compiled once instead of three times per row.
    strip_markup = re.compile(r"[\[\]'<>]")

    # relation name -> list of answers; dict insertion order keeps the
    # relations in the order the query returned them, so the attribute list
    # is deterministic (a set() round-trip is not).
    relations2answers = {}
    for item in triple_dict:
        trip = strip_markup.sub("", item['r.name'])
        relations2answers.setdefault(trip, []).append(
            strip_markup.sub("", item['d.name']))

    triple_list = list(relations2answers)
    if 0 == len(triple_list):
        print("未找到 {} 相关信息".format(entity))
        return
    print("triple_list:", triple_list)
    attribute_list = triple_list
    answer_list = relations2answers

    # NOTE(review): answer_list is a dict keyed by relation name here.  A
    # text_match() that indexes answer_list by position would raise KeyError
    # on a hit - confirm this file's text_match() accepts a dict.
    attribute, answer = text_match(attribute_list, answer_list, sent)

    if attribute != '' and answer != '':
        ret = "直接匹配出来:{}的{}是{}".format(entity, attribute, answer)

    else:
        # The similarity model is only loaded when the literal match fails.
        sim_model = get_sim_model(
            config_file=
            '/home/daiyizheng/.cache/torch/transformers/bert-pretrainmodel/bert/bert-base-chinese/config.json',
            pre_train_model='./output/best_sim.bin',
            label_num=len(sim_processor.get_labels()))

        sim_model = sim_model.to(device)
        sim_model.eval()
        # NOTE(review): 56 differs from the max_len of 64 used everywhere
        # else in this function - confirm it is intentional.
        attribute_idx = semantic_matching(sim_model, tokenizer, sent,
                                          attribute_list, answer_list,
                                          56).item()
        # code.interact(local = locals())
        if -1 == attribute_idx:
            ret = ''
        else:
            attribute = attribute_list[attribute_idx]
            answer = answer_list[attribute]
            ret = "语义匹配:{}的{}是{}".format(entity, attribute, str(answer))
    if '' == ret:
        print("未找到{}相关信息".format(entity))
    else:
        print(ret)
Exemplo n.º 6
0
def main():
    """Predict an answer for every question in the NLPCC test file.

    Reads the question lines from ``Kbqa.testing-data``, runs each question
    through entity extraction, triple lookup, literal attribute matching and
    (as a fallback) semantic matching, and writes question/answer pairs to
    ``04_Kbqa.testing-data``.  An empty answer is written when no entity,
    no triple or no matching attribute is found.

    Progress and the running average time per question are printed.
    """
    question_path = "./input/data/NLPCC2016KBQA/Kbqa.testing-data"
    answer_path = "./input/data/NLPCC2016KBQA/04_Kbqa.testing-data"
    separator = '\n==================================================\n'

    with torch.no_grad():
        tokenizer_inputs = ()
        tokenizer_kwards = {
            'do_lower_case': False,
            'max_len': 128,
            'vocab_file': './input/config/bert-base-chinese-vocab.txt'
        }
        ner_processor = NerProcessor()
        sim_processor = SimProcessor()
        tokenizer = BertTokenizer(*tokenizer_inputs, **tokenizer_kwards)

        ner_model = get_ner_model(
            config_file='./input/config/bert-base-chinese-config.json',
            pre_train_model='./output/best_ner.bin',
            label_num=len(ner_processor.get_labels()))
        ner_model = ner_model.to(device)
        ner_model.eval()

        sim_model = get_sim_model(
            config_file='./input/config/bert-base-chinese-config.json',
            pre_train_model='./output/best_sim.bin',
            label_num=len(sim_processor.get_labels()))

        sim_model = sim_model.to(device)
        sim_model.eval()

        timestart = time.time()

        # Collect the question text: lines whose second character is 'q',
        # keeping everything after the first tab.  Slicing (line[1:2])
        # instead of indexing avoids an IndexError on empty or
        # one-character lines.
        with open(question_path, 'r', encoding='utf8') as fq:
            listq = [
                line[line.index('\t') + 1:].strip()
                for line in fq
                if line[1:2] == 'q'
            ]
        print("ListQ ready! Start to predict.....")

        # Write the answers.  One handle is kept open for the whole run and
        # flushed after every question, so partial results are on disk just
        # as with the original reopen-in-append-mode approach, and the file
        # is closed even if a prediction raises.
        with open(answer_path, 'w', encoding='utf8') as fo:
            for i, q in enumerate(listq, start=1):
                q = q.lower()
                fo.write('<question id=' + str(i) + '>\t' + q + '\n')
                entity = get_entity(model=ner_model,
                                    tokenizer=tokenizer,
                                    sentence=q,
                                    max_len=128)

                # Stays empty unless an answer is found below.
                ret = ''
                if len(entity) != 0:
                    triple_list = select_triple(entity)

                    if len(triple_list) > 0:
                        attribute_list = list(triple_list["attribute"])
                        answer_list = list(triple_list["answer"])

                        attribute, answer = text_match(attribute_list,
                                                       answer_list, q)

                        if attribute != '' and answer != '':
                            ret = answer
                        else:
                            attribute_idx = semantic_matching(
                                sim_model, tokenizer, q, attribute_list,
                                answer_list, 128).item()

                            # -1 means no attribute was selected.
                            if -1 != attribute_idx:
                                ret = answer_list[attribute_idx]

                fo.write('<answer id=' + str(i) + '>\t')
                fo.write(str(ret))
                fo.write(separator)
                fo.flush()

                print('processing ' + str(i) + 'th Q.\tAv time cost: ' +
                      str((time.time() - timestart) / i)[:6] + ' sec')
        print('Finished prediction.')