def save_torch(img_torch, root):
    """Save a torch image tensor to the file path ``root``.

    The tensor is first converted to a NumPy array (``torch_to_np``), then to a
    PIL image (``np_to_pil``), which is written with ``PIL.Image.save``.
    """
    np_to_pil(torch_to_np(img_torch)).save(root)
def denoising(noise_im, clean_im, LR=1e-2, sigma=3, rho=1, eta=0.5, total_step=30,
              prob1_iter=500, noise_level=None, result_root=None, f=None):
    """Denoise an RGB image with an ADMM loop coupling a VAE-style network and BM3D.

    Each outer step alternates three sub-problems:
      1. ``prob1_iter`` Adam steps on an encoder/decoder pair of UNets, with a
         closed-form update of ``i0`` after every step;
      2. BM3D denoising of ``i0 + Y`` using ``eval_sigma(step, noise_level)``;
      3. a dual-ascent update ``Y += eta * (i0 - i0_til)``.

    Intermediate images, histograms and a dual-norm heatmap are written under
    ``result_root`` every outer step. The loop stops early as soon as the PSNR
    against ``clean_im`` fails to improve (i.e. early stopping uses the ground
    truth).

    Args:
        noise_im: noisy image as a NumPy array; presumably channel-first
            (C, H, W) given the ``transpose(1, 2, 0)`` calls below -- confirm
            with callers.
        clean_im: ground-truth image in the same layout, used for PSNR/SSIM.
        LR: Adam learning rate shared by both networks.
        sigma: the latent-mean term is weighted by ``1 / sigma**2``.
        rho: ADMM penalty parameter.
        eta: dual step size for ``Y``.
        total_step: maximum number of outer ADMM steps.
        prob1_iter: inner optimization steps per outer step.
        noise_level: forwarded to ``eval_sigma`` to obtain the BM3D sigma.
        result_root: output directory. Most files are written to
            ``result_root + name`` (so it should end with a path separator),
            the per-step result via ``os.path.join``.
        f: file-like object for the progress log (``None`` -> stdout).

    Returns:
        ``(best_image, best_psnr, best_ssim)``: ``best_image`` is the clipped
        BM3D output of the step with the highest PSNR (``None`` if no step
        improved on the initial value of 0, e.g. ``total_step == 0``).
    """
    input_depth = 3
    latent_dim = 3

    en_net = UNet(input_depth, latent_dim).to(device)
    de_net = UNet(latent_dim, input_depth).to(device)

    parameters = list(en_net.parameters()) + list(de_net.parameters())
    optimizer = torch.optim.Adam(parameters, lr=LR)

    # MSELoss is stateless; keep it on the same device as everything else
    # rather than hard-coding CUDA.
    l2_loss = torch.nn.MSELoss().to(device)

    i0 = np_to_torch(noise_im).to(device)
    noise_im_torch = np_to_torch(noise_im).to(device)
    i0_til_torch = np_to_torch(noise_im).to(device)
    Y = torch.zeros_like(noise_im_torch)  # already on noise_im_torch's device

    diff_original_np = noise_im.astype(np.float32) - clean_im.astype(np.float32)
    diff_original_name = 'Original_dis.png'
    save_hist(diff_original_np, result_root + diff_original_name)

    best_psnr = 0
    # Initialized so that a break on the very first step (or total_step == 0)
    # cannot leave these names unbound at the return statement.
    best_ssim = 0
    best_i0_til_np = None

    for i in range(total_step):

################################# sub-problem 1 ###############################

        for i_1 in range(prob1_iter):

            optimizer.zero_grad()

            mean = en_net(noise_im_torch)
            z = sample_z(mean)
            out = de_net(z)

            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += 0.5 * (1/sigma**2) * l2_loss(mean, i0)
            # i0, Y and i0_til_torch are produced without grad tracking, so this
            # term only affects the reported loss value, not the gradients.
            total_loss += (rho/2) * l2_loss(i0 + Y, i0_til_torch)

            total_loss.backward()
            optimizer.step()

            with torch.no_grad():
                # Closed-form minimizer of the i0 sub-problem.
                i0 = ((1/sigma**2)*mean.detach() + rho*(i0_til_torch - Y)) / ((1/sigma**2) + rho)

        with torch.no_grad():

################################# sub-problem 2 ###############################

            i0_np = torch_to_np(i0)
            Y_np = torch_to_np(Y)

            sig = eval_sigma(i+1, noise_level)

            # BM3D works on HWC arrays; convert from CHW and back.
            i0_til_np = bm3d.bm3d_rgb(i0_np.transpose(1, 2, 0) + Y_np.transpose(1, 2, 0), sig).transpose(2, 0, 1)
            i0_til_torch = np_to_torch(i0_til_np).to(device)

################################# sub-problem 3 ###############################

            Y = Y + eta * (i0 - i0_til_torch)

###############################################################################

            Y_name = 'Y_{:04d}'.format(i) + '.png'
            i0_name = 'i0_num_epoch_{:04d}'.format(i) + '.png'
            mean_name = 'Latent_im_num_epoch_{:04d}'.format(i) + '.png'
            out_name = 'res_of_dec_num_epoch_{:04d}'.format(i) + '.png'
            diff_name = 'Latent_dis_num_epoch_{:04d}'.format(i) + '.png'

            # Per-pixel L2 norm of the dual variable across channels (axis 0).
            Y_np = torch_to_np(Y)
            Y_norm_np = np.sqrt((Y_np*Y_np).sum(0))
            save_heatmap(Y_norm_np, result_root + Y_name)

            save_torch(mean, result_root + mean_name)
            save_torch(out, result_root + out_name)
            save_torch(i0, result_root + i0_name)

            mean_np = torch_to_np(mean)
            diff_np = mean_np - clean_im
            save_hist(diff_np, result_root + diff_name)

            i0_til_np = torch_to_np(i0_til_torch).clip(0, 255)
            psnr = compare_psnr(clean_im.transpose(1, 2, 0), i0_til_np.transpose(1, 2, 0), 255)
            ssim = compare_ssim(clean_im.transpose(1, 2, 0), i0_til_np.transpose(1, 2, 0), multichannel=True, data_range=255)

            i0_til_pil = np_to_pil(i0_til_np)
            i0_til_pil.save(os.path.join(result_root, '{}'.format(i) + '.png'))

            print('Iteration: {:02d}, VAE Loss: {:f}, PSNR: {:f}, SSIM: {:f}'.format(i, total_loss.item(), psnr, ssim), file=f, flush=True)

            # Ground-truth early stopping: keep the best step, stop at the
            # first step that does not improve PSNR.
            if best_psnr < psnr:
                best_psnr = psnr
                best_ssim = ssim
                best_i0_til_np = i0_til_np
            else:
                break

    # Return the image that actually achieved best_psnr/best_ssim, not the
    # (worse) image from the step that triggered the early stop.
    return best_i0_til_np, best_psnr, best_ssim
Example #3
0
def denoising_gray(noise_im, clean_im, LR=1e-2, sigma=5, rho=1, eta=0.5,
                   total_step=20, prob1_iter=1000, result_root=None, fo=None):
    """Denoise a grayscale image with an ADMM loop coupling a VAE-style network and DnCNN.

    Each outer step:
      1. passes ``prob1_iter`` through ``iteration_decay`` and runs that many
         Adam steps on separate encoder/decoder UNets, updating ``i0`` in
         closed form after every step;
      2. denoises ``i0 + Y`` with a frozen pretrained blind DnCNN, dividing by
         255 before and multiplying by 255 after (presumably mapping a
         [0, 255] image to the network's [0, 1] range -- confirm);
      3. applies the dual update ``Y += eta * (i0 - i0_til)``.

    Intermediate images are saved under ``result_root`` (file names are
    appended directly, so it should end with a path separator) and progress
    is logged to ``fo`` (``None`` -> stdout). The loop stops early as soon as
    PSNR against ``clean_im`` fails to improve (ground-truth early stopping).

    Returns:
        ``(best_psnr, best_ssim)`` over the executed steps.
    """
    input_depth = 1
    latent_dim = 1

    # Use the same ``device`` as the input tensors below instead of hard-coding
    # .cuda(), which breaks whenever ``device`` is not the default CUDA device.
    en_net = UNet(input_depth, latent_dim, need_sigmoid=False).to(device)
    de_net = UNet(latent_dim, input_depth, need_sigmoid=False).to(device)

    # Pretrained DnCNN used as a frozen plug-and-play denoiser.
    model = DnCNN(1, 1, nc=64, nb=20, act_mode='R')
    model_path = './model_zoo/dncnn_gray_blind.pth'
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    model = model.to(device)
    for param in model.parameters():
        param.requires_grad = False

    en_optimizer = torch.optim.Adam(en_net.parameters(), lr=LR)
    de_optimizer = torch.optim.Adam(de_net.parameters(), lr=LR)

    l2_loss = torch.nn.MSELoss().to(device)

    i0 = np_to_torch(noise_im).to(device)
    noise_im_torch = np_to_torch(noise_im).to(device)
    Y = torch.zeros_like(noise_im_torch)  # already on noise_im_torch's device
    i0_til_torch = np_to_torch(noise_im).to(device)

    best_psnr = 0
    best_ssim = 0

    for i in range(total_step):

############################### sub-problem 1 #################################

        # NOTE(review): if iteration_decay ever returns 0, total_loss/mean/out
        # are never bound and the logging below raises NameError.
        prob1_iter = iteration_decay(prob1_iter)

        for i_1 in range(prob1_iter):
            mean = en_net(noise_im_torch)

            # Reparameterization noise; randn_like avoids the needless autograd
            # node created by mean.clone().normal_().
            eps = torch.randn_like(mean)
            out = de_net(mean + eps)

            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += 0.5 * 1/(sigma**2) * l2_loss(mean, i0)

            en_optimizer.zero_grad()
            de_optimizer.zero_grad()

            total_loss.backward()

            en_optimizer.step()
            de_optimizer.step()

            with torch.no_grad():
                # Closed-form minimizer of the i0 sub-problem.
                i0 = ((1/sigma**2)*mean.detach() + rho*(i0_til_torch - Y)) / ((1/sigma**2) + rho)

        with torch.no_grad():

############################### sub-problem 2 #################################

            i0_til_torch = model((i0+Y)/255) * 255

############################### sub-problem 3 #################################

            Y = Y + eta * (i0 - i0_til_torch)

###############################################################################

            i0_til_np = torch_to_np(i0_til_torch).clip(0, 255)
            psnr_gt = compare_psnr(clean_im, i0_til_np, 255)
            ssim_gt = compare_ssim(i0_til_np, clean_im, multichannel=False, data_range=255)

            denoise_obj_name = 'denoise_obj_{:04d}'.format(i) + '.png'
            i0_name = 'i0_num_epoch_{:04d}'.format(i) + '.png'
            result_name = 'num_epoch_{:04d}'.format(i) + '.png'
            mean_name = 'Latent_im_num_epoch_{:04d}'.format(i) + '.png'
            out_name = 'res_of_dec_num_epoch_{:04d}'.format(i) + '.png'

            save_torch(Y+i0, result_root + denoise_obj_name)
            save_torch(i0, result_root + i0_name)
            save_torch(i0_til_torch, result_root + result_name)
            save_torch(mean, result_root + mean_name)
            save_torch(out, result_root + out_name)

            print('Iteration %02d  Loss %f  PSNR_gt: %f, SSIM_gt: %f' % (i, total_loss.item(), psnr_gt, ssim_gt), file=fo, flush=True)

            # Ground-truth early stopping: stop at the first non-improving step.
            if best_psnr < psnr_gt:
                best_psnr = psnr_gt
                best_ssim = ssim_gt
            else:
                break

    return best_psnr, best_ssim