# Code example #1
# File: kbqa_test.py  Project: Estherbdf/SafeGo
import pandas as pd
import urllib.request
import urllib.parse
import tensorflow as tf
# Connect to the database
from Data.load_dbdata import upload_data
from global_config import Logger

from run_similarity import BertSim
# On module imports, see: https://blog.csdn.net/xiongchengluo1129/article/details/80453599

loginfo = Logger("recommend_articles.log", "info")
file = "./Data/NER_Data/q_t_a_testing_predict.txt"

bs = BertSim()
bs.set_mode(tf.estimator.ModeKeys.PREDICT)


def dataset_test():
    '''
    用训练问答对中的实体+属性,去知识库中进行问答测试准确率上限
    :return:
    '''
    with open(file) as f:
        total = 0
        recall = 0
        correct = 0

        for line in f:
            question, entity, attribute, answer, ner = line.split("\t")
            ner = ner.replace("#", "").replace("[UNK]", "%")
# Code example #2
# File: qa_my.py  Project: wangbq18/KBQA-BERT
def predict_online():
    """
    do online prediction. each time make prediction for one instance.
    you can change to a batch if you want.

    :param line: a list. element is: [dummy_label,text_a,text_b]
    :return:
    """

    #driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "Nic180319"))
    def convert(line):
        feature = convert_single_example(0, line, label_list,
                                         FLAGS.max_seq_length, tokenizer, 'p')
        input_ids = np.reshape([feature.input_ids],
                               (batch_size, FLAGS.max_seq_length))
        input_mask = np.reshape([feature.input_mask],
                                (batch_size, FLAGS.max_seq_length))
        segment_ids = np.reshape([feature.segment_ids],
                                 (batch_size, FLAGS.max_seq_length))
        label_ids = np.reshape([feature.label_ids],
                               (batch_size, FLAGS.max_seq_length))
        return input_ids, input_mask, segment_ids, label_ids

    global graph
    with graph.as_default():
        print(id2label)
        while True:
            print('input the test sentence:')
            sentence_l = input()
            sentence = str(sentence_l)
            start = datetime.now()
            if len(sentence) < 2:
                print(sentence)
                continue
            sentence = tokenizer.tokenize(sentence)
            # print('your input is:{}'.format(sentence))
            input_ids, input_mask, segment_ids, label_ids = convert(sentence)

            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids
            }
            # run session get current feed_dict result
            pred_ids_result = sess.run([pred_ids], feed_dict)
            pred_label_result = convert_id_to_label(pred_ids_result, id2label)
            print(pred_label_result)
            #todo: 组合策略
            result = strage_combined_link_org_loc(sentence,
                                                  pred_label_result[0], True)
            print('识别的实体有:{}'.format(' '.join(result)))
            #print('Time used: {} sec'.format((datetime.now() - start).seconds))

            #   yueuu
            #driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "Nic180319"))
            hd_graph = Klg()
            all_rel = hd_graph.old_allrel(driver, name=result[0])
            print('知识图谱中跟实体关联的所有关系为:', all_rel)

            time.sleep(5)

            sim = BertSim()
            sim.set_mode(tf.estimator.ModeKeys.PREDICT)
            sim_score = []
            for j in range(len(all_rel)):
                sim_score.append(sim.predict(sentence_l, all_rel[j])[0][1])
            for j in range(len(sim_score)):
                print(all_rel[j], f'similarity:{sim_score[j]}')
            max_idx = sim_score.index(max(sim_score))
            print('相似度最高的关系为:', all_rel[max_idx])

            answer = hd_graph.find(driver, result[0], all_rel[max_idx])
            if answer:
                print("answer:", answer)
            else:
                print("不知道")