def pred_input(str_input, path_hyper_parameter=path_hyper_parameters):
    # predict from a single input string
    # load hyperparameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti()
    # initialize and load the model
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    ques = str_input
    # str to token
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] == 'bert':
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
    else:
        x_val = ques_embed
    # predict
    pred = graph.predict(x_val)
    print(pred)
    # map ids back to labels and probabilities
    pre = pt.prereocess_idx(pred[0])
    # keep only labels whose probability reaches the threshold
    ls_nulti = []
    for ls in pre[0]:
        if ls[1] >= 0.73:
            ls_nulti.append(ls)
    print(str_input)
    print(pre[0])
    print(ls_nulti)
def pred_tet(path_hyper_parameter=path_hyper_parameters, path_test=None, rate=1.0):
    """
        Run the test set and evaluate the model.
    :param path_hyper_parameter: str, path of the hyperparameter json
    :param path_test: str, path of test data
    :param rate: float, fraction of the corpus to evaluate
    :return: None
    """
    hyper_parameters = load_json(path_hyper_parameter)
    if path_test:  # use an externally supplied test data path
        hyper_parameters['data']['val_data'] = path_test
    time_start = time.time()
    # initialize the graph
    graph = Graph(hyper_parameters)
    print("graph init ok!")
    graph.load_model()
    print("graph load ok!")
    ra_ed = graph.word_embedding
    # data preprocessing
    pt = PreprocessText()
    y, x = read_and_process(hyper_parameters['data']['val_data'])
    # take only a fraction (rate) of the dataset for testing
    len_rate = int(len(y) * rate)
    x = x[1:len_rate]
    y = y[1:len_rate]
    y_pred = []
    count = 0
    for x_one in x:
        count += 1
        ques_embed = ra_ed.sentence2idx(x_one)
        if hyper_parameters['embedding_type'] == 'bert':
            # bert-style inputs: token ids and segment ids
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # predict
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        label_pred = pre[0][0][0]
        if count % 1000 == 0:
            print(label_pred)
        y_pred.append(label_pred)
    print("data pred ok!")
    # convert labels and predictions to integer indices
    index_y = [pt.l2i_i2l['l2i'][i] for i in y]
    index_pred = [pt.l2i_i2l['l2i'][i] for i in y_pred]
    # target_names must follow the sorted label order used by classification_report
    target_names = [pt.l2i_i2l['i2l'][str(i)] for i in sorted(set(index_pred + index_y))]
    # evaluation
    report_predict = classification_report(index_y, index_pred,
                                            target_names=target_names, digits=9)
    print(report_predict)
    print("time cost: " + str(time.time() - time_start))
def pred_input(path_hyper_parameter=path_hyper_parameters):
    # interactive prediction
    # load hyperparameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti(path_model_dir)
    # initialize and load the model
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    ques = '我要打王者荣耀'
    # str to token
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] in ['bert', 'albert']:
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
    else:
        x_val = ques_embed
    # predict
    pred = graph.predict(x_val)
    print(pred)
    # map ids back to labels and probabilities
    pre = pt.prereocess_idx(pred[0])
    ls_nulti = []
    for ls in pre[0]:
        if ls[1] >= 0.5:
            ls_nulti.append(ls)
    print(pre[0])
    print(ls_nulti)
    while True:
        print("Please input: ")
        ques = input()
        ques_embed = ra_ed.sentence2idx(ques)
        print(ques_embed)
        if hyper_parameters['embedding_type'] in ['bert', 'albert']:
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        ls_nulti = []
        for ls in pre[0]:
            if ls[1] >= 0.5:
                ls_nulti.append(ls)
        print(pre[0])
        print(ls_nulti)
def pred_input(path_hyper_parameter=path_hyper_parameters):
    """
        Interactive prediction for sentence pairs.
    :param path_hyper_parameter: str, path where the hyperparameters are stored
    :return: None
    """
    # load hyperparameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessSim(path_model_dir)
    # initialize and load the model
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    sen1 = '我要打王者荣耀'
    sen2 = '我要打梦幻西游'
    # str to token
    ques_embed = ra_ed.sentence2idx(text=sen1, second_text=sen2)
    if hyper_parameters['embedding_type'] in ['bert', 'albert']:
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
        # predict
        pred = graph.predict(x_val)
        # map ids back to labels and probabilities
        pre = pt.prereocess_idx(pred[0])
        print(pre)
        while True:
            print("Please input sen1: ")
            sen1 = input()
            print("Please input sen2: ")
            sen2 = input()
            ques_embed = ra_ed.sentence2idx(text=sen1, second_text=sen2)
            print(ques_embed)
            if hyper_parameters['embedding_type'] in ['bert', 'albert']:
                x_val_1 = np.array([ques_embed[0]])
                x_val_2 = np.array([ques_embed[1]])
                x_val = [x_val_1, x_val_2]
                pred = graph.predict(x_val)
                pre = pt.prereocess_idx(pred[0])
                print(pre)
            else:
                print("error, just support bert or albert")
    else:
        print("error, just support bert or albert")
def pred_input(path_hyper_parameter=path_hyper_parameters):
    """
        Interactive prediction.
    :param path_hyper_parameter: str, path where the hyperparameters are stored
    :return: None
    """
    # load hyperparameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessText()
    # initialize and load the model
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    ques = '我要打王者荣耀'
    # str to token
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] == 'bert':
        x_val_1 = np.array([ques_embed[0]])
        x_val_2 = np.array([ques_embed[1]])
        x_val = [x_val_1, x_val_2]
    else:
        x_val = ques_embed
    # predict
    pred = graph.predict(x_val)
    # map ids back to labels and probabilities
    pre = pt.prereocess_idx(pred[0])
    print(pre)
    while True:
        print("Please input: ")
        ques = input()
        ques_embed = ra_ed.sentence2idx(ques)
        print(ques_embed)
        if hyper_parameters['embedding_type'] == 'bert':
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        print(pre)
def preprocess_label_ques_to_idx(self, embedding_type, batch_size, path, embed, rate=1, epcoh=20):
    label_set, len_all = self.preprocess_get_label_set(path)
    # build the label<->index dictionaries; if the mapping file already exists
    # (e.g. when preparing the dev/validation set), reuse it instead of rebuilding
    if not os.path.exists(self.path_fast_text_model_l2i_i2l):
        count = 0
        label2index = {}
        index2label = {}
        for label_one in label_set:
            label2index[label_one] = count
            index2label[count] = label_one
            count = count + 1
        l2i_i2l = {}
        l2i_i2l['l2i'] = label2index
        l2i_i2l['i2l'] = index2label
        save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
    else:
        l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)
    # fraction of the data to read
    len_ql = int(rate * len_all)
    if len_ql <= 500:  # do not down-sample small corpora, keep enough data to train
        len_ql = len_all

    def process_line(line):
        # process one line: get the label and the token ids of the question
        line_sp = line.split(",")
        ques = str(line_sp[1]).strip().upper()
        label = str(line_sp[0]).strip().upper()
        label = "NAN" if label == "" else label
        que_embed = embed.sentence2idx(ques)
        label_zeros = [0] * len(l2i_i2l['l2i'])
        label_zeros[l2i_i2l['l2i'][label]] = 1
        return que_embed, label_zeros

    for _ in range(epcoh):
        while True:
            file_csv = open(path, "r", encoding="utf-8")
            cout_all_line = 0
            cnt = 0
            x, y = [], []
            # break out of the loop
            if len_ql < cout_all_line:
                break
            for line in file_csv:
                cout_all_line += 1
                if cout_all_line > 1:  # skip the first line, which is the header 'label,ques'
                    x_line, y_line = process_line(line)
                    x.append(x_line)
                    y.append(y_line)
                    cnt += 1
                    if cnt == batch_size:
                        if embedding_type in ['bert', 'albert']:
                            x_, y_ = np.array(x), np.array(y)
                            x_1 = np.array([x[0] for x in x_])
                            x_2 = np.array([x[1] for x in x_])
                            x_all = [x_1, x_2]
                        elif embedding_type == 'xlnet':
                            x_, y_ = x, np.array(y)
                            x_1 = np.array([x[0][0] for x in x_])
                            x_2 = np.array([x[1][0] for x in x_])
                            x_3 = np.array([x[2][0] for x in x_])
                            x_all = [x_1, x_2, x_3]
                        else:
                            x_all, y_ = np.array(x), np.array(y)
                        cnt = 0
                        yield (x_all, y_)
                        x, y = [], []
            file_csv.close()
    print("preprocess_label_ques_to_idx ok")
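# --- Hedged usage sketch (not repository code) ---
# A generator like preprocess_label_ques_to_idx above yields (x_all, y_) batches
# indefinitely, so Keras needs an explicit steps_per_epoch to know where an epoch
# ends. The names `pt`, `ra_ed`, `graph.model`, the corpus size and `path_train`
# are assumptions made only for illustration.
batch_size = 32
epochs = 20
len_train = 50000  # assumed number of training samples
steps_per_epoch = max(1, len_train // batch_size)

train_generator = pt.preprocess_label_ques_to_idx(embedding_type='bert',
                                                   batch_size=batch_size,
                                                   path=path_train,  # hypothetical 'label,ques' csv
                                                   embed=ra_ed,
                                                   rate=1,
                                                   epcoh=epochs)
graph.model.fit_generator(train_generator,
                          steps_per_epoch=steps_per_epoch,
                          epochs=epochs)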
def preprocess_label_ques_to_idx(self, embedding_type, batch_size, path, embed, rate=1, epcoh=20):
    label_set, len_all = self.preprocess_get_label_set(path)
    # build the label<->index dictionaries; if the mapping file already exists
    # (e.g. when preparing the dev/validation set), reuse it instead of rebuilding
    if not os.path.exists(self.path_fast_text_model_l2i_i2l):
        count = 0
        label2index = {}
        index2label = {}
        for label_one in label_set:
            label2index[label_one] = count
            index2label[count] = label_one
            count = count + 1
        l2i_i2l = {}
        l2i_i2l['l2i'] = label2index
        l2i_i2l['i2l'] = index2label
        save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
    else:
        l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)
    # fraction of the data to read
    len_ql = int(rate * len_all)
    if len_ql <= 500:  # do not down-sample small corpora, keep enough data to train
        len_ql = len_all

    def process_line(line):
        # process one line: get the label and the token ids / masks of both sentences
        data = json.loads(line)
        label = data['label']
        ques_1 = data['sentence1']
        ques_2 = data['sentence2']
        offset = data['offset']
        mention_1 = data["mention"]
        offset_i = int(offset)
        que_embed_1 = embed.sentence2idx(text=ques_1)
        que_embed_2 = embed.sentence2idx(text=ques_2)
        # ques1: start / end / entity-span masks around the mention
        [input_id_1, input_type_id_1, input_mask_1] = que_embed_1
        input_start_mask_1 = [0] * len(input_id_1)
        input_start_mask_1[offset_i] = 1
        input_end_mask_1 = [0] * len(input_id_1)
        input_end_mask_1[offset_i + len(mention_1) - 1] = 1
        input_entity_mask_1 = [0] * len(input_id_1)
        input_entity_mask_1[offset_i:offset_i + len(mention_1)] = [1] * len(mention_1)
        # ques2: masks for the "type" and "tag" spans of the knowledge-base description
        [input_id_2, input_type_id_2, input_mask_2] = que_embed_2
        kind_2 = [0] * len(input_type_id_2)
        kind_21 = [0] * len(input_type_id_2)
        que_2_sp = ques_2.split("|")
        if len(que_2_sp) >= 2:
            que_2_sp_sp = que_2_sp[0].split(":")
            if len(que_2_sp_sp) == 2:
                kind_2_start = len(que_2_sp_sp[0]) - 1
                kind_2_end = kind_2_start + len(que_2_sp_sp[1]) - 1
                kind_2[kind_2_start:kind_2_end] = [1] * (kind_2_end - kind_2_start)
            if "标签:" in que_2_sp[1]:
                que_21_sp_sp = que_2_sp[1].split(":")
                kind_21_start = len(que_2_sp[0]) + len(que_21_sp_sp[0]) - 1
                kind_21_end = len(que_2_sp[0]) + len(que_21_sp_sp[0]) + len(que_21_sp_sp[1]) - 1
                kind_21[kind_21_start:kind_21_end] = [1] * (kind_21_end - kind_21_start)
        que_embed_x = [input_id_1, input_type_id_1, input_mask_1,
                       input_start_mask_1, input_end_mask_1, input_entity_mask_1,
                       input_id_2, input_type_id_2, input_mask_2, kind_2, kind_21]
        label_zeros = [0] * len(l2i_i2l['l2i'])
        label_zeros[l2i_i2l['l2i'][label]] = 1
        return que_embed_x, label_zeros

    for _ in range(epcoh):
        while True:
            file_csv = open(path, "r", encoding="utf-8")
            cout_all_line = 0
            cnt = 0
            x, y = [], []
            # break out of the loop
            if len_ql < cout_all_line:
                break
            for line in file_csv:
                cout_all_line += 1
                x_line, y_line = process_line(line)
                x.append(x_line)
                y.append(y_line)
                cnt += 1
                if cnt == batch_size:
                    if embedding_type in ['bert', 'albert']:
                        x_, y_ = np.array(x), np.array(y)
                        x_all = []
                        for i in range(len(x_[0])):
                            x_1 = np.array([x[i] for x in x_])
                            x_all.append(x_1)
                    elif embedding_type == 'xlnet':
                        x_, y_ = x, np.array(y)
                        x_1 = np.array([x[0][0] for x in x_])
                        x_2 = np.array([x[1][0] for x in x_])
                        x_3 = np.array([x[2][0] for x in x_])
                        x_all = [x_1, x_2, x_3]
                    else:
                        x_all, y_ = np.array(x), np.array(y)
                    cnt = 0
                    yield (x_all, y_)
                    x, y = [], []
            file_csv.close()
    print("preprocess_label_ques_to_idx ok")
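# Hedged sketch of the input format the generator above expects: one JSON object
# per line with the fields read in process_line. The example values below are
# invented for illustration, not taken from the actual dataset.
# {"sentence1": "周杰伦在台北出生", "mention": "周杰伦", "offset": "0",
#  "sentence2": "类型:歌手|标签:华语流行音乐", "label": "1"}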
def __init__(self, path_model_dir):
    # file paths for the vocab and the label<->index mapping; load the mapping if it already exists
    self.l2i_i2l = None
    self.path_fast_text_model_vocab2index = path_model_dir + 'vocab2index.json'
    self.path_fast_text_model_l2i_i2l = path_model_dir + 'l2i_i2l.json'
    if os.path.exists(self.path_fast_text_model_l2i_i2l):
        self.l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)
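# Hedged sketch (an assumption, not repository output) of the l2i_i2l.json file
# loaded above: json serialisation turns the integer keys of "i2l" into strings,
# which is why the prediction code reads it as l2i_i2l['i2l'][str(i)].
# {"l2i": {"游戏": 0, "体育": 1}, "i2l": {"0": "游戏", "1": "体育"}}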
def preprocess_label_ques_to_idx_old(self, embedding_type, batch_size, path, embed, rate=1, epcoh=20):
    label_set, len_all = self.preprocess_get_label_set(path)
    # build the label<->index dictionaries; if the mapping file already exists
    # (e.g. when preparing the dev/validation set), reuse it instead of rebuilding
    if not os.path.exists(self.path_fast_text_model_l2i_i2l):
        count = 0
        label2index = {}
        index2label = {}
        for label_one in label_set:
            label2index[label_one] = count
            index2label[count] = label_one
            count = count + 1
        l2i_i2l = {}
        l2i_i2l['l2i'] = label2index
        l2i_i2l['i2l'] = index2label
        save_json(l2i_i2l, self.path_fast_text_model_l2i_i2l)
    else:
        l2i_i2l = load_json(self.path_fast_text_model_l2i_i2l)
    # fraction of the data to read
    len_ql = int(rate * len_all)
    if len_ql <= 500:  # do not down-sample small corpora, keep enough data to train
        len_ql = len_all

    def process_line(line):
        # process one line: get the label and the token ids of the sentence pair
        data = json.loads(line)
        label = data['label']
        ques_1 = data['sentence1']
        ques_2 = data['sentence2']
        offset = data['offset']
        mention = data["mention"]
        offset_i = int(offset)
        # if data.get("label_l2i"):
        #     ques_entity = data.get("label_l2i") + "#" + ques_1[:offset_i] + "#" + mention + "#" + ques_1[offset_i+len(mention):]
        # else:
        #     ques_entity = ques_1[:offset_i] + "#" + mention + "#" + ques_1[offset_i+len(mention):] + "$$" + ques_2
        # que_embed = embed.sentence2idx(text=ques_entity)
        que_embed = embed.sentence2idx(ques_1, second_text=ques_2)
        label_zeros = [0] * len(l2i_i2l['l2i'])
        label_zeros[l2i_i2l['l2i'][label]] = 1
        return que_embed, label_zeros

    for _ in range(epcoh):
        while True:
            file_csv = open(path, "r", encoding="utf-8")
            cout_all_line = 0
            cnt = 0
            x, y = [], []
            # break out of the loop
            if len_ql < cout_all_line:
                break
            for line in file_csv:
                cout_all_line += 1
                x_line, y_line = process_line(line)
                x.append(x_line)
                y.append(y_line)
                cnt += 1
                if cnt == batch_size:
                    if embedding_type in ['bert', 'albert']:
                        x_, y_ = np.array(x), np.array(y)
                        x_1 = np.array([x[0] for x in x_])
                        x_2 = np.array([x[1] for x in x_])
                        x_all = [x_1, x_2]
                    elif embedding_type == 'xlnet':
                        x_, y_ = x, np.array(y)
                        x_1 = np.array([x[0][0] for x in x_])
                        x_2 = np.array([x[1][0] for x in x_])
                        x_3 = np.array([x[2][0] for x in x_])
                        x_all = [x_1, x_2, x_3]
                    else:
                        x_all, y_ = np.array(x), np.array(y)
                    cnt = 0
                    yield (x_all, y_)
                    x, y = [], []
            file_csv.close()
    print("preprocess_label_ques_to_idx ok")
def evaluate(path_hyper_parameter=path_hyper_parameters, rate=1.0):
    # evaluate on the validation data
    # load hyperparameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti()
    # initialize and load the model
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    # get validation data
    ques_list, val_list, que, val = pt.preprocess_label_ques_to_idx(
        hyper_parameters['embedding_type'],
        hyper_parameters['data']['val_data'],
        ra_ed, rate=rate, shuffle=True)
    print(len(ques_list))
    print("que:", len(que))
    # print(val)
    # str to token
    ques_embed_list = []
    count = 0
    acc_count = 0
    not_none_count = 0
    not_none_acc_count = 0
    sum_iou = 0
    sum_all_iou = 0
    for index, que___ in enumerate(que):
        # print("sentence ", index, que[index])
        # print("gold labels ", index, val[index])
        # print("ques: ", ques)
        ques_embed = ra_ed.sentence2idx(que[index])
        if hyper_parameters['embedding_type'] == 'albert':
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            ques_embed = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # print("ques_embed: ", ques_embed)
        if hyper_parameters['embedding_type'] == 'bert':
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # print("x_val", x_val)
        ques_embed_list.append(x_val)
        # predict
        pred = graph.predict(x_val)
        # print(pred)
        # map ids back to labels and probabilities
        pre = pt.prereocess_idx(pred[0])
        # print("pre", pre)
        # keep the top label plus any label above the threshold or close to the current top score
        ls_nulti = []
        threshold = 0.44
        top_threshold = 0
        for i, ls in enumerate(pre[0]):
            if i == 0 or ls[1] > threshold:
                ls_nulti.append(ls)
                top_threshold = ls[1]
            elif abs(ls[1] - top_threshold) < top_threshold / 4.0:
                ls_nulti.append(ls)
        # print("prediction", index, pre[0])
        # print(ls_nulti)
        res = cal_acc(ls_nulti, val[index].split(","))
        res_iou, res_all_iou = cal_iou(ls_nulti, val[index].split(","))
        sum_iou += res_iou
        sum_all_iou += res_all_iou
        if res:
            if val[index] != "无":
                not_none_acc_count += 1
            acc_count += 1
        else:
            print("sentence ", index, que[index])
            print("gold labels ", index, val[index])
            print("pre ", pre)
            print("iou ", res_iou)
        count += 1
        if val[index] != "无":
            not_none_count += 1
    print("acc: ", acc_count / count)
    print("not none acc: ", not_none_acc_count / not_none_count)
    print("average iou: ", sum_iou / sum_all_iou)
    # log
    append_log(hyper_parameters, acc_count / count, not_none_acc_count / not_none_count, threshold)
def evaluate(path_hyper_parameter=path_hyper_parameters, rate=1.0):
    # evaluate on the test data
    # load hyperparameters
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessTextMulti()
    # initialize and load the model
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    # init confusion table
    dict_all = initConfusion()
    # get validation data
    ques_list, val_list, que, val = pt.preprocess_label_ques_to_idx(
        hyper_parameters['embedding_type'],
        hyper_parameters['data']['test_data'],
        ra_ed, rate=rate, shuffle=True)
    print(len(ques_list))
    print("que:", len(que))
    # print(val)
    # str to token
    ques_embed_list = []
    count = 0
    acc_count = 0
    not_none_count = 0
    not_none_acc_count = 0
    sum_iou = 0
    sum_all_iou = 0
    for index, que___ in enumerate(que):
        # print("sentence ", index, que[index])
        # print("gold labels ", index, val[index])
        # print("ques: ", ques)
        ques_embed = ra_ed.sentence2idx(que[index])
        if hyper_parameters['embedding_type'] == 'albert':
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            ques_embed = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # print("ques_embed: ", ques_embed)
        if hyper_parameters['embedding_type'] == 'bert':
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
        else:
            x_val = ques_embed
        # print("x_val", x_val)
        ques_embed_list.append(x_val)
        # predict
        pred = graph.predict(x_val)
        # print(pred)
        # map ids back to labels and probabilities
        pre = pt.prereocess_idx(pred[0])
        # print("pre", pre)
        ls_nulti = []
        threshold = 0.65
        has_scope = False
        has_dense = False
        for i, ls in enumerate(pre[0]):
            # scope labels: keep only the first one; count accepted-in-gold as TP,
            # accepted-not-in-gold as FP, rejected-in-gold as FN, rejected-not-in-gold as TN
            if ls[0] in ['多发', '散发', '无']:
                if not has_scope:
                    has_scope = True
                    ls_nulti.append(ls)
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['TP'] += 1
                    else:
                        dict_all[ls[0]]['FP'] += 1
                else:
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['FN'] += 1
                    else:
                        dict_all[ls[0]]['TN'] += 1
            # non-scope labels: keep those above the threshold, or the first one otherwise
            if ls[0] not in ['多发', '散发', '无']:
                if ls[1] > threshold or not has_dense:
                    ls_nulti.append(ls)
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['TP'] += 1
                    else:
                        dict_all[ls[0]]['FP'] += 1
                    has_dense = True
                else:
                    if ls[0] in val[index].split(","):
                        dict_all[ls[0]]['FN'] += 1
                    else:
                        dict_all[ls[0]]['TN'] += 1
        # print("prediction", index, pre[0])
        # print(ls_nulti)
        res = cal_acc(ls_nulti, val[index].split(","))
        res_iou = cal_iou(ls_nulti, val[index].split(","))
        sum_iou += res_iou
        # sum_all_iou += res_all_iou
        if res:
            # if val[index] != "无":
            #     not_none_acc_count += 1
            acc_count += 1
        else:
            print("sentence ", index, que[index])
            print("gold labels ", index, val[index])
            print("pre ", pre)
            print("iou ", res_iou)
            print(ls_nulti)
        count += 1
        if val[index] != "无":
            not_none_count += 1
    print("acc: ", acc_count / count)
    # print("not none acc: ", not_none_acc_count / not_none_count)
    print("average iou: ", sum_iou / count)
    # imported without the usual `pt` alias to avoid shadowing the preprocessor above
    import prettytable
    tb = prettytable.PrettyTable()
    tb.field_names = [" ", "Recall", "Precision", "TP", "FP", "TN", "FN"]
    for item in dict_all:
        if dict_all[item]['TP'] + dict_all[item]['FN'] == 0:
            recall = 1
        else:
            recall = dict_all[item]['TP'] / (dict_all[item]['TP'] + dict_all[item]['FN'])
        if dict_all[item]['TP'] + dict_all[item]['FP'] == 0:
            precision = 1
        else:
            precision = dict_all[item]['TP'] / (dict_all[item]['TP'] + dict_all[item]['FP'])
        # print(item, recall, precision)
        tb.add_row([item, recall, precision,
                    dict_all[item]['TP'], dict_all[item]['FP'],
                    dict_all[item]['TN'], dict_all[item]['FN']])
    print(tb)
    # log
    append_log(hyper_parameters, acc_count / count, not_none_acc_count / not_none_count, threshold)
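# initConfusion() is defined elsewhere in the project; a minimal sketch of what it
# plausibly returns (an assumption), so the loop above can update
# dict_all[label]['TP'] / ['FP'] / ['TN'] / ['FN'] per label:
def initConfusion(labels=('多发', '散发', '无')):  # the label list here is illustrative only
    return {label: {'TP': 0, 'FP': 0, 'TN': 0, 'FN': 0} for label in labels}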
import pathlib
import sys
import os

project_path = str(pathlib.Path(os.path.abspath(__file__)).parent.parent.parent)
sys.path.append(project_path)
# paths
from keras_textclassification.conf.path_config import path_model, path_fineture, path_model_dir, path_hyper_parameters
# train/validation data paths
from keras_textclassification.conf.path_config import path_baidu_qa_2019_train, path_baidu_qa_2019_valid
# data preprocessing, delete files under the directory
from keras_textclassification.data_preprocess.text_preprocess import PreprocessText, read_and_process, load_json
# model graph
from keras_textclassification.m02_TextCNN.graph import TextCNNGraph as Graph

import numpy as np
# flask
from flask import Flask, request, jsonify

app = Flask(__name__)

hyper_parameters = load_json(path_hyper_parameters)
pt = PreprocessText(path_model_dir)
# initialize and load the model
graph = Graph(hyper_parameters)
graph.load_model()
ra_ed = graph.word_embedding
ques = '我要打王者荣耀'
# str to token
ques_embed = ra_ed.sentence2idx(ques)
if hyper_parameters['embedding_type'] in ['bert', 'albert']:
    x_val_1 = np.array([ques_embed[0]])
    x_val_2 = np.array([ques_embed[1]])
    x_val = [x_val_1, x_val_2]
else:
    x_val = ques_embed
# predict
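# The module above stops right before the prediction step; a minimal sketch of a
# Flask endpoint built on the preloaded objects. The route name, request format
# and response fields are assumptions, not the repository's actual API.
@app.route("/predict", methods=["POST"])
def predict():
    ques = request.json.get("ques", "")
    ques_embed = ra_ed.sentence2idx(ques)
    if hyper_parameters['embedding_type'] in ['bert', 'albert']:
        x_val = [np.array([ques_embed[0]]), np.array([ques_embed[1]])]
    else:
        x_val = ques_embed
    pred = graph.predict(x_val)
    pre = pt.prereocess_idx(pred[0])
    # prereocess_idx returns [[(label, prob), ...]]; cast probs to float for json
    result = [{"label": label, "prob": float(prob)} for label, prob in pre[0]]
    return jsonify({"ques": ques, "pred": result})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=8080)  # the port is an assumption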
def __init__(self):
    # load the label<->index mapping if it already exists
    self.l2i_i2l = None
    if os.path.exists(path_fast_text_model_l2i_i2l):
        self.l2i_i2l = load_json(path_fast_text_model_l2i_i2l)
def pred_tet(path_hyper_parameter=path_hyper_parameters, path_test=None, rate=1.0):
    """
        Run the test set and evaluate the model.
    :param path_hyper_parameter: str, path of the hyperparameter json
    :param path_test: str, path of test data
    :param rate: float, fraction of the corpus to evaluate
    :return: None
    """
    hyper_parameters = load_json(path_hyper_parameter)
    if path_test:  # use an externally supplied test data path
        hyper_parameters['data']['test_data'] = path_test
    time_start = time.time()
    # initialize the graph
    graph = Graph(hyper_parameters)
    print("graph init ok!")
    graph.load_model()
    print("graph load ok!")
    ra_ed = graph.word_embedding
    # data preprocessing
    pt = PreprocessSim(path_model_dir)
    data = pd.read_csv(hyper_parameters['data']['test_data'])
    sentence_1 = data["sentence1"].values.tolist()
    sentence_2 = data["sentence2"].values.tolist()
    labels = data["label"].values.tolist()
    sentence_1 = [extract_chinese(str(line1).upper()) for line1 in sentence_1]
    sentence_2 = [extract_chinese(str(line2).upper()) for line2 in sentence_2]
    labels = [extract_chinese(str(line3).upper()) for line3 in labels]
    # take only a fraction (rate) of the dataset for testing
    len_rate = int(len(labels) * rate)
    sentence_1 = sentence_1[0:len_rate]
    sentence_2 = sentence_2[0:len_rate]
    labels = labels[0:len_rate]
    y_pred = []
    count = 0
    for i in range(len_rate):
        count += 1
        ques_embed = ra_ed.sentence2idx(text=sentence_1[i], second_text=sentence_2[i])
        # print(hyper_parameters['embedding_type'])
        if hyper_parameters['embedding_type'] in ['bert', 'albert']:
            # bert-style inputs: token ids and segment ids
            x_val_1 = np.array([ques_embed[0]])
            x_val_2 = np.array([ques_embed[1]])
            x_val = [x_val_1, x_val_2]
            # predict
            pred = graph.predict(x_val)
            pre = pt.prereocess_idx(pred[0])
            label_pred = pre[0][0][0]
            if count % 1000 == 0:
                print(label_pred)
            y_pred.append(label_pred)
    print("data pred ok!")
    # convert labels and predictions to integer indices
    index_y = [pt.l2i_i2l['l2i'][i] for i in labels]
    index_pred = [pt.l2i_i2l['l2i'][i] for i in y_pred]
    # target_names must follow the sorted label order used by classification_report
    target_names = [pt.l2i_i2l['i2l'][str(i)] for i in sorted(set(index_pred + index_y))]
    # evaluation
    report_predict = classification_report(index_y, index_pred,
                                            target_names=target_names, digits=9)
    print(report_predict)
    print("time cost: " + str(time.time() - time_start))