コード例 #1
0
def test_wienerUnet(data_path,
                    psf_path,
                    method,
                    scale,
                    model_path,
                    visual,
                    use_gpu,
                    n_iter=10,
                    b_size=1):
    """Evaluate a restoration model on a Poisson-noise test set.

    Gradient descent scheme + predicted with UNet per-image gradient of a
    regularizer.

    For every batch the noisy input is passed through ``anscombe``, padded
    with ``pad_psf_shape`` and edge-tapered with ``EdgeTaper``, then restored
    by the model for ``n_iter`` iterations. The output is mapped back with
    ``exact_unbiased`` and cropped with ``crop_psf_shape``. The restored
    image, the untouched noisy input and the ground truth are each divided by
    the per-image ground-truth maximum. The restored image and the noisy input
    are then clamped to [0, 1] and scored with PSNR/SSIM against the ground
    truth. The mean scores over all batches are printed.

    Args:
        data_path: dataset location, forwarded to ``CellDataset``.
        psf_path: PSF location, forwarded to ``CellDataset``.
        method: model identifier, forwarded to ``model_load``; also used in
            output names.
        scale: value forwarded to ``CellDataset`` as its noise parameter;
            ``int(scale)`` is also used in the output directory/file names.
        model_path: checkpoint location, forwarded to ``model_load``.
        visual: when equal to 1, every restored image is written as an 8-bit
            PNG under ``./Results/<method>_poisson_peak_<int(scale)>/``.
        use_gpu: when equal to 1, the model and batch tensors are moved to CUDA.
        n_iter: iteration count forwarded to the model call.
        b_size: DataLoader batch size. The last batch may be smaller.

    Returns:
        None. Results are printed and optionally saved to disk.
    """
    model_name = method + '_poisson'
    save_images_path = './Results/' + model_name + '_peak_' + str(
        int(scale)) + '/'

    test_dataset = CellDataset(data_path, psf_path, 'poisson', scale, 0.0)
    test_loader = DataLoader(test_dataset, batch_size=b_size, shuffle=False)

    model = model_load('poisson', method, model_path)
    model.eval()

    if use_gpu == 1:
        model.cuda()

    if visual == 1:
        # Create the output directory once instead of checking on every batch.
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []

    distorted_psnr_test = []
    distorted_ssim_test = []

    peak = None

    with torch.no_grad():

        for (gt, image), psf, _index, image_name, peak, _ in tqdm(test_loader):

            # The last batch is smaller than b_size when the dataset length is
            # not a multiple of it, so take the batch size from the data itself.
            n_batch = image.shape[0]

            image = image.reshape(
                (n_batch, 1, image.shape[-2], image.shape[-1]))

            gt = gt.reshape((n_batch, 1, gt.shape[-2], gt.shape[-1]))

            psf = psf.reshape([n_batch, 1, psf.shape[-2],
                               psf.shape[-1]]).float()

            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()
                psf = psf.cuda()

            # Keep the raw noisy input to report the "distorted" baseline scores.
            image_batch_tmp = image.clone()

            image = anscombe(image)

            # NOTE(review): only the first PSF of the batch is used for edge
            # tapering; this assumes all images in a batch share one PSF.
            image = EdgeTaper.apply(pad_psf_shape(image, psf), psf[0][0])

            out = model(image, psf, n_iter)

            out = exact_unbiased(out)

            out = crop_psf_shape(out, psf)

            # Scale output, noisy input and ground truth by the same
            # per-image ground-truth maximum before computing metrics.
            for j in range(out.shape[0]):
                gt_max = gt[j].max()
                out[j] = out[j] / gt_max
                image_batch_tmp[j] = image_batch_tmp[j] / gt_max
                gt[j] /= gt_max

            distorted_psnr = calc_psnr(image_batch_tmp.clamp(0, 1), gt)
            distorted_ssim = ssim(image_batch_tmp.clamp(0, 1), gt)

            psnr_test = calc_psnr(out.clamp(0, 1), gt)
            s_sim_test = ssim(out.clamp(0, 1), gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())
            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            # Save every restored image of the batch (not only the first one).
            if visual == 1:
                for j in range(out.shape[0]):
                    # splitext handles extensions of any length (e.g. ".tiff").
                    stem = os.path.splitext(str(image_name[j]))[0]
                    file_name = ('output_' + stem + '_' + str(model_name) +
                                 '_' + str(int(scale)) + '.png')
                    io.imsave(
                        os.path.join(save_images_path, file_name),
                        np.uint8(out[j][0].detach().cpu().numpy().clip(0, 1) *
                                 255.))

    if not psnr_values_test:
        # No batches were produced: avoid NameError on `peak` and nan means.
        print('Test on Poisson noise: no test images found in %s' % data_path)
        return

    # `peak` holds the last batch's collated value(s); report the first one so
    # the %d format works for any batch size.
    peak_value = torch.as_tensor(peak).reshape(-1)[0].item()

    print('Test on Poisson noise with peak %d: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f' % (peak_value,
                                                                                              np.mean(psnr_values_test),
                                                                                              np.mean(ssim_values_test),
                                                                                              np.mean(distorted_psnr_test),
                                                                                              np.mean(distorted_ssim_test)))

    return