def thtest_static(args, enc, dec, test_loader, use_psnr=False):
    """Evaluate an encoder/decoder pair on a denoising test set.

    For each batch, every noisy frame in ``pic_set`` is reconstructed and
    compared (MSE, optionally converted to PSNR) against the clean target
    ``th_img`` broadcast to the reconstruction set's shape.

    :param args: config namespace; reads ``args.device`` and ``args.share_enc``
    :param enc: encoder module (put in eval mode here)
    :param dec: decoder module (put in eval mode here)
    :param test_loader: yields ``(pic_set, th_img)`` batches
    :param use_psnr: if True, report PSNR instead of raw MSE
    :returns: average loss (or PSNR) over the loader
    """
    device = args.device
    enc.eval()
    dec.eval()
    test_loss = 0
    with torch.no_grad():
        for pic_set, th_img in tqdm(test_loader):
            th_img = th_img.to(device)
            pic_set = pic_set.to(device)
            # -- target is broadcast over the (N, BS, ...) reconstruction set --
            N = len(pic_set)
            BS = len(pic_set[0])
            pshape = pic_set[0][0].shape
            shape = (
                N,
                BS,
            ) + pshape
            rec_set = reconstruct_set(pic_set, enc, dec, args.share_enc)
            rec_set = rescale_noisy_image(rec_set)
            cmp_img = th_img.expand(shape)
            set_loss = F.mse_loss(cmp_img, rec_set).item()
            if use_psnr:
                set_loss = mse_to_psnr(set_loss)
            test_loss += set_loss
    test_loss /= len(test_loader)
    print('\nTest set: Average loss: {:2.3e}\n'.format(test_loss))
    return test_loss
def cnn_forward(cnn_info, cfg, z, raw_img, t_img):
    """One optimization step of the CNN (deep-image-prior style) model.

    The fixed latent ``z`` is perturbed with Gaussian noise each step; the
    model is trained to map the perturbed latent to the (noisy) target
    ``t_img``, and PSNR is measured against the clean ``raw_img``.

    :param cnn_info: holder with ``.gpuid``, ``.model``, ``.optimizer``
    :param cfg: config (unused here, kept for a uniform forward signature)
    :param z: latent input tensor
    :param raw_img: clean reference image (PSNR only; not in the loss)
    :param t_img: training target image
    :returns: ``(psnr, per_image_mse, rec_img)``
    """
    # -- cuda --
    gpuid = cnn_info.gpuid
    raw_img = raw_img.cuda(gpuid)  # single transfer (was duplicated)
    cnn_info.model = cnn_info.model.cuda(gpuid)
    z = z.cuda(gpuid)
    t_img = t_img.cuda(gpuid)

    # -- init --
    BS = raw_img.shape[0]
    cnn_info.optimizer.zero_grad()
    cnn_info.model.zero_grad()

    # -- forward pass: jitter the latent with sigma = 1/20 --
    z_prime = torch.normal(z, 1. / 20)
    rec_img = cnn_info.model(z_prime)
    loss = F.mse_loss(t_img, rec_img)

    # -- sgd step --
    loss.backward()
    cnn_info.optimizer.step()

    # -- psnr vs. the clean image --
    loss = F.mse_loss(raw_img, rec_img, reduction='none').reshape(BS, -1)
    loss = torch.mean(loss, 1).detach().cpu().numpy()
    psnr = np.mean(mse_to_psnr(loss))
    return psnr, loss, rec_img
def thtest_denoising(cfg, model, test_loader):
    """Evaluate a denoising model: average MSE (or PSNR) over the test set.

    The model returns a set of decoded images per clean target; the target
    is broadcast to the decoded set's shape before comparison.

    :param cfg: config; reads ``cfg.test_with_psnr``
    :param model: returns ``(dec_imgs, proj)`` for a batch of noisy images
    :param test_loader: yields ``(noisy_imgs, raw_img)`` batches
    :returns: average loss (or PSNR) over the loader
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for noisy_imgs, raw_img in tqdm(test_loader):
            # -- single gpu transfer each (was duplicated for noisy_imgs) --
            noisy_imgs = noisy_imgs.cuda(non_blocking=True)
            raw_img = raw_img.cuda(non_blocking=True)
            dec_imgs, proj = model(noisy_imgs)
            dec_imgs = rescale_noisy_image(dec_imgs)
            N = len(dec_imgs)
            BS = len(dec_imgs[0])
            dshape = (
                N,
                BS,
            ) + dec_imgs.shape[2:]
            dec_imgs = dec_imgs.reshape(dshape)
            raw_img = raw_img.expand(dshape)
            loss = F.mse_loss(raw_img, dec_imgs).item()
            if cfg.test_with_psnr:
                loss = mse_to_psnr(loss)
            test_loss += loss
    test_loss /= len(test_loader)
    print('\nTest set: Average loss: {:2.3e}\n'.format(test_loss))
    return test_loss
def attn_forward(attn_info, raw_img, stacked_burst, idx, iters, loss_diff,
                 loss_prev):
    """One optimization step of the attention model.

    Feeds the burst to itself (query == key), takes a gradient step on the
    model's energy, and measures PSNR of the re-centered reconstruction
    against the clean image.

    :returns: ``(psnr, lossE, rec_img, loss_diff, loss_prev)`` where the
        last two track the step-to-step change in the energy.
    """
    device_id = attn_info.gpuid

    # -- move data and model to the assigned gpu --
    stacked_burst = stacked_burst.cuda(device_id)
    raw_img = raw_img.cuda(device_id)
    attn_info.model = attn_info.model.cuda(device_id)

    # -- reset gradients --
    batch_size = raw_img.shape[0]
    attn_info.optimizer.zero_grad()
    attn_info.model.zero_grad()

    # -- forward: burst serves as both inputs --
    lossE, rec_img = attn_info.model(stacked_burst, stacked_burst, idx,
                                     iters, loss_diff)
    energy = lossE.item()
    loss_diff = energy - loss_prev
    loss_prev = energy

    # -- gradient step --
    lossE.backward()
    attn_info.optimizer.step()

    # -- psnr of re-centered reconstruction --
    rec_img += 0.5
    per_pixel = F.mse_loss(raw_img, rec_img, reduction='none')
    per_image = torch.mean(per_pixel.reshape(batch_size, -1), 1)
    per_image = per_image.detach().cpu().numpy()
    psnr = np.mean(mse_to_psnr(per_image))

    return psnr, lossE, rec_img, loss_diff, loss_prev
def test_loop(cfg, model, criterion, test_loader, epoch):
    """Evaluate the residual-prediction model and periodically save image grids.

    The model predicts a residual for the first burst frame; the denoised
    frame is ``img0 - pred_res``.  Returns the average PSNR over the loader.

    :param cfg: config; reads ``cfg.device``, ``cfg.test_log_interval``,
        ``cfg.batch_size``
    :param criterion: unused here (kept for a uniform loop signature)
    :param epoch: used only to name the output image directory
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    with torch.no_grad():
        for batch_idx, (burst_imgs, res_img,
                        raw_img) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # reshaping of data
            raw_img = raw_img.cuda(non_blocking=True)
            burst_imgs = burst_imgs.cuda(non_blocking=True)
            img0 = burst_imgs[0]

            # denoising: subtract the predicted residual from the input frame
            pred_res = model(img0)
            rec_img = img0 - pred_res

            # compare with stacked targets (per-image MSE -> PSNR)
            rec_img = rescale_noisy_image(rec_img).detach()
            loss = F.mse_loss(raw_img, rec_img,
                              reduction='none').reshape(BS, -1)
            # loss = F.mse_loss(burst_imgs[0]+0.5,raw_img,reduction='none').reshape(BS,-1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            # periodically dump a grid of reconstructions for inspection
            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(
                    f"{settings.ROOT_PATH}/output/n2n/rec_imgs/e{epoch}")
                if not root.exists(): root.mkdir(parents=True)
                fn = root / Path(f"b{batch_idx}.png")
                nrow = int(np.sqrt(cfg.batch_size))
                rec_img = rec_img.detach().cpu()
                grid_imgs = vutils.make_grid(rec_img,
                                             padding=2,
                                             normalize=True,
                                             nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')
    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print("Testing results: Ave psnr %2.3e Ave loss %2.3e" %
          (ave_psnr, ave_loss))
    return ave_psnr
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch):
    """Run one epoch of residual-prediction training.

    The model predicts the noise residual of the first burst frame; the
    denoised frame is ``img0 - pred_res`` and the loss is MSE against the
    clean image (shifted by +0.5 since frames are stored zero-centered).

    :param criterion: unused here (kept for a uniform loop signature)
    :returns: epoch-average training loss
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss = 0
    for batch_idx, (burst_imgs, raw_img) in enumerate(train_loader):
        optimizer.zero_grad()
        model.zero_grad()

        # -- reshaping of data --
        raw_img = raw_img.cuda(non_blocking=True)
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        img0 = burst_imgs[0]

        # -- predict residual --
        pred_res = model(img0)
        rec_img = img0 - pred_res

        # -- compare with clean image; +0.5 re-centers the reconstruction --
        loss = F.mse_loss(raw_img, rec_img + 0.5)

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- compute per-image PSNR for logging --
            # (raw_img is already on the gpu; no second transfer needed)
            BS = raw_img.shape[0]
            mse_loss = F.mse_loss(raw_img,
                                  rec_img + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr))
            running_loss = 0
    total_loss /= len(train_loader)
    return total_loss
def test_denoising(cfg, model, test_loader):
    """Evaluate a denoiser that takes the flattened (N*BS) burst at once.

    Noisy frames are flattened to a single batch, denoised, re-shaped back
    to (N, BS, ...), and compared against the broadcast clean image.

    :param cfg: config; reads ``cfg.test_with_psnr``
    :returns: average loss (or PSNR) over the loader
    """
    model.eval()
    test_loss = 0
    idx = 0
    with torch.no_grad():
        for noisy_imgs, raw_img in tqdm(test_loader):
            N, BS = noisy_imgs.shape[:2]
            p_shape = noisy_imgs.shape[2:]
            # -- single gpu transfer each (was duplicated for noisy_imgs) --
            noisy_imgs = noisy_imgs.cuda(non_blocking=True)
            raw_img = raw_img.cuda(non_blocking=True)

            # -- flatten set dim into the batch dim for the forward pass --
            noisy_imgs = noisy_imgs.view((N * BS, ) + p_shape)
            dec_imgs = model(noisy_imgs)
            dec_no_rescale = dec_imgs
            dec_imgs = rescale_noisy_image(dec_imgs)
            dshape = (
                N,
                BS,
            ) + p_shape
            dec_imgs = dec_imgs.reshape(dshape)
            raw_img = raw_img.expand(dshape)

            # -- one-off debug dump of tensor statistics --
            if idx == 10:
                print('dec_no_rescale', dec_no_rescale.mean(),
                      dec_no_rescale.min(), dec_no_rescale.max())
                print('noisy', noisy_imgs.mean(), noisy_imgs.min(),
                      noisy_imgs.max())
                print('dec', dec_imgs.mean(), dec_imgs.min(),
                      dec_imgs.max())
                print('raw', raw_img.mean(), raw_img.min(), raw_img.max())

            loss = F.mse_loss(raw_img, dec_imgs).item()
            if cfg.test_with_psnr:
                loss = mse_to_psnr(loss)
            test_loss += loss
            idx += 1
    test_loss /= len(test_loader)
    print('\nTest set: Average loss: {:2.3e}\n'.format(test_loss))
    return test_loss
def kpn_forward(kpn_info, cfg, raw_img, stacked_burst, cat_burst, t_img, idx,
                iters):
    """One optimization step of the KPN (kernel-prediction network) model.

    :param kpn_info: holder with ``.gpuid``, ``.model``, ``.optimizer``,
        ``.criterion``
    :param cfg: config; ``cfg.global_step`` is read by the criterion and
        advanced here
    :param raw_img: clean reference image (PSNR only; not in the loss)
    :param stacked_burst: burst as (B, N, C, H, W) — assumed; TODO confirm
    :param cat_burst: burst with frames concatenated on channels — assumed
    :param idx, iters: current / total iteration counts (unused here)
    :returns: ``(psnr, lossE, rec_img)``
    """
    # -- cuda --
    gpuid = kpn_info.gpuid
    raw_img = raw_img.cuda(gpuid)
    kpn_info.model = kpn_info.model.cuda(gpuid)
    stacked_burst = stacked_burst.cuda(gpuid)
    cat_burst = cat_burst.cuda(gpuid)
    t_img = t_img.cuda(gpuid)
    # -- init --
    BS = raw_img.shape[0]
    kpn_info.optimizer.zero_grad()
    kpn_info.model.zero_grad()
    # -- forward pass --
    rec_img_i, rec_img = kpn_info.model(cat_burst, stacked_burst)
    lossE_ = kpn_info.criterion(rec_img_i, rec_img, t_img, cfg.global_step)
    # lossE_ = criterion(rec_img_i, rec_img, t_img, cfg.global_step)
    # NOTE(review): np.sum over what appears to be a list/tuple of torch
    # losses — the result must stay a tensor for .backward() below; verify
    # the criterion's return type.
    lossE = np.sum(lossE_)
    # -- sgd step --
    lossE.backward()
    kpn_info.optimizer.step()
    # -- post --
    # NOTE(review): global_step advances by 30 per call here (other loops in
    # this file use += 1) — presumably to accelerate the criterion's
    # step-based schedule; confirm intent.
    cfg.global_step += 30
    # rec_img += 0.5
    # -- psnr --
    loss = F.mse_loss(raw_img, rec_img, reduction='none').reshape(BS, -1)
    loss = torch.mean(loss, 1).detach().cpu().numpy()
    psnr = np.mean(mse_to_psnr(loss))
    return psnr, lossE, rec_img
def train_loop(cfg,model_target,model_online,optim_target,optim_online,criterion,train_loader,epoch,record_losses):
    """One epoch of BYOL-style training with an online and a target model.

    The online model is trained by gradient descent; the target model is
    updated as an exponential moving average of the online model's weights.
    Returns ``(epoch_average_loss, record_losses)``.

    :param record_losses: DataFrame of per-log-step losses; created here
        when ``None`` (only appended to when ``use_record`` is enabled)
    """
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #    setup for train epoch
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # -- setup for training --
    model_online.train()
    model_online = model_online.to(cfg.device)
    model_target.train()
    model_target = model_target.to(cfg.device)
    moving_average_decay = 0.99
    ema_updater = EMA(moving_average_decay)

    # -- init vars --
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize,1,0,blocksize)
    D = 5 * 10**3  # hard cap on batches per epoch
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({'burst':[],'ave':[],'ot':[],'psnr':[],'psnr_std':[]})
    nc_losses,nc_count = 0,0
    al_ot_losses,al_ot_count = 0,0
    rec_ot_losses,rec_ot_count = 0,0
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #      run training epoch
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    for batch_idx, (burst, res_imgs, raw_img, directions) in enumerate(train_loader):
        if batch_idx > D: break

        # -- zero gradient (both models and their inner denoiser optims) --
        optim_online.zero_grad()
        optim_target.zero_grad()
        model_online.zero_grad()
        model_online.denoiser_info.optim.zero_grad()
        model_target.zero_grad()
        model_target.denoiser_info.optim.zero_grad()

        # -- reshaping of data --
        N,BS,C,H,W = burst.shape
        burst = burst.cuda(non_blocking=True)
        stacked_burst = rearrange(burst,'n b c h w -> b n c h w')

        # -- create target image: middle frame (or zero-mean clean image
        # when supervised) --
        mid_img = burst[N//2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised:
            t_img = szm(raw_img.cuda(non_blocking=True))
        else:
            t_img = burst[N//2]

        # -- direct denoising: run both models on the same burst --
        aligned_o,aligned_ave_o,denoised_o,rec_img_o,filters_o = model_online(burst)
        aligned_t,aligned_ave_t,denoised_t,rec_img_t,filters_t = model_target(burst)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #        alignment losses
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute aligned losses to optimize (rec image detached so the
        # alignment loss does not backprop through the reconstruction) --
        rec_img_d_o = rec_img_o.detach()
        losses = criterion(aligned_o,aligned_ave_o,rec_img_d_o,t_img,raw_zm_img,cfg.global_step)
        nc_loss,ave_loss,burst_loss,ot_loss = [loss.item() for loss in losses]
        kpn_loss = losses[1] + losses[2] # np.sum(losses)
        kpn_coeff = .9997**cfg.global_step

        # -- OT loss (disabled; kept at zero) --
        al_ot_loss = torch.FloatTensor([0.]).to(cfg.device)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #     reconstruction losses
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- decaying rec loss --
        rec_mse_coeff = 0.997**cfg.global_step
        rec_mse_loss = F.mse_loss(rec_img_o,mid_img)

        # -- BYOL loss: online reconstruction chases the target model's --
        byol_loss = F.mse_loss(rec_img_o,rec_img_t)

        # -- OT loss (disabled; kept at zero) --
        rec_ot_coeff = 100
        residuals = denoised_o - mid_img.unsqueeze(1).repeat(1,N,1,1,1)
        residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        # rec_ot_loss = ot_pairwise_bp(residuals,K=3)
        rec_ot_loss = torch.FloatTensor([0.]).to(cfg.device)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #   final losses & recording
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- final losses --
        align_loss = kpn_coeff * kpn_loss
        denoise_loss = rec_ot_coeff * rec_ot_loss + byol_loss + rec_mse_coeff * rec_mse_loss

        # -- update alignment kl loss info --
        al_ot_losses += al_ot_loss.item()
        al_ot_count += 1

        # -- update reconstruction kl loss info --
        rec_ot_losses += rec_ot_loss.item()
        rec_ot_count += 1

        # -- update info (nc loss averaged only over nonzero entries) --
        if not np.isclose(nc_loss,0):
            nc_losses += nc_loss
            nc_count += 1
        running_loss += align_loss.item() + denoise_loss.item()
        total_loss += align_loss.item() + denoise_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #     backprop and optimize
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        loss = align_loss + denoise_loss
        loss.backward()

        # -- backprop for [online] --
        optim_online.step()
        model_online.denoiser_info.optim.step()

        # -- exponential moving average for [target] --
        update_moving_average(ema_updater,model_target,model_online)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #       message to stdout
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img,aligned_ave_o+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst,dim=1)
            mse_loss = F.mse_loss(raw_img,mis_ave+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] baseline --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0,2)+0.5,
                                     sigma_psd=noise_level/255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0,2)
                # NOTE(review): reshape(BS,-1) on a single image — the
                # sibling loop at the other train_loop uses reshape(1,-1);
                # verify this is intended.
                b_loss = F.mse_loss(raw_img[b].cpu(),bm3d_rec,reduction='none').reshape(BS,-1)
                b_loss = torch.mean(b_loss,1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1,N,1,1,1)
            mse_loss = F.mse_loss(raw_img_repN,denoised_o+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,rec_img_o+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- write record --
            if use_record:
                record_losses = record_losses.append({'burst':burst_loss,'ave':ave_loss,'ot':ot_loss,'psnr':psnr,'psnr_std':psnr_std},ignore_index=True)

            # -- update losses --
            running_loss /= cfg.log_interval
            ave_nc_loss = nc_losses / nc_count if nc_count > 0 else 0

            # -- alignment kl loss --
            ave_al_ot_loss = al_ot_losses / al_ot_count if al_ot_count > 0 else 0
            al_ot_losses,al_ot_count = 0,0

            # -- reconstruction kl loss --
            ave_rec_ot_loss = rec_ot_losses / rec_ot_count if rec_ot_count > 0 else 0
            rec_ot_losses,rec_ot_count = 0,0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx,len(train_loader),running_loss,psnr,psnr_std,
                          psnr_denoised_ave,psnr_denoised_std,psnr_aligned_ave,psnr_aligned_std,
                          psnr_misaligned_ave,psnr_misaligned_std,bm3d_nb_ave,bm3d_nb_std,
                          ave_nc_loss,ave_rec_ot_loss,ave_al_ot_loss)
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [denoised]: %2.2f +/- %2.2f [aligned]: %2.2f +/- %2.2f [misaligned]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [loss-nc]: %.2e [loss-rot]: %.2e [loss-aot]: %.2e" % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg,stacked_burst,aligned_o,denoised_o,filters_o,directions)

        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss,record_losses
def dip_loop(cfg, test_loader, epoch):
    """Deep-image-prior style evaluation: fit three models (attn, kpn, cnn)
    per test batch and report their best PSNRs against bm3d baselines.

    Fixes vs. the original: the final summary line referenced an undefined
    ``ave_loss`` (NameError); it is now computed.  The ``cfg.color_cat``
    branch for ``stacked_burst`` was dead code (immediately overwritten by
    an unconditional ``torch.stack``) and has been removed, as was the
    always-true ``bm3d_rec is None`` check.

    :returns: average of ``total_psnr`` over processed samples (note:
        ``total_psnr`` is never incremented in this version, matching the
        original's commented-out accumulation).
    """
    total_psnr = 0
    total_loss = 0
    num_samples = 0
    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(test_loader):
        BS = raw_img.shape[0]
        N, BS, C, H, W = burst_imgs.shape

        # -- selecting input frames: optionally drop the middle frame from
        # the input set while remembering its index as the target --
        input_order = np.arange(cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            middle = cfg.N // 2
            middle_img_idx = input_order[middle]
            input_order = np.r_[input_order[:middle],
                                input_order[middle + 1:]]
        else:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.arange(cfg.N)

        # -- reshaping of data (stack on a new dim / concat on channels) --
        stacked_burst = torch.stack(
            [burst_imgs[input_order[x]] for x in range(cfg.input_N)], dim=1)
        cat_burst = torch.cat(
            [burst_imgs[input_order[x]] for x in range(cfg.input_N)], dim=1)

        # -- dip denoising targets (+0.5 re-centers to image range) --
        t_img = burst_imgs[middle_img_idx] + 0.5
        img = stacked_burst + 0.5

        # -- baseline psnr: plain average of the burst --
        ave_rec = torch.mean(stacked_burst, dim=1) + 0.5
        b_loss = F.mse_loss(raw_img, ave_rec,
                            reduction='none').reshape(BS, -1)
        b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
        ave_psnr = np.mean(mse_to_psnr(b_loss))

        # -- bm3d with known noise level --
        bm3d_rec = bm3d.bm3d(t_img[0].transpose(0, 2),
                             sigma_psd=25 / 255,
                             stage_arg=bm3d.BM3DStages.ALL_STAGES)
        bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
        b_loss = F.mse_loss(raw_img[0], bm3d_rec,
                            reduction='none').reshape(BS, -1)
        b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
        bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))

        # -- blind bm3d: estimate sigma from the noisy middle frame --
        noisy_mid = t_img[0].transpose(0, 2)
        sigma_est = torch.std(noisy_mid - 0.5)
        bm3d_rec = bm3d.bm3d(noisy_mid,
                             sigma_psd=sigma_est,
                             stage_arg=bm3d.BM3DStages.ALL_STAGES)
        bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
        b_loss = F.mse_loss(raw_img[0], bm3d_rec,
                            reduction='none').reshape(BS, -1)
        b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
        bm3d_b_psnr = np.mean(mse_to_psnr(b_loss))

        # -- fit the three models from scratch for this batch --
        iters = 2000
        repeats = 1
        best_attn_psnr, best_kpn_psnr, best_cnn_psnr = 0, 0, 0
        for repeat in range(repeats):
            idx = 0
            attn_info = get_attn_model(cfg, 0)
            kpn_info = get_kpn_model(cfg, 1)
            cnn_info = get_cnn_model(cfg, 2)
            z = torch.normal(0, torch.ones_like(t_img))
            attn_loss_diff, attn_loss_prev = 1., 0
            while (idx < iters):
                idx += 1

                # -- attn model --
                fwd_args = [
                    attn_info, raw_img, stacked_burst, idx, iters,
                    attn_loss_diff, attn_loss_prev
                ]
                attn_psnr, loss_attn, rec_attn, attn_loss_diff, \
                    attn_loss_prev = attn_forward(*fwd_args)

                # -- kpn model --
                fwd_args = [
                    kpn_info, cfg, raw_img, stacked_burst, cat_burst, t_img,
                    idx, iters
                ]
                kpn_psnr, loss_kpn, rec_kpn = kpn_forward(*fwd_args)

                # -- cnn model; middle frame --
                fwd_args = [cnn_info, cfg, z, raw_img, t_img]
                cnn_psnr, loss_cnn, rec_cnn = cnn_forward(*fwd_args)

                # -- progress report every 250 iterations --
                if (idx % 250) == 0 or idx == 1:
                    print(
                        "[%d] [%d/%d] [PSNR] [attn: %2.2f] [kpn: %2.2f] [cnn-m: %2.2f] [ave: %2.2f] [bm3d-nb: %2.2f] [bm3d-b: %2.2f]"
                        % (batch_idx, idx, iters, attn_psnr, kpn_psnr,
                           cnn_psnr, ave_psnr, bm3d_nb_psnr, bm3d_b_psnr))

                # -- keep best-so-far reconstructions --
                if attn_psnr > best_attn_psnr:
                    best_attn_psnr = attn_psnr
                    best_rec_attn = rec_attn
                if kpn_psnr > best_kpn_psnr:
                    best_kpn_psnr = kpn_psnr
                    best_rec_kpn = rec_kpn
                if cnn_psnr > best_cnn_psnr:
                    best_cnn_psnr = cnn_psnr
                    best_rec_cnn = rec_cnn

                # -- bail out if optimization diverged --
                if torch.isinf(loss_attn) or torch.isinf(loss_kpn):
                    print("UH OH! inf loss")
                    break
        print(
            f"Best PSNR [attn: {best_attn_psnr}] [kpn: {best_kpn_psnr}] [cnn-m: {best_cnn_psnr}]"
        )
        num_samples += 1

        # -- save the best attn reconstruction next to the clean image --
        root = Path(
            f"{settings.ROOT_PATH}/output/dip/rec_imgs/N{cfg.N}/e{epoch}/")
        if not root.exists(): root.mkdir(parents=True)
        fn = root / Path(f"b{batch_idx}_attn.png")
        nrow = 2
        rec_img = best_rec_attn.detach().cpu()
        save_img = torch.cat([rec_img, raw_img.cpu()], dim=0)
        grid_imgs = vutils.make_grid(save_img,
                                     padding=2,
                                     normalize=False,
                                     nrow=nrow,
                                     pad_value=0)
        plt.title(f"PSNR: {best_attn_psnr}")
        plt.imshow(grid_imgs.permute(1, 2, 0))
        plt.savefig(fn)
        plt.close('all')
        print(f"Saved figure to {fn}")

    # -- summary (guard against an empty loader; ave_loss was previously
    # undefined here, causing a NameError) --
    ave_psnr = total_psnr / num_samples if num_samples > 0 else 0
    ave_loss = total_loss / num_samples if num_samples > 0 else 0
    print(
        "[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e"
        % (cfg.blind, cfg.N, ave_psnr, ave_loss))
    return ave_psnr
def train_loop(cfg, model, train_loader, epoch, record_losses):
    """One epoch of joint alignment + denoising training.

    The composite ``model`` carries two sub-models (``align_info`` and
    ``denoiser_info``), each with its own optimizer.  The total loss is a
    weighted sum of alignment MSE, alignment distribution (KL), denoiser
    MSE, denoiser distribution (KL), a filter-norm penalty, and two
    filter-entropy terms gathered via forward hooks.

    :param record_losses: DataFrame of per-log-step losses; created here
        when ``None`` (only appended to when ``use_record`` is enabled)
    :returns: ``(epoch_average_loss, record_losses)``
    """

    # -=-=-=-=-=-=-=-=-=-=- #
    #    Setup for epoch    #
    # -=-=-=-=-=-=-=-=-=-=- #

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    #          Init Record Keeping        #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

    align_mse_losses, align_mse_count = 0, 0
    align_ot_losses, align_ot_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=- #
    #    Add hooks for epoch    #
    # -=-=-=-=-=-=-=-=-=-=-=-=- #

    # capture the alignment filters produced by each "filter_cls" layer so
    # the entropy losses below can read them after the forward pass
    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    #          Init Loss Functions        #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=- #
    #     Final Configs     #
    # -=-=-=-=-=-=-=-=-=-=- #

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    write_examples_iter = steps_per_epoch // 3

    # -=-=-=-=-=-=-=-=-=-=- #
    #       Start Epoch     #
    # -=-=-=-=-=-=-=-=-=-=- #

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #      Setting up for Iteration       #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # -- setup iteration timer --
        if use_timer:
            clock.tic()

        # -- zero gradients; ready 2 go --
        model.align_info.model.zero_grad()
        model.align_info.optim.zero_grad()
        model.denoiser_info.model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #      Formatting Images for FP       #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image: middle frame, or zero-mean clean image
        # when supervised --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised:
            gt_img = raw_zm_img
        else:
            gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #           Foward Pass               #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        outputs = model(burst)
        aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
        aligned_filters, denoised_filters = outputs[4:]

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #    Require Approx Equal Filter Norms    #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # penalize per-frame filter norms that deviate from the middle frame's
        denoised_filters = rearrange(denoised_filters.detach(),
                                     'b n k2 c h w -> n (b k2 c h w)')
        norms = denoised_filters.norm(dim=1)
        norm_loss_denoiser = torch.mean((norms - norms[N // 2])**2)
        norm_loss_coeff = 100.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #    Decrease Entropy within a Kernel     #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        filters_entropy = 0
        filters_entropy_coeff = 1000.
        all_filters = []
        L = len(align_hook.filters)
        # fall back to the model's returned filters if the hooks saw none
        iter_filters = align_hook.filters if L > 0 else [aligned_filters]
        for filters in iter_filters:
            filters_shaped = rearrange(filters,
                                       'b n k2 c h w -> (b n c h w) k2',
                                       n=N)
            filters_entropy += entropyLoss(filters_shaped)
            all_filters.append(filters)
        if L > 0:
            filters_entropy /= L
        all_filters = torch.stack(all_filters, dim=1)
        align_hook.clear()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #    Increase Entropy across each Kernel      #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        filters_dist_entropy = 0

        # -- across each frame --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each batch --
        filters_shaped = rearrange(all_filters,
                                   'b l n k2 c h w -> (n l) (b c h w) k2')
        filters_shaped = torch.mean(filters_shaped, dim=1)
        filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each kpn cascade --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        filters_dist_coeff = 0  # cross-kernel entropy term currently disabled

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #      Alignment Losses (MSE)         #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 0.95**cfg.global_step  # decays to zero with steps

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #   Alignment Losses (Distribution)   #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # KL loss on the center crop of the aligned-frame residuals
        fs = cfg.dynamic.frame_size
        residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        centered_residuals = tvF.center_crop(residuals, (fs // 2, fs // 2))
        align_ot = kl_gaussian_bp(centered_residuals, noise_level, flip=True)
        align_ot_coeff = 100.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #   Reconstruction Losses (MSE)   #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        losses = denoiseLossMSE(denoised, denoised_ave, gt_img,
                                cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #    Reconstruction Losses (Distribution)     #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # -- regularization scheduler (unused by the active kl loss below;
        # kept for the commented-out OT variants) --
        if cfg.global_step < 100:
            reg = 0.5
        elif cfg.global_step < 200:
            reg = 0.25
        elif cfg.global_step < 5000:
            reg = 0.15
        elif cfg.global_step < 10000:
            reg = 0.1
        else:
            reg = 0.05

        # -- computation --
        residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # residuals = rearrange(residuals,'b n c h w -> b n (h w) c')

        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        rec_ot_loss_v1 = kl_gaussian_bp(residuals, noise_level, flip=True)
        # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals)
        # rec_ot_loss_v1 = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
        # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2)
        rec_ot_coeff = 100.  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #              Final Losses                   #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse
        norm_loss = norm_loss_coeff * norm_loss_denoiser
        align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot
        entropy_loss = filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy
        final_loss = align_loss + rec_loss + entropy_loss + norm_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #              Record Keeping                 #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- alignment Dist --
        align_ot_losses += align_ot.item()
        align_ot_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #        Gradients & Backpropogration         #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # -- compute the gradients! --
        final_loss.backward()

        # -- backprop now. --
        model.align_info.optim.step()
        model.denoiser_info.optim.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #            Printing to Stdout               #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img,
                                  aligned_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] baseline --
            bm3d_nb_psnrs = []
            M = 10 if B > 10 else B  # NOTE(review): M is computed but unused
            for b in range(B):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(),
                                    bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN,
                                  denoised + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img,
                                  denoised_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- alignment Dist. --
            align_ot_ave = align_ot_losses / align_ot_count
            align_ot_losses, align_ot_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, directions)

        if use_timer:
            clock.toc()
        if use_timer:
            print(clock)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
def train_loop_n2n(cfg, model, optimizer, criterion, train_loader, epoch):
    """Run one Noise2Noise training epoch.

    Trains ``model`` to map one noisy frame of a burst (``burst[0]``) onto a
    second, independently noisy frame of the same scene (``burst[1]``) with an
    MSE loss — no clean target is used for the gradient step.

    Side effects: mutates ``cfg.noise_params['qis']`` in place, increments
    ``cfg.global_steps`` once per batch, and periodically prints tensor stats
    and a PSNR estimate against ``raw_img``.

    Returns:
        float: training loss averaged over the number of batches.

    NOTE(review): ``criterion``, ``N``, ``K``, ``train_iter``, ``noise_xform``,
    ``burst0`` and ``burst1`` are assigned but unused here; they are leftovers
    from alternative objectives (noise2sim, supervised, ...) that were
    previously toggled inside this function.
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N  # burst length; unused by the n2n objective below
    total_loss = 0
    running_loss = 0
    train_iter = iter(train_loader)  # unused; the loop enumerates train_loader directly
    K = cfg.sim_K
    noise_type = cfg.noise_params.ntype
    noise_level = cfg.noise_params['g']['stddev']
    # Force fixed QIS noise parameters. This mutates the shared cfg object,
    # so it is visible to every later user of cfg.noise_params.
    cfg.noise_params['qis']['alpha'] = 255.0
    cfg.noise_params['qis']['readout'] = 0.0
    cfg.noise_params['qis']['nbits'] = 8
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)  # unused here
    for batch_idx, (burst, res_img, raw_img, d) in enumerate(train_loader):
        optimizer.zero_grad()
        model.zero_grad()
        # -- move batch to GPU --
        BS = raw_img.shape[0]
        raw_img = raw_img.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)
        # -- optional Anscombe variance stabilization of the noisy burst --
        # (the +0.5 suggests images are stored zero-centered — TODO confirm)
        if cfg.use_anscombe:
            burst = anscombe_nmlz.forward(cfg, burst + 0.5)
        burst0 = burst[[0]]  # unused; kept from similar-burst experiments
        burst1 = burst[[1]]  # unused; kept from similar-burst experiments
        # -- noise2noise pair: two noisy views of the same scene --
        img0 = burst[0]
        img1 = burst[1]
        # -- denoise frame 0 --
        rec_img = model(img0)
        # -- n2n loss: predict the sibling noisy frame, not the clean image --
        loss = F.mse_loss(img1, rec_img)
        # -- bookkeeping --
        running_loss += loss.item()
        total_loss += loss.item()
        # -- backprop + SGD step --
        loss.backward()
        optimizer.step()
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            print_tensor_stats("burst", burst)
            if cfg.use_anscombe:
                print_tensor_stats("rec", rec_img)
                # Invert the Anscombe transform before comparing to raw_img.
                rec_img = anscombe_nmlz.backward(cfg, rec_img) - 0.5
                print_tensor_stats("nmlz-rec", rec_img)
            # -- PSNR vs. the clean raw image (diagnostic only, no gradient) --
            loss = F.mse_loss(raw_img, rec_img + 0.5,
                              reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            psnr_ave = np.mean(psnr)
            psnr_std = np.std(psnr)
            print_tensor_stats("rec_img", rec_img + 0.5)
            print_tensor_stats("raw_img", raw_img)
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR] %2.2f +/- %2.2f " %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr_ave, psnr_std))
            running_loss = 0
        # NOTE(review): this function uses cfg.global_steps (plural) while the
        # sibling train loops use cfg.global_step — verify both exist on cfg.
        cfg.global_steps += 1
    total_loss /= len(train_loader)
    return total_loss
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch,
               record_losses):
    """Run one training epoch for the burst alignment + denoising model.

    Combines three loss terms per batch: an alignment MSE (coefficient fixed
    to 0 below, i.e. disabled), a reconstruction MSE from ``criterion`` with a
    geometrically decaying coefficient, and a KL-to-Gaussian distribution loss
    on the denoised residuals. Only ``optimizer`` and
    ``model.denoiser_info.optim`` are stepped.

    Side effects: increments ``cfg.global_step`` per batch; periodically
    prints PSNR diagnostics (including a BM3D baseline, which is slow) and may
    call ``write_input_output``.

    Returns:
        tuple: (epoch-average total loss, ``record_losses`` DataFrame —
        unchanged unless ``use_record`` is flipped to True below).
    """
    # -- setup for epoch --
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)  # unused in this loop
    use_record = False  # hard-off switch for DataFrame record keeping
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    # -- running-average accumulators, reset at every log interval --
    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']
    # -- loss functions --
    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()  # unused; rec loss comes from `criterion`
    entropyLoss = EntropyLoss()
    # -- final configs --
    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)  # unused
    switch = True  # unused
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    D = 5 * 10**3  # unused
    steps_per_epoch = len(train_loader)
    for batch_idx in range(steps_per_epoch):
        # -- per-iteration setup --
        if use_timer: clock.tic()
        optimizer.zero_grad()
        model.zero_grad()
        model.denoiser_info.optim.zero_grad()
        # -- grab data batch; burst is (N, BS, C, H, W) --
        burst, res_imgs, raw_img, directions = next(train_iter)
        N, BS, C, H, W = burst.shape  # NOTE: shadows the earlier N = cfg.N
        burst = burst.cuda(non_blocking=True)
        # -- reshape views of the burst --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')  # unused
        # -- target: clean (zero-mean) image if supervised, else middle frame --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img
        # -- forward pass --
        aligned, aligned_ave, denoised, denoised_ave, filters = model(burst)
        # -- entropy loss over the predicted alignment filters --
        filters_shaped = rearrange(filters, 'b n k2 1 1 1 -> (b n) k2', n=N)
        filters_entropy = entropyLoss(filters_shaped)
        filters_entropy_coeff = 10.
        # -- alignment MSE (coefficient 0: effectively disabled) --
        losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 0  # .933**cfg.global_step if cfg.global_step < 100 else 0
        # -- reconstruction MSE, coefficient decays toward 0 with steps --
        denoised_ave_d = denoised_ave.detach()  # unused
        losses = criterion(denoised, denoised_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.997**cfg.global_step
        # -- regularization schedule (used by the OT variants; `reg` is
        #    currently not consumed since the OT call below ignores it) --
        if cfg.global_step < 100: reg = 0.5
        elif cfg.global_step < 200: reg = 0.25
        elif cfg.global_step < 5000: reg = 0.15
        elif cfg.global_step < 10000: reg = 0.1
        else: reg = 0.05
        # -- distribution loss: residuals of denoised frames vs middle frame
        #    pushed toward a zero-mean Gaussian at the known noise level --
        residuals = denoised - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)  # disabled variant
        rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2.
        rec_ot_pair_coeff = 100
        # -- combine final losses --
        align_loss = align_mse_coeff * align_mse
        rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse
        entropy_loss = filters_entropy_coeff * filters_entropy
        final_loss = align_loss + rec_loss + entropy_loss
        # -- record keeping --
        align_mse_losses += align_mse.item()
        align_mse_count += 1
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1
        rec_ot_losses += rec_ot_pair.item()
        rec_ot_count += 1
        running_loss += final_loss.item()
        total_loss += final_loss.item()
        # -- backprop + step both optimizers --
        final_loss.backward()
        model.denoiser_info.optim.step()
        optimizer.step()
        # -- periodic stdout diagnostics --
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            # -- psnr for [average of aligned frames] --
            # (+0.5 shifts predictions back to the raw image's range)
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))
            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))
            # -- psnr for [bm3d] baseline on the middle frame (CPU, slow) --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)
            # -- psnr for aligned + denoised (per-frame) --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))
            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))
            # -- average and reset interval accumulators --
            running_loss /= cfg.log_interval
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0
            # -- optional DataFrame record (dead: use_record is False) --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)
            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0
        # -- periodically dump example input/output images --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               filters, directions)
        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
def test_loop_n2n(cfg, model, criterion, test_loader, epoch):
    """Evaluate a Noise2Noise model over a test set.

    Denoises the first frame of each burst (with optional Anscombe
    normalization around the forward pass), accumulates the per-image MSE
    against the clean ``raw_img`` (shifted by +0.5) and its PSNR, and
    periodically saves a grid of reconstructions under
    ``{ROOT_PATH}/output/n2n/rec_imgs/e{epoch}``.

    Returns:
        float: PSNR averaged over batches. ``criterion`` is unused.
    """
    model.eval()
    model = model.to(cfg.device)
    psnr_sum = 0
    mse_sum = 0
    noise_type = cfg.noise_params.ntype  # read kept for parity; value unused
    with torch.no_grad():
        for batch_idx, (burst, res_img, raw_img, d) in enumerate(test_loader):
            batch_size = raw_img.shape[0]
            # -- move batch to GPU --
            raw_img = raw_img.cuda(non_blocking=True)
            burst = burst.cuda(non_blocking=True)
            img0 = burst[0]
            # -- optional Anscombe normalization before denoising --
            if cfg.use_anscombe:
                img0 = anscombe_nmlz.forward(cfg, img0 + 0.5) - 0.5
            # -- denoise the first frame --
            rec_img = model(img0)
            # -- undo the Anscombe transform on the reconstruction --
            if cfg.use_anscombe:
                rec_img = anscombe_nmlz.backward(cfg, rec_img + 0.5) - 0.5
            # -- per-image MSE and PSNR vs. the clean image --
            per_pixel = F.mse_loss(raw_img, rec_img + 0.5, reduction='none')
            per_image = torch.mean(per_pixel.reshape(batch_size, -1),
                                   1).detach().cpu().numpy()
            psnr_sum += np.mean(mse_to_psnr(per_image))
            mse_sum += np.mean(per_image)
            # -- periodically save a grid of reconstructions --
            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(
                    f"{settings.ROOT_PATH}/output/n2n/rec_imgs/e{epoch}")
                if not root.exists():
                    root.mkdir(parents=True)
                fn = root / Path(f"b{batch_idx}.png")
                nrow = int(np.sqrt(cfg.batch_size))
                rec_img = rec_img.detach().cpu()
                grid_imgs = tv_utils.make_grid(rec_img,
                                               padding=2,
                                               normalize=True,
                                               nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')
    ave_psnr = psnr_sum / len(test_loader)
    ave_loss = mse_sum / len(test_loader)
    print("Testing results: Ave psnr %2.3e Ave loss %2.3e" %
          (ave_psnr, ave_loss))
    return ave_psnr
def train_loop(cfg, model, train_loader, epoch, record_losses):
    """Run one training epoch for the three-part (align/denoise/unet) model.

    ``model`` bundles three sub-models, each with its own optimizer
    (``model.align_info``, ``model.denoiser_info``, ``model.unet_info``).
    The effective objective below is the reconstruction loss only
    (``final_loss = rec_loss``): a KL-to-Gaussian distribution term on the
    denoised residuals plus a decaying reconstruction MSE. The alignment,
    filter-norm and entropy terms are computed but carry zero coefficients or
    are excluded from ``final_loss``.

    Side effects: registers forward hooks on the alignment model's
    ``filter_cls`` layers (removed at the end of the epoch), increments
    ``cfg.global_step`` per batch, prints PSNR diagnostics (including a slow
    CPU BM3D baseline) and may call ``write_input_output``.

    Returns:
        tuple: (epoch-average total loss, ``record_losses`` DataFrame).
    """
    # -- setup for epoch --
    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)  # unused in this loop
    use_record = False  # hard-off switch for DataFrame record keeping
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    # -- running-average accumulators, reset at every log interval --
    align_mse_losses, align_mse_count = 0, 0
    align_ot_losses, align_ot_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']
    # -- hooks capturing intermediate alignment filters for this epoch --
    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)
    # -- loss functions --
    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    entropyLoss = EntropyLoss()
    # -- final configs --
    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)  # unused
    switch = True  # unused
    if use_timer: clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    write_examples_iter = steps_per_epoch // 2
    for batch_idx in range(steps_per_epoch):
        # -- per-iteration setup: zero all three models and optimizers --
        if use_timer: clock.tic()
        model.align_info.model.zero_grad()
        model.align_info.optim.zero_grad()
        model.denoiser_info.model.zero_grad()
        model.denoiser_info.optim.zero_grad()
        model.unet_info.model.zero_grad()
        model.unet_info.optim.zero_grad()
        # -- grab data batch; burst is (N, B, C, H, W) --
        burst, res_imgs, raw_img, directions = next(train_iter)
        N, B, C, H, W = burst.shape  # NOTE: shadows the earlier N = cfg.N
        burst = burst.cuda(non_blocking=True)
        # -- reshape views of the burst --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')  # unused
        # -- target: clean (zero-mean) image if supervised, else middle frame --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised: gt_img = raw_zm_img
        else: gt_img = mid_img

        # Debug-only probe comparing gradients of the MSE objective vs. the
        # Gaussian/Wasserstein-style objective. Closes over N, B and
        # noise_level. WARNING: ends with exit(), killing the process — only
        # ever enabled by uncommenting the call below.
        def mse_v_wassersteinG_check_some_gradients(cfg, burst, gt_img, model):
            grads = edict()
            gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            model.unet_info.model.zero_grad()
            burst.requires_grad_(True)
            outputs = model(burst)
            aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]
            residuals = denoised - gt_img_rs
            P = 1.  # residuals.numel()
            # -- gradients through the elementwise MSE --
            denoised.retain_grad()
            rec_mse = (denoised.reshape(B, -1) - gt_img.reshape(B, -1))**2
            rec_mse.retain_grad()
            ones = P * torch.ones_like(rec_mse)
            rec_mse.backward(ones, retain_graph=True)
            grads.rmse = rec_mse.grad.clone().reshape(B, -1)
            grad_rec_mse = grads.rmse
            grads.dmse = denoised.grad.clone().reshape(B, -1)
            grad_denoised_mse = grads.dmse
            ones = torch.ones_like(rec_mse)
            grads.d_to_b = torch.autograd.grad(rec_mse, denoised,
                                               ones)[0].reshape(B, -1)
            # -- fresh forward pass; gradients through the OT-style term --
            model.unet_info.model.zero_grad()
            outputs = model(burst)
            aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]
            denoised.retain_grad()
            rec_ot_v = (denoised - gt_img_rs)**2
            rec_ot_v.retain_grad()
            rec_ot = (rec_ot_v.mean() - noise_level / 255.)**2
            rec_ot.retain_grad()
            ones = P * torch.ones_like(rec_ot)
            rec_ot.backward(ones)
            grad_denoised_ot = denoised.grad.clone().reshape(B, -1)
            grads.dot = grad_denoised_ot
            grad_rec_ot = rec_ot_v.grad.clone().reshape(B, -1)
            grads.rot = grad_denoised_ot  # NOTE(review): likely meant grad_rec_ot
            print("Gradient Name Info")
            for name, g in grads.items():
                g_norm = g.norm().item()
                g_mean = g.mean().item()
                g_std = g.std().item()
                print(name, g.shape, g_norm, g_mean, g_std)
            print_pairs = False
            if print_pairs:
                print("All Gradient Ratios")
                for name_t, g_t in grads.items():
                    for name_b, g_b in grads.items():
                        ratio = g_t / g_b
                        ratio_m = ratio.mean().item()
                        ratio_std = ratio.std().item()
                        print("[%s/%s] [%2.2e +/- %2.2e]" %
                              (name_t, name_b, ratio_m, ratio_std))
            use_true_mse = False
            if use_true_mse:
                print("Ratios with Estimated MSE Gradient")
                true_dmse = 2 * torch.mean(denoised_ave - gt_img)**2
                ratio_mse = grads.dmse / true_dmse
                ratio_mse_dtb = grads.dmse / grads.d_to_b
                print(ratio_mse)
                print(ratio_mse_dtb)
            dot_v_dmse = True
            if dot_v_dmse:
                print("Ratio of Denoised OT and Denoised MSE")
                ratio_mseot = (grads.dmse / grads.dot)
                print(ratio_mseot.mean(), ratio_mseot.std())
                ratio_mseot = ratio_mseot[0, 0].item()
                c1 = torch.mean((denoised - gt_img_rs)**2).item()
                c2 = noise_level / 255
                m = torch.mean(gt_img_rs).item()
                true_ratio = 2. * (c1 - c2) / (np.product(burst.shape))
                ratio_mseot = (grads.dmse / (grads.dot))
                print(ratio_mseot * true_ratio)
            exit()

        model.unet_info.model.zero_grad()
        # mse_v_wassersteinG_check_some_gradients(cfg,burst,gt_img,model)
        # -- forward pass --
        outputs = model(burst)
        aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
        aligned_filters, denoised_filters = outputs[4:]
        # -- require approx equal filter norms (aligned); coeff 0 below --
        aligned_filters_rs = rearrange(aligned_filters,
                                       'b n k2 c h w -> b n (k2 c h w)')
        norms = torch.norm(aligned_filters_rs, p=2., dim=2)
        norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N)
        norm_loss_align = torch.mean(
            torch.pow(torch.abs(norms - norms_mid), 1.))
        # -- require approx equal filter norms (denoised); coeff 0 below --
        denoised_filters = rearrange(denoised_filters,
                                     'b n k2 c h w -> b n (k2 c h w)')
        norms = torch.norm(denoised_filters, p=2., dim=2)
        norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N)
        norm_loss_denoiser = torch.mean(
            torch.pow(torch.abs(norms - norms_mid), 1.))
        norm_loss_coeff = 0.
        # -- decrease entropy within each kernel (disabled: coeff 0) --
        filters_entropy = 0
        filters_entropy_coeff = 0.  # 1000.
        all_filters = []
        L = len(align_hook.filters)
        # fall back to the model's final filters if the hooks caught nothing
        iter_filters = align_hook.filters if L > 0 else [aligned_filters]
        for filters in iter_filters:
            filters_shaped = rearrange(filters,
                                       'b n k2 c h w -> (b n c h w) k2',
                                       n=N)
            filters_entropy += entropyLoss(filters_shaped)
            all_filters.append(filters)
        if L > 0: filters_entropy /= L
        all_filters = torch.stack(all_filters, dim=1)
        align_hook.clear()  # reset captured filters for the next batch
        # -- increase entropy across kernels, per batch (disabled: coeff 0) --
        filters_dist_entropy = 0
        filters_shaped = rearrange(all_filters,
                                   'b l n k2 c h w -> (n l) (b c h w) k2')
        filters_shaped = torch.mean(filters_shaped, dim=1)
        filters_dist_entropy += -1 * entropyLoss(filters_shaped)
        filters_dist_coeff = 0
        # -- alignment MSE (disabled: coeff 0) --
        losses = alignmentLossMSE(aligned, aligned_ave, gt_img,
                                  cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 0.  # 0.95**cfg.global_step
        # -- alignment distribution loss (disabled: coeff 0) --
        residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        align_ot = kl_gaussian_bp_patches(residuals,
                                          noise_level,
                                          flip=True,
                                          patchsize=16)
        align_ot_coeff = 0  # 100.
        # -- reconstruction MSE, coefficient decays toward 0 with steps --
        losses = denoiseLossMSE(denoised, denoised_ave, gt_img,
                                cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.95**cfg.global_step
        # -- reconstruction distribution loss (the dominant term) --
        gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)  # unused
        residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        rec_ot = kl_gaussian_bp(residuals, noise_level, flip=True)
        rec_ot_coeff = 100.
        # -- final losses; only rec_loss reaches the gradient --
        rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse
        norm_loss = norm_loss_coeff * (norm_loss_denoiser + norm_loss_align)
        align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot
        entropy_loss = 0
        final_loss = rec_loss
        # -- record keeping --
        align_mse_losses += align_mse.item()
        align_mse_count += 1
        align_ot_losses += align_ot.item()
        align_ot_count += 1
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1
        rec_ot_losses += rec_ot.item()
        rec_ot_count += 1
        running_loss += final_loss.item()
        total_loss += final_loss.item()
        # -- backprop + step all three optimizers --
        final_loss.backward()
        model.align_info.optim.step()
        model.denoiser_info.optim.step()
        model.unet_info.optim.step()
        # -- periodic stdout diagnostics --
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))
            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))
            # -- psnr for [bm3d] baseline (CPU, slow) --
            bm3d_nb_psnrs = []
            M = 10 if B > 10 else B  # NOTE(review): computed but unused; the
            # loop below still runs over all B images
            for b in range(B):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)
            # -- psnr for aligned + denoised (per-frame) --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))
            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))
            # -- average and reset interval accumulators --
            running_loss /= cfg.log_interval
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0
            align_ot_ave = align_ot_losses / align_ot_count
            align_ot_losses, align_ot_count = 0, 0
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0
            # -- optional DataFrame record (dead: use_record is False) --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)
            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0
        # -- periodically dump example input/output images --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, directions)
        if use_timer: clock.toc()
        if use_timer: print(clock)
        cfg.global_step += 1
    # -- remove the alignment-filter hooks installed for this epoch --
    for hook in align_hooks:
        hook.remove()
    total_loss /= len(train_loader)
    return total_loss, record_losses
def train_loop(cfg,model,optimizer,criterion,train_loader,epoch):
    """Train `model` for one epoch on noisy bursts.

    The model maps a stacked (or channel-concatenated) burst of noisy frames
    to a denoised estimate of the middle frame. With `cfg.blind` the target
    is the (noisy) middle frame itself (noise2noise style); otherwise it is
    the zero-mean-scaled clean image.

    :param cfg: experiment config (device, N, input_N, blind, log_interval, ...)
    :param model: denoiser taking the stacked burst, returning one image
    :param optimizer: optimizer over `model`'s parameters
    :param criterion: unused here (kept for a uniform train-loop signature)
    :param train_loader: yields (burst_imgs, res_imgs, raw_img) batches
    :param epoch: current epoch index (for logging only)
    :returns: average training loss over the epoch
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- choose which frames feed the network --
        # When the middle frame is excluded from the input it is still used
        # as the (blind) training target below.
        input_order = np.arange(cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.r_[input_order[:middle],input_order[middle+1:]]
        else:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.arange(cfg.N)

        # -- optionally perturb the middle frame with extra input noise --
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            burst_imgs_noisy[middle_img_idx] = torch.normal(burst_imgs_noisy[middle_img_idx],noise)

        # -- stack along channels or along a new burst dimension --
        if cfg.color_cat:
            stacked_burst = torch.cat([burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],dim=1)
        else:
            stacked_burst = torch.stack([burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],dim=1)

        # -- extract target image --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.cuda(non_blocking=True))

        # -- denoising --
        rec_img = model(stacked_burst)

        # -- compute loss --
        loss = F.mse_loss(t_img,rec_img)

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- psnr against the clean image (rec_img is zero-mean; +0.5
            #    shifts it into the clean image's range) --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            mse_loss = F.mse_loss(raw_img,rec_img+0.5,reduction='none').reshape(BS,-1)
            mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e"%(epoch, cfg.epochs, batch_idx, len(train_loader), running_loss,psnr))
            # FIX: reset the logging window; previously running_loss was never
            # cleared, so later prints showed a stale, repeatedly re-divided
            # value (the sibling train_loop below does reset it).
            running_loss = 0
    total_loss /= len(train_loader)
    return total_loss
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch):
    """One epoch of residual-prediction training.

    The model sees only the first frame of each burst, predicts its noise
    residual, and the reconstruction (frame - residual) is regressed onto
    the clean image (shifted by +0.5 into the clean image's range).

    :param cfg: experiment config (device, N, log_interval, epochs, ...)
    :param model: residual predictor taking a single frame
    :param optimizer: optimizer over `model`'s parameters
    :param criterion: unused (kept for a uniform train-loop signature)
    :param train_loader: yields (burst_imgs, res_imgs, raw_img) batches
    :param epoch: current epoch index (for logging only)
    :returns: mean training loss over the loader
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        burst_imgs, res_imgs, raw_img = batch
        optimizer.zero_grad()
        model.zero_grad()

        # -- move data onto the gpu --
        raw_img = raw_img.cuda(non_blocking=True)
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        res_imgs = res_imgs.cuda(non_blocking=True)
        frame0, residual0 = burst_imgs[0], res_imgs[0]

        # -- residual prediction: reconstruction = frame - predicted noise --
        predicted_residual = model(frame0)
        rec_img = frame0 - predicted_residual

        # -- supervise against the clean image --
        loss = F.mse_loss(raw_img, rec_img + 0.5)

        # -- bookkeeping --
        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss

        # -- backprop + sgd step --
        loss.backward()
        optimizer.step()

        # -- periodic logging --
        should_log = batch_idx > 0 and (batch_idx % cfg.log_interval) == 0
        if should_log:
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            per_pixel = F.mse_loss(raw_img, rec_img + 0.5, reduction='none')
            per_image = torch.mean(per_pixel.reshape(BS, -1), 1)
            psnr = np.mean(mse_to_psnr(per_image.detach().cpu().numpy()))
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr))
            running_loss = 0
    total_loss /= len(train_loader)
    return total_loss
def train_loop_offset(cfg, model, optimizer, criterion, train_loader, epoch,
                      record_losses):
    """Train a KPN-style burst denoiser with an added OT (Sinkhorn) loss.

    The KPN criterion supervises per-frame and averaged reconstructions;
    on top of that, residuals of randomly sampled frame pairs are compared
    with a stabilized Sinkhorn distance to encourage matching noise
    distributions across frames.

    :param record_losses: pandas DataFrame of per-batch (kpn, ot, psnr,
        psnr_std) rows, created here if None
    :returns: (average epoch loss, updated `record_losses` dataframe)
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    # `unfold` is not used by the active code path (its blockwise rearrange
    # of the residuals is commented out); kept as-is.
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    D = 5 * 10**3  # hard cap on batches consumed per epoch
    if record_losses is None:
        record_losses = pd.DataFrame({
            'kpn': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    # if cfg.N != 5: return
    switch = True
    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):

        if batch_idx > D: break

        optimizer.zero_grad()
        model.zero_grad()

        # -- reshaping of data --
        # raw_img = raw_img.cuda(non_blocking=True)
        input_order = np.arange(cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            # input_order = np.r_[input_order[:middle],input_order[middle+1:]]
        else:
            middle = len(input_order) // 2
            input_order = np.arange(cfg.N)
            middle_img_idx = input_order[middle]

        # NOTE: N is re-bound here from cfg.N to this batch's burst length.
        N, BS, C, H, W = burst_imgs.shape
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        middle_img = burst_imgs[middle_img_idx]

        # -- add input noise --
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(
                    burst_imgs_noisy[middle_img_idx], noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy, noise)

        # -- create inputs for kpn --
        stacked_burst = torch.stack(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)
        cat_burst = torch.cat(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)

        # -- extract target image --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.cuda(non_blocking=True))

        # -- direct denoising: per-frame estimates + fused estimate --
        rec_img_i, rec_img = model(cat_burst, stacked_burst)

        # -- compute mse (monitoring only; not part of `loss` below) --
        mse_loss = F.mse_loss(rec_img, t_img)

        # -- compute kpn loss to optimize --
        kpn_losses = criterion(rec_img_i, rec_img, t_img, cfg.global_step)
        kpn_loss = np.sum(kpn_losses)

        # -- per-frame residuals against the target --
        rec_img_i_bn = rearrange(rec_img_i, 'b n c h w -> (b n) c h w')
        r_middle_img = t_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        r_middle_img = rearrange(r_middle_img, 'b n c h w -> (b n) c h w')
        diffs = r_middle_img - rec_img_i_bn

        # -- OT loss over sampled frame pairs --
        diffs = rearrange(diffs, '(b n) c h w -> b n (h w) c', n=N)
        ot_loss = 0
        pairs = list(set([(i, j) for i in range(N) for j in range(N)
                          if i < j]))
        P = len(pairs)
        S = 3  #P
        # NOTE(review): npr.choice samples WITH replacement, so the same pair
        # can be drawn (and counted) more than once — confirm intended.
        r_idx = npr.choice(range(P), S)
        for idx in r_idx:
            i, j = pairs[idx]
            if i >= j: continue  # always false: pairs already satisfy i < j
            for b in range(BS):
                di, dj = diffs[b, i], diffs[b, j]
                # pairwise squared distances between pixel residuals
                M = torch.sum(torch.pow(di.unsqueeze(1) - dj, 2), dim=-1)
                ot_loss += sink_stabilized(M, 0.5)
        ot_loss /= S * BS

        # -- combine loss --
        alpha, beta = criterion.loss_anneal.alpha, criterion.loss_anneal.beta
        ot_coeff = 10
        loss = kpn_loss + ot_coeff * ot_loss
        switch = not switch  # leftover from an alternating-loss experiment

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if True:
            # -- compute psnr for monitoring --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))
            record_losses = record_losses.append(
                {
                    'kpn': kpn_loss.item(),
                    'ot': ot_loss.item(),
                    'psnr': psnr,
                    'psnr_std': psnr_std
                },
                ignore_index=True)
            # NOTE(review): this division runs on EVERY batch (it sits outside
            # the log-interval guard below), so running_loss decays
            # geometrically between prints — looks unintended; confirm.
            running_loss /= cfg.log_interval
            if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
                print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f" %
                      (epoch, cfg.epochs, batch_idx, len(train_loader),
                       running_loss, psnr, psnr_std))
                running_loss = 0
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
def main():
    """Benchmark patch-similarity search across noise settings.

    For each noise setting: builds an output directory keyed by the search
    hyper-parameters, loads one burst sample, grid-searches a rotation angle
    and an integer vertical shift aligning frame 1 to frame 0, runs the
    similarity-search analysis, saves qualitative images / patch grids /
    distance statistics, and writes per-(b, k) PSNRs to `psnrs.csv`.
    """
    # -- init --
    cfg = get_main_config()
    cfg.gpuid = 0
    cfg.batch_size = 1
    cfg.N = 2
    cfg.num_workers = 0
    cfg.dynamic.frames = cfg.N
    cfg.rot = edict()
    cfg.rot.skip = 0  # big gap between 2 and 3.

    # -- dynamics --
    cfg.dataset.name = "rots"
    cfg.dataset.load_residual = True
    cfg.dynamic.frame_size = 256
    cfg.frame_size = cfg.dynamic.frame_size
    cfg.dynamic.ppf = 0
    cfg.dynamic.total_pixels = cfg.N * cfg.dynamic.ppf
    torch.cuda.set_device(cfg.gpuid)

    # -- sim params --
    K = 10
    patchsize = 9
    db_level = "frame"
    search_method = "l2"
    # database_str = f"burstAll"
    database_idx = 1
    database_str = "burst{}".format(database_idx)

    # -- grab grids for experiments --
    noise_settings = create_noise_level_grid(cfg)
    # sim_settings = create_sim_grid(cfg)
    # motion_settings = create_motion_grid(cfg)

    for ns in noise_settings:

        # -- loop params --
        noise_level = 0.
        noise_type = ns.ntype
        noise_str = set_noise_setting(cfg, ns)

        # -- create path for results --
        path_args = (K, patchsize, cfg.batch_size, cfg.N, noise_str,
                     database_str, db_level, search_method)
        base = Path(f"output/benchmark_noise_types/{cfg.dataset.name}")
        root = Path(base /
                    "k{}_ps{}_b{}_n{}_{}_db-{}_sim-{}-{}".format(*path_args))
        print(f"Writing to {root}")
        if root.exists():
            print("Running Experiment Again.")
        else:
            root.mkdir(parents=True)

        # -- dataset --
        data, loader = load_dataset(cfg, 'dynamic')
        if cfg.dataset.name == "voc":
            sample = next(iter(loader.tr))
        else:
            sample = data.tr[0]

        # -- load sample (images are shifted into [-0.5, 0.5]) --
        burst, raw_img, res = sample['burst'], sample['clean'] - 0.5, sample[
            'res']
        kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)
        N, B, C, H, W = burst.shape
        if 'clean_burst' in sample:
            clean = sample['clean_burst'] - 0.5
        else:
            clean = burst - res
        if noise_type in ["qis", "pn"]:
            # NOTE(review): rgb_to_grayscale is not in-place and its return
            # value is discarded, so this call has no effect — confirm intent.
            tvF.rgb_to_grayscale(clean, 3)
            # burst = tvF.rgb_to_grayscale(burst,3)
            # raw_img = tvF.rgb_to_grayscale(raw_img,3)
            # clean = tvF.rgb_to_grayscale(clean,3)

        # -- temp (delete me soon): grid-search a rotation aligning frame 1
        #    to frame 0 by minimizing the blend-vs-frame-0 mse --
        search_rot_grid = np.linspace(.3, .32, 100)
        losses = np.zeros_like(search_rot_grid)
        for idx, angle in enumerate(search_rot_grid):
            save_alpha_burst = 0.5 * burst[0] + 0.5 * tvF.rotate(
                burst[1], angle)
            losses[idx] = F.mse_loss(save_alpha_burst, burst[0]).item()
        min_arg = np.argmin(losses)
        angle = search_rot_grid[min_arg]
        ref_img = tvF.rotate(burst[1], angle)

        # -- grid-search an integer vertical shift the same way --
        # FIX: np.int / np.float were removed in NumPy 1.24; use the builtin
        # int / float (identical dtype semantics here).
        shift_grid = np.linspace(-20, 20, 40 - 1).astype(int)
        losses = np.zeros_like(shift_grid).astype(float)
        for idx, shift in enumerate(shift_grid):
            save_alpha_burst = 0.5 * burst[0] + 0.5 * torch.roll(
                ref_img, shift, -2)
            losses[idx] = F.mse_loss(save_alpha_burst, burst[0]).item()
        min_arg = np.argmin(losses)
        shift = shift_grid[min_arg]

        # -- run search --
        kindex = kindex_ds[0]
        database = None
        if database_str == f"burstAll":
            database = burst
            clean_db = clean
        else:
            database = burst[[database_idx]]
            clean_db = clean[[database_idx]]
        query = burst[[0]]
        sim_outputs = compute_similar_bursts_analysis(
            cfg,
            query,
            database,
            clean_db,
            K,
            patchsize=patchsize,
            shuffle_k=False,
            kindex=kindex,
            only_middle=cfg.sim_only_middle,
            db_level=db_level,
            search_method=search_method,
            noise_level=noise_level / 255.)
        sims, csims, wsims, b_dist, b_indx = sim_outputs

        # -- save images --
        fs = cfg.frame_size
        save_K = 1
        save_sims = rearrange(sims[:, :, :save_K],
                              'n b k1 c h w -> (n b k1) c h w')
        save_csims = rearrange(csims[:, :, :save_K],
                               'n b k1 c h w -> (n b k1) c h w')
        save_cdelta = clean[0] - save_csims[0]
        save_alpha_burst = 0.5 * burst[0] + 0.5 * torch.roll(
            tvF.rotate(burst[1], angle), shift, -2)
        save_burst = rearrange(burst, 'n b c h w -> (b n) c h w')
        save_clean = rearrange(clean, 'n b c h w -> (b n) c h w')
        save_b_dist = rearrange(b_dist[:, :, :save_K],
                                'n b k1 h w -> (n b k1) 1 h w')
        save_b_indx = rearrange(b_indx[:, :, :save_K],
                                'n b k1 h w -> (n b k1) 1 h w')
        # distance (in raster index units) of each match from its own pixel
        save_b_indx = torch.abs(
            torch.arange(fs * fs).reshape(fs, fs) - save_b_indx).float()
        save_b_indx /= (torch.sum(save_b_indx) + 1e-16)
        tv_utils.save_image(save_sims,
                            root / 'sims.png',
                            nrow=B,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_csims,
                            root / 'csims.png',
                            nrow=B,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_cdelta,
                            root / 'cdelta.png',
                            nrow=B,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_clean,
                            root / 'clean.png',
                            nrow=N,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_burst,
                            root / 'burst.png',
                            nrow=N,
                            normalize=True,
                            range=(-0.5, 0.5))
        tv_utils.save_image(save_b_dist,
                            root / 'b_dist.png',
                            nrow=B,
                            normalize=True)
        tv_utils.save_image(raw_img, root / 'raw.png', nrow=B, normalize=True)
        tv_utils.save_image(save_b_indx,
                            root / 'b_indx.png',
                            nrow=B,
                            normalize=True)
        tv_utils.save_image(save_alpha_burst,
                            root / 'alpha_burst.png',
                            nrow=B,
                            normalize=True)

        # -- save top K patches at location --
        b = 0
        ref_img = clean[0, b]
        ps, fs = patchsize, cfg.frame_size
        xx, yy = np.mgrid[32:48, 48:64]
        xx, yy = xx.ravel(), yy.ravel()
        clean_pad = F.pad(clean[database_idx, [b]],
                          (ps // 2, ps // 2, ps // 2, ps // 2),
                          mode='reflect')[0]
        patches = []
        for x, y in zip(xx, yy):
            gt_patch = tvF.crop(ref_img, x - ps // 2, y - ps // 2, ps, ps)
            patches_xy = [gt_patch]
            for k in range(save_K):
                # b_indx holds raster indices into the fs x fs frame
                indx = b_indx[0, 0, k, x, y]
                xp, yp = (indx // fs) + ps // 2, (indx % fs) + ps // 2
                t, l = xp - ps // 2, yp - ps // 2
                clean_patch = tvF.crop(clean_pad, t, l, ps, ps)
                patches_xy.append(clean_patch)
                # constant image encoding the center-pixel error, for viz
                pix_diff = F.mse_loss(gt_patch[:, ps // 2, ps // 2],
                                      clean_patch[:, ps // 2,
                                                  ps // 2]).item()
                pix_diff_img = pix_diff * torch.ones_like(clean_patch)
                patches_xy.append(pix_diff_img)
            patches_xy = torch.stack(patches_xy, dim=0)
            patches.append(patches_xy)
        patches = torch.stack(patches, dim=0)
        R = patches.shape[1]
        patches = rearrange(patches, 'l k c h w -> (l k) c h w')
        fn = f"patches_{b}.png"
        tv_utils.save_image(patches, root / fn, nrow=R, normalize=True)

        # -- stats about distance --
        mean_along_k = reduce(b_dist, 'n b k1 h w -> k1', 'mean')
        std_along_k = torch.std(b_dist, dim=(0, 1, 3, 4))
        fig, ax = plt.subplots(figsize=(8, 8))
        R = mean_along_k.shape[0]
        ax.errorbar(np.arange(R), mean_along_k, yerr=std_along_k)
        plt.savefig(root / "distance_stats.png", dpi=300)
        plt.clf()
        plt.close("all")

        # -- psnr between 1st neighbor and clean --
        psnrs = pd.DataFrame({
            "b": [],
            "k": [],
            "psnr": [],
            'crop200_psnr': []
        })
        for b in range(B):
            for k in range(K):
                # -- psnr --
                crop_raw = clean[0, b]
                crop_cmp = csims[0, b, k]
                rc_mse = F.mse_loss(crop_raw, crop_cmp,
                                    reduction='none').reshape(1, -1)
                rc_mse = torch.mean(rc_mse, 1).numpy() + 1e-16
                psnr_bk = np.mean(mse_to_psnr(rc_mse))
                print(psnr_bk)

                # -- center-crop psnr (ignores border effects) --
                crop_raw = tvF.center_crop(clean[0, b], 200)
                crop_cmp = tvF.center_crop(csims[0, b, k], 200)
                rc_mse = F.mse_loss(crop_raw, crop_cmp,
                                    reduction='none').reshape(1, -1)
                rc_mse = torch.mean(rc_mse, 1).numpy() + 1e-16
                crop_psnr = np.mean(mse_to_psnr(rc_mse))
                # if np.isinf(psnr_bk): psnr_bk = 50.
                psnrs = psnrs.append(
                    {
                        'b': b,
                        'k': k,
                        'psnr': psnr_bk,
                        'crop200_psnr': crop_psnr
                    },
                    ignore_index=True)
        psnrs = psnrs.astype({
            'b': int,
            'k': int,
            'psnr': float,
            'crop200_psnr': float
        })
        psnrs.to_csv(root / "psnrs.csv", sep=",", index=False)
def test_loop(cfg,model,optimizer,criterion,test_loader,epoch):
    """Per-batch deep-image-prior style evaluation with a KPN model.

    NOTE(review): despite the name, this runs *optimization at test time*:
    for every test batch a fresh model/criterion/optimizer are loaded
    (shadowing the `model`/`criterion`/`optimizer` arguments) and fit to the
    noisy burst for up to `iters` steps; the best PSNR reached is recorded.
    Hence `model.train()` below appears intentional for a "test" loop.

    :returns: average best-PSNR over the test loader
    """
    model.train()
    model = model.to(cfg.device)
    total_psnr = 0
    # NOTE(review): total_loss is never accumulated (the update below is
    # commented out), so the printed "Ave loss" is always 0 — confirm.
    total_loss = 0
    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(test_loader):
        # for batch_idx, (burst_imgs, raw_img) in enumerate(test_loader):

        BS = raw_img.shape[0]
        N,BS,C,H,W = burst_imgs.shape

        # -- selecting input frames --
        input_order = np.arange(cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            middle = cfg.N // 2
            middle_img_idx = input_order[middle]
            input_order = np.r_[input_order[:middle],input_order[middle+1:]]
        else:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.arange(cfg.N)

        # -- reshaping of data --
        raw_img = raw_img.cuda(non_blocking=True)
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        # NOTE(review): this color_cat branch is dead code — stacked_burst is
        # unconditionally overwritten by the torch.stack two lines below.
        if cfg.color_cat:
            stacked_burst = torch.cat([burst_imgs[input_order[x]] for x in range(cfg.input_N)],dim=1)
        else:
            stacked_burst = torch.stack([burst_imgs[input_order[x]] for x in range(cfg.input_N)],dim=1)
        stacked_burst = torch.stack([burst_imgs[input_order[x]] for x in range(cfg.input_N)],dim=1)
        cat_burst = torch.cat([burst_imgs[input_order[x]] for x in range(cfg.input_N)],dim=1)

        # -- dip denoising: target is the noisy middle frame shifted to [0,1] --
        # img = burst_imgs[middle_img_idx] + 0.5
        t_img = burst_imgs[middle_img_idx] + 0.5
        img = stacked_burst + 0.5
        diff = 100  # currently unused
        idx = 0
        iters = 2400
        tol = 5e-9  # currently unused

        best_psnr = 0
        # -- fresh model per batch (shadows the function's arguments) --
        model,criterion = load_model_kpn(cfg)
        optimizer = load_optimizer(cfg,model)
        model = model.cuda()
        model.apply(weights_init)
        # print(f"global_step: {cfg.global_step}")
        cfg.global_step = 0
        while (idx < iters):
            idx += 1
            optimizer.zero_grad()
            model.zero_grad()

            # -- forward kpn model --
            rec_img_i,rec_img = model(cat_burst,stacked_burst)
            lossE_ = criterion(rec_img_i, rec_img, t_img, cfg.global_step)
            # global_step drives the criterion's annealing schedule; bumping
            # by 30 per iteration accelerates it relative to training.
            cfg.global_step += 30
            lossE = np.sum(lossE_)

            # -- track psnr on a 16x16 corner against the clean image --
            if (idx % 1) == 0 or idx == 1:
                loss = F.mse_loss(raw_img[:,:,:16,:16],rec_img[:,:,:16,:16],reduction='none').reshape(BS,-1)
                loss = torch.mean(loss,1).detach().cpu().numpy()
                psnr = np.mean(mse_to_psnr(loss))
                if (idx % 100) == 0 or idx == 1:
                    print("[%d/%d] lossE: [%.2e] psnr: [%.2f]" % (idx,iters,lossE,psnr))
                if psnr > best_psnr:
                    best_psnr = psnr

            # bail out before backprop if the loss diverged
            if torch.isinf(lossE): break
            lossE.backward()
            optimizer.step()

        print(f"Best PSNR: {best_psnr}")

        total_psnr += np.mean(best_psnr)
        # total_loss += np.mean(loss)

        # -- periodically save a grid of the last reconstruction --
        if (batch_idx % cfg.test_log_interval) == 0:
            root = Path(f"{settings.ROOT_PATH}/output/n2n/offset_out_noise/rec_imgs/N{cfg.N}/e{epoch}")
            if not root.exists(): root.mkdir(parents=True)
            fn = root / Path(f"b{batch_idx}.png")
            nrow = int(np.sqrt(cfg.batch_size))
            rec_img = rec_img.detach().cpu()
            grid_imgs = vutils.make_grid(rec_img, padding=2, normalize=True, nrow=nrow)
            plt.imshow(grid_imgs.permute(1,2,0))
            plt.savefig(fn)
            plt.close('all')

    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print("[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e"%(cfg.blind,cfg.N,ave_psnr,ave_loss))
    return ave_psnr
def train_loop_offset(cfg, model, optimizer, criterion, train_loader, epoch):
    """Train an STN-style alignment model for one (capped) epoch.

    The model aligns the frames of a noisy burst; the reconstruction is the
    mean of the aligned frames and is regressed onto the target (noisy middle
    frame when `cfg.blind`, zero-mean clean image otherwise). Logging also
    reports a per-batch BM3D baseline PSNR on the middle frame.

    :param criterion: unused in the active path (an MSE loss is used directly)
    :returns: (average loss over len(train_loader), record dataframe)
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    write_examples = True
    write_examples_iter = 800
    szm = ScaleZeroMean()
    record = init_record()
    use_record = False

    # epoch is capped at B batches; the loader is consumed manually.
    B = 1500  # len(train_loader)
    train_it = iter(train_loader)
    for batch_idx in range(B):
        (burst_imgs, res_imgs, raw_img, directions) = next(train_it)

        optimizer.zero_grad()
        model.zero_grad()

        # -- choose which frames feed the network --
        input_order = np.arange(cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            # input_order = np.r_[input_order[:middle],input_order[middle+1:]]
        else:
            middle = len(input_order) // 2
            input_order = np.arange(cfg.N)
            middle_img_idx = input_order[middle]
        burst_imgs = burst_imgs.cuda(non_blocking=True)

        # -- add input noise --
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(
                    burst_imgs_noisy[middle_img_idx], noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy, noise)

        # -- create inputs for stn --
        stacked_burst = torch.stack(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)
        cat_burst = torch.cat(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)

        # -- extract target image --
        mid_img = burst_imgs[middle_img_idx]
        raw_img_zm = szm(raw_img.cuda(non_blocking=True))
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.cuda(non_blocking=True))

        # -- align frames; reconstruction = mean of aligned frames --
        aligned, thetas = model(stacked_burst)
        rec_img = torch.mean(aligned, dim=1)

        # -- compute loss to optimize --
        loss = F.mse_loss(rec_img, t_img)
        stn_loss = loss

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- psnr for [rec img] --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval

            # -- psnr for [bm3d] baseline on the noisy middle frame --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=25 / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # FIX: this loss covers a single image, so it must be reshaped
                # to (1, -1); the previous reshape(BS, -1) split one image's
                # pixels across BS rows whenever BS > 1, corrupting the
                # per-image PSNR (the sibling loop above uses reshape(1, -1)).
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr_ave, psnr_std, bm3d_nb_ave,
                          bm3d_nb_std)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f"
                % write_info)
            running_loss = 0

            # -- record information --
            if use_record:
                rec = rec_img
                raw = raw_img_zm
                frame_results = compute_ot_frame(aligned, rec, raw, reg=0.5)
                burst_results = compute_ot_burst(aligned, rec, raw, reg=0.5)
                psnr_record = {'psnr_ave': psnr_ave, 'psnr_std': psnr_std}
                # NOTE(review): stn_loss is stored as a tensor, not a float —
                # confirm downstream consumers expect that.
                stn_record = {'stn_loss': stn_loss}
                new_record = merge_records(frame_results, burst_results,
                                           psnr_record, stn_record)
                record = record.append(new_record, ignore_index=True)

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0:
            write_input_output(cfg, stacked_burst, aligned, thetas,
                               directions)

        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record
def test_denoising(cfg, model, test_loader, epoch, num_epochs):
    """Evaluate a denoising model on a test set and plot example grids.

    Computes the average reconstruction loss (MSE, or PSNR when
    ``cfg.test_with_psnr`` is set) over the whole loader, both for the
    rescaled decoder output and for a [0,1]-normalized ("rigid") variant,
    then saves a 3-row image grid (noisy / rigid-normalized / decoded)
    built from the LAST batch seen.

    Returns the averaged test loss (decoder output vs. raw image).
    """
    model.eval()
    rigid_loss = 0
    test_loss = 0
    idx = 0
    with torch.no_grad():
        for noisy_imgs, raw_img in tqdm(test_loader):
            # -- shapes: burst of N frames x batch of BS images --
            N, BS = noisy_imgs.shape[:2]
            p_shape = noisy_imgs.shape[2:]
            bshape = (N * BS, ) + p_shape  # burst flattened into the batch dim
            dshape = (N, BS, ) + p_shape
            # NOTE: fixed a redundant second .cuda() transfer of noisy_imgs.
            noisy_imgs = noisy_imgs.cuda(non_blocking=True)
            raw_img = raw_img.cuda(non_blocking=True)
            noisy_imgs = noisy_imgs.view(bshape)
            dec_imgs = model(noisy_imgs).detach()
            dec_no_rescale = dec_imgs  # kept for the debug prints below
            dec_imgs = rescale_noisy_image(dec_imgs.clone())
            rigid_nmlz_imgs = normalize_image_to_zero_one(dec_imgs.clone())
            rigid_nmlz_imgs = rigid_nmlz_imgs.reshape(dshape)
            dec_imgs = dec_imgs.reshape(dshape)
            # broadcast the single clean image across the N burst frames
            raw_img = raw_img.expand(dshape)
            if idx == 10:
                # debug: value ranges on one fixed batch
                print('dec_no_rescale', dec_no_rescale.mean(),
                      dec_no_rescale.min(), dec_no_rescale.max())
                print('noisy', noisy_imgs.mean(), noisy_imgs.min(),
                      noisy_imgs.max())
                print('dec', dec_imgs.mean(), dec_imgs.min(), dec_imgs.max())
                print('raw', raw_img.mean(), raw_img.min(), raw_img.max())
            # -- loss on [0,1]-normalized output --
            r_loss = F.mse_loss(raw_img, rigid_nmlz_imgs).item()
            if cfg.test_with_psnr:
                r_loss = mse_to_psnr(r_loss)
            rigid_loss += r_loss
            # -- loss on rescaled output --
            loss = F.mse_loss(raw_img, dec_imgs).item()
            if cfg.test_with_psnr:
                loss = mse_to_psnr(loss)
            test_loss += loss
            idx += 1
    test_loss /= len(test_loader)
    rigid_loss /= len(test_loader)
    print('\n[Test set] Average loss: {:2.3e}\n'.format(test_loss))
    print('\n[Test set with rigid loss] Average loss: {:2.3e}\n'.format(
        rigid_loss))
    # -- plot grids from the last batch only --
    noisy_imgs = noisy_imgs.detach().cpu().view(bshape)
    rigid_nmlz_imgs = rigid_nmlz_imgs.detach().cpu().view(bshape)
    dec_imgs = dec_imgs.detach().cpu().view(bshape)
    fig, ax = plt.subplots(3, 1, figsize=(10, 5))
    grid_im = vutils.make_grid(noisy_imgs, padding=2, normalize=True, nrow=16)
    ax[0].imshow(grid_im.permute(1, 2, 0))
    grid_im = vutils.make_grid(rigid_nmlz_imgs, padding=2, normalize=False,
                               nrow=16)
    ax[1].imshow(grid_im.permute(1, 2, 0))
    grid_im = vutils.make_grid(dec_imgs, padding=2, normalize=False, nrow=16)
    ax[2].imshow(grid_im.permute(1, 2, 0))
    plt.savefig(f"./gan_examples_e{epoch}o{num_epochs}g{cfg.gpuid}.png")
    return test_loss
def test_loop_offset(cfg, model, criterion, test_loader, epoch):
    """Evaluate a burst model that denoises the middle frame.

    For each batch, selects ``cfg.input_N`` input frames (excluding the
    middle frame unless ``cfg.input_with_middle_frame``), runs the model on
    the stacked burst plus the middle frame, and accumulates per-image MSE
    and PSNR against the clean ``raw_img``. Periodically writes a grid of
    reconstructed images to disk. Returns the average PSNR.

    NOTE(review): ``criterion`` is accepted but never used here.
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    with torch.no_grad():
        for batch_idx, (burst_imgs, res_imgs,
                        raw_img) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # -- selecting input frames --
            input_order = np.arange(cfg.N)
            middle_img_idx = -1
            if not cfg.input_with_middle_frame:
                # drop the middle frame from the inputs; it is the target
                middle = cfg.N // 2
                middle_img_idx = input_order[middle]
                input_order = np.r_[input_order[:middle],
                                    input_order[middle + 1:]]
            else:
                # keep all frames, middle included
                middle = len(input_order) // 2
                middle_img_idx = input_order[middle]
                input_order = np.arange(cfg.N)

            # -- reshaping of data --
            raw_img = raw_img.cuda(non_blocking=True)
            burst_imgs = burst_imgs.cuda(non_blocking=True)
            if cfg.color_cat:
                # concatenate frames along the channel dimension
                stacked_burst = torch.cat([
                    burst_imgs[input_order[x]] for x in range(cfg.input_N)
                ], dim=1)
            else:
                # stack frames along a new burst dimension
                stacked_burst = torch.stack([
                    burst_imgs[input_order[x]] for x in range(cfg.input_N)
                ], dim=1)

            # -- direct denoising: model returns its own loss + rec image --
            middle_img = burst_imgs[middle_img_idx]
            loss, rec_img = model(stacked_burst, middle_img)

            # -- compare with clean target (per-image MSE -> PSNR) --
            rec_img = rescale_noisy_image(rec_img)
            loss = F.mse_loss(raw_img, rec_img,
                              reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            # -- periodically save a grid of reconstructions --
            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(
                    f"{settings.ROOT_PATH}/output/n2n/offset_out_noise/rec_imgs/N{cfg.N}/e{epoch}"
                )
                # NOTE(review): exists()/mkdir() pair is race-prone;
                # mkdir(parents=True, exist_ok=True) would be safer.
                if not root.exists(): root.mkdir(parents=True)
                fn = root / Path(f"b{batch_idx}.png")
                nrow = int(np.sqrt(cfg.batch_size))
                rec_img = rec_img.detach().cpu()
                grid_imgs = vutils.make_grid(rec_img, padding=2,
                                             normalize=True, nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')
    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print("[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e"
          % (cfg.blind, cfg.N, ave_psnr, ave_loss))
    return ave_psnr
def train_loop(cfg, model, noise_critic, optimizer, criterion, train_loader,
               epoch, record_losses):
    """Adversarial training epoch: alternate denoiser and noise-critic steps.

    Each outer iteration runs ``gen_steps`` denoiser updates (MSE on the
    averaged + per-frame outputs plus a noise-critic residual loss) followed
    by ``disc_steps`` critic updates on the denoiser's residuals. Logging
    every ``cfg.log_interval`` iterations reports PSNR against the clean
    image plus a BM3D baseline.

    Returns ``(total_loss, record_losses)``.

    NOTE(review): ``criterion`` is accepted but never used here.
    """
    # -=-=-=-=-=-=-=-=-=-=-
    #    Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    szm = ScaleZeroMean()
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': []
        })
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -- record keeping --
    losses_nc, losses_nc_count = 0, 0
    losses_mse, losses_mse_count = 0, 0
    running_loss, total_loss = 0, 0

    # -- loss functions --
    lossRecMSE = LossRec(tensor_grad=cfg.supervised)
    lossBurstMSE = LossRecBurst(tensor_grad=cfg.supervised)

    # -- final configs --
    use_timer = False
    train_iter = iter(train_loader)
    if use_timer: clock = Timer()

    # -=-=-=-=-=-=-=-=-=-=-
    #     GAN Scheduler
    # -=-=-=-=-=-=-=-=-=-=-
    # -- noise critic steps --
    if epoch == 0:
        disc_steps = 0
    elif epoch < 3:
        disc_steps = 1
    elif epoch < 10:
        disc_steps = 1
    else:
        disc_steps = 1

    # -- denoising steps --
    # FIX: the original used independent `if` statements here (not elif),
    # so the epoch==0 and epoch<3 branches were always overwritten and
    # every epoch < 10 silently ran 10 steps.
    if epoch == 0:
        gen_steps = 1
    elif epoch < 3:
        gen_steps = 15
    elif epoch < 10:
        gen_steps = 10
    else:
        gen_steps = 10

    # -- steps each epoch --
    # FIX: guard against steps_per_iter == 0 (epoch 0 has disc_steps == 0,
    # which made the original raise ZeroDivisionError below).
    steps_per_iter = max(disc_steps * gen_steps, 1)
    steps_per_epoch = len(train_loader) // steps_per_iter
    if steps_per_epoch > 120: steps_per_epoch = 120

    # -=-=-=-=-=-=-=-=-=-=-
    #      Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #   Denoiser (generator) steps
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        for gen_step in range(gen_steps):
            if use_timer: clock.tic()

            # -- zero gradients everywhere --
            optimizer.zero_grad()
            model.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            noise_critic.disc.zero_grad()
            noise_critic.optim.zero_grad()

            # -- grab data batch --
            burst, res_imgs, raw_img, directions = next(train_iter)
            N, BS, C, H, W = burst.shape
            burst = burst.cuda(non_blocking=True)

            # -- alternate layouts of the burst --
            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -- target: clean zero-mean image if supervised, else middle frame --
            mid_img = burst[N // 2]
            raw_zm_img = szm(raw_img.cuda(non_blocking=True))
            if cfg.supervised:
                gt_img = raw_zm_img
            else:
                gt_img = mid_img

            # -- forward pass --
            aligned, aligned_ave, denoised, denoised_ave, filters = model(
                burst)

            # -- MSE (KPN) reconstruction loss --
            loss_rec = lossRecMSE(denoised_ave, gt_img)
            loss_burst = lossBurstMSE(denoised, gt_img)
            loss_mse = loss_rec + 100 * loss_burst
            gbs, spe = cfg.global_step, steps_per_epoch
            if epoch < 3:
                weight_mse = 10
            else:
                # decay the mse weight after the warm-up epochs
                weight_mse = 10 * 0.9999**(gbs - 3 * spe)

            # -- noise critic loss --
            loss_nc = noise_critic.compute_residual_loss(denoised, gt_img)
            weight_nc = 1

            # -- final loss --
            final_loss = weight_mse * loss_mse + weight_nc * loss_nc

            # -- record keeping --
            losses_nc += loss_nc.item()
            losses_nc_count += 1
            losses_mse += loss_mse.item()
            losses_mse_count += 1
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -- backward pass + optimize --
            final_loss.backward()
            model.denoiser_info.optim.step()
            optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #     Noise critic steps
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        for disc_step in range(disc_steps):
            # -- zero gradients --
            optimizer.zero_grad()
            model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            noise_critic.disc.zero_grad()
            noise_critic.optim.zero_grad()

            # -- grab noisy data --
            _burst, _res_imgs, _raw_img, _directions = next(train_iter)
            _burst = _burst.to(cfg.device)

            # -- "fake" residuals from the current denoiser --
            _aligned, _aligned_ave, _denoised, _denoised_ave, _filters = model(
                _burst)
            _residuals = _denoised - _burst[N // 2].unsqueeze(1).repeat(
                1, N, 1, 1, 1)

            # -- update discriminator --
            loss_disc = noise_critic.update_disc(_residuals)
            print(f"[Noise Critic]: {loss_disc}")

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Print message to stdout
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] baseline --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnrs.append(np.mean(mse_to_psnr(b_loss)))
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised frames --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- write record --
            # FIX: the original referenced undefined names (burst_loss,
            # ave_loss, ot_loss) here and would NameError if enabled;
            # record the losses actually computed in this scope.
            if use_record:
                record_losses = record_losses.append(
                    {
                        'burst': loss_burst.item(),
                        'ave': loss_rec.item(),
                        'ot': loss_nc.item(),
                        'psnr': psnr,
                        'psnr_std': psnr_std
                    },
                    ignore_index=True)

            # -- update losses --
            running_loss /= cfg.log_interval
            ave_losses_mse = losses_mse / losses_mse_count
            losses_mse, losses_mse_count = 0, 0
            ave_losses_nc = losses_nc / losses_nc_count
            losses_nc, losses_nc_count = 0, 0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          ave_losses_mse, ave_losses_nc)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [mse]: %.2e [nc]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, stacked_burst, aligned, denoised, filters,
                               directions)
        if use_timer:
            clock.toc()
            print(clock)
        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record_losses
def train_loop_offset(cfg, model, optimizer, criterion, train_loader, epoch):
    """One training epoch for a burst model that denoises the middle frame.

    Selects ``cfg.input_N`` frames per burst (excluding the middle frame
    unless ``cfg.input_with_middle_frame``), optionally perturbs the inputs
    with extra Gaussian noise, and trains on the loss returned by
    ``model(stacked_burst, middle_img)``. Logs MSE-derived PSNR against the
    clean image every ``cfg.log_interval`` batches. Returns the average
    training loss over the loader.

    NOTE(review): ``criterion`` is never used; ``random_eraser`` and
    ``t_img`` are computed but unused (the model computes its own loss).
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    random_eraser = th_trans.RandomErasing(scale=(0.02, 0.33))
    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- shape info: frames x batch x channels x height x width --
        N, BS, C, H, W = burst_imgs.shape

        # -- selecting input frames --
        input_order = np.arange(cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            # drop the middle frame from the inputs; it is the target
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.r_[input_order[:middle],
                                input_order[middle + 1:]]
        else:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.arange(cfg.N)

        # -- add input noise (optionally to the middle frame only) --
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        middle_img = burst_imgs[middle_img_idx]
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(
                    burst_imgs_noisy[middle_img_idx], noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy, noise)

        # -- arrange input frames: channel-concat or burst-stack --
        if cfg.color_cat:
            stacked_burst = torch.cat([
                burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)
            ], dim=1)
        else:
            stacked_burst = torch.stack([
                burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)
            ], dim=1)

        # -- extract target image: middle frame (blind) or clean zero-mean --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.cuda(non_blocking=True))

        # -- denoising: model returns its own loss + reconstruction --
        loss, rec_img = model(stacked_burst, middle_img)

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- compute psnr against the clean image for logging --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            # rec_img is zero-mean; +0.5 shifts it back to [0,1] range
            mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr))
            running_loss = 0
    total_loss /= len(train_loader)
    return total_loss
def test_loop(cfg, model, criterion, test_loader, epoch):
    """Evaluate the align-and-denoise model over a test set.

    Runs the model on each burst, rescales the averaged denoised output,
    and accumulates per-image MSE/PSNR against the clean image. Returns
    ``(psnr_ave, record_test)``.

    NOTE(review): ``criterion`` is accepted but never used here.
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    # FIX: the original preallocated np.zeros((len(test_loader),
    # cfg.batch_size)) and assigned psnrs[batch_idx, :] = psnr, which
    # crashes whenever the final batch is smaller than cfg.batch_size.
    # Accumulate per-image psnrs in a list instead.
    psnrs = []
    use_record = False
    record_test = pd.DataFrame({'psnr': []})
    with torch.no_grad():
        for batch_idx, (burst, res_imgs, raw_img,
                        directions) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # -- selecting input frames --
            input_order = np.arange(cfg.N)
            middle_img_idx = -1
            if not cfg.input_with_middle_frame:
                middle = cfg.N // 2
                middle_img_idx = input_order[middle]
            else:
                middle = len(input_order) // 2
                input_order = np.arange(cfg.N)
                middle_img_idx = input_order[middle]

            # -- reshaping of data --
            raw_img = raw_img.cuda(non_blocking=True)
            burst = burst.cuda(non_blocking=True)
            stacked_burst = torch.stack(
                [burst[input_order[x]] for x in range(cfg.input_N)], dim=1)
            cat_burst = torch.cat(
                [burst[input_order[x]] for x in range(cfg.input_N)], dim=1)

            # -- denoising --
            aligned, aligned_ave, denoised, denoised_ave, filters = model(
                burst)
            denoised_ave = denoised_ave.detach()

            # -- compare with clean target --
            denoised_ave = rescale_noisy_image(denoised_ave)

            # -- compute psnr --
            loss = F.mse_loss(raw_img, denoised_ave,
                              reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            psnrs.append(np.asarray(psnr).ravel())
            if use_record:
                record_test = record_test.append({'psnr': psnr},
                                                 ignore_index=True)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if batch_idx % 100 == 0:
                print("[%d/%d] Test PSNR: %2.2f" %
                      (batch_idx, len(test_loader), total_psnr /
                       (batch_idx + 1)))

    psnrs = np.concatenate(psnrs) if psnrs else np.zeros(1)
    psnr_ave = np.mean(psnrs)
    psnr_std = np.std(psnrs)
    ave_loss = total_loss / len(test_loader)
    print("[N: %d] Testing: [psnr: %2.2f +/- %2.2f] [ave loss %2.3e]" %
          (cfg.N, psnr_ave, psnr_std, ave_loss))
    return psnr_ave, record_test
def test_loop(cfg, model, optimizer, criterion, test_loader, epoch):
    """Deep-image-prior style evaluation: fit the model per test batch.

    For each batch, the model weights are re-initialized and optimized for
    ``iters`` steps to map a fixed noise code ``z`` (with fresh jitter each
    step) to the noisy middle frame; the final reconstruction from the
    clean ``z`` is scored by PSNR against the clean image. Gradients are
    required, so there is deliberately no ``torch.no_grad()`` and the model
    stays in ``train()`` mode.

    NOTE(review): this shadows the earlier ``test_loop`` definition in this
    file; ``criterion``, ``diff`` and ``tol`` are unused.
    """
    model.train()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(test_loader):

        BS = raw_img.shape[0]

        # -- selecting input frames --
        input_order = np.arange(cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            middle = cfg.N // 2
            middle_img_idx = input_order[middle]
            input_order = np.r_[input_order[:middle],
                                input_order[middle + 1:]]
        else:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.arange(cfg.N)
        print("post", input_order, middle_img_idx, cfg.blind, cfg.N)

        # -- reshaping of data --
        raw_img = raw_img.cuda(non_blocking=True)
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        if cfg.color_cat:
            stacked_burst = torch.cat(
                [burst_imgs[input_order[x]] for x in range(cfg.input_N)],
                dim=1)
        else:
            stacked_burst = torch.stack(
                [burst_imgs[input_order[x]] for x in range(cfg.input_N)],
                dim=1)

        # -- dip denoising: fit model(z) to the noisy middle frame --
        img = burst_imgs[middle_img_idx] + 0.5  # shift to [0,1] range
        z = torch.normal(0, torch.ones_like(img))
        z = z.requires_grad_(True)
        diff = 100
        idx = 0
        iters = 4800
        tol = 5e-9
        # fresh weights for every batch: DIP optimizes per image
        model.apply(weights_init)
        while (idx < iters):
            idx += 1
            optimizer.zero_grad()
            model.zero_grad()
            # jitter the noise code each step (sigma = 1/20)
            z_img = z + torch.normal(0, torch.ones_like(z)) * 1. / 20
            rec_img = model(z_img)
            lossE = F.mse_loss(img, rec_img)
            if (idx % 100) == 0:
                # monitor psnr against the clean image during fitting
                loss = F.mse_loss(raw_img, rec_img,
                                  reduction='none').reshape(BS, -1)
                loss = torch.mean(loss, 1).detach().cpu().numpy()
                psnr = np.mean(mse_to_psnr(loss))
                print("[%d/%d] lossE: [%.2e] psnr: [%.2f]" %
                      (idx, iters, lossE, psnr))
            lossE.backward()
            optimizer.step()

        # final reconstruction from the un-jittered noise code
        rec_img = model(z)

        # -- compare with clean target --
        loss = F.mse_loss(raw_img, rec_img, reduction='none').reshape(BS, -1)
        loss = torch.mean(loss, 1).detach().cpu().numpy()
        psnr = mse_to_psnr(loss)
        print(np.mean(psnr))
        total_psnr += np.mean(psnr)
        total_loss += np.mean(loss)

        if (batch_idx % cfg.test_log_interval) == 0:
            root = Path(
                f"{settings.ROOT_PATH}/output/n2n/offset_out_noise/rec_imgs/N{cfg.N}/e{epoch}"
            )
            if not root.exists():
                root.mkdir(parents=True)
            fn = root / Path(f"b{batch_idx}.png")
            nrow = int(np.sqrt(cfg.batch_size))
            rec_img = rec_img.detach().cpu()
            grid_imgs = vutils.make_grid(rec_img, padding=2, normalize=True,
                                         nrow=nrow)
            plt.imshow(grid_imgs.permute(1, 2, 0))
            plt.savefig(fn)
            plt.close('all')
    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print(
        "[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e"
        % (cfg.blind, cfg.N, ave_psnr, ave_loss))
    return ave_psnr