Ejemplo n.º 1
0
def cosegmentation_plot_closure(iter_number,
                                losses,
                                other_outs,
                                mask_outs,
                                same_outs,
                                original_images,
                                show_every=1000):
    """
    Print the current losses and, every ``show_every`` iterations, plot the
    co-segmentation outputs.

    A one-line progress message (terminated by a carriage return instead of a
    newline) is written on every call. The outputs are converted with
    ``torch_to_np`` only on the iterations that actually plot, so the
    conversion cost is not paid when nothing is displayed.

    :param iter_number: the number of the iteration
    :param losses: iterable of loss values; each must expose ``.item()``
    :param other_outs: outputs shown first in each "segment" grid
    :param mask_outs: masks; each is shown next to ``1 - mask``
    :param same_outs: outputs shown second in each "segment" grid
    :param original_images: accepted for interface compatibility; not used here
    :param int show_every: plot when ``iter_number`` is a multiple of this value

    :return: None
    """
    # TODO: handle other ratios
    loss_values = [loss.item() for loss in losses]
    print(('Iteration %05d    Loss ' + (" %f " * len(loss_values))) %
          (iter_number, *loss_values),
          '\r',
          end='')
    if iter_number % show_every != 0:
        return

    # Convert lazily, one triple at a time; zip stops at the shortest input.
    for i, (other_out, same_out, mask_out) in enumerate(
            zip(other_outs, same_outs, mask_outs)):
        other_out_np = torch_to_np(other_out)
        same_out_np = torch_to_np(same_out)
        mask_out_np = torch_to_np(mask_out)
        plot_image_grid(
            "segment_{}_{}".format(iter_number, i),
            [np.clip(other_out_np, 0, 1),
             np.clip(same_out_np, 0, 1)])
        plot_image_grid(
            "mask_{}_{}".format(iter_number, i),
            [np.clip(mask_out_np, 0, 1),
             np.clip(1 - mask_out_np, 0, 1)])
Ejemplo n.º 2
0
 def _update_result_closure(self, step):
     """
     Store the latest network outputs as the current result and keep track
     of the best result seen so far.

     The outputs are converted with ``torch_to_np`` and bundled with
     ``self.current_psnr``. The best result is replaced when none exists yet
     or when the new PSNR is at least as high, so ties go to the newer result.

     :param step: not used by this method
     :return: None
     """
     cleans_np = list(map(torch_to_np, self.clean_nets_outputs))
     watermark_np = torch_to_np(self.watermark_net_output)
     mask_np = torch_to_np(self.mask_net_output)
     latest = ManyImageWatermarkResult(cleans=cleans_np,
                                       watermark=watermark_np,
                                       mask=mask_np,
                                       psnr=self.current_psnr)
     self.current_result = latest

     previous_best = self.best_result
     if previous_best is None or previous_best.psnr <= latest.psnr:
         self.best_result = latest
Ejemplo n.º 3
0
 def _update_result_closure(self, step):
     """
     Store the latest network outputs as the current result and keep track
     of the best result seen so far.

     Both clean outputs and the watermark output are converted with
     ``torch_to_np`` and bundled with ``self.current_psnr``. The best result
     is replaced when none exists yet or when the new PSNR is at least as
     high, so ties go to the newer result.

     :param step: not used by this method
     :return: None
     """
     first_clean_np = torch_to_np(self.clean_net_output1)
     second_clean_np = torch_to_np(self.clean_net_output2)
     watermark_np = torch_to_np(self.watermark_net_output)
     latest = TwoImageWatermarkResult(clean1=first_clean_np,
                                      clean2=second_clean_np,
                                      watermark=watermark_np,
                                      psnr=self.current_psnr)
     self.current_result = latest

     previous_best = self.best_result
     if previous_best is None or previous_best.psnr <= latest.psnr:
         self.best_result = latest
Ejemplo n.º 4
0
 def finalize(self):
     """
     Run both networks one more time and save the stored images together
     with the two network outputs.

     ``net_output1`` / ``net_output2`` are refreshed first. Six images are
     then saved in a fixed order: the two stored images, the ``first_half``
     and ``second_half`` attributes, and the two outputs after conversion
     with ``torch_to_np``.

     :return: None
     """
     self.net_output1 = self.net1(self.net_input1)
     self.net_output2 = self.net2(self.net_input2)

     # (saved name, attribute) pairs that are saved exactly as stored.
     stored_as_is = (
         ("original_image1", "image1"),
         ("original_image2", "image2"),
         ("learn_on", "first_half"),
         ("apply_on", "second_half"),
     )
     for saved_name, attribute in stored_as_is:
         save_image(saved_name, getattr(self, attribute))

     # Network outputs go through torch_to_np before being saved.
     network_outputs = (
         ("learned_image1", "net_output1"),
         ("learned_image2", "net_output2"),
     )
     for saved_name, attribute in network_outputs:
         save_image(saved_name, torch_to_np(getattr(self, attribute)))
Ejemplo n.º 5
0
 def _iteration_plot_closure(self, step_number, iter_number):
     """
     Update ``self.current_psnr`` and print a one-line progress message.

     The PSNR compares ``self.image1`` with
     ``watermark_hint * watermark_output + clean_output1``. The message ends
     with a carriage return instead of a newline, and includes the gradient
     value only when ``self.current_gradient`` is set.

     :param step_number: not used by this method
     :param iter_number: iteration index shown in the message
     :return: None
     """
     clean_np = torch_to_np(self.clean_net_output1)
     watermark_np = torch_to_np(self.watermark_net_output)
     reconstruction = self.watermark_hint * watermark_np + clean_np
     self.current_psnr = compare_psnr(self.image1, reconstruction)

     if self.current_gradient is None:
         message = 'Iteration {:5d} total_loss {:5f} PSNR {:5f} '.format(
             iter_number, self.total_loss.item(), self.current_psnr)
     else:
         message = (
             'Iteration {:5d} total_loss {:5f} grad {:5f} PSNR {:5f} '.format(
                 iter_number, self.total_loss.item(),
                 self.current_gradient.item(), self.current_psnr))
     print(message, '\r', end='')
Ejemplo n.º 6
0
    def _step_plot_closure(self, step_number):
        """
        Runs at the end of each step: for every image, plots the recomposed
        image next to the stored image.

        The recomposition is
        ``watermark * mask + (1 - mask) * clean_output`` clipped to [0, 1].
        The watermark and mask outputs do not depend on the image, so they
        are converted with ``torch_to_np`` once, before the loop.

        :param step_number: step index, used in the plot name
        :return: None
        """
        watermark_np = torch_to_np(self.watermark_net_output)
        mask_np = torch_to_np(self.mask_net_output)
        masked_watermark = watermark_np * mask_np
        clean_weight = 1 - mask_np

        for image_name, image, clean_net_output in zip(
                self.images_names, self.images, self.clean_nets_outputs):
            recomposed = masked_watermark + clean_weight * torch_to_np(
                clean_net_output)
            plot_image_grid(
                image_name + "_learned_image_{}".format(step_number),
                [np.clip(recomposed, 0, 1), image])
Ejemplo n.º 7
0
 def _iteration_plot_closure(self, step_number, iter_number):
     """
     Update ``self.current_psnr`` and print a one-line progress message.

     The PSNR compares ``self.images[0]`` with
     ``clean_output[0] * (1 - mask) + mask * watermark``. Only the first
     clean output enters that formula, so only it is converted with
     ``torch_to_np``. The message ends with a carriage return instead of a
     newline, and includes the gradient value only when
     ``self.current_gradient`` is set.

     :param step_number: not used by this method
     :param iter_number: iteration index shown in the message
     :return: None
     """
     first_clean_np = torch_to_np(self.clean_nets_outputs[0])
     watermark_out_np = torch_to_np(self.watermark_net_output)
     mask_out_np = torch_to_np(self.mask_net_output)
     self.current_psnr = compare_psnr(
         self.images[0],
         first_clean_np * (1 - mask_out_np) + mask_out_np * watermark_out_np)

     if self.current_gradient is not None:
         message = (
             'Iteration {:5d} total_loss {:5f} grad {:5f} PSNR {:5f} '.format(
                 iter_number, self.total_loss.item(),
                 self.current_gradient.item(), self.current_psnr))
     else:
         message = 'Iteration {:5d} total_loss {:5f} PSNR {:5f} '.format(
             iter_number, self.total_loss.item(), self.current_psnr)
     print(message, '\r', end='')
Ejemplo n.º 8
0
    def _step_plot_closure(self, step_number):
        """
        Runs at the end of each step and produces up to four plots:

        * the watermark hint next to its complement (only when a hint is set),
        * the watermark output next to the first clean output,
        * ``hint * watermark + clean_output1`` next to ``self.image1``,
        * ``hint * watermark + clean_output2`` next to ``self.image2``.

        All plotted network arrays are clipped to [0, 1]. Each network output
        is converted with ``torch_to_np`` once and reused.

        :param step_number: step index, used in the plot names
        :return: None
        """
        if self.watermark_hint is not None:
            plot_image_grid("watermark_hint_{}".format(step_number), [
                np.clip(self.watermark_hint, 0, 1),
                np.clip(1 - self.watermark_hint, 0, 1)
            ])

        watermark_np = torch_to_np(self.watermark_net_output)
        clean1_np = torch_to_np(self.clean_net_output1)

        plot_image_grid("watermark_clean_{}".format(step_number), [
            np.clip(watermark_np, 0, 1),
            np.clip(clean1_np, 0, 1)
        ])

        # NOTE(review): the hint is None-checked above but used unconditionally
        # here, so a None hint raises at this multiplication. Left as-is
        # because the intended behavior without a hint is not visible from
        # this method - confirm with the class author.
        hinted_watermark = self.watermark_hint * watermark_np

        plot_image_grid("learned_image1_{}".format(step_number), [
            np.clip(hinted_watermark + clean1_np, 0, 1), self.image1
        ])
        plot_image_grid("learned_image2_{}".format(step_number), [
            np.clip(hinted_watermark + torch_to_np(self.clean_net_output2),
                    0, 1), self.image2
        ])
Ejemplo n.º 9
0
def deblurring_plot_closure(iter_number, total_loss, image_out, blurred_image, blur_mask, original_image, show_every=1000):
    """
    Print the loss and PSNR for the current iteration and, every
    ``show_every`` iterations, plot the outputs against the original image.

    The PSNR is computed between ``original_image`` and the converted
    ``blurred_image``. ``image_out`` is converted with ``torch_to_np`` only
    on the iterations that plot, since it is not needed for the message.

    :param iter_number: the number of the iteration
    :param total_loss: the total loss of the current iteration step; must expose ``.item()``
    :param image_out: torch tensor output, plotted next to the original image
    :param blurred_image: torch tensor of the blurred image that is obtained
    :param blur_mask: the learned torch mask (not used by this function)
    :param original_image: original numpy image
    :param int show_every: plot when ``iter_number`` is a multiple of this value

    :return: None
    """
    blurred_image_np = torch_to_np(blurred_image)
    psrn_gt = compare_psnr(original_image, blurred_image_np)
    print('Iteration %05d    Loss %f  PSRN_gt: %f ' % (iter_number, total_loss.item(), psrn_gt), '\r', end='')
    if iter_number % show_every != 0:
        return

    image_out_np = torch_to_np(image_out)
    plot_image_grid([np.clip(image_out_np, 0, 1),
                     np.clip(original_image, 0, 1)], 4, 5)
    plot_image_grid([np.clip(original_image, 0, 1),
                     np.clip(blurred_image_np, 0, 1)], 4, 5)
Ejemplo n.º 10
0
def denoising(noise_im,
              clean_im,
              LR=1e-2,
              sigma=5,
              rho=1,
              eta=0.5,
              total_step=20,
              prob1_iter=500,
              result_root=None,
              f=None,
              model_path='/home/dihan/KAIR/model_zoo/dncnn_color_blind.pth'):
    """
    Iteratively denoise ``noise_im`` with a trainable encoder/decoder pair
    and a frozen, pretrained denoising network.

    Each of the ``total_step`` outer steps performs the three updates marked
    in the body:

    1. ``prob1_iter`` Adam steps on the encoder/decoder, each followed by a
       closed-form update of ``i0``;
    2. ``i0_til = model(i0 + Y)`` using the pretrained network;
    3. ``Y = Y + eta * (i0 - i0_til)``.

    After every outer step the intermediate images, a heat map of the norm
    of ``Y`` over axis 0 and a histogram of ``mean - clean_im`` are written
    into ``result_root``, and PSNR/SSIM of the clipped ``i0_til`` against
    ``clean_im`` are printed. The outer loop stops the first time the PSNR
    does not improve on the best value so far.

    Every network and tensor is moved with ``.cuda()``, so a CUDA device is
    required.

    :param noise_im: noisy image as a 3-D numpy array; axis 0 is moved last
        before metrics are computed (presumably channel-first - confirm)
    :param clean_im: reference image with the same layout, used only for
        metrics and difference histograms
    :param LR: Adam learning rate for the encoder/decoder
    :param sigma: passed to ``kl_loss`` and used as ``1 / sigma**2`` weight in
        the ``i0`` update
    :param rho: weight of the ``(i0 + Y) - i0_til`` penalty and of the
        corresponding term in the ``i0`` update
    :param eta: step size of the ``Y`` update
    :param total_step: maximum number of outer steps
    :param prob1_iter: number of optimizer steps per outer step
    :param result_root: directory that receives all output files; it is
        joined with each file name via ``os.path.join`` and must not be None
    :param f: file object for the printed metrics (``None`` prints to stdout)
    :param model_path: checkpoint loaded into the pretrained network

    :return: ``(i0_til_np, best_psnr, best_ssim)`` where ``i0_til_np`` is the
        clipped output of the last executed outer step.
        NOTE(review): when the loop stops early, that image comes from the
        step whose PSNR did not improve, while ``best_psnr`` / ``best_ssim``
        belong to the step before it - confirm this is what callers expect.
    """

    input_depth = 3
    latent_dim = 3

    en_net = Encoder(input_depth,
                     latent_dim,
                     down_sample_norm='batchnorm',
                     up_sample_norm='batchnorm').cuda()
    de_net = Decoder(latent_dim,
                     input_depth,
                     down_sample_norm='batchnorm',
                     up_sample_norm='batchnorm').cuda()

    # Pretrained network: loaded once, switched to eval mode and frozen.
    model = net(3, 3, nc=64, nb=20, act_mode='R')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    model = model.cuda()

    noise_im_torch = np_to_torch(noise_im)
    noise_im_torch = noise_im_torch.cuda()

    # Baseline: the pretrained network applied directly to the noisy image.
    with torch.no_grad():
        r_dncnn_np = torch_to_np(model(noise_im_torch))

    psnr_dncnn = compare_psnr(clean_im.transpose(1, 2, 0),
                              r_dncnn_np.transpose(1, 2, 0), 1)
    ssim_dncnn = compare_ssim(r_dncnn_np.transpose(1, 2, 0),
                              clean_im.transpose(1, 2, 0),
                              multichannel=True,
                              data_range=1)

    print('PSNR_DNCNN: {}, SSIM_DNCNN: {}'.format(psnr_dncnn, ssim_dncnn),
          file=f,
          flush=True)

    # Only the encoder/decoder are optimized.
    parameters = list(en_net.parameters()) + list(de_net.parameters())
    optimizer = torch.optim.Adam(parameters, lr=LR)
    l2_loss = torch.nn.MSELoss(reduction='sum').cuda()

    i0 = np_to_torch(noise_im).cuda()
    Y = torch.zeros_like(noise_im_torch).cuda()

    i0_til_torch = np_to_torch(noise_im).cuda()

    diff_original_np = noise_im.astype(np.float32) - clean_im.astype(
        np.float32)
    diff_original_name = 'Original_dis.png'
    save_hist(diff_original_np, os.path.join(result_root, diff_original_name))

    best_psnr = 0
    best_ssim = 0

    for i in range(total_step):

        ############################### sub-problem 1 #################################

        for i_1 in range(prob1_iter):

            optimizer.zero_grad()

            mean, log_var = en_net(noise_im_torch)

            z = sample_z(mean, log_var)
            out = de_net(z)

            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += kl_loss(mean, log_var, i0, sigma)
            total_loss += (rho / 2) * l2_loss(i0 + Y, i0_til_torch)

            total_loss.backward()
            optimizer.step()

            # Closed-form i0 update; kept out of the autograd graph.
            with torch.no_grad():
                i0 = ((1 / sigma**2) * mean + rho *
                      (i0_til_torch - Y)) / ((1 / sigma**2) + rho)

        with torch.no_grad():

            ############################### sub-problem 2 #################################

            i0_til_torch = model(i0 + Y)

            ############################### sub-problem 3 #################################

            Y = Y + eta * (i0 - i0_til_torch)

            ###############################################################################

            i0_np = torch_to_np(i0)
            Y_np = torch_to_np(Y)

            denoise_obj_pil = np_to_pil((i0_np + Y_np).clip(0, 1))

            # Norm of Y over axis 0, saved as a heat map below.
            Y_norm_np = np.sqrt((Y_np * Y_np).sum(0))

            i0_pil = np_to_pil(i0_np)

            # mean / out / total_loss are the values left by the last inner
            # iteration above.
            mean_np = torch_to_np(mean)

            mean_pil = np_to_pil(mean_np)

            out_np = torch_to_np(out)
            out_pil = np_to_pil(out_np)

            diff_np = mean_np - clean_im

            denoise_obj_name = 'denoise_obj_{:04d}'.format(i) + '.png'
            Y_name = 'Y_{:04d}'.format(i) + '.png'
            i0_name = 'i0_num_epoch_{:04d}'.format(i) + '.png'
            mean_i_name = 'Latent_im_num_epoch_{:04d}'.format(i) + '.png'
            out_name = 'res_of_dec_num_epoch_{:04d}'.format(i) + '.png'
            diff_name = 'Latent_dis_num_epoch_{:04d}'.format(i) + '.png'

            # All outputs are joined the same way so they land in the same
            # directory whether or not result_root ends with a separator.
            denoise_obj_pil.save(os.path.join(result_root, denoise_obj_name))
            save_heatmap(Y_norm_np, os.path.join(result_root, Y_name))
            i0_pil.save(os.path.join(result_root, i0_name))
            mean_pil.save(os.path.join(result_root, mean_i_name))
            out_pil.save(os.path.join(result_root, out_name))
            save_hist(diff_np, os.path.join(result_root, diff_name))

            i0_til_np = torch_to_np(i0_til_torch).clip(0, 1)

            psnr = compare_psnr(clean_im.transpose(1, 2, 0),
                                i0_til_np.transpose(1, 2, 0), 1)
            ssim = compare_ssim(clean_im.transpose(1, 2, 0),
                                i0_til_np.transpose(1, 2, 0),
                                multichannel=True,
                                data_range=1)
            i0_til_pil = np_to_pil(i0_til_np)
            i0_til_pil.save(os.path.join(result_root, '{}'.format(i) + '.png'))

            print('Iteration: %02d, VAE Loss: %f, PSNR: %f, SSIM: %f' %
                  (i, total_loss.item(), psnr, ssim),
                  file=f,
                  flush=True)

            # Stop as soon as the PSNR stops improving.
            if best_psnr < psnr:
                best_psnr = psnr
                best_ssim = ssim
            else:
                break

    return i0_til_np, best_psnr, best_ssim