def cosegmentation_plot_closure(iter_number, losses, other_outs, mask_outs, same_outs, original_images,
                                show_every=1000):
    """
    Print the per-iteration losses and, every ``show_every`` iterations, plot
    the co-segmentation outputs of each image.

    :param iter_number: the number of the iteration
    :param losses: iterable of scalar losses; each is printed via ``.item()``
    :param other_outs: per-image outputs plotted as the first image of the "segment" grid
    :param mask_outs: per-image mask outputs; plotted together with ``1 - mask``
    :param same_outs: per-image outputs plotted as the second image of the "segment" grid
    :param original_images: not used by this function; kept for interface compatibility
    :param int show_every: plot only when ``iter_number`` is a multiple of this value
    :return: None
    """
    # TODO: handle other ratios
    loss_values = [loss.item() for loss in losses]
    print(('Iteration %05d Loss ' + (" %f " * len(loss_values))) % (iter_number, *loss_values), '\r', end='')
    if iter_number % show_every != 0:
        return
    # The numpy copies are only needed for plotting, so convert lazily here
    # instead of on every iteration.
    for i, (other_out, same_out, mask_out) in enumerate(zip(other_outs, same_outs, mask_outs)):
        other_out_np = torch_to_np(other_out)
        same_out_np = torch_to_np(same_out)
        mask_out_np = torch_to_np(mask_out)
        plot_image_grid(
            "segment_{}_{}".format(iter_number, i),
            [np.clip(other_out_np, 0, 1), np.clip(same_out_np, 0, 1)])
        plot_image_grid(
            "mask_{}_{}".format(iter_number, i),
            [np.clip(mask_out_np, 0, 1), np.clip(1 - mask_out_np, 0, 1)])
def _update_result_closure(self, step):
    """
    Store a snapshot of the current outputs as ``self.current_result`` and
    promote it to ``self.best_result`` when no best exists yet or its PSNR is
    at least the stored best's PSNR (ties go to the newer snapshot).

    :param step: current step number (not used here)
    """
    clean_images = [torch_to_np(clean_output) for clean_output in self.clean_nets_outputs]
    snapshot = ManyImageWatermarkResult(
        cleans=clean_images,
        watermark=torch_to_np(self.watermark_net_output),
        mask=torch_to_np(self.mask_net_output),
        psnr=self.current_psnr,
    )
    self.current_result = snapshot

    previous_best = self.best_result
    if previous_best is not None and not previous_best.psnr <= snapshot.psnr:
        return
    self.best_result = snapshot
def _update_result_closure(self, step):
    """
    Store a two-image snapshot of the current outputs as ``self.current_result``
    and promote it to ``self.best_result`` when no best exists yet or its PSNR
    is at least the stored best's PSNR (ties go to the newer snapshot).

    :param step: current step number (not used here)
    """
    first_clean = torch_to_np(self.clean_net_output1)
    second_clean = torch_to_np(self.clean_net_output2)
    watermark = torch_to_np(self.watermark_net_output)
    snapshot = TwoImageWatermarkResult(
        clean1=first_clean,
        clean2=second_clean,
        watermark=watermark,
        psnr=self.current_psnr,
    )
    self.current_result = snapshot

    previous_best = self.best_result
    if previous_best is not None and not previous_best.psnr <= snapshot.psnr:
        return
    self.best_result = snapshot
def finalize(self):
    """
    Run both networks on their inputs one more time, keep the outputs on
    ``self.net_output1`` / ``self.net_output2``, and save the input images,
    the two halves, and the two learned images.
    """
    self.net_output1 = self.net1(self.net_input1)
    self.net_output2 = self.net2(self.net_input2)

    inputs_to_save = (
        ("original_image1", self.image1),
        ("original_image2", self.image2),
        ("learn_on", self.first_half),
        ("apply_on", self.second_half),
    )
    for label, picture in inputs_to_save:
        save_image(label, picture)

    for label, output in (("learned_image1", self.net_output1), ("learned_image2", self.net_output2)):
        save_image(label, torch_to_np(output))
def _iteration_plot_closure(self, step_number, iter_number):
    """
    Per-iteration hook: recompute ``self.current_psnr`` for the first image and
    print a one-line progress message (carriage-return terminated).

    The PSNR compares ``self.image1`` against the unclipped reconstruction
    ``watermark_hint * watermark + clean1``.

    :param step_number: current step (not used here)
    :param iter_number: iteration index shown in the progress line
    """
    first_clean = torch_to_np(self.clean_net_output1)
    watermark = torch_to_np(self.watermark_net_output)
    reconstruction = self.watermark_hint * watermark + first_clean
    self.current_psnr = compare_psnr(self.image1, reconstruction)

    loss_value = self.total_loss.item()
    if self.current_gradient is None:
        message = 'Iteration {:5d} total_loss {:5f} PSNR {:5f} '.format(
            iter_number, loss_value, self.current_psnr)
    else:
        message = 'Iteration {:5d} total_loss {:5f} grad {:5f} PSNR {:5f} '.format(
            iter_number, loss_value, self.current_gradient.item(), self.current_psnr)
    print(message, '\r', end='')
def _step_plot_closure(self, step_number):
    """
    Runs at the end of each step: for every input image, plot the composite
    ``mask * watermark + (1 - mask) * clean`` (clipped to [0, 1]) next to the
    original image.

    :param step_number: current step, used in the plot names
    :return: None
    """
    # Watermark and mask are shared by all images: convert them once rather
    # than (three times) per image inside the loop.
    watermark_np = torch_to_np(self.watermark_net_output)
    mask_np = torch_to_np(self.mask_net_output)
    for image_name, image, clean_net_output in zip(
            self.images_names, self.images, self.clean_nets_outputs):
        composite = watermark_np * mask_np + (1 - mask_np) * torch_to_np(clean_net_output)
        plot_image_grid(
            image_name + "_learned_image_{}".format(step_number),
            [np.clip(composite, 0, 1), image])
def _iteration_plot_closure(self, step_number, iter_number):
    """
    Per-iteration hook: recompute ``self.current_psnr`` for the first image and
    print a one-line progress message (carriage-return terminated).

    The PSNR compares ``self.images[0]`` against the unclipped composite
    ``clean0 * (1 - mask) + mask * watermark``.

    :param step_number: current step (not used here)
    :param iter_number: iteration index shown in the progress line
    """
    # Only the first clean output participates in the PSNR, so convert just
    # that one instead of every clean output on every iteration.
    first_clean_np = torch_to_np(self.clean_nets_outputs[0])
    watermark_out_np = torch_to_np(self.watermark_net_output)
    mask_out_np = torch_to_np(self.mask_net_output)
    self.current_psnr = compare_psnr(
        self.images[0],
        first_clean_np * (1 - mask_out_np) + mask_out_np * watermark_out_np)

    loss_value = self.total_loss.item()
    if self.current_gradient is not None:
        print('Iteration {:5d} total_loss {:5f} grad {:5f} PSNR {:5f} '.format(
            iter_number, loss_value, self.current_gradient.item(), self.current_psnr), '\r', end='')
    else:
        print('Iteration {:5d} total_loss {:5f} PSNR {:5f} '.format(
            iter_number, loss_value, self.current_psnr), '\r', end='')
def _step_plot_closure(self, step_number):
    """
    Runs at the end of each step: plot the watermark hint (when set), the
    watermark next to the first clean image, and each reconstruction
    ``watermark_hint * watermark + clean_k`` (clipped to [0, 1]) next to its
    input image.

    :param step_number: current step, used in the plot names
    :return: None
    """
    if self.watermark_hint is not None:
        plot_image_grid("watermark_hint_{}".format(step_number), [
            np.clip(self.watermark_hint, 0, 1),
            np.clip(1 - self.watermark_hint, 0, 1)
        ])

    # Each network output is converted once and reused by every plot below.
    watermark_np = torch_to_np(self.watermark_net_output)
    clean1_np = torch_to_np(self.clean_net_output1)
    clean2_np = torch_to_np(self.clean_net_output2)

    plot_image_grid("watermark_clean_{}".format(step_number), [
        np.clip(watermark_np, 0, 1),
        np.clip(clean1_np, 0, 1)
    ])
    # NOTE(review): the reconstructions still multiply by self.watermark_hint
    # unconditionally, so a None hint would raise here despite the guard
    # above -- confirm the hint is always set before step plots run.
    plot_image_grid("learned_image1_{}".format(step_number), [
        np.clip(self.watermark_hint * watermark_np + clean1_np, 0, 1),
        self.image1
    ])
    plot_image_grid("learned_image2_{}".format(step_number), [
        np.clip(self.watermark_hint * watermark_np + clean2_np, 0, 1),
        self.image2
    ])
def deblurring_plot_closure(iter_number, total_loss, image_out, blurred_image, blur_mask, original_image,
                            show_every=1000):
    """
    Print the loss and the PSNR between ``original_image`` and the blurred
    output, and every ``show_every`` iterations plot the outputs.

    :param iter_number: the number of the iteration
    :param total_loss: the total loss of the current iteration step (printed via ``.item()``)
    :param image_out: torch tensor left output
    :param blurred_image: the blurred image that is obtained (torch tensor)
    :param blur_mask: the learned torch mask (not used by this function)
    :param original_image: original numpy image
    :param int show_every: plot only when ``iter_number`` is a multiple of this value
    :return: None
    """
    blurred_image_np = torch_to_np(blurred_image)
    psrn_gt = compare_psnr(original_image, blurred_image_np)
    print('Iteration %05d Loss %f PSRN_gt: %f ' % (iter_number, total_loss.item(), psrn_gt), '\r', end='')
    if iter_number % show_every == 0:
        # Only needed for the plot, so convert lazily.
        image_out_np = torch_to_np(image_out)
        # Call plot_image_grid as (name, images) like every other call site;
        # the former (images, 4, 5) form passed the image list as the name.
        plot_image_grid("deblurring_out_{}".format(iter_number),
                        [np.clip(image_out_np, 0, 1), np.clip(original_image, 0, 1)])
        plot_image_grid("deblurring_blurred_{}".format(iter_number),
                        [np.clip(original_image, 0, 1), np.clip(blurred_image_np, 0, 1)])
def _save_admm_step_images(step, i0, Y, mean, out, clean_im, result_root):
    """
    Write the diagnostic images of one ADMM outer step.

    Paths are built by plain concatenation (``result_root + name``), so
    ``result_root`` is expected to end with a path separator.
    """
    i0_np = torch_to_np(i0)
    Y_np = torch_to_np(Y)
    mean_np = torch_to_np(mean)
    out_np = torch_to_np(out)
    # Current estimate i0 + Y, clipped for display.
    np_to_pil((i0_np + Y_np).clip(0, 1)).save(result_root + 'denoise_obj_{:04d}.png'.format(step))
    # Per-pixel L2 norm of the dual variable, reduced over axis 0 (presumably channels).
    save_heatmap(np.sqrt((Y_np * Y_np).sum(0)), result_root + 'Y_{:04d}.png'.format(step))
    np_to_pil(i0_np).save(result_root + 'i0_num_epoch_{:04d}.png'.format(step))
    np_to_pil(mean_np).save(result_root + 'Latent_im_num_epoch_{:04d}.png'.format(step))
    np_to_pil(out_np).save(result_root + 'res_of_dec_num_epoch_{:04d}.png'.format(step))
    save_hist(mean_np - clean_im, result_root + 'Latent_dis_num_epoch_{:04d}.png'.format(step))


def denoising(noise_im, clean_im, LR=1e-2, sigma=5, rho=1, eta=0.5, total_step=20, prob1_iter=500,
              result_root=None, f=None, model_path='/home/dihan/KAIR/model_zoo/dncnn_color_blind.pth'):
    """
    Denoise ``noise_im`` with an ADMM-style loop.

    Each outer step runs three sub-problems:
      1. ``prob1_iter`` Adam updates of an encoder/decoder pair (VAE-like loss),
         updating ``i0`` after each update;
      2. a frozen pretrained DnCNN applied to ``i0 + Y``;
      3. a dual update ``Y += eta * (i0 - i0_til)``.

    The loop stops at the first outer step whose PSNR against ``clean_im`` does
    not improve, so the ground truth is used for early stopping.

    :param noise_im: noisy numpy image (passed to np_to_torch; presumably CHW in [0, 1] -- confirm)
    :param clean_im: ground-truth numpy image with the same layout as ``noise_im``
    :param LR: Adam learning rate for encoder + decoder
    :param sigma: weight passed to kl_loss and used in the closed-form i0 update
    :param rho: ADMM penalty weight
    :param eta: dual step size
    :param total_step: number of outer ADMM steps (must be >= 1)
    :param prob1_iter: optimisation iterations per outer step (must be >= 1)
    :param result_root: prefix for diagnostic images; when None, no images are written
    :param f: file-like object for the log lines (None prints to stdout)
    :param model_path: path of the DnCNN state dict to load
    :return: ``(image, best_psnr, best_ssim)``; ``image`` is the clipped DnCNN
             output of the best-PSNR step (the last step's output if no step
             improved on the initial PSNR of 0)
    :raises ValueError: if ``total_step`` or ``prob1_iter`` is smaller than 1
    """
    if total_step < 1 or prob1_iter < 1:
        raise ValueError('total_step and prob1_iter must both be >= 1')

    input_depth = 3
    latent_dim = 3
    en_net = Encoder(input_depth, latent_dim, down_sample_norm='batchnorm', up_sample_norm='batchnorm').cuda()
    de_net = Decoder(latent_dim, input_depth, down_sample_norm='batchnorm', up_sample_norm='batchnorm').cuda()

    # Pretrained DnCNN used, frozen, as the denoiser of sub-problem 2.
    model = net(3, 3, nc=64, nb=20, act_mode='R')
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model.requires_grad_(False)
    model = model.cuda()

    noise_im_torch = np_to_torch(noise_im).cuda()

    # Baseline: a single DnCNN pass over the noisy input.
    with torch.no_grad():
        r_dncnn_np = torch_to_np(model(noise_im_torch))
    psnr_dncnn = compare_psnr(clean_im.transpose(1, 2, 0), r_dncnn_np.transpose(1, 2, 0), 1)
    ssim_dncnn = compare_ssim(r_dncnn_np.transpose(1, 2, 0), clean_im.transpose(1, 2, 0),
                              multichannel=True, data_range=1)
    print('PSNR_DNCNN: {}, SSIM_DNCNN: {}'.format(psnr_dncnn, ssim_dncnn), file=f, flush=True)

    optimizer = torch.optim.Adam(list(en_net.parameters()) + list(de_net.parameters()), lr=LR)
    l2_loss = torch.nn.MSELoss(reduction='sum').cuda()

    # ADMM variables, all initialised from the noisy image (Y from zeros).
    i0 = noise_im_torch.clone()
    Y = torch.zeros_like(noise_im_torch)
    i0_til_torch = noise_im_torch.clone()

    if result_root is not None:
        diff_original_np = noise_im.astype(np.float32) - clean_im.astype(np.float32)
        save_hist(diff_original_np, result_root + 'Original_dis.png')

    best_psnr = 0
    best_ssim = 0
    best_im = None
    for i in range(total_step):
        ############################### sub-problem 1 #################################
        for _ in range(prob1_iter):
            optimizer.zero_grad()
            mean, log_var = en_net(noise_im_torch)
            z = sample_z(mean, log_var)
            out = de_net(z)
            total_loss = 0.5 * l2_loss(out, noise_im_torch)
            total_loss += kl_loss(mean, log_var, i0, sigma)
            total_loss += (rho / 2) * l2_loss(i0 + Y, i0_til_torch)
            total_loss.backward()
            optimizer.step()
            with torch.no_grad():
                i0 = ((1 / sigma**2) * mean + rho * (i0_til_torch - Y)) / ((1 / sigma**2) + rho)

        with torch.no_grad():
            ############################### sub-problem 2 #################################
            i0_til_torch = model(i0 + Y)
            ############################### sub-problem 3 #################################
            Y = Y + eta * (i0 - i0_til_torch)

        if result_root is not None:
            _save_admm_step_images(i, i0, Y, mean, out, clean_im, result_root)

        i0_til_np = torch_to_np(i0_til_torch).clip(0, 1)
        psnr = compare_psnr(clean_im.transpose(1, 2, 0), i0_til_np.transpose(1, 2, 0), 1)
        ssim = compare_ssim(clean_im.transpose(1, 2, 0), i0_til_np.transpose(1, 2, 0),
                            multichannel=True, data_range=1)
        if result_root is not None:
            np_to_pil(i0_til_np).save(os.path.join(result_root, '{}'.format(i) + '.png'))
        print('Iteration: %02d, VAE Loss: %f, PSNR: %f, SSIM: %f' % (i, total_loss.item(), psnr, ssim),
              file=f, flush=True)

        if best_psnr < psnr:
            # Keep the image together with its scores so the returned triple is consistent.
            best_psnr, best_ssim, best_im = psnr, ssim, i0_til_np
        else:
            break

    if best_im is None:
        # No step beat the initial PSNR of 0 (e.g. NaN PSNR): fall back to the last output.
        best_im = i0_til_np
    return best_im, best_psnr, best_ssim