def test_wienerUnet(data_path, psf_path, method, scale, model_path, visual,
                    use_gpu, n_iter=10, b_size=1):
    '''
    Gradient descent scheme + predicted with UNet per-image gradient of a regularizer.

    Evaluates a trained model on a Poisson-noise test set and prints the mean
    PSNR / SSIM of the restored images and of the noisy inputs.

    Args:
        data_path: Forwarded to ``CellDataset`` as the test-data location.
        psf_path: Forwarded to ``CellDataset`` as the PSF location.
        method: Model identifier forwarded to ``model_load``; also used to
            build the result name ``<method>_poisson``.
        scale: Forwarded to ``CellDataset`` (presumably the Poisson peak) and
            used as ``int(scale)`` in the result directory / file names.
        model_path: Checkpoint location forwarded to ``model_load``.
        visual: When equal to 1, every restored image is saved as an 8-bit PNG
            under ``./Results/<method>_poisson_peak_<int(scale)>/``.
        use_gpu: When equal to 1, the model and batch tensors are moved with
            ``.cuda()``.
        n_iter: Third argument of the model call (presumably the number of
            unrolled iterations).
        b_size: DataLoader batch size. A smaller final batch is handled.

    Returns:
        None. The aggregated metrics are printed.
    '''
    model_name = method + '_poisson'
    save_images_path = './Results/' + model_name + '_peak_' + str(int(scale)) + '/'

    test_dataset = CellDataset(data_path, psf_path, 'poisson', scale, 0.0)
    test_loader = DataLoader(test_dataset, batch_size=b_size, shuffle=False)

    model = model_load('poisson', method, model_path)
    model.eval()
    if use_gpu == 1:
        model.cuda()

    if visual == 1:
        # Create the output directory once instead of checking on every batch.
        os.makedirs(save_images_path, exist_ok=True)

    psnr_values_test = []
    ssim_values_test = []
    distorted_psnr_test = []
    distorted_ssim_test = []
    # Peak reported in the summary line; taken from the last batch. Falls back
    # to `scale` so the print does not hit an unbound name on an empty loader.
    last_peak = scale

    with torch.no_grad():
        for (gt, image), psf, _index, image_name, peak, _ in tqdm(test_loader):
            # Use -1 for the batch dimension: the last batch may hold fewer
            # than b_size samples, and reshaping to b_size would then raise.
            image = image.reshape((-1, 1, image.shape[-2], image.shape[-1]))
            gt = gt.reshape((-1, 1, gt.shape[-2], gt.shape[-1]))
            psf = psf.reshape([-1, 1, psf.shape[-2], psf.shape[-1]]).float()
            if use_gpu == 1:
                image = image.cuda()
                gt = gt.cuda()
                psf = psf.cuda()

            # Keep the raw noisy input for the "distorted" baseline metrics.
            noisy = image.clone()

            image = anscombe(image)
            # NOTE(review): edge tapering uses only the first sample's PSF
            # (psf[0][0]) for the whole batch; with b_size > 1 and per-image
            # PSFs this may be incorrect -- confirm.
            image = EdgeTaper.apply(pad_psf_shape(image, psf), psf[0][0])
            out = model(image, psf, n_iter)
            out = exact_unbiased(out)
            out = crop_psf_shape(out, psf)

            # Normalise output, noisy input and ground truth by the per-sample
            # ground-truth maximum. The max is read once before `gt` is
            # modified in place.
            for j in range(out.shape[0]):
                gt_max = gt[j].max()
                out[j] = out[j] / gt_max
                noisy[j] = noisy[j] / gt_max
                gt[j] /= gt_max

            noisy_clamped = noisy.clamp(0, 1)
            out_clamped = out.clamp(0, 1)
            distorted_psnr = calc_psnr(noisy_clamped, gt)
            distorted_ssim = ssim(noisy_clamped, gt)
            psnr_test = calc_psnr(out_clamped, gt)
            s_sim_test = ssim(out_clamped, gt)

            psnr_values_test.append(psnr_test.item())
            ssim_values_test.append(s_sim_test.item())
            distorted_psnr_test.append(distorted_psnr.item())
            distorted_ssim_test.append(distorted_ssim.item())

            # `peak` is a collated batch value; '%d' on a multi-element tensor
            # raises, so reduce it to the first element's Python scalar.
            last_peak = torch.as_tensor(peak).reshape(-1)[0].item()

            # Save every image of the batch (not only the first one).
            if visual == 1:
                for k in range(out.shape[0]):
                    stem = os.path.splitext(str(image_name[k]))[0]
                    file_name = 'output_%s_%s_%d.png' % (stem, model_name, int(scale))
                    io.imsave(os.path.join(save_images_path, file_name),
                              np.uint8(out[k][0].cpu().numpy().clip(0, 1) * 255.))

    print('Test on Poisson noise with peak %d: PSNR %.2f, SSIM %.4f, distorted PSNR %.2f, distorted SSIM %.4f'
          % (last_peak,
             np.array(psnr_values_test).mean(),
             np.array(ssim_values_test).mean(),
             np.array(distorted_psnr_test).mean(),
             np.array(distorted_ssim_test).mean()))