# -- imports assumed by this file; the project-local helpers (ScaleZeroMean,
#    BurstRecLoss, EntropyLoss, Timer, AlignmentFilterHooks, EMA,
#    update_moving_average, mse_to_psnr, kl_gaussian_bp, rescale_noisy_image,
#    write_input_output, settings) live elsewhere in the repository --
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
import torchvision.transforms.functional as tvF
import matplotlib.pyplot as plt
from pathlib import Path
from einops import rearrange
import bm3d


def train_loop(cfg, model, optimizer, criterion, train_loader, epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        clock = Timer()
    train_iter = iter(train_loader)
    D = 5 * 10**3  # iteration cap (unused in this loop)
    steps_per_epoch = len(train_loader)

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #      Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            clock.tic()

        # -- zero gradients; ready 2 go --
        optimizer.zero_grad()
        model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, BS, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised:
            gt_img = raw_zm_img
        else:
            gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Forward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        aligned, aligned_ave, denoised, denoised_ave, filters = model(burst)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Entropy Loss for Filters
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_shaped = rearrange(filters, 'b n k2 1 1 1 -> (b n) k2', n=N)
        filters_entropy = entropyLoss(filters_shaped)
        filters_entropy_coeff = 10.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = sum(losses)
        align_mse_coeff = 0  # .933**cfg.global_step if cfg.global_step < 100 else 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_ave_d = denoised_ave.detach()
        losses = criterion(denoised, denoised_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = sum(losses)
        rec_mse_coeff = 0.997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- regularization scheduler --
        if cfg.global_step < 100:
            reg = 0.5
        elif cfg.global_step < 200:
            reg = 0.25
        elif cfg.global_step < 5000:
            reg = 0.15
        elif cfg.global_step < 10000:
            reg = 0.1
        else:
            reg = 0.05

        # -- computation --
        residuals = denoised - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level)
        # rec_ot_pair_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_pair_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2.
        rec_ot_pair_coeff = 100  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #               Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        align_loss = align_mse_coeff * align_mse
        rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse
        entropy_loss = filters_entropy_coeff * filters_entropy
        final_loss = align_loss + rec_loss + entropy_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot_pair.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- backprop now. --
        model.denoiser_info.optim.step()
        optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for [aligned + denoised] --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                # one-row concat (DataFrame.append was removed in pandas 2.0)
                record_losses = pd.concat(
                    [record_losses, pd.DataFrame([info])], ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               filters, directions)

        if use_timer:
            clock.toc()
            print(clock)
        cfg.global_step += 1

    total_loss /= len(train_loader)
    return total_loss, record_losses
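# The PSNR numbers printed above come from `mse_to_psnr`, which is defined
# elsewhere in the project. Below is a minimal sketch of the conversion,
# assuming images live in [0,1] (peak value 1); the project's actual helper
# may handle the peak differently.
def _mse_to_psnr_sketch(mse):
    """Per-image MSE -> PSNR in dB, assuming a peak signal value of 1."""
    return 10. * np.log10(1. / mse)

# e.g. an MSE of 1e-3 on [0,1] images gives 10*log10(1/1e-3) = 30 dB.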
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch):
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()

    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):

        optimizer.zero_grad()
        model.zero_grad()

        # -- reshaping of data --
        # raw_img = raw_img.cuda(non_blocking=True)
        input_order = np.arange(cfg.N)
        # print("pre",input_order,cfg.blind,cfg.N)
        middle_img_idx = -1
        if not cfg.input_with_middle_frame:
            middle = len(input_order) // 2
            # print(middle)
            middle_img_idx = input_order[middle]
            input_order = np.r_[input_order[:middle], input_order[middle + 1:]]
        else:
            middle = len(input_order) // 2
            middle_img_idx = input_order[middle]
            input_order = np.arange(cfg.N)
        # print("post",input_order,middle_img_idx,cfg.blind,cfg.N)

        # -- add input noise --
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            burst_imgs_noisy[middle_img_idx] = torch.normal(
                burst_imgs_noisy[middle_img_idx], noise)

        # print(cfg.N,cfg.blind,[input_order[x] for x in range(cfg.input_N)])
        if cfg.color_cat:
            stacked_burst = torch.cat(
                [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
                dim=1)
        else:
            stacked_burst = torch.stack(
                [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
                dim=1)
        # print("stacked_burst",stacked_burst.shape)

        # if cfg.input_noise:
        #     stacked_burst = torch.normal(stacked_burst,noise)

        # -- extract target image --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.cuda(non_blocking=True))

        # -- denoising --
        rec_img = model(stacked_burst)

        # -- compute loss --
        loss = F.mse_loss(t_img, rec_img)

        # -- dncnn denoising --
        # rec_res = model(stacked_burst)

        # -- compute loss --
        # t_res = t_img - burst_imgs[middle_img_idx]
        # loss = F.mse_loss(t_res,rec_res)

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr))
            running_loss = 0

    total_loss /= len(train_loader)
    return total_loss
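# A minimal driver sketch for the loop above; the epoch loop and names here
# are illustrative assumptions, not the project's actual entry point.
def _run_training_sketch(cfg, model, optimizer, criterion, train_loader):
    for epoch in range(cfg.epochs):
        epoch_loss = train_loop(cfg, model, optimizer, criterion,
                                train_loader, epoch)
        print("epoch %d: ave train loss %2.3e" % (epoch, epoch_loss))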
def train_loop(cfg, model, train_loader, epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #    Setup for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Record Keeping
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    align_ot_losses, align_ot_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    write_examples = True
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #    Add hooks for epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-

    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #
    #      Init Loss Functions
    #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #     Final Configs
    #
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        clock = Timer()
    train_iter = iter(train_loader)
    steps_per_epoch = len(train_loader)
    write_examples_iter = steps_per_epoch // 3

    # -=-=-=-=-=-=-=-=-=-=-
    #
    #      Start Epoch
    #
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Setting up for Iteration
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            clock.tic()

        # -- zero gradients; ready 2 go --
        model.align_info.model.zero_grad()
        model.align_info.optim.zero_grad()
        model.denoiser_info.model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Formatting Images for FP
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised:
            gt_img = raw_zm_img
        else:
            gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #           Forward Pass
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        outputs = model(burst)
        aligned, aligned_ave, denoised, denoised_ave = outputs[:4]
        aligned_filters, denoised_filters = outputs[4:]

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Require Approx Equal Filter Norms
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_filters = rearrange(denoised_filters.detach(),
                                     'b n k2 c h w -> n (b k2 c h w)')
        norms = denoised_filters.norm(dim=1)
        norm_loss_denoiser = torch.mean((norms - norms[N // 2])**2)
        norm_loss_coeff = 100.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Decrease Entropy within a Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_entropy = 0
        filters_entropy_coeff = 1000.
        all_filters = []
        L = len(align_hook.filters)
        iter_filters = align_hook.filters if L > 0 else [aligned_filters]
        for filters in iter_filters:
            filters_shaped = rearrange(filters,
                                       'b n k2 c h w -> (b n c h w) k2',
                                       n=N)
            filters_entropy += entropyLoss(filters_shaped)
            all_filters.append(filters)
        if L > 0:
            filters_entropy /= L
        all_filters = torch.stack(all_filters, dim=1)
        align_hook.clear()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Increase Entropy across each Kernel
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_dist_entropy = 0

        # -- across each frame --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each batch --
        filters_shaped = rearrange(all_filters,
                                   'b l n k2 c h w -> (n l) (b c h w) k2')
        filters_shaped = torch.mean(filters_shaped, dim=1)
        filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        # -- across each kpn cascade --
        # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2')
        # filters_shaped = torch.mean(filters_shaped,dim=1)
        # filters_dist_entropy += -1 * entropyLoss(filters_shaped)

        filters_dist_coeff = 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Alignment Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = sum(losses)
        align_mse_coeff = 0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Alignment Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        fs = cfg.dynamic.frame_size
        residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        centered_residuals = tvF.center_crop(residuals, (fs // 2, fs // 2))
        align_ot = kl_gaussian_bp(centered_residuals, noise_level, flip=True)
        align_ot_coeff = 100.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #   Reconstruction Losses (MSE)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = denoiseLossMSE(denoised, denoised_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = sum(losses)
        rec_mse_coeff = 0.95**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #    Reconstruction Losses (Distribution)
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- regularization scheduler --
        if cfg.global_step < 100:
            reg = 0.5
        elif cfg.global_step < 200:
            reg = 0.25
        elif cfg.global_step < 5000:
            reg = 0.15
        elif cfg.global_step < 10000:
            reg = 0.1
        else:
            reg = 0.05

        # -- computation --
        residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        # residuals = rearrange(residuals,'b n c h w -> b n (h w) c')
        # rec_ot_loss_v1 = w_gaussian_bp(residuals,noise_level)
        rec_ot_loss_v1 = kl_gaussian_bp(residuals, noise_level, flip=True)
        # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals)
        # rec_ot_loss_v1 = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
        # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot = rec_ot_loss_v1 + rec_ot_loss_v2
        rec_ot_coeff = 100.  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #               Final Losses
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse
        norm_loss = norm_loss_coeff * norm_loss_denoiser
        align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot
        entropy_loss = (filters_entropy_coeff * filters_entropy +
                        filters_dist_coeff * filters_dist_entropy)
        final_loss = align_loss + rec_loss + entropy_loss + norm_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #              Record Keeping
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- alignment Dist --
        align_ot_losses += align_ot.item()
        align_ot_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #      Gradients & Backpropagation
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- backprop now. --
        model.align_info.optim.step()
        model.denoiser_info.optim.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #
        #            Printing to Stdout
        #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            M = 10 if B > 10 else B  # cap the (slow) BM3D baseline at 10 images
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for [aligned + denoised] --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- alignment Dist. --
            align_ot_ave = align_ot_losses / align_ot_count
            align_ot_losses, align_ot_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                # one-row concat (DataFrame.append was removed in pandas 2.0)
                record_losses = pd.concat(
                    [record_losses, pd.DataFrame([info])], ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, directions)

        if use_timer:
            clock.toc()
            print(clock)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
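# `kl_gaussian_bp` is defined elsewhere in the project. The sketch below shows
# the assumed idea: score the empirical residual distribution against a
# zero-mean Gaussian whose stddev is the known noise level, via the
# closed-form KL divergence between two Gaussians. The real helper's `flip`
# keyword and exact reduction are not modeled here; this is an illustration
# only. The division by 255 assumes `noise_stddev` is on the 0-255 scale, as
# suggested by the `sigma_psd=noise_level/255` passed to bm3d above.
def _kl_to_noise_gaussian_sketch(residuals, noise_stddev):
    sigma2 = (noise_stddev / 255.)**2  # target noise variance
    mu = residuals.mean()              # empirical residual mean
    var = residuals.var()              # empirical residual variance
    # KL( N(mu, var) || N(0, sigma2) ), closed form
    return 0.5 * (torch.log(sigma2 / var) + (var + mu**2) / sigma2 - 1.)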
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch):
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    spoof_loss = 0
    optimizer.zero_grad()
    model.zero_grad()

    # if cfg.N != 5: return
    # for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):
    for batch_idx, stuff in enumerate(train_loader):

        if len(stuff) == 3:
            burst_imgs, res_imgs, raw_img = stuff
        elif len(stuff) == 2:
            burst_imgs, raw_img = stuff
        else:
            raise ValueError("unexpected batch format with %d elements" % len(stuff))

        optimizer.zero_grad()
        model.zero_grad()

        # fig,ax = plt.subplots(figsize=(10,10))
        # imgs = burst_imgs + 0.5
        # imgs.clamp_(0.,1.)
        # raw_img = raw_img.expand(burst_imgs.shape)
        # print(imgs.shape,raw_img.shape)
        # all_img = torch.cat([imgs,raw_img],dim=1)
        # print(all_img.shape)
        # grids = [vutils.make_grid(all_img[i],nrow=16) for i in range(cfg.dynamic.frames)]
        # ims = [[ax.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in grids]
        # ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
        # Writer = animation.writers['ffmpeg']
        # writer = Writer(fps=1, metadata=dict(artist='Me'), bitrate=1800)
        # ani.save(f"{settings.ROOT_PATH}/train_loop_voc.mp4", writer=writer)
        # print("I DID IT!")
        # return

        # -- reshaping of data --
        # raw_img = raw_img.cuda(non_blocking=True)
        input_order = np.arange(cfg.N)
        # print("pre",input_order,cfg.blind,cfg.N)
        middle = len(input_order) // 2
        middle_img_idx = input_order[middle]
        if not cfg.input_with_middle_frame:
            middle = len(input_order) // 2
            # print(middle)
            middle_img_idx = input_order[middle]
            input_order = np.r_[input_order[:middle], input_order[middle + 1:]]
        else:
            input_order = np.arange(cfg.N)

        # -- add input noise --
        burst_imgs = burst_imgs.cuda(non_blocking=True)
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            burst_imgs_noisy[middle_img_idx] = torch.normal(
                burst_imgs_noisy[middle_img_idx], noise)

        # -- stack images --
        stacked_burst = torch.stack(
            [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)],
            dim=1)
        # print("post",input_order,cfg.blind,cfg.N,middle_img_idx)
        # print(cfg.N,cfg.blind,[input_order[x] for x in range(cfg.input_N)])
        # print("stacked_burst",stacked_burst.shape)

        # -- extract target image --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.cuda(non_blocking=True))

        # -- denoising --
        loss, rec_img = model(stacked_burst, t_img)

        # -- spoof batch size --
        if cfg.spoof_batch:
            spoof_loss += loss

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()

        # -- BP and optimize --
        if cfg.spoof_batch and (batch_idx % cfg.spoof_batch_size) == 0:
            spoof_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            model.zero_grad()
            spoof_loss = 0
        elif not cfg.spoof_batch:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            model.zero_grad()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            running_loss /= cfg.log_interval
            print("[%d/%d][%d/%d]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss))
            running_loss = 0

    total_loss /= len(train_loader)
    return total_loss
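# The "spoof batch" branch above is gradient accumulation: per-batch losses
# are summed and backpropagated every `cfg.spoof_batch_size` iterations. A
# common variant, sketched below under the assumption of independent
# per-batch graphs, normalizes by the accumulation length so gradient
# magnitudes match a single large batch; the loop above accumulates the raw
# sum instead.
def _accumulate_and_step_sketch(losses, optimizer):
    optimizer.zero_grad()
    for loss in losses:
        (loss / len(losses)).backward()  # normalize each contribution
    optimizer.step()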
def test_loop(cfg, model, criterion, test_loader, epoch):
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    szm = ScaleZeroMean()
    with torch.no_grad():
        # for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(test_loader):
        for batch_idx, stuff in enumerate(test_loader):

            if len(stuff) == 3:
                burst_imgs, res_imgs, raw_img = stuff
            elif len(stuff) == 2:
                burst_imgs, raw_img = stuff
            else:
                raise ValueError("unexpected batch format with %d elements" % len(stuff))
            BS = raw_img.shape[0]

            # -- selecting input frames --
            input_order = np.arange(cfg.N)
            # print("pre",input_order)
            middle_img_idx = -1
            if not cfg.input_with_middle_frame:
                middle = cfg.N // 2
                # print(middle)
                middle_img_idx = input_order[middle]
                input_order = np.r_[input_order[:middle],
                                    input_order[middle + 1:]]
            else:
                input_order = np.arange(cfg.N)

            if cfg.blind:
                t_img = burst_imgs[middle_img_idx]
            else:
                t_img = szm(raw_img.cuda(non_blocking=True))
            t_img = t_img.to(cfg.device)

            # -- reshaping of data --
            raw_img = raw_img.cuda(non_blocking=True)
            burst_imgs = burst_imgs.cuda(non_blocking=True)
            stacked_burst = torch.stack(
                [burst_imgs[input_order[x]] for x in range(cfg.input_N)],
                dim=1)

            # -- denoising --
            loss, rec_img = model(stacked_burst, t_img)
            # rec_img = burst_imgs[middle_img_idx] - rec_res

            # -- compare with stacked targets --
            rec_img = rescale_noisy_image(rec_img)
            loss = F.mse_loss(raw_img, rec_img,
                              reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(
                    f"{settings.ROOT_PATH}/output/attn/offset_out_noise/rec_imgs/N{cfg.N}/e{epoch}"
                )
                if not root.exists():
                    root.mkdir(parents=True)
                fn = root / Path(f"b{batch_idx}.png")
                nrow = int(np.sqrt(cfg.batch_size))
                rec_img = rec_img.detach().cpu()
                grid_imgs = vutils.make_grid(rec_img,
                                             padding=2,
                                             normalize=True,
                                             nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')

    ave_psnr = total_psnr / len(test_loader)
    ave_loss = total_loss / len(test_loader)
    print("[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e" %
          (cfg.blind, cfg.N, ave_psnr, ave_loss))
    return ave_psnr
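# `rescale_noisy_image` is defined elsewhere; judging from the `+ 0.5` used in
# the PSNR blocks of the training loops, it plausibly maps zero-mean outputs
# back to [0,1]. A sketch under that assumption (not the project's code):
def _rescale_noisy_image_sketch(img):
    return torch.clamp(img + 0.5, 0., 1.)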
def train_loop(cfg, model_target, model_online, optim_target, optim_online,
               criterion, train_loader, epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #     setup for train epoch
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # -- setup for training --
    model_online.train()
    model_online = model_online.to(cfg.device)
    model_target.train()
    model_target = model_target.to(cfg.device)
    moving_average_decay = 0.99
    ema_updater = EMA(moving_average_decay)

    # -- init vars --
    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    D = 5 * 10**3  # max number of iterations per epoch
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    nc_losses, nc_count = 0, 0
    al_ot_losses, al_ot_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #     run training epoch
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    for batch_idx, (burst, res_imgs, raw_img,
                    directions) in enumerate(train_loader):
        if batch_idx > D:
            break

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #         forward pass
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- zero gradient --
        optim_online.zero_grad()
        optim_target.zero_grad()
        model_online.zero_grad()
        model_online.denoiser_info.optim.zero_grad()
        model_target.zero_grad()
        model_target.denoiser_info.optim.zero_grad()

        # -- reshaping of data --
        N, BS, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')

        # -- create target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised:
            t_img = szm(raw_img.cuda(non_blocking=True))
        else:
            t_img = burst[N // 2]

        # -- direct denoising --
        aligned_o, aligned_ave_o, denoised_o, rec_img_o, filters_o = model_online(burst)
        aligned_t, aligned_ave_t, denoised_t, rec_img_t, filters_t = model_target(burst)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #       alignment losses
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute aligned losses to optimize --
        rec_img_d_o = rec_img_o.detach()
        losses = criterion(aligned_o, aligned_ave_o, rec_img_d_o, t_img,
                           raw_zm_img, cfg.global_step)
        nc_loss, ave_loss, burst_loss, ot_loss = [loss.item() for loss in losses]
        kpn_loss = losses[1] + losses[2]  # np.sum(losses)
        kpn_coeff = .9997**cfg.global_step

        # -- OT loss --
        al_ot_loss = torch.FloatTensor([0.]).to(cfg.device)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #     reconstruction losses
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- decaying rec loss --
        rec_mse_coeff = 0.997**cfg.global_step
        rec_mse_loss = F.mse_loss(rec_img_o, mid_img)

        # -- BYOL loss (note: no stop-gradient on the target branch here) --
        byol_loss = F.mse_loss(rec_img_o, rec_img_t)

        # -- OT loss --
        rec_ot_coeff = 100
        residuals = denoised_o - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        # rec_ot_loss = ot_pairwise_bp(residuals,K=3)
        rec_ot_loss = torch.FloatTensor([0.]).to(cfg.device)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #   final losses & recording
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- final losses --
        align_loss = kpn_coeff * kpn_loss
        denoise_loss = (rec_ot_coeff * rec_ot_loss + byol_loss +
                        rec_mse_coeff * rec_mse_loss)

        # -- update alignment kl loss info --
        al_ot_losses += al_ot_loss.item()
        al_ot_count += 1

        # -- update reconstruction kl loss info --
        rec_ot_losses += rec_ot_loss.item()
        rec_ot_count += 1

        # -- update info --
        if not np.isclose(nc_loss, 0):
            nc_losses += nc_loss
            nc_count += 1
        running_loss += align_loss.item() + denoise_loss.item()
        total_loss += align_loss.item() + denoise_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #     backprop and optimize
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        loss = align_loss + denoise_loss
        loss.backward()

        # -- backprop for [online] --
        optim_online.step()
        model_online.denoiser_info.optim.step()

        # -- exponential moving average for [target] --
        update_moving_average(ema_updater, model_target, model_online)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #       message to stdout
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave_o + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for [aligned + denoised] --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised_o + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, rec_img_o + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': ot_loss,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                # one-row concat (DataFrame.append was removed in pandas 2.0)
                record_losses = pd.concat(
                    [record_losses, pd.DataFrame([info])], ignore_index=True)

            # -- update losses --
            running_loss /= cfg.log_interval
            ave_nc_loss = nc_losses / nc_count if nc_count > 0 else 0

            # -- alignment kl loss --
            ave_al_ot_loss = al_ot_losses / al_ot_count if al_ot_count > 0 else 0
            al_ot_losses, al_ot_count = 0, 0

            # -- reconstruction kl loss --
            ave_rec_ot_loss = rec_ot_losses / rec_ot_count if rec_ot_count > 0 else 0
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          ave_nc_loss, ave_rec_ot_loss, ave_al_ot_loss)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [denoised]: %2.2f +/- %2.2f [aligned]: %2.2f +/- %2.2f [misaligned]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [loss-nc]: %.2e [loss-rot]: %.2e [loss-aot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, stacked_burst, aligned_o, denoised_o,
                               filters_o, directions)

        cfg.global_step += 1

    total_loss /= len(train_loader)
    return total_loss, record_losses
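# `EMA` and `update_moving_average` come from the project's BYOL-style
# utilities. The standard target update, sketched below under that assumption,
# moves each target parameter toward its online counterpart:
#     target <- decay * target + (1 - decay) * online
def _update_moving_average_sketch(decay, model_target, model_online):
    with torch.no_grad():
        for p_t, p_o in zip(model_target.parameters(),
                            model_online.parameters()):
            p_t.data.mul_(decay).add_(p_o.data, alpha=1. - decay)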