Exemplo n.º 1
0
def main(args):
    """Run a two-stage UNet cascade over a prediction set and report metrics.

    The first network is applied to each input batch. Its output is
    concatenated channel-wise with the input and fed to the second network.
    The argmax over the second network's output channels is the predicted
    label map. For every image whose prediction is non-empty, the prediction
    is written to ``args.preDir`` as a 0/255 PNG. Its confusion matrix against
    ``args.label_root_dir + name + '.png'`` is added to a running total. The
    total is printed and passed to ``evaluate``.

    Args:
        args: parsed command-line namespace. Fields read here: ``cuda``,
            ``colordim``, ``colordim2``, ``num_class``, ``num_class2``,
            ``pretrain_net``, ``pretrain_net2``, ``pre_root_dir``,
            ``img_size``, ``predictbatchsize``, ``threads``, ``preDir``,
            ``label_root_dir``.

    Raises:
        Exception: if ``args.cuda`` is set but no CUDA device is available.
    """
    if args.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")
    # Load checkpoints onto the device we actually run on, so weights saved
    # from a GPU session can still be evaluated on a CPU-only machine.
    map_location = 'cuda' if args.cuda else 'cpu'

    model = UNet(n_channels=args.colordim, n_classes=args.num_class)
    model2 = UNet(n_channels=args.colordim2, n_classes=args.num_class2)
    if args.cuda:
        model = model.cuda()
        model2 = model2.cuda()
    model.load_state_dict(torch.load(args.pretrain_net, map_location=map_location))
    model2.load_state_dict(torch.load(args.pretrain_net2, map_location=map_location))
    model.eval()
    model2.eval()

    predDataset = generateDataset(args.pre_root_dir,
                                  args.img_size,
                                  args.colordim,
                                  isTrain=False)
    predLoader = DataLoader(dataset=predDataset,
                            batch_size=args.predictbatchsize,
                            num_workers=args.threads)
    with torch.no_grad():
        # Running 2x2 confusion matrix (rows: true class, cols: predicted).
        cm_w = np.zeros((2, 2))
        for batch_x, batch_name in predLoader:
            # Cast on both paths; previously only the CUDA path called .float().
            batch_x = batch_x.float()
            if args.cuda:
                batch_x = batch_x.cuda()

            # Stage 2 sees the raw input plus stage 1's output, channel-wise.
            out1 = model(batch_x)
            out = model2(torch.cat((batch_x, out1), 1))
            _, pred_label = torch.max(out, 1)
            pred_label_np = pred_label.cpu().numpy()

            for name, pred_single in zip(batch_name, pred_label_np):
                predLabel_filename = args.preDir + '/' + name + '.png'
                label_filename = args.label_root_dir + name + '.png'
                label = io.imread(label_filename)
                # 0/255 rendering of the predicted label map for saving.
                pred_vis = np.where(pred_single > 0, 255, 0)
                print(np.max(pred_vis))
                print(name)
                if np.max(pred_vis) > 0:
                    io.imsave(predLabel_filename, pred_vis.astype(np.uint8))
                    # labels=[0, 1] keeps the matrix 2x2 even when an image
                    # contains only one class. Without it sklearn returns a
                    # 1x1 matrix that NumPy broadcasts across all of cm_w.
                    # Assumes label PNGs hold class indices {0, 1} -- confirm.
                    # NOTE(review): images with an all-background prediction
                    # are skipped here, so their false negatives never reach
                    # cm_w -- confirm this exclusion is intended.
                    cm_w = cm_w + confusion_matrix(label.ravel(),
                                                   pred_single.ravel(),
                                                   labels=[0, 1])

        print(cm_w)
        OA_w, F1_w, IoU_w = evaluate(cm_w)
        print('OA_w = ' + str(OA_w) + ', F1_w = ' + str(F1_w) + ', IoU = ' +
              str(IoU_w))
Exemplo n.º 2
0
    writer = SummaryWriter()

    image = Image.open('./ht2-c2.jpg')
    out = TF.to_tensor(image)
    out = out.reshape(1, 3, 640, 640)
    inp = torch.rand(1, 3, 640, 640)

    fig = plt.figure()
    plt.imshow(out[0].permute(1, 2, 0).numpy())
    # plt.show
    writer.add_figure("Ground Truth", fig)

    fig = plt.figure()
    plt.imshow(inp[0].permute(1, 2, 0).numpy())
    writer.add_figure("Input", fig)

    num_iter = 500
    writer.add_scalar("Number_of_Iterations", num_iter)

    model = UNet(3, 3)
    if torch.cuda.is_available():
        model.cuda()

    criterion = nn.MSELoss()

    learning_rate = 0.1
    writer.add_scalar("Learning_Rate", learning_rate)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    train(num_iter, inp, out, model, optimizer, criterion)
    writer.close()
Exemplo n.º 3
0
                                in_size = 320,
                                train_in_size = input_size)
    train_dataloader = torch.utils.data.DataLoader(SRRFDATASET, batch_size=batch_size, shuffle=True, pin_memory=True) # better than for loop
    
    SRRFDATASET2 = ReconsDataset(all_data_path="/media/star/LuhongJin/UNC_data/SRRF/New_training_20190829/0NPY_Dataset/Dataset/Microtubule/",
                                maximum_intensity_4normalization_path="/home/star/0_code_lhj/DL-SIM-github/Training_codes/UNetMax_intensity.npy",
                                transform = ToTensor(),
                                training_dataset = False,
                                in_size = 320,
                                train_in_size = input_size)
    validation_dataloader = torch.utils.data.DataLoader(SRRFDATASET2, batch_size=batch_size, shuffle=True, pin_memory=True) # better than for loop

    model = UNet(n_channels=input_size, n_classes=output_size)

    print("{} paramerters in total".format(sum(x.numel() for x in model.parameters())))
    model.cuda(cuda)
    optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate,  betas=(0.9, 0.999))

    loss_all = np.zeros((2000, 4))
    for epoch in range(2000):
        
        mae_m, mae_s = val_during_training(train_dataloader)
        loss_all[epoch,0] = mae_m
        loss_all[epoch,1] = mae_s
        mae_m, mae_s = val_during_training(validation_dataloader) 
        loss_all[epoch,2] = mae_m
        loss_all[epoch,3] = mae_s
        
        file = Workbook(encoding = 'utf-8')
        table = file.add_sheet('loss_all')
        for i,p in enumerate(loss_all):
Exemplo n.º 4
0
    #train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    #val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = UNet(n_channels=15, n_classes=1)
    #print(summary(netG,(15,128,128)))
    print('# generator parameters:',
          sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:',
          sum(param.numel() for param in netD.parameters()))
    #print(summary(netD,(1,256,256)))

    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {
        'd_loss': [],
        'g_loss': [],
        'd_score': [],
        'g_score': [],
        'psnr': [],
        'ssim': []
    }
    data, target = get_train_data(X_test, y_test, batch_size)
Exemplo n.º 5
0
    LE_img = imgread(os.path.join(dir_path_LE, "LE_01.tif"))

    LE_512 = cropImage(LE_img, IMG_SHAPE[0],IMG_SHAPE[1])
    sample_le = {}
    for le_512 in LE_512:
        tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE)
        for n,img in enumerate(tiles):
            if n not in sample_le:
                sample_le[n] = []
            img = transform.resize(img,(IMG_SIZE*2, IMG_SIZE*2),preserve_range=True,order=3)
            sample_le[n].append(img)

	SNR_model = UNet(n_channels=15, n_classes=15)
	print("{} paramerters in total".format(sum(x.numel() for x in SNR_model.parameters())))
	SNR_model.cuda(cuda)
	SNR_model.load_state_dict(torch.load(SNR_model_path))
	# SNR_model.load_state_dict(torch.load(os.path.join(dir_path,"model","LE_HE_mito","LE_HE_0825.pkl")))
	SNR_model.eval()

	SIM_UNET = UNet(n_channels=15, n_classes=1)
	print("{} paramerters in total".format(sum(x.numel() for x in SIM_UNET.parameters())))
	SIM_UNET.cuda(cuda)
	SIM_UNET.load_state_dict(torch.load(SIM_UNET_model_path))
	# SIM_UNET.load_state_dict(torch.load(os.path.join(dir_path,"model","HE_HER_mito","HE_X2_HER_0825.pkl")))
	SIM_UNET.eval()

    SRRFDATASET = ReconsDataset(
    img_dict=sample_le,
    transform=ToTensor(),
    in_norm = LE_in_norm,
Exemplo n.º 6
0
#quit after interrupt
import sys

# NOTE: `dir` shadows the builtin dir(); it is kept because train() reads
# this module-level name.
dir = 'data'
ids = []

# Every training file stem appears twice in the work list, tagged 0 and 1
# (train() reads the tag as `pos`). The loop variable is `stem` rather than
# `id` so the builtin id() is not rebound at module scope.
for f in os.listdir(dir + '/train'):
    stem = f[:-4]  # drop a 4-character extension such as ".jpg"
    ids.append([stem, 0])
    ids.append([stem, 1])

np.random.shuffle(ids)
#%%

net = UNet(3, 1)
net.cuda()

def train(net):
    """Train ``net`` with Adam and ``DiceLoss`` over the module-level ``ids``.

    NOTE(review): this body appears truncated where it was copied from. The
    loop opens an image but never uses it, and ``l``, ``optimizer``,
    ``criterion`` and ``pos`` are never used in the visible code.
    """
    # NOTE(review): lr=1 is unusually large for Adam -- confirm intended.
    optimizer = optim.Adam(net.parameters(), lr=1)
    criterion = DiceLoss()

    epochs = 5
    for epoch in range(epochs):
        print('epoch {}/{}...'.format(epoch + 1, epochs))
        l = 0  # presumably a running-loss accumulator -- TODO confirm

        for i, c in enumerate(ids):
            id = c[0]  # file stem taken from the module-level `ids` list
            pos = c[1]  # 0/1 tag paired with each stem in `ids`
            # NOTE(review): the opened image file is never explicitly closed.
            im = PIL.Image.open(dir + '/train/' + id + '.jpg')
def _thin_vessel_metrics(gtruth_masks, pred_imgs_out, footprint_size):
    """Per-image recall and ROC-AUC restricted to thin vessels.

    "Thin" vessels are the ground-truth pixels removed by a morphological
    opening with a ``footprint_size`` x ``footprint_size`` square.
    Predictions inside the opened ("thick") region are zeroed before scoring.
    Each prediction image is copied first, so ``pred_imgs_out`` is left
    unmodified.

    Returns:
        (recalls, aucs): two lists with one value per image, each rounded to
        4 decimal places.
    """
    recalls = []
    aucs = []
    for j in range(pred_imgs_out.shape[0]):
        gt = gtruth_masks[j, 0, :, :]
        thick = opening(gt, square(footprint_size))
        thin_gt = gt - thick
        # Copy: slicing yields a view, and zeroing it in place would silently
        # rewrite pred_imgs_out for every later consumer.
        thin_pred = pred_imgs_out[j, 0, :, :].copy()
        thin_pred[thick == 1] = 0
        recalls.append(round(thin_recall(thin_gt, thin_pred, thresh=0.5), 4))
        aucs.append(round(roc_auc_score(thin_gt.flatten(), thin_pred.flatten()), 4))
    return recalls, aucs


def test(experiment_path, test_epoch):
    """Evaluate a saved segmentation checkpoint on the DRIVE test set.

    Reads ``./<experiment_path>/<experiment_path>_config.txt`` and extracts
    test patches. It predicts them with the checkpoint
    ``./<experiment name>/DRIVE_<test_epoch>epoch.pth`` and reassembles full
    images. Into the experiment directory it writes visualisations,
    ROC / precision-recall plots and the raw predictions, and it appends a
    metrics summary to ``test_performances_all_epochs.txt``.

    Args:
        experiment_path: directory (relative to the cwd) holding the config.
        test_epoch: epoch number used to pick the checkpoint and tag outputs.

    The network and inputs are moved to CUDA, so a GPU is required.
    """
    # ========= CONFIG FILE TO READ FROM =======
    config = configparser.RawConfigParser()
    config.read('./' + experiment_path + '/' + experiment_path + '_config.txt')
    # ===========================================
    # run the training on invariant or local
    path_data = config.get('data paths', 'path_local')
    model = config.get('training settings', 'model')
    # original test images (for FOV selection)
    DRIVE_test_imgs_original = path_data + config.get('data paths', 'test_imgs_original')
    test_imgs_orig = load_hdf5(DRIVE_test_imgs_original)
    full_img_height = test_imgs_orig.shape[2]
    full_img_width = test_imgs_orig.shape[3]
    # the border masks provided by the DRIVE
    DRIVE_test_border_masks = path_data + config.get('data paths', 'test_border_masks')
    test_border_masks = load_hdf5(DRIVE_test_border_masks)
    # dimension of the patches
    patch_height = int(config.get('data attributes', 'patch_height'))
    patch_width = int(config.get('data attributes', 'patch_width'))
    # the stride in case output with average
    stride_height = int(config.get('testing settings', 'stride_height'))
    stride_width = int(config.get('testing settings', 'stride_width'))
    assert (stride_height < patch_height and stride_width < patch_width)
    # model name
    name_experiment = config.get('experiment name', 'name')
    path_experiment = './' + name_experiment + '/'
    # N full images to be predicted
    Imgs_to_test = int(config.get('testing settings', 'full_images_to_test'))
    # Grouping of the predicted images
    N_visual = int(config.get('testing settings', 'N_group_visual'))
    # ====== average mode ===========
    average_mode = config.getboolean('testing settings', 'average_mode')
    DRIVE_test_groundTruth = path_data + config.get('data paths', 'test_groundTruth')

    # ============ Load the data and divide in patches
    new_height = None
    new_width = None
    masks_test = None
    patches_masks_test = None

    if average_mode:
        # Overlapping patches; later averaged back into full images.
        patches_imgs_test, new_height, new_width, masks_test = get_data_testing_overlap(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width,
            stride_height=stride_height,
            stride_width=stride_width)
    else:
        patches_imgs_test, patches_masks_test = get_data_testing_test(
            DRIVE_test_imgs_original=DRIVE_test_imgs_original,
            DRIVE_test_groudTruth=DRIVE_test_groundTruth,
            Imgs_to_test=Imgs_to_test,
            patch_height=patch_height,
            patch_width=patch_width)

    # ================ Run the prediction of the patches ==================================
    best_last = config.get('testing settings', 'best_last')
    # Load the saved model
    if model == 'UNet':
        net = UNet(n_channels=1, n_classes=2)
    elif model == 'UNet_cat':
        net = UNet_cat(n_channels=1, n_classes=2)
    else:
        net = UNet_level4_our(n_channels=1, n_classes=2)
    # load data
    test_data = data.TensorDataset(torch.tensor(patches_imgs_test),
                                   torch.zeros(patches_imgs_test.shape[0]))
    test_loader = data.DataLoader(test_data, batch_size=1, pin_memory=True, shuffle=False)
    trained_model = path_experiment + 'DRIVE_' + str(test_epoch) + 'epoch.pth'
    print(trained_model)
    net.load_state_dict(torch.load(trained_model))
    net.eval()
    print('Finished loading model :' + trained_model)
    net = net.cuda()
    cudnn.benchmark = True

    # Calculate the predictions: one row of per-pixel class probabilities
    # per patch.
    n_pixels = patch_height * patch_width
    predictions_out = np.empty((patches_imgs_test.shape[0], n_pixels, 2))
    with torch.no_grad():
        for i_batch, (images, _) in enumerate(test_loader):
            out1 = net(images.float().cuda())
            # NCHW -> NHWC so the class axis is last, then softmax over classes.
            pred = F.softmax(out1.permute(0, 2, 3, 1), dim=-1)
            # batch_size is 1, so each batch fills exactly one row. The tensor
            # must be moved to the CPU before NumPy can read it.
            predictions_out[i_batch] = pred.reshape(n_pixels, 2).cpu().numpy()

    # ===== Convert the prediction arrays in corresponding images
    pred_patches_out = pred_to_imgs(predictions_out, patch_height, patch_width, "original")

    #========== Elaborate and visualize the predicted images ====================
    if average_mode:
        pred_imgs_out = recompone_overlap(pred_patches_out, new_height, new_width,
                                          stride_height, stride_width)
        orig_imgs = my_PreProc(test_imgs_orig[0:pred_imgs_out.shape[0], :, :, :])  # originals
        gtruth_masks = masks_test  # ground truth masks
    else:
        # NOTE(review): assumes a fixed 10 x 9 patch grid per image -- confirm.
        pred_imgs_out = recompone(pred_patches_out, 10, 9)  # predictions
        orig_imgs = recompone(patches_imgs_test, 10, 9)  # originals
        gtruth_masks = recompone(patches_masks_test, 10, 9)  # masks

    # Apply the DRIVE masks to the predictions (per the original note:
    # everything outside the FOV is set to zero).
    kill_border(pred_imgs_out, test_border_masks)
    # back to original dimensions
    orig_imgs = orig_imgs[:, :, 0:full_img_height, 0:full_img_width]
    pred_imgs_out = pred_imgs_out[:, :, 0:full_img_height, 0:full_img_width]
    gtruth_masks = gtruth_masks[:, :, 0:full_img_height, 0:full_img_width]

    print ("Orig imgs shape: "+str(orig_imgs.shape))
    print("pred imgs shape: " + str(pred_imgs_out.shape))
    print("Gtruth imgs shape: " + str(gtruth_masks.shape))
    np.save(path_experiment + 'pred_img_' + str(test_epoch) + "_epoch" + '.npy', pred_imgs_out)
    if average_mode:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "thresh_epoch")
    else:
        visualize(group_images(pred_imgs_out, N_visual),
                  path_experiment + "all_predictions_" + str(test_epoch) + "epoch_no_average")
    visualize(group_images(gtruth_masks, N_visual), path_experiment + "all_groundTruths")

    # ====== Evaluate the results
    print("\n\n========  Evaluate the results =======================")

    # predictions only inside the FOV
    y_scores, y_true = pred_only_FOV(pred_imgs_out, gtruth_masks, test_border_masks)

    # Area under the ROC curve
    fpr, tpr, _ = roc_curve((y_true), y_scores)
    AUC_ROC = roc_auc_score(y_true, y_scores)
    print("\nArea under the ROC curve: " + str(AUC_ROC))
    rOc_curve = plt.figure()
    plt.plot(fpr, tpr, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_ROC)
    plt.title('ROC curve')
    plt.xlabel("FPR (False Positive Rate)")
    plt.ylabel("TPR (True Positive Rate)")
    plt.legend(loc="lower right")
    plt.savefig(path_experiment + "ROC.png")
    # Release the figure; test() may be called once per epoch.
    plt.close(rOc_curve)

    # Precision-recall curve
    precision, recall, _ = precision_recall_curve(y_true, y_scores)
    # Reverse so recall is increasing (otherwise trapz yields a negative AUC).
    precision = precision[::-1]
    recall = recall[::-1]
    AUC_prec_rec = np.trapz(precision, recall)
    print("\nArea under Precision-Recall curve: " + str(AUC_prec_rec))
    prec_rec_curve = plt.figure()
    plt.plot(recall, precision, '-', label='Area Under the Curve (AUC = %0.4f)' % AUC_prec_rec)
    plt.title('Precision - Recall curve')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="lower right")
    plt.savefig(path_experiment + "Precision_recall.png")
    plt.close(prec_rec_curve)

    # Confusion matrix
    threshold_confusion = 0.5
    print("\nConfusion matrix:  Custom threshold (for positive) of " + str(threshold_confusion))
    # 1.0 where the score reaches the threshold, else 0.0 (float64 labels).
    y_pred = np.where(np.ravel(y_scores) >= threshold_confusion, 1.0, 0.0)
    confusion = confusion_matrix(y_true, y_pred)
    print(confusion)
    accuracy = 0
    if float(np.sum(confusion)) != 0:
        accuracy = float(confusion[0, 0] + confusion[1, 1]) / float(np.sum(confusion))
    print("Global Accuracy: " + str(accuracy))
    specificity = 0
    if float(confusion[0, 0] + confusion[0, 1]) != 0:
        specificity = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])
    print("Specificity: " + str(specificity))
    sensitivity = 0
    if float(confusion[1, 1] + confusion[1, 0]) != 0:
        sensitivity = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])
    print("Sensitivity: " + str(sensitivity))
    precision = 0
    if float(confusion[1, 1] + confusion[0, 1]) != 0:
        precision = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[0, 1])
    print("Precision: " + str(precision))

    # Jaccard similarity index
    # NOTE(review): jaccard_similarity_score was removed in scikit-learn 0.23;
    # newer versions need a replacement with matching semantics.
    jaccard_index = jaccard_similarity_score(y_true, y_pred, normalize=True)
    print("\nJaccard similarity score: " + str(jaccard_index))

    # F1 score
    F1_score = f1_score(y_true, y_pred, labels=None, average='binary', sample_weight=None)
    print("\nF1 score (F-measure): " + str(F1_score))

    # Thin-vessel evaluation (3-pixel and 2-pixel opening footprints).
    thin_3pixel_recall_indivi, thin_3pixel_auc_roc = _thin_vessel_metrics(
        gtruth_masks, pred_imgs_out, 3)
    thin_2pixel_recall_indivi, thin_2pixel_auc_roc = _thin_vessel_metrics(
        gtruth_masks, pred_imgs_out, 2)

    # Save the results
    with open(path_experiment + 'test_performances_all_epochs.txt', mode='a') as f:
        f.write("\n\n" + path_experiment + " test epoch:" + str(test_epoch)
                + '\naverage mode is:' + str(average_mode)
                + "\nArea under the ROC curve: %.4f" % (AUC_ROC)
                + "\nArea under Precision-Recall curve: %.4f" % (AUC_prec_rec)
                + "\nJaccard similarity score: %.4f" % (jaccard_index)
                + "\nF1 score (F-measure): %.4f" % (F1_score)
                + "\nConfusion matrix:"
                + str(confusion)
                + "\nACCURACY: %.4f" % (accuracy)
                + "\nSENSITIVITY: %.4f" % (sensitivity)
                + "\nSPECIFICITY: %.4f" % (specificity)
                + "\nPRECISION: %.4f" % (precision)
                + "\nthin 2vessels recall indivi:\n" + str(thin_2pixel_recall_indivi)
                + "\nthin 2vessels recall mean:%.4f" % (np.mean(thin_2pixel_recall_indivi))
                + "\nthin 2vessels auc indivi:\n" + str(thin_2pixel_auc_roc)
                + "\nthin 2vessels auc score mean:%.4f" % (np.mean(thin_2pixel_auc_roc))
                + "\nthin 3vessels recall indivi:\n" + str(thin_3pixel_recall_indivi)
                + "\nthin 3vessels recall mean:%.4f" % (np.mean(thin_3pixel_recall_indivi))
                + "\nthin 3vessels auc indivi:\n" + str(thin_3pixel_auc_roc)
                + "\nthin 3vessels auc score mean:%.4f" % (np.mean(thin_3pixel_auc_roc))
                )
Exemplo n.º 8
0
    LE_512 = cropImage(LE_img, IMG_SHAPE[0], IMG_SHAPE[1])
    sample_le = {}
    for le_512 in LE_512:
        tiles = crop_prepare(le_512, CROP_STEP, IMG_SIZE)
        for n, img in enumerate(tiles):
            if n not in sample_le:
                sample_le[n] = []
            img = transform.resize(img, (IMG_SIZE * 2, IMG_SIZE * 2),
                                   preserve_range=True,
                                   order=3)
            sample_le[n].append(img)

    SC_UNET = UNet(n_channels=15, n_classes=1)
    print("{} paramerters in total".format(
        sum(x.numel() for x in SC_UNET.parameters())))
    SC_UNET.cuda(cuda)
    SC_UNET.load_state_dict(torch.load(model_path))
    # SC_UNET.load_state_dict(torch.load(os.path.join(dir_path,"model","HE_HER_mito","HE_X2_HER_0825.pkl")))
    SC_UNET.eval()

    SRRFDATASET = ReconsDataset(img_dict=sample_he,
                                transform=ToTensor(),
                                in_norm=LE_in_norm,
                                img_type=".tif",
                                in_size=256)
    test_dataloader = torch.utils.data.DataLoader(
        SRRFDATASET, batch_size=1, shuffle=False,
        pin_memory=True)  # better than for loop
    result = np.zeros((256, 256, len(SRRFDATASET)))
    for batch_idx, items in enumerate(test_dataloader):
        image = items['image_in']