def test(model, split, data_loader, device, task, dir_path):
    """Evaluate ``model`` on ``data_loader``, print AUC/ACC and save the scores.

    NOTE(review): a second ``test`` with a different signature
    (``flag``/``output_root`` parameters) is defined later in this file and
    rebinds the name, so this definition is not the one module users get.

    :param model: the model to test
    :param split: the data to test, 'train/val/test'; used in the printout
        and as the CSV file name
    :param data_loader: DataLoader yielding ``(inputs, targets)`` batches
    :param device: cpu or cuda
    :param task: task of current dataset, binary-class/multi-class/multi-label,
        binary-class; 'multi-label, binary-class' selects sigmoid scores,
        any other value selects softmax scores
    :param dir_path: directory in which ``<split>.csv`` is written by
        ``save_results``
    :raises ValueError: if ``data_loader`` yields no batches
    """
    model.eval()
    true_chunks = []
    score_chunks = []
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs.to(device))
            if task == 'multi-label, binary-class':
                true_chunks.append(targets.to(device, torch.float32))
                score_chunks.append(torch.sigmoid(outputs))
            else:
                # reshape(-1) rather than squeeze(): squeeze() collapses a
                # one-sample batch to a 0-d tensor, which has no len().
                labels = targets.reshape(-1).long().to(device)
                # Integer class labels stored as an (N, 1) float column.
                true_chunks.append(labels.float().view(-1, 1))
                score_chunks.append(torch.softmax(outputs, dim=1))

    if not true_chunks:
        raise ValueError('data_loader yielded no batches to evaluate')
    # Concatenate once instead of re-allocating a growing tensor per batch.
    y_true = torch.cat(true_chunks, 0).cpu().numpy()
    y_score = torch.cat(score_chunks, 0).cpu().numpy()

    auc = getAUC(y_true, y_score, task)
    acc = getACC(y_true, y_score, task)
    print('%s AUC: %.5f ACC: %.5f' % (split, auc, acc))

    output_path = os.path.join(dir_path, '%s.csv' % (split))
    save_results(y_true, y_score, output_path)
# Number of classes of each MedMNIST subset with a known class list; the
# class names given to the confusion-matrix plot are the indices '0'..'n-1'.
_MEDMNIST_NUM_CLASSES = {
    "pathmnist": 9,
    "chestmnist": 14,
    "dermamnist": 7,
    "octmnist": 4,
    "retinamnist": 5,
    "pneumoniamnist": 2,
    "breastmnist": 2,
    "organmnist_axial": 11,
    "organmnist_coronal": 11,
    "organmnist_sagittal": 11,
}


def _collect_test_outputs(model, data_loader, device, task):
    """Run ``model`` over ``data_loader`` and return ``(y_true, y_score)`` numpy arrays.

    Results are concatenated along the batch dimension. For the
    'multi-label, binary-class' task ``y_score`` holds sigmoid scores and
    ``y_true`` the targets converted to float32; for any other task
    ``y_score`` holds softmax scores over dim 1 and ``y_true`` is an
    ``(N, 1)`` float column of the integer class labels.

    :raises ValueError: if ``data_loader`` yields no batches
    """
    true_chunks = []
    score_chunks = []
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs.to(device))
            if task == 'multi-label, binary-class':
                true_chunks.append(targets.to(device, torch.float32))
                score_chunks.append(torch.sigmoid(outputs))
            else:
                # reshape(-1) rather than squeeze(): squeeze() collapses a
                # one-sample batch to a 0-d tensor, which has no len().
                labels = targets.reshape(-1).long().to(device)
                true_chunks.append(labels.float().view(-1, 1))
                score_chunks.append(torch.softmax(outputs, dim=1))

    if not true_chunks:
        raise ValueError('data_loader yielded no batches to evaluate')
    # Concatenate once instead of re-allocating a growing tensor per batch.
    y_true = torch.cat(true_chunks, 0).cpu().numpy()
    y_score = torch.cat(score_chunks, 0).cpu().numpy()
    return y_true, y_score


def test(model, split, data_loader, device, flag, task, output_root=None):
    """Evaluate ``model`` on one split, print AUC/ACC and optionally save scores.

    On the 'test' split of every subset except 'chestmnist' it also reports
    ``metric_results`` via ``printMetricResults``, saves a confusion-matrix
    figure to ``config.confusionMatrixPath/<flag>`` and calls
    ``plotRocCurve`` with ``config.rocPic_path``.

    :param model: the model to test
    :param split: the data to test, 'train/val/test'; also the CSV file name
    :param data_loader: DataLoader yielding ``(inputs, targets)`` batches
    :param device: cpu or cuda
    :param flag: subset name (e.g. 'pathmnist'); used as plot title/file name
        and as the output sub-directory
    :param task: task of current dataset, binary-class/multi-class/multi-label,
        binary-class
    :param output_root: if not None, scores are written to
        ``<output_root>/<flag>/<split>.csv``; missing directories are created
    :raises ValueError: if ``data_loader`` yields no batches
    """
    model.eval()
    y_true, y_score = _collect_test_outputs(model, data_loader, device, task)

    if split == 'test' and flag != 'chestmnist':
        # Same element types the per-batch accumulation used to produce:
        # numpy int64 argmax predictions, one (1,)-shaped label array per
        # sample, and one Python list of class scores per sample.
        pred_labels = list(y_score.argmax(axis=1))
        true_labels = list(y_true)
        scores = y_score.tolist()

        # Unknown subsets fall back to the model's output width instead of
        # raising KeyError.
        n_classes = _MEDMNIST_NUM_CLASSES.get(flag, y_score.shape[1])
        classes = [str(i) for i in range(n_classes)]

        result = metric_results(pred_labels, true_labels)
        printMetricResults(result)
        plot_confusion_matrix(result['confusion_matrix'], classes=classes,
                              normalize=False, title=flag)
        confusion_matrixPath = os.path.join(config.confusionMatrixPath, flag)
        plt.savefig(confusion_matrixPath, dpi=600)
        plt.clf()
        plotRocCurve(flag, true_labels, scores, config.rocPic_path, flag)

    auc = getAUC(y_true, y_score, task)
    acc = getACC(y_true, y_score, task)
    print('%s AUC: %.5f ACC: %.5f' % (split, auc, acc))

    if output_root is not None:
        output_dir = os.path.join(output_root, flag)
        # makedirs also creates a missing output_root and tolerates reruns.
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, '%s.csv' % (split))
        save_results(y_true, y_score, output_path)