class RnnModel:
    """Wrap a trained ``TextRNN`` for classifying a single message at a time."""

    def __init__(self):
        """Load categories, vocabulary and trained weights; prepare for inference."""
        # Category names plus a name -> id mapping, as returned by read_category().
        self.categories, self.cat_to_id = read_category()
        # Vocabulary plus a token -> id mapping, as returned by read_vocab().
        self.words, self.word_to_id = read_vocab(vocab_file)
        self.model = TextRNN()
        # The model is never moved off the CPU here, so map the stored tensors to
        # the CPU; without this a checkpoint saved on a GPU fails to load on a
        # CPU-only machine.
        self.model.load_state_dict(
            torch.load('model_params.pkl', map_location='cpu')
        )
        # Inference only: put any dropout / normalization layers into eval mode
        # so repeated predictions for the same message are deterministic.
        self.model.eval()

    def predict(self, message):
        """Return the predicted category name for ``message``.

        ``message`` is iterated element by element (characters, if it is a
        ``str``); elements missing from the vocabulary are dropped. The id
        sequence is padded/truncated to length 600 with Keras
        ``pad_sequences`` default settings before being fed to the model.
        """
        data = [self.word_to_id[x] for x in message if x in self.word_to_id]
        data = kr.preprocessing.sequence.pad_sequences([data], 600)
        data = torch.LongTensor(data)
        # No gradients are needed for inference; skip building the autograd graph.
        with torch.no_grad():
            y_pred_cls = self.model(data)
        # Single-sample batch: take the highest-scoring class of row 0.
        class_index = torch.argmax(y_pred_cls[0]).item()
        return self.categories[class_index]
def test():
    """Evaluate the trained TextRNN on the test set and print metrics.

    Loads the configuration, the test data, the pretrained embedding matrix and
    the saved weights from ``./output/model.bin``. It then runs the model over
    every test batch and prints sklearn's classification report (precision,
    recall, F1) followed by the confusion matrix.
    """
    # Configuration file
    cf = Config('./config.yaml')
    # Use the GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Test data. Batch order does not affect the metrics, so shuffling is
    # unnecessary; a fixed order keeps evaluation runs reproducible.
    test_data = NewsDataset("./data/cnews_final_test.txt", cf.max_seq_len)
    test_dataloader = DataLoader(test_data, batch_size=cf.batch_size, shuffle=False)

    # Pretrained word-vector matrix
    embedding_matrix = get_pre_embedding_matrix("./data/final_vectors")

    # Model. map_location lets a checkpoint that was saved on a GPU be loaded
    # on a CPU-only machine (and vice versa).
    model = TextRNN(cf, torch.tensor(embedding_matrix))
    model.load_state_dict(torch.load("./output/model.bin", map_location=device))
    model.to(device)
    # No torch.nn.DataParallel wrapper here: evaluation calls model.get_labels,
    # which DataParallel neither parallelizes (it only splits forward()) nor
    # exposes as an attribute, so wrapping raises AttributeError on multi-GPU hosts.

    # Evaluation
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch in test_dataloader:
            label_id = batch['label_id'].squeeze(1).to(device)
            seq_len = batch["seq_len"].to(device)
            segment_ids = batch['segment_ids'].to(device)
            # Sort the batch by sequence length, descending (presumably required
            # by packed-sequence handling inside the model — confirm), and apply
            # the same permutation to the labels so they stay aligned.
            seq_len, perm_idx = seq_len.sort(0, descending=True)
            label_id = label_id[perm_idx]
            # Swap the first two axes; assumes the model expects
            # (seq_len, batch, ...) input — TODO confirm against TextRNN.
            segment_ids = segment_ids[perm_idx].transpose(0, 1)
            pred = model.get_labels(segment_ids, seq_len)
            # Collect per-batch results and concatenate once at the end instead
            # of re-copying the growing array on every batch.
            pred_chunks.append(pred)
            label_chunks.append(label_id.to("cpu").numpy())

    y_pred = np.hstack(pred_chunks) if pred_chunks else np.array([])
    y_test = np.hstack(label_chunks) if label_chunks else np.array([])

    # Evaluation metrics
    print("Precision, Recall and F1-Score...")
    print(
        metrics.classification_report(y_test, y_pred, target_names=get_labels('./data/label')))

    # Confusion matrix
    print("Confusion Matrix...")
    cm = metrics.confusion_matrix(y_test, y_pred)
    print(cm)