import pandas as pd
import urllib.request
import urllib.parse
import tensorflow as tf
# connect to the database
from Data.load_dbdata import upload_data
from global_config import Logger
from run_similarity import BertSim
# module import background: https://blog.csdn.net/xiongchengluo1129/article/details/80453599

# Module-level logger, test-data path, and a shared BertSim instance in PREDICT mode.
loginfo = Logger("recommend_articles.log", "info")
file = "./Data/NER_Data/q_t_a_testing_predict.txt"
bs = BertSim()
bs.set_mode(tf.estimator.ModeKeys.PREDICT)


def dataset_test():
    '''
    Upper-bound accuracy test: use the entity + attribute from each
    training QA pair to query the knowledge base directly.
    :return:
    '''
    with open(file) as f:
        total = 0
        recall = 0
        correct = 0
        for line in f:
            # Each line is tab-separated: question, entity, attribute, answer, ner.
            # NOTE(review): the trailing field keeps its '\n' -- presumably handled
            # (or harmless) downstream; confirm.
            question, entity, attribute, answer, ner = line.split("\t")
            # Strip BERT-tokenizer artifacts from the NER string.
            ner = ner.replace("#", "").replace("[UNK]", "%")
            # NOTE(review): the loop body appears truncated here -- total/recall/
            # correct are initialized but never updated in the visible code.
def predict_online():
    """
    Interactive online prediction loop: each iteration reads one sentence
    from stdin, runs NER to extract an entity, ranks that entity's
    knowledge-graph relations by BERT similarity to the question, and
    prints the answer found in the graph (or "不知道" if none).

    Relies on module-level globals defined elsewhere in this file:
    graph, sess, pred_ids, tokenizer, FLAGS, label_list, batch_size,
    id2label, driver, the *_p feed placeholders, convert_single_example,
    convert_id_to_label, strage_combined_link_org_loc, Klg.

    :return: None (loops forever; all results are printed)
    """
    # driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "Nic180319"))

    def convert(line):
        # Build batch-shaped feed arrays for one tokenized sentence.
        feature = convert_single_example(0, line, label_list, FLAGS.max_seq_length, tokenizer, 'p')
        input_ids = np.reshape([feature.input_ids], (batch_size, FLAGS.max_seq_length))
        input_mask = np.reshape([feature.input_mask], (batch_size, FLAGS.max_seq_length))
        segment_ids = np.reshape([feature.segment_ids], (batch_size, FLAGS.max_seq_length))
        label_ids = np.reshape([feature.label_ids], (batch_size, FLAGS.max_seq_length))
        return input_ids, input_mask, segment_ids, label_ids

    # Fix: build the similarity model ONCE, not once per query -- the
    # original constructed BertSim() inside the while-loop every iteration.
    sim = BertSim()
    sim.set_mode(tf.estimator.ModeKeys.PREDICT)

    global graph
    with graph.as_default():
        print(id2label)
        while True:
            print('input the test sentence:')
            sentence_l = input()
            sentence = str(sentence_l)
            start = datetime.now()
            # Guard clause: skip degenerate one-character / empty input.
            if len(sentence) < 2:
                print(sentence)
                continue
            sentence = tokenizer.tokenize(sentence)
            input_ids, input_mask, segment_ids, label_ids = convert(sentence)
            feed_dict = {
                input_ids_p: input_ids,
                input_mask_p: input_mask,
                segment_ids_p: segment_ids,
                label_ids_p: label_ids,
            }
            # Run the NER session and map predicted ids back to label strings.
            pred_ids_result = sess.run([pred_ids], feed_dict)
            pred_label_result = convert_id_to_label(pred_ids_result, id2label)
            print(pred_label_result)
            # todo: 组合策略 (combination strategy for entity linking)
            result = strage_combined_link_org_loc(sentence, pred_label_result[0], True)
            print('识别的实体有:{}'.format(' '.join(result)))
            # Fix: original indexed result[0] unconditionally -> IndexError
            # when NER recognized no entity.
            if not result:
                print("不知道")
                continue
            hd_graph = Klg()
            all_rel = hd_graph.old_allrel(driver, name=result[0])
            print('知识图谱中跟实体关联的所有关系为:', all_rel)
            # Fix: original called max() on an empty score list -> ValueError
            # when the entity has no relations in the knowledge graph.
            if not all_rel:
                print("不知道")
                continue
            time.sleep(5)  # NOTE(review): looks like a debug pause -- confirm it is intentional
            # Score each candidate relation against the raw question.
            sim_score = [sim.predict(sentence_l, rel)[0][1] for rel in all_rel]
            for rel, score in zip(all_rel, sim_score):
                print(rel, f'similarity:{score}')
            max_idx = sim_score.index(max(sim_score))
            print('相似度最高的关系为:', all_rel[max_idx])
            answer = hd_graph.find(driver, result[0], all_rel[max_idx])
            if answer:
                print("answer:", answer)
            else:
                print("不知道")