def _select_test_device(args):
    """Resolve the compute device and DataParallel device ids for ``test``.

    Reads ``args.gpu`` and, in GPU mode, ``args.device1`` .. ``args.device4``;
    a value of -1 for ``device2`` .. ``device4`` marks an unused slot.

    Returns:
        ``(device, device_ids)``: ``device`` is ``torch.device(args.device1)`` in
        GPU mode and the CPU otherwise; ``device_ids`` is empty on CPU.

    Raises:
        SystemExit: GPU mode was requested but CUDA is not available.
    """
    if args.gpu != 'gpu':
        return torch.device("cpu"), []
    if not torch.cuda.is_available():
        print("No cuda available")
        raise SystemExit
    device = torch.device(args.device1)
    device_ids = [args.device1]
    # device2..device4 are optional extra devices; -1 means "not used".
    device_ids += [d for d in (args.device2, args.device3, args.device4) if d != -1]
    return device, device_ids


def test(model, args):
    """Evaluate ``model`` on the dataset at ``args.data_path`` and print its mean dice score.

    The score is ``dice_score(pred, make_one_hot(labels))`` averaged over the
    batches of the test loader (batch size 10), computed under ``torch.no_grad``
    with the model in eval mode. When more than one device id is configured the
    model is wrapped in ``nn.DataParallel``.

    Args:
        model: the network to evaluate.
        args: namespace providing ``data_path``, ``gpu``, ``classes`` and, in GPU
            mode, ``device1`` .. ``device4``.

    Returns:
        The mean per-batch dice score, or ``None`` when the loader yields no batches.

    Raises:
        SystemExit: GPU mode was requested but CUDA is not available.
    """
    n_classes = args.classes
    device, device_ids = _select_test_device(args)

    if len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    # set test dataset
    test_dataset = SampleDataset(args.data_path)
    test_loader = DataLoader(
        test_dataset,
        batch_size=10,
        num_workers=4,
    )
    n_batches = len(test_loader)
    print('test_dataset : {}, test_loader : {}'.format(len(test_dataset), n_batches))
    if n_batches == 0:
        # Averaging over zero batches would raise ZeroDivisionError.
        print('dice_score : no test batches')
        return None

    avg_score = 0.0
    model.eval()  # Set model to evaluate mode
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).long()
            # labels is indexed as a 4-D tensor; only its channel 0 is passed to
            # make_one_hot (presumably expanded to n_classes channels — TODO confirm).
            target = make_one_hot(labels[:, 0, :, :], n_classes, device)
            pred = model(inputs)
            score = dice_score(pred, target)
            avg_score += score.item()
            # Drop references to this batch's tensors before the next batch is loaded.
            del inputs, labels, target, pred, score
    avg_score /= n_batches
    print('dice_score : {:.4f}'.format(avg_score))
    return avg_score
from loss import dice_score

# Stand-alone evaluation script: loads the U-Net weights at model_path, runs them
# on the data at file_path and prints the accumulated dice score.
torch.backends.cudnn.benchmark = True
torch.manual_seed(10)

file_path = './brats18_dataset/npy_test/test_t2.npy'
model_path = './weight/unet_4.pth'

# NOTE: the `train_*` names are kept for compatibility; this is the test data
# read from file_path.
train_data = Dataset_brain_4(file_path)
batch_size = 64
train_loader = data.DataLoader(dataset=train_data, batch_size=batch_size, num_workers=4)

unet = Unet(4)
unet.load_state_dict(torch.load(model_path))
unet.cuda()
unet.eval()

batch_score = 0
num_batch = 0
with torch.no_grad():
    for img, label in train_loader:
        seg = unet(img.float().cuda()).cpu()
        # Binarise at 0.5: values >= 0.5 become 1, everything else (including NaN)
        # becomes 0, keeping the network's output dtype.
        seg = (seg >= 0.5).to(seg.dtype)
        batch_score += dice_score(seg, label.float()).detach().numpy()
        # NOTE(review): this counts samples (img.size(0)), not batches, so the final
        # value is a per-sample average only if dice_score returns a sum over the
        # batch — confirm against loss.dice_score.
        num_batch += img.size(0)
        del seg, img, label

if num_batch:
    batch_score /= num_batch
    print(' test_score = %.4f' % (batch_score))
else:
    # An empty test set would otherwise raise ZeroDivisionError.
    print(' test_score : no test samples')
def _build_autoencoder(name, n_channels, role):
    """Instantiate the MagNet autoencoder selected by ``name``.

    ``name`` must be 'autoencoder1' (builds ``autoencoder``) or 'autoencoder2'
    (builds ``autoencoder2``); ``role`` ('reformer' or 'detector') is only used
    in the error message.

    Raises:
        SystemExit: ``name`` is not one of the two supported models.
    """
    if name == 'autoencoder1':
        return autoencoder(n_channels)
    if name == 'autoencoder2':
        return autoencoder2(n_channels)
    print("wrong {} model : must be autoencoder1 or autoencoder2".format(role))
    raise SystemExit


def _select_magnet_devices(args):
    """Resolve devices for the classifier and for the MagNet defense models.

    Returns ``(device, device_defense, device_ids)``. In GPU mode
    (``args.gpu == 'gpu'``) ``device`` is ``args.model_device1``,
    ``device_defense`` is ``args.defense_model_device`` and ``device_ids`` lists
    ``model_device1`` plus whichever of ``model_device2`` .. ``model_device4`` are
    not -1. Otherwise both devices are the CPU and ``device_ids`` is empty.

    Raises:
        SystemExit: GPU mode was requested but CUDA is not available.
    """
    if args.gpu != 'gpu':
        return torch.device("cpu"), torch.device("cpu"), []
    if not torch.cuda.is_available():
        print("No cuda available")
        raise SystemExit
    device = torch.device(args.model_device1)
    device_defense = torch.device(args.defense_model_device)
    device_ids = [args.model_device1]
    # model_device2..4 are optional extra devices; -1 means "not used".
    device_ids += [
        d for d in (args.model_device2, args.model_device3, args.model_device4) if d != -1
    ]
    return device, device_defense, device_ids


def test(model, args):
    """Evaluate ``model`` behind a MagNet defense (detector + reformer), per threshold.

    For every detector threshold, each test batch is passed through
    ``filters(detector, ...)``; the reformer+classifier output (``operate``) of
    the images that pass is scored with ``dice_score`` against the one-hot
    labels. The mean over batches with at least one passing image is printed per
    threshold; a message is printed instead when no image passed at all.

    Args:
        model: the network wrapped by ``Classifier``.
        args: namespace providing ``data_path``, ``channels``, ``classes``,
            ``width``, ``height``, ``gpu``, ``reformer``, ``detector``,
            ``reformer_path``, ``detector_path``, ``model_path`` and, in GPU mode,
            ``model_device1`` .. ``model_device4`` and ``defense_model_device``.
            An optional ``thresholds`` sequence overrides the default detector
            thresholds.

    Raises:
        SystemExit: an unknown reformer/detector model name, or GPU mode was
            requested but CUDA is not available.
    """
    data_path = args.data_path
    n_channels = args.channels
    n_classes = args.classes
    data_width = args.width
    data_height = args.height

    # Hyperparameter for MagNet: detector thresholds to evaluate.
    thresholds = getattr(args, 'thresholds', None) or [0.01, 0.05, 0.001, 0.005]

    reformer_model = _build_autoencoder(args.reformer, n_channels, 'reformer')
    print('reformer model')
    summary(reformer_model, input_size=(n_channels, data_height, data_width), device='cpu')

    detector_model = _build_autoencoder(args.detector, n_channels, 'detector')
    print('detector model')
    summary(detector_model, input_size=(n_channels, data_height, data_width), device='cpu')

    # set device configuration
    device, device_defense, device_ids = _select_magnet_devices(args)

    detector = AEDetector(detector_model, device_defense, args.detector_path, p=2)
    reformer = SimpleReformer(reformer_model, device_defense, args.reformer_path)
    classifier = Classifier(model, device, args.model_path, device_ids)

    # set test dataset
    test_dataset = SampleDataset(data_path)
    test_loader = DataLoader(
        test_dataset,
        batch_size=10,
        num_workers=4,
    )
    print('test_dataset : {}, test_loader : {}'.format(len(test_dataset), len(test_loader)))

    thresholds = [torch.tensor(t) for t in thresholds]
    score_sums = [0.0] * len(thresholds)
    batch_counts = [0] * len(thresholds)

    # Defense with MagNet
    print('test start')
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.float()
            labels = labels.to(device).long()
            target = make_one_hot(labels[:, 0, :, :], n_classes, device)
            # The reformer + classifier output does not depend on the threshold, so
            # it is computed once per batch rather than once per threshold, and the
            # loader is iterated only once.
            # NOTE(review): assumes operate() and filters() are deterministic and do
            # not modify `inputs` in place — confirm.
            operate_results = operate(reformer, classifier, inputs)
            for k, thrs in enumerate(thresholds):
                all_pass, _ = filters(detector, inputs, thrs)
                if len(all_pass) == 0:
                    continue  # no image of this batch passed the detector
                pred = operate_results[all_pass].to(device).float()
                score = dice_score(pred, target[all_pass])
                score_sums[k] += score.item()
                batch_counts[k] += 1
            # Drop references to this batch's tensors before the next batch is loaded.
            del inputs, labels, target, operate_results

    for thrs, score_sum, counter in zip(thresholds, score_sums, batch_counts):
        print('----------------------------------------')
        if counter:
            print('threshold : {:.4f}, avg_score : {:.4f}'.format(
                thrs, score_sum / counter))
        else:
            print(
                'threshold : {:.4f} , no images pass from filter'.format(thrs))