def save_torch(img_torch, root):
    """Write an image tensor to ``root`` as an image file.

    The tensor is turned into a NumPy array with ``torch_to_np``, that array
    into an image object with ``np_to_pil``, and the resulting object's
    ``save`` method is called with ``root``.

    Args:
        img_torch: Tensor in whatever form ``torch_to_np`` accepts.
        root: Destination path handed to ``save``.
    """
    np_to_pil(torch_to_np(img_torch)).save(root)
def denoising(noise_im, clean_im, LR=1e-2, sigma=3, rho=1, eta=0.5, total_step=30,
              prob1_iter=500, noise_level=None, result_root=None, f=None):
    """Denoise an RGB image by alternating a network fit with a BM3D step.

    Every outer step runs three sub-problems (an ADMM-style splitting, judging
    from the scaled dual update of ``Y``):

    1. Train an encoder/decoder UNet pair for ``prob1_iter`` Adam steps,
       updating the estimate ``i0`` after each step.
    2. Denoise ``i0 + Y`` with ``bm3d.bm3d_rgb`` at the level returned by
       ``eval_sigma(step + 1, noise_level)``, giving ``i0_til``.
    3. Dual update ``Y <- Y + eta * (i0 - i0_til)``.

    Each step writes a heat map of the channel-wise L2 norm of ``Y``, the
    latent mean, decoder output, ``i0``, a histogram of ``mean - clean_im`` and
    the clipped BM3D estimate, then logs PSNR/SSIM of that estimate against
    ``clean_im`` to ``f``. The loop stops at the first step whose PSNR does not
    improve on the best seen so far.

    Args:
        noise_im: Noisy image as a NumPy array. Assumed channel-first
            ``(C, H, W)`` on a 0-255 scale (the code transposes with
            ``(1, 2, 0)`` and uses ``data_range=255``) -- TODO confirm.
        clean_im: Reference image in the same layout; used only for the
            histograms, the metrics and the early-stopping decision.
        LR: Adam learning rate shared by both UNets.
        sigma: The latent-consistency term is weighted by ``1 / sigma**2``.
        rho: Penalty weight of the splitting.
        eta: Step size of the dual update.
        total_step: Maximum number of outer steps.
        prob1_iter: Network-training iterations per outer step.
        noise_level: Forwarded to ``eval_sigma``.
        result_root: Output location. Most file names are appended with ``+``
            (so presumably this should end in a path separator); the per-step
            estimate is written with ``os.path.join``.
        f: File-like object for the log line; ``None`` prints to stdout.

    Returns:
        ``(best_image, best_psnr, best_ssim)``: the clipped BM3D estimate from
        the step with the highest PSNR and its metrics. ``best_image`` is
        ``None`` and both metrics are 0 if no step improved on a PSNR of 0
        (e.g. ``total_step == 0``).
    """
    input_depth = 3
    latent_dim = 3
    en_net = UNet(input_depth, latent_dim).to(device)
    de_net = UNet(latent_dim, input_depth).to(device)
    optimizer = torch.optim.Adam(
        list(en_net.parameters()) + list(de_net.parameters()), lr=LR)
    # MSELoss owns no tensors; place it like everything else instead of
    # hard-wiring CUDA.
    l2_loss = torch.nn.MSELoss().to(device)

    noise_im_torch = np_to_torch(noise_im).to(device)
    i0 = np_to_torch(noise_im).to(device)            # primal estimate
    i0_til_torch = np_to_torch(noise_im).to(device)  # BM3D estimate
    Y = torch.zeros_like(noise_im_torch)             # dual variable (same device)

    diff_original_np = noise_im.astype(np.float32) - clean_im.astype(np.float32)
    save_hist(diff_original_np, result_root + 'Original_dis.png')

    inv_var = 1 / sigma**2
    best_psnr = 0
    # Initialised up front: previously best_ssim was unbound (and the returned
    # image undefined) unless the first step improved PSNR.
    best_ssim = 0
    best_im = None
    for i in range(total_step):
        # ---------------- sub-problem 1: fit the networks, update i0 ----------------
        for _ in range(prob1_iter):
            optimizer.zero_grad()
            mean = en_net(noise_im_torch)
            z = sample_z(mean)
            out = de_net(z)
            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += 0.5 * inv_var * l2_loss(mean, i0)
            # i0, Y and i0_til_torch are all produced outside autograd (from
            # NumPy or under no_grad), so this term yields no gradient; it
            # only contributes to the logged loss value.
            total_loss += (rho / 2) * l2_loss(i0 + Y, i0_til_torch)
            total_loss.backward()
            optimizer.step()
            # NOTE(review): nesting reconstructed from a whitespace-mangled
            # source -- i0 is updated once per optimizer step; confirm upstream.
            with torch.no_grad():
                i0 = (inv_var * mean.detach() + rho * (i0_til_torch - Y)) / (inv_var + rho)

        with torch.no_grad():
            # ---------------- sub-problem 2: BM3D on i0 + Y ----------------
            i0_np = torch_to_np(i0)
            Y_np = torch_to_np(Y)
            sig = eval_sigma(i + 1, noise_level)
            # bm3d_rgb is called on channel-last data, then converted back.
            i0_til_np = bm3d.bm3d_rgb(
                i0_np.transpose(1, 2, 0) + Y_np.transpose(1, 2, 0), sig
            ).transpose(2, 0, 1)
            i0_til_torch = np_to_torch(i0_til_np).to(device)

            # ---------------- sub-problem 3: dual update ----------------
            Y = Y + eta * (i0 - i0_til_torch)

            # ---------------- per-step diagnostics ----------------
            Y_np = torch_to_np(Y)
            Y_norm_np = np.sqrt((Y_np * Y_np).sum(0))
            save_heatmap(Y_norm_np, result_root + 'Y_{:04d}.png'.format(i))
            save_torch(mean, result_root + 'Latent_im_num_epoch_{:04d}.png'.format(i))
            save_torch(out, result_root + 'res_of_dec_num_epoch_{:04d}.png'.format(i))
            save_torch(i0, result_root + 'i0_num_epoch_{:04d}.png'.format(i))
            diff_np = torch_to_np(mean) - clean_im
            save_hist(diff_np, result_root + 'Latent_dis_num_epoch_{:04d}.png'.format(i))

            i0_til_np = torch_to_np(i0_til_torch).clip(0, 255)
            clean_hwc = clean_im.transpose(1, 2, 0)
            est_hwc = i0_til_np.transpose(1, 2, 0)
            psnr = compare_psnr(clean_hwc, est_hwc, 255)
            ssim = compare_ssim(clean_hwc, est_hwc, multichannel=True, data_range=255)
            np_to_pil(i0_til_np).save(os.path.join(result_root, '{}.png'.format(i)))
            print('Iteration: {:02d}, VAE Loss: {:f}, PSNR: {:f}, SSIM: {:f}'.format(
                i, total_loss.item(), psnr, ssim), file=f, flush=True)

        # Early stopping on PSNR against the clean reference. Keep the image
        # that achieved the best score so it matches the returned metrics.
        if best_psnr < psnr:
            best_psnr = psnr
            best_ssim = ssim
            best_im = i0_til_np
        else:
            break
    return best_im, best_psnr, best_ssim
def denoising_gray(noise_im, clean_im, LR=1e-2, sigma=5, rho=1, eta=0.5, total_step=20,
                   prob1_iter=1000, result_root=None, fo=None,
                   model_path='./model_zoo/dncnn_gray_blind.pth'):
    """Denoise a single-channel image by alternating a network fit with DnCNN.

    Every outer step:

    1. Updates ``prob1_iter`` via ``iteration_decay`` and trains an
       encoder/decoder UNet pair for that many steps. The decoder sees the
       encoder output plus unit Gaussian noise. ``i0`` is updated after each
       step.
    2. Runs the frozen, pretrained DnCNN on ``(i0 + Y) / 255`` and rescales its
       output by 255, giving ``i0_til``.
    3. Applies the dual update ``Y <- Y + eta * (i0 - i0_til)``.

    It then writes ``Y + i0``, ``i0``, ``i0_til``, the latent mean and decoder
    output, and logs PSNR/SSIM of the clipped ``i0_til`` against ``clean_im``
    to ``fo``. It stops at the first step whose PSNR does not improve.

    Args:
        noise_im: Noisy image as a NumPy array accepted by ``np_to_torch``;
            assumed on a 0-255 scale (metrics use ``data_range=255``) -- TODO
            confirm layout.
        clean_im: Reference image, used only for the metrics and early stop.
        LR: Adam learning rate for both UNets.
        sigma: The latent-consistency term is weighted by ``1 / sigma**2``.
        rho: Penalty weight in the ``i0`` update.
        eta: Step size of the dual update.
        total_step: Maximum number of outer steps.
        prob1_iter: Initial training-iteration count, passed through
            ``iteration_decay`` at the start of every outer step.
        result_root: Prefix for output files (names are appended with ``+``).
        fo: File-like object for the log line; ``None`` prints to stdout.
        model_path: DnCNN checkpoint to load (defaults to the previously
            hard-coded path).

    Returns:
        ``(best_psnr, best_ssim)`` of the best step (0, 0 if none improved).
    """
    input_depth = 1
    latent_dim = 1
    # Place everything on `device`; the previous mix of .cuda() for the
    # models and .to(device) for the data broke whenever they differed.
    en_net = UNet(input_depth, latent_dim, need_sigmoid=False).to(device)
    de_net = UNet(latent_dim, input_depth, need_sigmoid=False).to(device)

    # Pretrained DnCNN used as the fixed denoiser in sub-problem 2.
    model = DnCNN(1, 1, nc=64, nb=20, act_mode='R')
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    model = model.to(device)
    model.requires_grad_(False)  # frozen: never updated here

    en_optimizer = torch.optim.Adam(en_net.parameters(), lr=LR)
    de_optimizer = torch.optim.Adam(de_net.parameters(), lr=LR)
    l2_loss = torch.nn.MSELoss().to(device)

    noise_im_torch = np_to_torch(noise_im).to(device)
    i0 = np_to_torch(noise_im).to(device)            # primal estimate
    i0_til_torch = np_to_torch(noise_im).to(device)  # DnCNN estimate
    Y = torch.zeros_like(noise_im_torch)             # dual variable (same device)

    inv_var = 1 / sigma**2
    best_psnr = 0
    best_ssim = 0
    for i in range(total_step):
        # ---------------- sub-problem 1: fit the networks, update i0 ----------------
        prob1_iter = iteration_decay(prob1_iter)
        for _ in range(prob1_iter):
            mean = en_net(noise_im_torch)
            # Unit Gaussian noise; randn_like draws the same distribution as
            # mean.clone().normal_() without an extra autograd node.
            eps = torch.randn_like(mean)
            out = de_net(mean + eps)
            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += 0.5 * inv_var * l2_loss(mean, i0)
            en_optimizer.zero_grad()
            de_optimizer.zero_grad()
            total_loss.backward()
            en_optimizer.step()
            de_optimizer.step()
            # NOTE(review): nesting reconstructed from a whitespace-mangled
            # source -- i0 is updated once per optimizer step; confirm upstream.
            with torch.no_grad():
                i0 = (inv_var * mean + rho * (i0_til_torch - Y)) / (inv_var + rho)

        with torch.no_grad():
            # ---------------- sub-problem 2: DnCNN on i0 + Y ----------------
            # Input is divided by 255 and the output multiplied back by 255.
            i0_til_torch = model((i0 + Y) / 255) * 255

            # ---------------- sub-problem 3: dual update ----------------
            Y = Y + eta * (i0 - i0_til_torch)

            # ---------------- per-step diagnostics ----------------
            i0_til_np = torch_to_np(i0_til_torch).clip(0, 255)
            psnr_gt = compare_psnr(clean_im, i0_til_np, 255)
            ssim_gt = compare_ssim(i0_til_np, clean_im, multichannel=False, data_range=255)
            save_torch(Y + i0, result_root + 'denoise_obj_{:04d}.png'.format(i))
            save_torch(i0, result_root + 'i0_num_epoch_{:04d}.png'.format(i))
            save_torch(i0_til_torch, result_root + 'num_epoch_{:04d}.png'.format(i))
            save_torch(mean, result_root + 'Latent_im_num_epoch_{:04d}.png'.format(i))
            save_torch(out, result_root + 'res_of_dec_num_epoch_{:04d}.png'.format(i))
            print('Iteration %02d Loss %f PSNR_gt: %f, SSIM_gt: %f'
                  % (i, total_loss.item(), psnr_gt, ssim_gt), file=fo, flush=True)

        # Early stopping on PSNR against the clean reference.
        if best_psnr < psnr_gt:
            best_psnr = psnr_gt
            best_ssim = ssim_gt
        else:
            break
    return best_psnr, best_ssim