def pred_tet(path_hyper_parameter=path_hyper_parameters, path_test=None, rate=1.0):
    """
    Run the trained model over the test set and print a classification report.

    :param path_hyper_parameter: str, path of the JSON file holding the hyper-parameters
    :param path_test: str, optional path of test data; when truthy it replaces
                      hyper_parameters['data']['val_data']
    :param rate: float, fraction of the loaded corpus that is evaluated
    :return: None
    """
    hyper_parameters = load_json(path_hyper_parameter)
    if path_test:  # test-data location supplied from outside
        hyper_parameters['data']['val_data'] = path_test
    time_start = time.time()

    # Build the graph, then restore the saved model.
    graph = Graph(hyper_parameters)
    print("graph init ok!")
    graph.load_model()
    print("graph load ok!")
    ra_ed = graph.word_embedding

    # Data preprocessing.
    pt = PreprocessText()
    y, x = read_and_process(hyper_parameters['data']['val_data'])

    # Evaluate only the leading `rate` share of the corpus.
    len_rate = int(len(y) * rate)
    # NOTE(review): the slice starts at index 1, so the sample at index 0 is
    # never evaluated — confirm whether that is intentional.
    x = x[1:len_rate]
    y = y[1:len_rate]

    # 'bert' and 'albert' embeddings are fed to the model as two separate
    # arrays; this is the same set of types pred_input checks for.
    two_inputs = hyper_parameters['embedding_type'] in ('bert', 'albert')

    y_pred = []
    for count, x_one in enumerate(x, start=1):
        ques_embed = ra_ed.sentence2idx(x_one)
        if two_inputs:
            x_val = [np.array([ques_embed[0]]), np.array([ques_embed[1]])]
        else:
            x_val = ques_embed
        # Predict and keep the first label of the decoded result.
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        label_pred = pre[0][0][0]
        if count % 1000 == 0:  # progress output every 1000 samples
            print(label_pred)
        y_pred.append(label_pred)
    print("data pred ok!")

    # Map true and predicted labels to their indices.
    l2i = pt.l2i_i2l['l2i']
    i2l = pt.l2i_i2l['i2l']
    index_y = [l2i[label] for label in y]
    index_pred = [l2i[label] for label in y_pred]
    # classification_report lists classes in sorted label order, so the display
    # names must be built in that same order; a bare set has no guaranteed order.
    target_names = [i2l[str(i)] for i in sorted(set(index_pred + index_y))]

    # Evaluation.
    report_predict = classification_report(index_y, index_pred,
                                           target_names=target_names, digits=9)
    print(report_predict)
    print("耗时:" + str(time.time() - time_start))
def pred_input(path_hyper_parameter=path_hyper_parameters):
    """
    Interactive prediction: classify one built-in sample sentence, then keep
    classifying sentences read from standard input.

    The input loop has no exit condition, so the function only ends when an
    exception propagates out of it.

    :param path_hyper_parameter: str, path where the hyper-parameters are stored
    :return: None
    """
    # Load hyper-parameters and the label pre-processor.
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessText(path_model_dir)
    # Initialise the graph and load the model.
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding

    def _as_model_input(token_ids):
        # 'bert' / 'albert' embeddings are passed as two separate arrays.
        if hyper_parameters['embedding_type'] in ['bert', 'albert']:
            first = np.array([token_ids[0]])
            second = np.array([token_ids[1]])
            return [first, second]
        return token_ids

    def _decode(token_ids):
        # Predict, then convert the first output row from ids to labels.
        scores = graph.predict(_as_model_input(token_ids))
        return pt.prereocess_idx(scores[0])

    # One prediction on a fixed sentence before the interactive loop starts.
    sample = '我要打王者荣耀'
    print(_decode(ra_ed.sentence2idx(sample)))

    while True:
        print("请输入: ")
        sentence = input()
        token_ids = ra_ed.sentence2idx(sentence)
        print(token_ids)
        print(_decode(token_ids))
'embedding_type': 'random', 'is_training': False, 'model_path': path_model_fast_text_baiduqa_2019, }, 'embedding': { 'embedding_type': 'random', 'corpus_path': path_embedding_random_char, 'level_type': 'char', 'embed_size': 300, 'len_max': 50, }, } # ns = np.array([1,2,3,4]) # print(type(ns)) pt = PreprocessText graph = FastTextGraph(hyper_parameters) graph.load_model() ra_ed = graph.word_embedding ques = '你好呀' ques_embed = ra_ed.sentence2idx(ques) pred = graph.predict(np.array([ques_embed])) pre = pt.prereocess_idx(pred[0]) print(pre) while True: print("请输入: ") ques = input() ques_embed = ra_ed.sentence2idx(ques) pred = graph.predict(np.array([ques_embed])) pre = pt.prereocess_idx(pred[0]) print(pre)