def run_experiment_serial(cfgs, version, mode, use_ddp=True):
    for cfg in cfgs:
        t = Timer()

        # log process
        record_experiment(cfg, version, mode, 'start', t)
        print(_build_v1_summary(cfg))

        cfg.mode = mode
        if cfg.mode == "test":
            cfg.load = True
            cfg.epoch_num = cfg.epochs  # load last epoch
        else:
            cfg.load = False
            cfg.epoch_num = -1
            # cfg.load = True
            # cfg.epoch_num = 500

        if use_ddp:
            # cfg.use_apex = False
            run_ddp(cfg=cfg)
        else:
            run_serial(cfg=cfg)

        # log experiment completion (the 'end' label is an assumption)
        record_experiment(cfg, version, mode, 'end', t)
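# `Timer` is used by the experiment drivers and training loops in this file but is
# not defined here. Below is a minimal stand-in (named SimpleTimer so it does not
# shadow the real import) illustrating the tic/toc interface these call sites
# appear to assume; the repo's actual Timer may differ.
import time


class SimpleTimer:
    """Sketch of the assumed tic()/toc()/total_time/average_time interface."""

    def __init__(self):
        self.total_time = 0.
        self.calls = 0
        self.average_time = 0.
        self._start = None

    def tic(self):
        # start (or restart) the clock
        self._start = time.perf_counter()

    def toc(self):
        # accumulate the time elapsed since the last tic()
        elapsed = time.perf_counter() - self._start
        self.total_time += elapsed
        self.calls += 1
        self.average_time = self.total_time / self.calls
        return elapsed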
def run_experiment_parallel(cfgs, version, mode, use_ddp=False):
    ngpus = 3
    nproc_per_gpu = 1
    max_procs = nproc_per_gpu * ngpus
    gpuid = 0
    procs = []
    # for idx,cfg in enumerate(cfgs):
    for idx in [0, 3, 6, 9]:
        cfg = cfgs[idx]
        # if idx < start_idx: continue
        print(f"Running idx {idx}")

        # log process
        t = Timer()

        # wait if needed
        if len(procs) == max_procs:
            wait(procs)
            procs = []

        # run in proper mode
        cfg.use_ddp = False
        if cfg.mode == "test":
            cfg.load = True
            cfg.epoch_num = cfg.epochs  # load last epoch
        else:
            cfg.load = False
            cfg.epoch_num = -1

        # log process
        record_experiment(cfg, version, mode, 'start', t)

        # launch process
        p = run_process(version, idx, gpuid)

        # print what process is being launched
        print(_build_summary(cfg, version))

        # sleep so each process creates a separate tensorboard file
        time.sleep(5)

        # update current gpuid
        gpuid = (gpuid + 1) % ngpus

        # add process to list
        procs.append(p)

    # wait for any remaining processes
    if len(procs) > 0:
        wait(procs)
        procs = []
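# `run_process` and `wait` are not defined in this file. The sketch below (with
# "_sketch" names so it does not clash with the real helpers) illustrates the
# contract run_experiment_parallel appears to assume: launch one experiment as a
# subprocess pinned to a GPU via CUDA_VISIBLE_DEVICES and block until all launched
# jobs finish. The entry point and CLI flags here are hypothetical.
import os
import subprocess


def run_process_sketch(version, idx, gpuid):
    # Launch one experiment pinned to `gpuid` and return the process handle.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpuid))
    cmd = ["python", "-m", "run_single_experiment",  # hypothetical entry point
           "--version", str(version), "--idx", str(idx)]
    return subprocess.Popen(cmd, env=env)


def wait_sketch(procs):
    # Block until every launched process has exited.
    for p in procs:
        p.wait()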
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses, writer): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.align_info.model.train() model.denoiser_info.model.train() model.unet_info.model.train() model.denoiser_info.model = model.denoiser_info.model.to(cfg.device) model.align_info.model = model.align_info.model.to(cfg.device) model.unet_info.model = model.unet_info.model.to(cfg.device) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) noise_type = cfg.noise_params.ntype # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 dynamics_acc, dynamics_count = 0, 0 write_examples = False write_examples_iter = 200 noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Load Pre-Simulated Random Numbers # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Dataset Augmentation # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- transforms = [tvF.vflip, tvF.hflip, tvF.rotate] aug = RandomChoice(transforms) def apply_transformations(burst, gt_img): N, B = burst.shape[:2] gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w') all_images = torch.cat([gt_img_rs, burst], dim=0) all_images = rearrange(all_images, 'n b c h w -> (n b) c h w') tv_utils.save_image(all_images, 'aug_original.png', nrow=N + 1, normalize=True) aug_images = aug(all_images) tv_utils.save_image(aug_images, 'aug_augmented.png', nrow=N + 1, normalize=True) aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B) aug_gt_img = aug_images[0] aug_burst = aug_images[1:] return aug_burst, aug_gt_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Half Precision # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # model.align_info.model.half() # model.denoiser_info.model.half() # model.unet_info.model.half() # models = [model.align_info.model, # model.denoiser_info.model, # model.unet_info.model] # for model_l in models: # model_l.half() # for layer in model_l.modules(): # if isinstance(layer, torch.nn.BatchNorm2d): # layer.float() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha, gradient_L1=~cfg.supervised) # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=-=-=- # # Add hooks for epoch # # -=-=-=-=-=-=-=-=-=-=-=-=- align_hook = AlignmentFilterHooks(cfg.N) align_hooks = [] for kpn_module in model.align_info.model.children(): for name, layer in kpn_module.named_children(): if name == "filter_cls": align_hook_handle = layer.register_forward_hook(align_hook) align_hooks.append(align_hook_handle) # -=-=-=-=-=-=-=-=-=-=- # # Noise2Noise # # -=-=-=-=-=-=-=-=-=-=- noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False) # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: data_clock = Timer() clock = Timer() ds_size = len(train_loader) small_ds = ds_size < 500 
steps_per_epoch = ds_size if not small_ds else 500 write_examples_iter = steps_per_epoch // 3 all_filters = [] # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- dynamics_acc_i = -1. if cfg.use_seed: init = torch.initial_seed() torch.manual_seed(cfg.seed + 1 + epoch + init) train_iter = iter(train_loader) for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: data_clock.tic() clock.tic() # -- grab data batch -- sample = next(train_iter) burst, raw_img, motion = sample['burst'], sample['clean'], sample[ 'directions'] raw_img_iid = sample['iid'] raw_img_iid = raw_img_iid.cuda(non_blocking=True) burst = burst.cuda(non_blocking=True) aligned, est_nnf = align_burst(cfg, burst, model) sim_images = subsample_aligned(cfg, aligned) burst_in, tgt_out = create_training_pairs(burst, sim_images) dn_losses = [] for burst, target in zip(burst_in, tgt_out): # -- forward pass -- est_denoised = model(burst) dn_loss = compute_denoising_loss(est_denoised, target) # -- compute grads -- if cfg.use_seed: torch.set_deterministic(False) dn_loss.backward() if cfg.use_seed: torch.set_deterministic(True) # -- backprop -- optim.step() scheduler.step() # -- store info -- losses.append(dn_loss.item()) # -- average over losses -- dn_loss = torch.mean(dn_losses) # -- alignment loss -- align_loss = compute_nnf_loss(gt_nnf, est_nnf) # -- total loss -- final_loss = dn_loss + align_loss running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- recompute model output for original images -- outputs = model(burst_og) m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # -- compute mse for fun -- B = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) raw_img = get_nmlz_tgt_img(cfg, raw_img) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, m_aligned_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(burst_og, dim=0) if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave) mse_loss = F.mse_loss(raw_img, mis_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25)) # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25)) # -- psnr for [bm3d] -- mid_img_og = burst[N // 2] bm3d_nb_psnrs = [] M = 4 if B > 4 else B for b in range(M): bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) # maybe an issue here b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for input averaged 
frames -- # burst_ave = torch.mean(burst_og,dim=0) # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1) # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy() # psnr_input_ave = np.mean(mse_to_psnr(mse_loss)) # psnr_input_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for aligned + denoised -- R = denoised.shape[1] raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1) # if noise_type == "qis": denoised = quantize_img(cfg,denoised) # save_image(denoised_ave,"denoised_ave.png") # save_image(denoised,"denoised.png") mse_loss = F.mse_loss(raw_img_repN, denoised, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) psnr_std = np.std(mse_to_psnr(mse_loss)) # -- update losses -- running_loss /= cfg.log_interval # -- reconstruction MSE -- rec_mse_ave = rec_mse_losses / rec_mse_count rec_mse_losses, rec_mse_count = 0, 0 # -- reconstruction Dist. -- rec_ot_ave = rec_ot_losses / rec_ot_count rec_ot_losses, rec_ot_count = 0, 0 # -- ave dynamic acc -- ave_dyn_acc = dynamics_acc / dynamics_count * 100. dynamics_acc, dynamics_count = 0, 0 # -- write record -- if use_record: info = { 'burst': burst_loss, 'ave': ave_loss, 'ot': rec_ot_ave, 'psnr': psnr, 'psnr_std': psnr_std } record_losses = record_losses.append(info, ignore_index=True) # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch, running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, rec_mse_ave, ave_dyn_acc) #rec_ot_ave) #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e" % write_info, flush=True) # -- write to summary writer -- if writer: writer.add_scalar('train/running-loss', running_loss, cfg.global_step) writer.add_scalars('train/model-psnr', { 'ave': psnr, 'std': psnr_std }, cfg.global_step) writer.add_scalars('train/dn-frame-psnr', { 'ave': psnr_denoised_ave, 'std': psnr_denoised_std }, cfg.global_step) # -- reset loss -- running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, model, stacked_burst, aligned, denoised, all_filters, motion) if use_timer: clock.toc() if use_timer: print("data_clock", data_clock.average_time) print("clock", clock.average_time) cfg.global_step += 1 # -- remove hooks -- for hook in align_hooks: hook.remove() total_loss /= len(train_loader) return total_loss, record_losses
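# `mse_to_psnr` is used throughout the logging blocks of these train loops but is
# not defined in this file. For images on a [0, 1] scale (peak intensity 1.0),
# PSNR = -10 * log10(MSE); a minimal sketch under that assumption, suffixed so it
# does not shadow the real helper:
import numpy as np


def mse_to_psnr_sketch(mse, max_intensity=1.0):
    # PSNR in dB for an array of per-image MSE values.
    mse = np.asarray(mse, dtype=np.float64)
    return 10.0 * np.log10(max_intensity**2 / mse)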
def train_loop(cfg, model, optimizer, criterion, train_loader, epoch, record_losses):

    # -=-=-=-=-=-=-=-=-=-=-
    #    Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-

    model.train()
    model = model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #    Init Record Keeping
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    write_examples = True
    write_examples_iter = 800
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    #    Init Loss Functions
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss()
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #    Final Configs
    # -=-=-=-=-=-=-=-=-=-=-

    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        clock = Timer()
    train_iter = iter(train_loader)
    D = 5 * 10**3
    steps_per_epoch = len(train_loader)

    # -=-=-=-=-=-=-=-=-=-=-
    #    Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-

    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Setting up for Iteration
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            clock.tic()

        # -- zero gradients; ready to go --
        optimizer.zero_grad()
        model.zero_grad()
        model.denoiser_info.optim.zero_grad()

        # -- grab data batch --
        burst, res_imgs, raw_img, directions = next(train_iter)

        # -- getting shapes of data --
        N, BS, C, H, W = burst.shape
        burst = burst.cuda(non_blocking=True)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Formatting Images for FP
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- creating some transforms --
        stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
        cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

        # -- extract target image --
        mid_img = burst[N // 2]
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised:
            gt_img = raw_zm_img
        else:
            gt_img = mid_img

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Forward Pass
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        aligned, aligned_ave, denoised, denoised_ave, filters = model(burst)

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Entropy Loss for Filters
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        filters_shaped = rearrange(filters, 'b n k2 1 1 1 -> (b n) k2', n=N)
        filters_entropy = entropyLoss(filters_shaped)
        filters_entropy_coeff = 10.

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Alignment Losses (MSE)
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        align_mse = np.sum(losses)
        align_mse_coeff = 0  # .933**cfg.global_step if cfg.global_step < 100 else 0

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Reconstruction Losses (MSE)
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        denoised_ave_d = denoised_ave.detach()
        losses = criterion(denoised, denoised_ave, gt_img, cfg.global_step)
        ave_loss, burst_loss = [loss.item() for loss in losses]
        rec_mse = np.sum(losses)
        rec_mse_coeff = 0.997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Reconstruction Losses (Distribution)
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- regularization scheduler --
        if cfg.global_step < 100:
            reg = 0.5
        elif cfg.global_step < 200:
            reg = 0.25
        elif cfg.global_step < 5000:
            reg = 0.15
        elif cfg.global_step < 10000:
            reg = 0.1
        else:
            reg = 0.05

        # -- computation --
        residuals = denoised - mid_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level)
        rec_ot_pair_loss_v1 = kl_gaussian_bp(residuals, noise_level)
        # rec_ot_pair_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg)
        # rec_ot_pair_loss_v2 = ot_pairwise_bp(residuals,K=3)
        rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device)
        rec_ot_pair = (rec_ot_pair_loss_v1 + rec_ot_pair_loss_v2) / 2.
        rec_ot_pair_coeff = 100  # - .997**cfg.global_step

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Final Losses
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        align_loss = align_mse_coeff * align_mse
        rec_loss = rec_ot_pair_coeff * rec_ot_pair + rec_mse_coeff * rec_mse
        entropy_loss = filters_entropy_coeff * filters_entropy
        final_loss = align_loss + rec_loss + entropy_loss

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Record Keeping
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- alignment MSE --
        align_mse_losses += align_mse.item()
        align_mse_count += 1

        # -- reconstruction MSE --
        rec_mse_losses += rec_mse.item()
        rec_mse_count += 1

        # -- reconstruction Dist. --
        rec_ot_losses += rec_ot_pair.item()
        rec_ot_count += 1

        # -- total loss --
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Gradients & Backpropagation
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        # -- compute the gradients! --
        final_loss.backward()

        # -- backprop now. --
        model.denoiser_info.optim.step()
        optimizer.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
        #    Printing to Stdout
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- compute mse for fun --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(stacked_burst, dim=1)
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] --
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised --
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5,
                                  reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- alignment MSE --
            align_mse_ave = align_mse_losses / align_mse_count
            align_mse_losses, align_mse_count = 0, 0

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0

            # -- write record --
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader),
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std,
                          psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave,
                          bm3d_nb_std, rec_mse_ave, rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e"
                % write_info)
            running_loss = 0

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               filters, directions)

        if use_timer:
            clock.toc()
        if use_timer:
            print(clock)
        cfg.global_step += 1

    total_loss /= len(train_loader)
    return total_loss, record_losses
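# The loss weights in the train loop above anneal the MSE term while keeping the
# distribution term fixed: rec_mse_coeff = 0.997**global_step falls to roughly 5%
# of its initial value after about 1000 steps, while rec_ot_pair_coeff stays at
# 100. A small, hypothetical helper that mirrors that schedule:
def rec_loss_coeffs(global_step):
    rec_mse_coeff = 0.997**global_step
    rec_ot_pair_coeff = 100.
    return rec_mse_coeff, rec_ot_pair_coeff


# e.g. rec_loss_coeffs(0) -> (1.0, 100.0); rec_loss_coeffs(1000) -> (~0.05, 100.0)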
def train_loop(cfg, model, noise_critic, optimizer, criterion, train_loader, epoch, record_losses): # -=-=-=-=-=-=-=-=-=-=- # Setup for epoch # -=-=-=-=-=-=-=-=-=-=- model.train() model = model.to(cfg.device) N = cfg.N szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) D = 5 * 10**3 use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) write_examples = True write_examples_iter = 800 noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Init Record Keeping # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses_nc, losses_nc_count = 0, 0 losses_mse, losses_mse_count = 0, 0 running_loss, total_loss = 0, 0 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Init Loss Functions # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- lossRecMSE = LossRec(tensor_grad=cfg.supervised) lossBurstMSE = LossRecBurst(tensor_grad=cfg.supervised) # -=-=-=-=-=-=-=-=-=-=- # Final Configs # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True train_iter = iter(train_loader) if use_timer: clock = Timer() # -=-=-=-=-=-=-=-=-=-=- # GAN Scheduler # -=-=-=-=-=-=-=-=-=-=- # -- noise critic steps -- if epoch == 0: disc_steps = 0 elif epoch < 3: disc_steps = 1 elif epoch < 10: disc_steps = 1 else: disc_steps = 1 # -- denoising steps -- if epoch == 0: gen_steps = 1 if epoch < 3: gen_steps = 15 if epoch < 10: gen_steps = 10 else: gen_steps = 10 # -- steps each epoch -- steps_per_iter = disc_steps * gen_steps steps_per_epoch = len(train_loader) // steps_per_iter if steps_per_epoch > 120: steps_per_epoch = 120 # -=-=-=-=-=-=-=-=-=-=- # Start Epoch # -=-=-=-=-=-=-=-=-=-=- for batch_idx in range(steps_per_epoch): for gen_step in range(gen_steps): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Setting up for Iteration # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: clock.tic() # -- zero gradients -- optimizer.zero_grad() model.zero_grad() model.denoiser_info.model.zero_grad() model.denoiser_info.optim.zero_grad() noise_critic.disc.zero_grad() noise_critic.optim.zero_grad() # -- grab data batch -- burst, res_imgs, raw_img, directions = next(train_iter) # -- getting shapes of data -- N, BS, C, H, W = burst.shape burst = burst.cuda(non_blocking=True) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Formatting Images for FP # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- creating some transforms -- stacked_burst = rearrange(burst, 'n b c h w -> b n c h w') cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w') # -- extract target image -- mid_img = burst[N // 2] raw_zm_img = szm(raw_img.cuda(non_blocking=True)) if cfg.supervised: gt_img = raw_zm_img else: gt_img = mid_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Foward Pass # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- aligned, aligned_ave, denoised, denoised_ave, filters = model( burst) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # MSE (KPN) Reconstruction Loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- loss_rec = lossRecMSE(denoised_ave, gt_img) loss_burst = lossBurstMSE(denoised, gt_img) loss_mse = loss_rec + 100 * loss_burst gbs, spe = cfg.global_step, steps_per_epoch if epoch < 3: weight_mse = 10 else: weight_mse = 10 * 0.9999**(gbs - 3 * spe) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Noise Critic Loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- loss_nc = noise_critic.compute_residual_loss(denoised, gt_img) weight_nc = 1 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Final Loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- final_loss = weight_mse * 
loss_mse + weight_nc * loss_nc # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Update Info for Record Keeping # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- update alignment kl loss info -- losses_nc += loss_nc.item() losses_nc_count += 1 # -- update reconstruction kl loss info -- losses_mse += loss_mse.item() losses_mse_count += 1 # -- update info -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Backward Pass # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- final_loss.backward() # -- backprop now. -- model.denoiser_info.optim.step() optimizer.step() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Iterate for Noise Critic # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- for disc_step in range(disc_steps): # -- zero gradients -- optimizer.zero_grad() model.zero_grad() model.denoiser_info.optim.zero_grad() noise_critic.disc.zero_grad() noise_critic.optim.zero_grad() # -- grab noisy data -- _burst, _res_imgs, _raw_img, _directions = next(train_iter) _burst = _burst.to(cfg.device) # -- generate "fake" data from noisy data -- _aligned, _aligned_ave, _denoised, _denoised_ave, _filters = model( _burst) _residuals = _denoised - _burst[N // 2].unsqueeze(1).repeat( 1, N, 1, 1, 1) # -- update discriminator -- loss_disc = noise_critic.update_disc(_residuals) # -- message to stdout -- first_update = (disc_step == 0) last_update = (disc_step == disc_steps - 1) iter_update = first_update or last_update # if (batch_idx % cfg.log_interval//2) == 0 and batch_idx > 0 and iter_update: print(f"[Noise Critic]: {loss_disc}") # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # Print Message to Stdout # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- init -- BS = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(stacked_burst, dim=1) mse_loss = F.mse_loss(raw_img, mis_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [bm3d] -- bm3d_nb_psnrs = [] for b in range(BS): bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for aligned + denoised -- raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1) mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5, reduction='none').reshape(BS, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) 
psnr_std = np.std(mse_to_psnr(mse_loss)) # -- write record -- if use_record: record_losses = record_losses.append( { 'burst': burst_loss, 'ave': ave_loss, 'ot': ot_loss, 'psnr': psnr, 'psnr_std': psnr_std }, ignore_index=True) # -- update losses -- running_loss /= cfg.log_interval # -- average mse losses -- ave_losses_mse = losses_mse / losses_mse_count losses_mse, losses_mse_count = 0, 0 # -- average noise critic loss -- ave_losses_nc = losses_nc / losses_nc_count losses_nc, losses_nc_count = 0, 0 # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch, running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, ave_losses_mse, ave_losses_nc) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [mse]: %.2e [nc]: %.2e" % write_info) running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, stacked_burst, aligned, denoised, filters, directions) if use_timer: clock.toc() if use_timer: print(clock) cfg.global_step += 1 total_loss /= len(train_loader) return total_loss, record_losses
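# Note on the GAN scheduler in the noise-critic train loop above: the gen_steps
# branches are written as separate `if` statements, so every epoch falls through to
# gen_steps = 10 and the epoch == 0 / epoch < 3 settings are dead code. The
# hypothetical helper below writes the schedule as the if/elif chain it appears to
# intend (the disc_steps chain already uses elif). Note also that disc_steps = 0 at
# epoch 0 makes steps_per_iter zero, so the caller's
# len(train_loader) // steps_per_iter would divide by zero at epoch 0.
def gan_schedule_sketch(epoch):
    # critic (discriminator) steps per outer iteration
    if epoch == 0:
        disc_steps = 0
    else:
        disc_steps = 1
    # denoiser (generator) steps per outer iteration
    if epoch == 0:
        gen_steps = 1
    elif epoch < 3:
        gen_steps = 15
    else:
        gen_steps = 10
    return disc_steps, gen_steps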
def train_loop(cfg, model, train_loader, epoch, record_losses): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.align_info.model.train() model.denoiser_info.model.train() model.unet_info.model.train() model.denoiser_info.model = model.denoiser_info.model.to(cfg.device) model.align_info.model = model.align_info.model.to(cfg.device) model.unet_info.model = model.unet_info.model.to(cfg.device) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 align_ot_losses, align_ot_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 write_examples = True noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=- # # Add hooks for epoch # # -=-=-=-=-=-=-=-=-=-=-=-=- align_hook = AlignmentFilterHooks(cfg.N) align_hooks = [] for kpn_module in model.align_info.model.children(): for name, layer in kpn_module.named_children(): if name == "filter_cls": align_hook_handle = layer.register_forward_hook(align_hook) align_hooks.append(align_hook_handle) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss() # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: clock = Timer() train_iter = iter(train_loader) steps_per_epoch = len(train_loader) write_examples_iter = steps_per_epoch // 2 # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: clock.tic() # -- zero gradients; ready 2 go -- model.align_info.model.zero_grad() model.align_info.optim.zero_grad() model.denoiser_info.model.zero_grad() model.denoiser_info.optim.zero_grad() model.unet_info.model.zero_grad() model.unet_info.optim.zero_grad() # -- grab data batch -- burst, res_imgs, raw_img, directions = next(train_iter) # -- getting shapes of data -- N, B, C, H, W = burst.shape burst = burst.cuda(non_blocking=True) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Formatting Images for FP # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- creating some transforms -- stacked_burst = rearrange(burst, 'n b c h w -> b n c h w') cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w') # -- extract target image -- mid_img = burst[N // 2] raw_zm_img = szm(raw_img.cuda(non_blocking=True)) if cfg.supervised: gt_img = raw_zm_img else: gt_img = mid_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Check Some Gradients # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- def mse_v_wassersteinG_check_some_gradients(cfg, burst, gt_img, model): grads = edict() gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) model.unet_info.model.zero_grad() burst.requires_grad_(True) outputs = model(burst) aligned, aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] residuals = denoised - gt_img_rs P = 1. 
#residuals.numel() denoised.retain_grad() rec_mse = (denoised.reshape(B, -1) - gt_img.reshape(B, -1))**2 rec_mse.retain_grad() ones = P * torch.ones_like(rec_mse) rec_mse.backward(ones, retain_graph=True) grads.rmse = rec_mse.grad.clone().reshape(B, -1) grad_rec_mse = grads.rmse grads.dmse = denoised.grad.clone().reshape(B, -1) grad_denoised_mse = grads.dmse ones = torch.ones_like(rec_mse) grads.d_to_b = torch.autograd.grad(rec_mse, denoised, ones)[0].reshape(B, -1) model.unet_info.model.zero_grad() outputs = model(burst) aligned, aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # residuals = denoised - gt_img_rs # rec_ot = w_gaussian_bp(residuals,noise_level) denoised.retain_grad() rec_ot_v = (denoised - gt_img_rs)**2 rec_ot_v.retain_grad() rec_ot = (rec_ot_v.mean() - noise_level / 255.)**2 rec_ot.retain_grad() ones = P * torch.ones_like(rec_ot) rec_ot.backward(ones) grad_denoised_ot = denoised.grad.clone().reshape(B, -1) grads.dot = grad_denoised_ot grad_rec_ot = rec_ot_v.grad.clone().reshape(B, -1) grads.rot = grad_denoised_ot print("Gradient Name Info") for name, g in grads.items(): g_norm = g.norm().item() g_mean = g.mean().item() g_std = g.std().item() print(name, g.shape, g_norm, g_mean, g_std) print_pairs = False if print_pairs: print("All Gradient Ratios") for name_t, g_t in grads.items(): for name_b, g_b in grads.items(): ratio = g_t / g_b ratio_m = ratio.mean().item() ratio_std = ratio.std().item() print("[%s/%s] [%2.2e +/- %2.2e]" % (name_t, name_b, ratio_m, ratio_std)) use_true_mse = False if use_true_mse: print("Ratios with Estimated MSE Gradient") true_dmse = 2 * torch.mean(denoised_ave - gt_img)**2 ratio_mse = grads.dmse / true_dmse ratio_mse_dtb = grads.dmse / grads.d_to_b print(ratio_mse) print(ratio_mse_dtb) dot_v_dmse = True if dot_v_dmse: print("Ratio of Denoised OT and Denoised MSE") ratio_mseot = (grads.dmse / grads.dot) print(ratio_mseot.mean(), ratio_mseot.std()) ratio_mseot = ratio_mseot[0, 0].item() c1 = torch.mean((denoised - gt_img_rs)**2).item() c2 = noise_level / 255 m = torch.mean(gt_img_rs).item() true_ratio = 2. 
* (c1 - c2) / (np.product(burst.shape)) # diff = denoised.reshape(B,-1)-gt_img_rs.reshape(B,-1) # true_ratio = 2.*(c1 - c2) * ( diff / ( np.product(burst.shape) ) ) # print(c1,c2,m,true_ratio,1./true_ratio) ratio_mseot = (grads.dmse / (grads.dot)) print(ratio_mseot * true_ratio) # ratio_mseot = (grads.dmse / ( grads.dot / diff) ) # print(ratio_mseot*true_ratio) # print(ratio_mseot.mean(),ratio_mseot.std()) exit() model.unet_info.model.zero_grad() # mse_v_wassersteinG_check_some_gradients(cfg,burst,gt_img,model) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Foward Pass # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- outputs = model(burst) aligned, aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Require Approx Equal Filter Norms (aligned) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- aligned_filters_rs = rearrange(aligned_filters, 'b n k2 c h w -> b n (k2 c h w)') norms = torch.norm(aligned_filters_rs, p=2., dim=2) norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N) norm_loss_align = torch.mean( torch.pow(torch.abs(norms - norms_mid), 1.)) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Require Approx Equal Filter Norms (denoised) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- denoised_filters = rearrange(denoised_filters, 'b n k2 c h w -> b n (k2 c h w)') norms = torch.norm(denoised_filters, p=2., dim=2) norms_mid = norms[:, N // 2].unsqueeze(1).repeat(1, N) norm_loss_denoiser = torch.mean( torch.pow(torch.abs(norms - norms_mid), 1.)) norm_loss_coeff = 0. # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Decrease Entropy within a Kernel # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_entropy = 0 filters_entropy_coeff = 0. # 1000. all_filters = [] L = len(align_hook.filters) iter_filters = align_hook.filters if L > 0 else [aligned_filters] for filters in iter_filters: filters_shaped = rearrange(filters, 'b n k2 c h w -> (b n c h w) k2', n=N) filters_entropy += entropyLoss(filters_shaped) all_filters.append(filters) if L > 0: filters_entropy /= L all_filters = torch.stack(all_filters, dim=1) align_hook.clear() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Increase Entropy across each Kernel # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_dist_entropy = 0 # -- across each frame -- # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2') # filters_shaped = torch.mean(filters_shaped,dim=1) # filters_dist_entropy += -1 * entropyLoss(filters_shaped) # -- across each batch -- filters_shaped = rearrange(all_filters, 'b l n k2 c h w -> (n l) (b c h w) k2') filters_shaped = torch.mean(filters_shaped, dim=1) filters_dist_entropy += -1 * entropyLoss(filters_shaped) # -- across each kpn cascade -- # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2') # filters_shaped = torch.mean(filters_shaped,dim=1) # filters_dist_entropy += -1 * entropyLoss(filters_shaped) filters_dist_coeff = 0 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Alignment Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] align_mse = np.sum(losses) align_mse_coeff = 0. 
#0.95**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Alignment Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # pad = 2*cfg.N # fs = cfg.dynamic.frame_size residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) # centered_residuals = tvF.center_crop(residuals,(fs-pad,fs-pad)) # centered_residuals = tvF.center_crop(residuals,(fs//2,fs//2)) # align_ot = kl_gaussian_bp(residuals,noise_level,flip=True) align_ot = kl_gaussian_bp_patches(residuals, noise_level, flip=True, patchsize=16) align_ot_coeff = 0 # 100. # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = denoiseLossMSE(denoised, denoised_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] rec_mse = np.sum(losses) rec_mse_coeff = 0.95**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- computation -- gt_img_rs = gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) # rec_ot = kl_gaussian_bp(residuals,noise_level) rec_ot = kl_gaussian_bp(residuals, noise_level, flip=True) # rec_ot /= 2. # alpha_grid = [0.,1.,5.,10.,25.] # for alpha in alpha_grid: # # residuals = torch.normal(torch.zeros_like(residuals)+ gt_img_rs*alpha/255.,noise_level/255.) # residuals = torch.normal(torch.zeros_like(residuals),noise_level/255.+ gt_img_rs*alpha/255.) # rec_ot_v2_a = kl_gaussian_bp_patches(residuals,noise_level,patchsize=16) # rec_ot_v1_b = kl_gaussian_bp(residuals,noise_level,flip=True) # rec_ot_v2_b = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16) # rec_ot_all = torch.tensor([rec_ot_v1_a,rec_ot_v2_a,rec_ot_v1_b,rec_ot_v2_b]) # rec_ot_v2 = (rec_ot_v2_a + rec_ot_v2_b).item()/2. # print(alpha,torch.min(rec_ot_all),torch.max(rec_ot_all),rec_ot_v1,rec_ot_v2) # exit() # rec_ot = w_gaussian_bp(residuals,noise_level) # print(residuals.numel()) rec_ot_coeff = 100. #residuals.numel()*2. # 1000.# - .997**cfg.global_step # residuals = rearrange(residuals,'b n c h w -> b n (h w) c') # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level) # rec_ot_loss_v1 = kl_gaussian_bp(residuals,noise_level,flip=True) # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals) # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg) # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3) # rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device) # rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Final Losses # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse norm_loss = norm_loss_coeff * (norm_loss_denoiser + norm_loss_align) align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot entropy_loss = 0 #filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy # final_loss = align_loss + rec_loss + entropy_loss + norm_loss final_loss = rec_loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- alignment MSE -- align_mse_losses += align_mse.item() align_mse_count += 1 # -- alignment Dist -- align_ot_losses += align_ot.item() align_ot_count += 1 # -- reconstruction MSE -- rec_mse_losses += rec_mse.item() rec_mse_count += 1 # -- reconstruction Dist. 
-- rec_ot_losses += rec_ot.item() rec_ot_count += 1 # -- total loss -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Gradients & Backpropogration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- final_loss.backward() # -- backprop now. -- model.align_info.optim.step() model.denoiser_info.optim.step() model.unet_info.optim.step() # for name,params in model.unet_info.model.named_parameters(): # if not ("weight" in name): continue # print(params.grad.norm()) # # print(module.conv1.parameters()) # # print(module.conv1.data.grad) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- compute mse for fun -- B = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(stacked_burst, dim=1) mse_loss = F.mse_loss(raw_img, mis_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [bm3d] -- bm3d_nb_psnrs = [] M = 10 if B > 10 else B for b in range(B): bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for aligned + denoised -- raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1) mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) psnr_std = np.std(mse_to_psnr(mse_loss)) # -- update losses -- running_loss /= cfg.log_interval # -- alignment MSE -- align_mse_ave = align_mse_losses / align_mse_count align_mse_losses, align_mse_count = 0, 0 # -- alignment Dist. -- align_ot_ave = align_ot_losses / align_ot_count align_ot_losses, align_ot_count = 0, 0 # -- reconstruction MSE -- rec_mse_ave = rec_mse_losses / rec_mse_count rec_mse_losses, rec_mse_count = 0, 0 # -- reconstruction Dist. 
-- rec_ot_ave = rec_ot_losses / rec_ot_count rec_ot_losses, rec_ot_count = 0, 0 # -- write record -- if use_record: info = { 'burst': burst_loss, 'ave': ave_loss, 'ot': rec_ot_ave, 'psnr': psnr, 'psnr_std': psnr_std } record_losses = record_losses.append(info, ignore_index=True) # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, len(train_loader), running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, rec_mse_ave, rec_ot_ave) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info) running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, model, stacked_burst, aligned, denoised, all_filters, directions) if use_timer: clock.toc() if use_timer: print(clock) cfg.global_step += 1 # -- remove hooks -- for hook in align_hooks: hook.remove() total_loss /= len(train_loader) return total_loss, record_losses
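# kl_gaussian_bp / kl_gaussian_bp_patches are not defined in this file. The sketch
# below illustrates only the assumed idea behind kl_gaussian_bp: fit a Gaussian to
# the residuals of each frame and penalize its KL divergence from the target
# zero-mean Gaussian with stddev noise_level / 255 (the `flip` flag is assumed to
# swap the argument order). This is an illustration, not the repo's implementation.
import math

import torch


def kl_gaussian_bp_sketch(residuals, noise_level, flip=False):
    # residuals: (batch, frames, ...) burst residuals
    tvar = (noise_level / 255.)**2
    res = residuals.reshape(residuals.shape[0], residuals.shape[1], -1)
    mu = res.mean(dim=-1)
    var = res.var(dim=-1) + 1e-12
    if flip:
        # KL( N(0, tvar) || N(mu, var) )
        kl = 0.5 * (torch.log(var / tvar) + (tvar + mu**2) / var - 1.)
    else:
        # KL( N(mu, var) || N(0, tvar) )
        kl = 0.5 * (math.log(tvar) - torch.log(var) + (var + mu**2) / tvar - 1.)
    return kl.mean()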
def run_train(cfg, rank, models, data, loader):
    this_proc_prints = (rank == 0 and cfg.use_ddp) or (cfg.use_ddp is False)

    s = int(npr.rand() * 5 + 1)
    time.sleep(s)

    hyperparams = load_hyperparameters(cfg)
    criterion_inputs = [hyperparams]
    criterion_inputs += extract_loss_inputs(cfg, rank)
    criterion = DenoisingLossDDP(*criterion_inputs)
    criterion = criterion.to(cfg.device)

    optimizer = load_optimizer(cfg, models)
    scheduler = load_scheduler(cfg, optimizer, len(loader.tr))
    print("Loaded optimizer: ")
    print(optimizer)
    print("Loaded scheduler: ")
    print(scheduler)

    # apply apex
    if cfg.use_apex:
        models, optimizer = amp.initialize(models, optimizer, opt_level='O2')

    # for name,param in models.named_parameters():
    #     wnorm = param.norm()
    #     print("{}: {}".format(name,wnorm))

    # init writer
    if this_proc_prints:
        writer = SummaryWriter(filename_suffix=cfg.exp_name)
    else:
        writer = None

    # init training loop
    global_step, current_epoch = get_model_epoch_info(cfg)
    cfg.global_step = global_step
    cfg.current_epoch = current_epoch

    # training loop
    loop_scheduler = get_loop_scheduler(scheduler)
    test_losses = {}
    print(f"cfg.epochs: {cfg.epochs}")
    print(f"cfg.use_apex: {cfg.use_apex}")
    print(f"cfg.use_bn: {cfg.use_bn}")
    print("len of loader.val", len(loader.val))
    print(_build_v2_summary(cfg))

    t = Timer()
    for epoch in range(cfg.current_epoch, cfg.epochs):

        t.tic()
        loss_epoch = train_loop(cfg, loader.tr, models, criterion, optimizer,
                                epoch, writer, loop_scheduler)
        t.toc()
        lr = optimizer.param_groups[0]["lr"]
        # print(t)

        # if ms_scheduler:
        #     ms_scheduler.step()
        if scheduler and loop_scheduler is None:
            val_loss = test_loop(cfg, models, loader.val)
            scheduler.step(val_loss)
            if this_proc_prints:
                writer.add_scalar("Loss/val", val_loss, epoch)

        if epoch % cfg.checkpoint_interval == 0 and this_proc_prints:
            save_denoising_model(cfg, models, optimizer)

        if epoch % cfg.test_interval == 0:
            if this_proc_prints:
                te_loss = test_loop(cfg, models, loader.te)
                writer.add_scalar("Loss/test", te_loss, epoch)

        if this_proc_prints:
            writer.add_scalar("Loss/train", loss_epoch / len(loader.tr), epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            msg = f"Epoch [{epoch}/{cfg.epochs}]\t"
            msg += f"Loss: {loss_epoch / len(loader.tr)}\t"
            msg += "{:2.3e}".format(lr)
            print(msg)

        cfg.current_epoch += 1

    if this_proc_prints:
        te_loss = test_loop(cfg, models, loader.te)
        writer.add_scalar("Loss/test", te_loss, epoch)
        save_denoising_model(cfg, models, optimizer)
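# load_scheduler is not defined in this file. Because run_train calls
# scheduler.step(val_loss) when no per-iteration loop scheduler is in use, the
# interface matches torch.optim.lr_scheduler.ReduceLROnPlateau. A speculative
# sketch of what load_scheduler might return under that assumption; the
# hyperparameter values are placeholders, not the repo's settings.
import torch


def load_scheduler_sketch(cfg, optimizer, steps_per_epoch):
    # steps_per_epoch would matter for cyclical/per-iteration schedulers; it is
    # unused for a plateau scheduler.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      mode="min",
                                                      factor=0.5,
                                                      patience=5)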
# set the size of the input (we can skip this if we're happy with the default;
# we can also change it later, e.g., for different batch sizes)
net.blobs['data'].reshape(
    50,        # batch size
    3,         # 3-channel (BGR) images
    224, 224)  # image size is 224x224

image_path = 'data/image/67/1698/2005/'
image_name = '0a884f5a90267d.jpg'
image = caffe.io.load_image(IMAGE_ROOT + image_path + image_name)
transformed_image = transformer.preprocess('data', image)
# plt.imshow(image)

# copy the image data into the memory allocated for the net
net.blobs['data'].data[...] = transformed_image

timer = Timer()
timer.tic()
### perform classification
output = net.forward()
timer.toc()
print('Elapsed time {:.3f} s.'.format(timer.total_time))

# the output probability vector for the first image in the batch
output_prob = output['prob'][0]
print('predicted class is:', output_prob.argmax())

# reverse sort and take the five largest items
top_inds = output_prob.argsort()[::-1][:5]
print('5 top probabilities and labels:')
print(list(zip(output_prob[top_inds], top_inds)))
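# The snippet above calls transformer.preprocess('data', image) without showing how
# the transformer was built; it was presumably set up earlier in the script. The
# lines below show the standard preprocessing setup from the Caffe classification
# example and are an assumption, not taken from this repo; the mean values are
# generic ImageNet-style BGR means.
import numpy as np

transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))                    # HWC -> CHW
transformer.set_mean('data', np.array([104.0, 117.0, 123.0]))   # per-channel BGR mean
transformer.set_raw_scale('data', 255)                          # [0, 1] -> [0, 255]
transformer.set_channel_swap('data', (2, 1, 0))                 # RGB -> BGR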
def train_loop(cfg, model, train_loader, epoch, record_losses): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.align_info.model.train() model.denoiser_info.model.train() model.denoiser_info.model = model.denoiser_info.model.to(cfg.device) model.align_info.model = model.align_info.model.to(cfg.device) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 align_ot_losses, align_ot_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 write_examples = True noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=- # # Add hooks for epoch # # -=-=-=-=-=-=-=-=-=-=-=-=- align_hook = AlignmentFilterHooks(cfg.N) align_hooks = [] for kpn_module in model.align_info.model.children(): for name, layer in kpn_module.named_children(): if name == "filter_cls": align_hook_handle = layer.register_forward_hook(align_hook) align_hooks.append(align_hook_handle) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss() # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: clock = Timer() train_iter = iter(train_loader) steps_per_epoch = len(train_loader) write_examples_iter = steps_per_epoch // 3 # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: clock.tic() # -- zero gradients; ready 2 go -- model.align_info.model.zero_grad() model.align_info.optim.zero_grad() model.denoiser_info.model.zero_grad() model.denoiser_info.optim.zero_grad() # -- grab data batch -- burst, res_imgs, raw_img, directions = next(train_iter) # -- getting shapes of data -- N, B, C, H, W = burst.shape burst = burst.cuda(non_blocking=True) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Formatting Images for FP # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- creating some transforms -- stacked_burst = rearrange(burst, 'n b c h w -> b n c h w') cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w') # -- extract target image -- mid_img = burst[N // 2] raw_zm_img = szm(raw_img.cuda(non_blocking=True)) if cfg.supervised: gt_img = raw_zm_img else: gt_img = mid_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Foward Pass # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- outputs = model(burst) aligned, aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Require Approx Equal Filter Norms # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- denoised_filters = rearrange(denoised_filters.detach(), 'b n k2 c h w -> n (b k2 c h w)') norms = denoised_filters.norm(dim=1) norm_loss_denoiser = torch.mean((norms - norms[N // 2])**2) norm_loss_coeff = 100. 
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Decrease Entropy within a Kernel # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_entropy = 0 filters_entropy_coeff = 1000. all_filters = [] L = len(align_hook.filters) iter_filters = align_hook.filters if L > 0 else [aligned_filters] for filters in iter_filters: filters_shaped = rearrange(filters, 'b n k2 c h w -> (b n c h w) k2', n=N) filters_entropy += entropyLoss(filters_shaped) all_filters.append(filters) if L > 0: filters_entropy /= L all_filters = torch.stack(all_filters, dim=1) align_hook.clear() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Increase Entropy across each Kernel # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_dist_entropy = 0 # -- across each frame -- # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b l) (n c h w) k2') # filters_shaped = torch.mean(filters_shaped,dim=1) # filters_dist_entropy += -1 * entropyLoss(filters_shaped) # -- across each batch -- filters_shaped = rearrange(all_filters, 'b l n k2 c h w -> (n l) (b c h w) k2') filters_shaped = torch.mean(filters_shaped, dim=1) filters_dist_entropy += -1 * entropyLoss(filters_shaped) # -- across each kpn cascade -- # filters_shaped = rearrange(all_filters,'b l n k2 c h w -> (b n) (l c h w) k2') # filters_shaped = torch.mean(filters_shaped,dim=1) # filters_dist_entropy += -1 * entropyLoss(filters_shaped) filters_dist_coeff = 0 # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Alignment Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = alignmentLossMSE(aligned, aligned_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] align_mse = np.sum(losses) align_mse_coeff = 0.95**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Alignment Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- fs = cfg.dynamic.frame_size residuals = aligned - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) centered_residuals = tvF.center_crop(residuals, (fs // 2, fs // 2)) align_ot = kl_gaussian_bp(centered_residuals, noise_level, flip=True) align_ot_coeff = 100. # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = denoiseLossMSE(denoised, denoised_ave, gt_img, cfg.global_step) ave_loss, burst_loss = [loss.item() for loss in losses] rec_mse = np.sum(losses) rec_mse_coeff = 0.95**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- regularization scheduler -- if cfg.global_step < 100: reg = 0.5 elif cfg.global_step < 200: reg = 0.25 elif cfg.global_step < 5000: reg = 0.15 elif cfg.global_step < 10000: reg = 0.1 else: reg = 0.05 # -- computation -- residuals = denoised - gt_img.unsqueeze(1).repeat(1, N, 1, 1, 1) # residuals = rearrange(residuals,'b n c h w -> b n (h w) c') # rec_ot_pair_loss_v1 = w_gaussian_bp(residuals,noise_level) rec_ot_loss_v1 = kl_gaussian_bp(residuals, noise_level, flip=True) # rec_ot_loss_v1 = kl_gaussian_pair_bp(residuals) # rec_ot_loss_v1 = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16) # rec_ot_loss_v1 = ot_pairwise2gaussian_bp(residuals,K=6,reg=reg) # rec_ot_loss_v2 = ot_pairwise_bp(residuals,K=3) rec_ot_pair_loss_v2 = torch.FloatTensor([0.]).to(cfg.device) rec_ot = (rec_ot_loss_v1 + rec_ot_pair_loss_v2) rec_ot_coeff = 100. 
# - .997**cfg.global_step # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Final Losses # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- rec_loss = rec_ot_coeff * rec_ot + rec_mse_coeff * rec_mse norm_loss = norm_loss_coeff * norm_loss_denoiser align_loss = align_mse_coeff * align_mse + align_ot_coeff * align_ot entropy_loss = filters_entropy_coeff * filters_entropy + filters_dist_coeff * filters_dist_entropy final_loss = align_loss + rec_loss + entropy_loss + norm_loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- alignment MSE -- align_mse_losses += align_mse.item() align_mse_count += 1 # -- alignment Dist -- align_ot_losses += align_ot.item() align_ot_count += 1 # -- reconstruction MSE -- rec_mse_losses += rec_mse.item() rec_mse_count += 1 # -- reconstruction Dist. -- rec_ot_losses += rec_ot.item() rec_ot_count += 1 # -- total loss -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Gradients & Backpropogration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- final_loss.backward() # -- backprop now. -- model.align_info.optim.step() model.denoiser_info.optim.step() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- compute mse for fun -- B = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, aligned_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(stacked_burst, dim=1) mse_loss = F.mse_loss(raw_img, mis_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [bm3d] -- bm3d_nb_psnrs = [] M = 10 if B > 10 else B for b in range(B): bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for aligned + denoised -- raw_img_repN = raw_img.unsqueeze(1).repeat(1, N, 1, 1, 1) mse_loss = F.mse_loss(raw_img_repN, denoised + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave + 0.5, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) psnr_std = np.std(mse_to_psnr(mse_loss)) # -- update losses -- running_loss /= cfg.log_interval # -- alignment MSE -- align_mse_ave = align_mse_losses / align_mse_count align_mse_losses, align_mse_count = 0, 0 # -- alignment Dist. 
-- align_ot_ave = align_ot_losses / align_ot_count align_ot_losses, align_ot_count = 0, 0 # -- reconstruction MSE -- rec_mse_ave = rec_mse_losses / rec_mse_count rec_mse_losses, rec_mse_count = 0, 0 # -- reconstruction Dist. -- rec_ot_ave = rec_ot_losses / rec_ot_count rec_ot_losses, rec_ot_count = 0, 0 # -- write record -- if use_record: info = { 'burst': burst_loss, 'ave': ave_loss, 'ot': rec_ot_ave, 'psnr': psnr, 'psnr_std': psnr_std } record_losses = record_losses.append(info, ignore_index=True) # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, len(train_loader), running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, rec_mse_ave, rec_ot_ave) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info) running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, model, stacked_burst, aligned, denoised, all_filters, directions) if use_timer: clock.toc() if use_timer: print(clock) cfg.global_step += 1 # -- remove hooks -- for hook in align_hooks: hook.remove() total_loss /= len(train_loader) return total_loss, record_losses
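# -- hedged sketch: one plausible form of the kl_gaussian_bp residual penalty used --
# -- in the train_loop above. It fits a single Gaussian to the residuals and takes --
# -- the closed-form KL divergence against N(0, (stddev/255)^2). The project's     --
# -- kl_gaussian_bp (and its `flip` flag) may differ; this is an illustration only. --
def kl_to_gaussian_sketch(residuals, stddev, eps=1e-8):
    """residuals: tensor of (output - target) values; stddev given on a [0,255] scale."""
    import torch
    sigma = torch.tensor(stddev / 255., device=residuals.device, dtype=residuals.dtype)
    mu_r = residuals.mean()
    std_r = residuals.std().clamp_min(eps)
    # closed-form KL( N(mu_r, std_r^2) || N(0, sigma^2) ) for 1-D Gaussians
    return torch.log(sigma / std_r) + (std_r**2 + mu_r**2) / (2 * sigma**2) - 0.5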
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses, writer): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.align_info.model.train() model.denoiser_info.model.train() model.unet_info.model.train() model.denoiser_info.model = model.denoiser_info.model.to(cfg.device) model.align_info.model = model.align_info.model.to(cfg.device) model.unet_info.model = model.unet_info.model.to(cfg.device) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) noise_type = cfg.noise_params.ntype # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 dynamics_acc, dynamics_count = 0, 0 write_examples = False write_examples_iter = 200 noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Load Pre-Simulated Random Numbers # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Dataset Augmentation # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- transforms = [tvF.vflip, tvF.hflip, tvF.rotate] aug = RandomChoice(transforms) def apply_transformations(burst, gt_img): N, B = burst.shape[:2] gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w') all_images = torch.cat([gt_img_rs, burst], dim=0) all_images = rearrange(all_images, 'n b c h w -> (n b) c h w') tv_utils.save_image(all_images, 'aug_original.png', nrow=N + 1, normalize=True) aug_images = aug(all_images) tv_utils.save_image(aug_images, 'aug_augmented.png', nrow=N + 1, normalize=True) aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B) aug_gt_img = aug_images[0] aug_burst = aug_images[1:] return aug_burst, aug_gt_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Half Precision # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # model.align_info.model.half() # model.denoiser_info.model.half() # model.unet_info.model.half() # models = [model.align_info.model, # model.denoiser_info.model, # model.unet_info.model] # for model_l in models: # model_l.half() # for layer in model_l.modules(): # if isinstance(layer, torch.nn.BatchNorm2d): # layer.float() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha, gradient_L1=~cfg.supervised) # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=-=-=- # # Add hooks for epoch # # -=-=-=-=-=-=-=-=-=-=-=-=- align_hook = AlignmentFilterHooks(cfg.N) align_hooks = [] for kpn_module in model.align_info.model.children(): for name, layer in kpn_module.named_children(): if name == "filter_cls": align_hook_handle = layer.register_forward_hook(align_hook) align_hooks.append(align_hook_handle) # -=-=-=-=-=-=-=-=-=-=- # # Noise2Noise # # -=-=-=-=-=-=-=-=-=-=- noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False) # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: data_clock = Timer() clock = Timer() ds_size = len(train_loader) small_ds = ds_size < 500 
steps_per_epoch = ds_size if not small_ds else 500 write_examples_iter = steps_per_epoch // 3 all_filters = [] # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- dynamics_acc_i = -1. if cfg.use_seed: init = torch.initial_seed() torch.manual_seed(cfg.seed + 1 + epoch + init) train_iter = iter(train_loader) for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: data_clock.tic() clock.tic() # -- grab data batch -- if small_ds and batch_idx >= ds_size: if cfg.use_seed: init = torch.initial_seed() torch.manual_seed(cfg.seed + 1 + epoch + init) train_iter = iter(train_loader) # reset if too big sample = next(train_iter) burst, raw_img, motion = sample['burst'], sample['clean'], sample[ 'flow'] raw_img_iid = sample['iid'] raw_img_iid = raw_img_iid.cuda(non_blocking=True) burst = burst.cuda(non_blocking=True) # -- handle possibly cached simulated bursts -- if 'sim_burst' in sample: sim_burst = rearrange(sample['sim_burst'], 'b n k c h w -> n b k c h w') else: sim_burst = None non_sim_method = cfg.n2n or cfg.supervised if sim_burst is None and not (non_sim_method or cfg.abps): if sim_burst is None: if cfg.use_kindex_lmdb: kindex = kindex_ds[batch_idx].cuda(non_blocking=True) else: kindex = None query = burst[[N // 2]] database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]]) sim_burst = compute_similar_bursts( cfg, query, database, cfg.sim_K, noise_level / 255., patchsize=cfg.sim_patchsize, shuffle_k=cfg.sim_shuffleK, kindex=kindex, only_middle=cfg.sim_only_middle, search_method=cfg.sim_method, db_level="frame") if (sim_burst is None) and cfg.abps: # scores,aligned = abp_search(cfg,burst) # scores,aligned = lpas_search(cfg,burst,motion) if cfg.lpas_method == "spoof": mtype = "global" acc = cfg.optical_flow_acc scores, aligned = lpas_spoof(burst, motion, cfg.nblocks, mtype, acc) else: ref_frame = (cfg.nframes + 1) // 2 nblocks = cfg.nblocks method = cfg.lpas_method scores, aligned, dacc = lpas_search(burst, ref_frame, nblocks, motion, method) dynamics_acc_i = dacc # scores,aligned = lpas_spoof(motion,accuracy=cfg.optical_flow_acc) # shuffled = shuffle_aligned_pixels_noncenter(aligned,cfg.nframes) nsims = cfg.nframes sim_aligned = create_sim_from_aligned(burst, aligned, nsims) burst_s = rearrange(burst, 't b c h w -> t b 1 c h w') sim_burst = torch.cat([burst_s, sim_aligned], dim=2) # print("sim_burst.shape",sim_burst.shape) # raw_img = raw_img.cuda(non_blocking=True)-0.5 # # print(np.sqrt(cfg.noise_params['g']['stddev'])) # print(motion) # tiled = tile_across_blocks(burst[[cfg.nframes//2]],cfg.nblocks) # rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2]) # for t in range(cfg.nframes): # save_image(tiled[0] - rep_burst[t],f"tiled_sub_burst_{t}.png") # save_image(aligned,"aligned.png") # print(aligned.shape) # # save_image(aligned[0] - aligned[cfg.nframes//2],"aligned_0.png") # # save_image(aligned[2] - aligned[cfg.nframes//2],"aligned_2.png") # M = (1+cfg.dynamic.ppf)*cfg.nframes # fs = cfg.dynamic.frame_size - M # fs = cfg.frame_size # cropped = crop_center_patch([burst,aligned,raw_img],cfg.nframes,cfg.frame_size) # burst,aligned,raw_img = cropped[0],cropped[1],cropped[2] # print(aligned.shape) # for t in range(cfg.nframes+1): # diff_t = aligned[t] - raw_img # spacing = cfg.nframes+1 # diff_t = crop_center_patch([diff_t],spacing,cfg.frame_size)[0] # print_tensor_stats(f"diff_aligned_{t}",diff_t) # 
save_image(diff_t,f"diff_aligned_{t}.png") # if t < cfg.nframes: # dt = aligned[t+1]-aligned[t] # dt = crop_center_patch([dt],spacing,cfg.frame_size)[0] # save_image(dt,f"dt_aligned_{t+1}m{t}.png") # save_image(aligned[t],f"aligned_{t}.png") # diff_t = tvF.crop(aligned[t] - raw_img,cfg.nframes,cfg.nframes,fs,fs) # print_tensor_stats(f"diff_aligned_{t}",diff_t) # save_image(burst,"burst.png") # save_image(burst[0] - burst[cfg.nframes//2],"burst_0.png") # save_image(burst[2] - burst[cfg.nframes//2],"burst_2.png") # exit() # print(sample['burst'].shape,sample['res'].shape) # b_clean = sample['burst'] - sample['res'] # scores,ave,t_aligned = test_abp_global_search(cfg,b_clean,noisy_img=burst) # burstBN = rearrange(burst,'n b c h w -> (b n) c h w') # tv_utils.save_image(burstBN,"abps_burst.png",normalize=True) # alignedBN = rearrange(aligned,'n b c h w -> (b n) c h w') # tv_utils.save_image(alignedBN,"abps_aligned.png",normalize=True) # rep_burst = burst[[N//2]].repeat(N,1,1,1,1) # deltaBN = rearrange(aligned - rep_burst,'n b c h w -> (b n) c h w') # tv_utils.save_image(deltaBN,"abps_delta.png",normalize=True) # b_clean_rep = b_clean[[N//2]].repeat(N,1,1,1,1) # tdeltaBN = rearrange(t_aligned - b_clean_rep.cpu(),'n b c h w -> (b n) c h w') # tv_utils.save_image(tdeltaBN,"abps_tdelta.png",normalize=True) if non_sim_method: sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1) else: sim_burst = sim_burst.cuda(non_blocking=True) if use_timer: data_clock.toc() # -- to cuda -- burst = burst.cuda(non_blocking=True) raw_zm_img = szm(raw_img.cuda(non_blocking=True)) # anscombe.test(cfg,burst_og) # save_image(burst,f"burst_{batch_idx}_{cfg.n2n}.png") # -- crop images -- if True: #cfg.abps or cfg.abps_inputs: images = [burst, sim_burst, raw_img, raw_img_iid] spacing = burst.shape[0] # we use frames as spacing cropped = crop_center_patch(images, spacing, cfg.frame_size) burst, sim_burst = cropped[0], cropped[1] raw_img, raw_img_iid = cropped[2], cropped[3] if cfg.abps or cfg.abps_inputs: aligned = crop_center_patch([aligned], spacing, cfg.frame_size)[0] # print_tensor_stats("d-eq?",burst[-1] - aligned[-1]) burst = burst[:cfg.nframes] # last frame is target # -- getting shapes of data -- N, B, C, H, W = burst.shape burst_og = burst.clone() # -- shuffle over Simulated Samples -- k_ins, k_outs = create_k_grid(sim_burst, shuffle=True) k_ins, k_outs = [k_ins[0]], [k_outs[0]] # k_ins,k_outs = create_k_grid_v3(sim_burst) for k_in, k_out in zip(k_ins, k_outs): if k_in == k_out: continue # -- zero gradients; ready 2 go -- model.align_info.model.zero_grad() model.align_info.optim.zero_grad() model.denoiser_info.model.zero_grad() model.denoiser_info.optim.zero_grad() model.unet_info.model.zero_grad() model.unet_info.optim.zero_grad() # -- compute input/output data -- if cfg.sim_only_middle and (not cfg.abps): # sim_burst.shape == T,B,K,C,H,W midi = 0 if sim_burst.shape[0] == 1 else N // 2 left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:] cat_burst = [ left_burst, sim_burst[[midi], :, k_in], right_burst ] burst = torch.cat(cat_burst, dim=0) mid_img = sim_burst[midi, :, k_out] elif cfg.abps and (not cfg.abps_inputs): # -- v1 -- mid_img = aligned[-1] # -- v2 -- # left_aligned,right_aligned = aligned[:N//2],aligned[N//2+1:] # nc_aligned = torch.cat([left_aligned,right_aligned],dim=0) # shuf = shuffle_aligned_pixels(nc_aligned,cfg.nframes) # mid_img = shuf[1] # ---- v3 ---- # shuf = shuffle_aligned_pixels(aligned) # shuf = aligned[[N//2,0]] # midi = 0 if sim_burst.shape[0] == 1 else N//2 # 
left_burst,right_burst = burst[:N//2],burst[N//2+1:] # burst = torch.cat([left_burst,shuf[[0]],right_burst],dim=0) # nc_burst = torch.cat([left_burst,right_burst],dim=0) # shuf = shuffle_aligned_pixels(aligned) # ---- v4 ---- # nc_shuf = shuffle_aligned_pixels(nc_aligned) # mid_img = nc_shuf[0] # pick = npr.randint(0,2,size=(1,))[0] # mid_img = nc_aligned[pick] # mid_img = shuf[1] # save_image(shuf,"shuf.png") # print(shuf.shape) # diff = raw_img.cuda(non_blocking=True) - aligned[0] # mean = torch.mean(diff).item() # std = torch.std(diff).item() # print(mean,std) # -- v1 -- # burst = burst # notMid = sample_not_mid(N) # mid_img = aligned[notMid] elif cfg.abps_inputs: burst = aligned.clone() burst_og = aligned.clone() mid_img = shuffle_aligned_pixels(burst, cfg.nframes)[0] else: burst = sim_burst[:, :, k_in] mid_img = sim_burst[N // 2, :, k_out] # mid_img = sim_burst[N//2,:] # print(burst.shape,mid_img.shape) # print(F.mse_loss(burst,mid_img).item()) if cfg.supervised: gt_img = get_nmlz_tgt_img(cfg, raw_img).cuda(non_blocking=True) elif cfg.n2n: gt_img = raw_img_iid #noise_xform(raw_img).cuda(non_blocking=True) else: gt_img = mid_img # another = noise_xform(raw_img).cuda(non_blocking=True) # print_tensor_stats("a-iid?",raw_img_iid.cuda() - raw_img.cuda()) # print_tensor_stats("b-iid?",mid_img.cuda() - raw_img.cuda()) # print_tensor_stats("c-iid?",mid_img.cuda() - another) # print_tensor_stats("d-iid?",raw_img_iid.cuda() - another) # print_tensor_stats("e-iid?",mid_img.cuda() - raw_img_iid.cuda()) # for bt in range(cfg.nframes): # tiled = tile_across_blocks(burst[[bt]],cfg.nblocks) # rep_burst = repeat(burst,'t b c h w -> t b g c h w',g=tiled.shape[2]) # for t in range(cfg.nframes): # save_image(tiled[0] - rep_burst[t],f"tiled_{bt}_sub_burst_{t}.png") # print_tensor_stats(f"delta_{bt}_{t}",tiled[0,:,4] - burst[t]) # raw_img = raw_img.cuda(non_blocking=True) - 0.5 # print_tensor_stats("gt_img - raw",gt_img - raw_img) # # save_image(gt_img,"gt.png") # # save_image(raw,"raw.png") # save_image(gt_img - raw_img,"gt_sub_raw.png") # print_tensor_stats("burst[N//2] - raw",burst[N//2] - raw_img) # save_image(burst[N//2] - raw_img,"burst_sub_raw.png") # print_tensor_stats("burst[N//2] - gt_img",burst[N//2] - gt_img) # save_image(burst[N//2] - gt_img,"burst_sub_gt.png") # print_tensor_stats("aligned[N//2] - raw",aligned[N//2] - raw_img) # save_image(aligned[N//2] - raw_img,"aligned_sub_raw.png") # print_tensor_stats("aligned[N//2] - burst[N//2]", # aligned[N//2] - burst[N//2]) # save_image(aligned[N//2] - burst[N//2],"aligned_sub_burst.png") # gt_img = torch.normal(raw_zm_img,noise_level/255.) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Dataset Augmentation # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # burst,gt_img = apply_transformations(burst,gt_img) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Formatting Images for FP # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- stacked_burst = rearrange(burst, 'n b c h w -> b n c h w') cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w') # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Foward Pass # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- outputs = model(burst) m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Decrease Entropy within a Kernel # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- filters_entropy = 0 filters_entropy_coeff = 0. # 1000. 
all_filters = [] L = len(align_hook.filters) iter_filters = align_hook.filters if L > 0 else [aligned_filters] for filters in iter_filters: f_shape = 'b n k2 c h w -> (b n c h w) k2' filters_shaped = rearrange(filters, f_shape) filters_entropy += one #entropyLoss(filters_shaped) all_filters.append(filters) if L > 0: filters_entropy /= L all_filters = torch.stack(all_filters, dim=1) align_hook.clear() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (MSE) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- losses = [F.mse_loss(denoised_ave, gt_img)] # losses = denoiseLossMSE(denoised,denoised_ave,gt_img,cfg.global_step) # losses = [ one, one ] # ave_loss,burst_loss = [loss.item() for loss in losses] rec_mse = np.sum(losses) # rec_mse = F.mse_loss(denoised_ave,gt_img) rec_mse_coeff = 1. # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Reconstruction Losses (Distribution) # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- gt_img_rep = gt_img.unsqueeze(1).repeat(1, denoised.shape[1], 1, 1, 1) residuals = denoised - gt_img_rep rec_ot = torch.FloatTensor([0.]).to(cfg.device) # rec_ot = kl_gaussian_bp(residuals,noise_level,flip=True) # rec_ot = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16) if torch.any(torch.isnan(rec_ot)): rec_ot = torch.FloatTensor([0.]).to(cfg.device) if torch.any(torch.isinf(rec_ot)): rec_ot = torch.FloatTensor([0.]).to(cfg.device) rec_ot_coeff = 0. # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Final Losses # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- rec_loss = rec_mse_coeff * rec_mse + rec_ot_coeff * rec_ot final_loss = rec_loss # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- reconstruction MSE -- rec_mse_losses += rec_mse.item() rec_mse_count += 1 # -- reconstruction Dist. -- rec_ot_losses += rec_ot.item() rec_ot_count += 1 # -- dynamic acc - dynamics_acc += dynamics_acc_i dynamics_count += 1 # -- total loss -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Gradients & Backpropogration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- if cfg.use_seed: torch.set_deterministic(False) final_loss.backward() if cfg.use_seed: torch.set_deterministic(True) # -- backprop now. 
-- model.align_info.optim.step() model.denoiser_info.optim.step() model.unet_info.optim.step() scheduler.step() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- recompute model output for original images -- outputs = model(burst_og) m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4] aligned_filters, denoised_filters = outputs[4:] # -- compute mse for fun -- B = raw_img.shape[0] raw_img = raw_img.cuda(non_blocking=True) raw_img = get_nmlz_tgt_img(cfg, raw_img) # -- psnr for [average of aligned frames] -- mse_loss = F.mse_loss(raw_img, m_aligned_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_aligned_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [average of input, misaligned frames] -- mis_ave = torch.mean(burst_og, dim=0) if noise_type == "qis": mis_ave = quantize_img(cfg, mis_ave) mse_loss = F.mse_loss(raw_img, mis_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss)) psnr_misaligned_std = np.std(mse_to_psnr(mse_loss)) # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25)) # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25)) # -- psnr for [bm3d] -- mid_img_og = burst[N // 2] bm3d_nb_psnrs = [] M = 4 if B > 4 else B for b in range(M): bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5, sigma_psd=noise_level / 255, stage_arg=bm3d.BM3DStages.ALL_STAGES) bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2) # maybe an issue here b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1) b_loss = torch.mean(b_loss, 1).detach().cpu().numpy() bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss)) bm3d_nb_psnrs.append(bm3d_nb_psnr) bm3d_nb_ave = np.mean(bm3d_nb_psnrs) bm3d_nb_std = np.std(bm3d_nb_psnrs) # -- psnr for input averaged frames -- # burst_ave = torch.mean(burst_og,dim=0) # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1) # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy() # psnr_input_ave = np.mean(mse_to_psnr(mse_loss)) # psnr_input_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for aligned + denoised -- R = denoised.shape[1] raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1) # if noise_type == "qis": denoised = quantize_img(cfg,denoised) # save_image(denoised_ave,"denoised_ave.png") # save_image(denoised,"denoised.png") mse_loss = F.mse_loss(raw_img_repN, denoised, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss)) psnr_denoised_std = np.std(mse_to_psnr(mse_loss)) # -- psnr for [model output image] -- mse_loss = F.mse_loss(raw_img, denoised_ave, reduction='none').reshape(B, -1) mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy() psnr = np.mean(mse_to_psnr(mse_loss)) psnr_std = np.std(mse_to_psnr(mse_loss)) # -- update losses -- running_loss /= cfg.log_interval # -- reconstruction MSE -- rec_mse_ave = rec_mse_losses / rec_mse_count rec_mse_losses, rec_mse_count = 0, 0 # -- reconstruction Dist. -- rec_ot_ave = rec_ot_losses / rec_ot_count rec_ot_losses, rec_ot_count = 0, 0 # -- ave dynamic acc -- ave_dyn_acc = dynamics_acc / dynamics_count * 100. 
dynamics_acc, dynamics_count = 0, 0 # -- write record -- if use_record: info = { 'burst': burst_loss, 'ave': ave_loss, 'ot': rec_ot_ave, 'psnr': psnr, 'psnr_std': psnr_std } record_losses = record_losses.append(info, ignore_index=True) # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch, running_loss, psnr, psnr_std, psnr_denoised_ave, psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std, psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std, rec_mse_ave, ave_dyn_acc) #rec_ot_ave) #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info) print( "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e" % write_info, flush=True) # -- write to summary writer -- if writer: writer.add_scalar('train/running-loss', running_loss, cfg.global_step) writer.add_scalars('train/model-psnr', { 'ave': psnr, 'std': psnr_std }, cfg.global_step) writer.add_scalars('train/dn-frame-psnr', { 'ave': psnr_denoised_ave, 'std': psnr_denoised_std }, cfg.global_step) # -- reset loss -- running_loss = 0 # -- write examples -- if write_examples and (batch_idx % write_examples_iter) == 0 and ( batch_idx > 0 or cfg.global_step == 0): write_input_output(cfg, model, stacked_burst, aligned, denoised, all_filters, motion) if use_timer: clock.toc() if use_timer: print("data_clock", data_clock.average_time) print("clock", clock.average_time) cfg.global_step += 1 # -- remove hooks -- for hook in align_hooks: hook.remove() total_loss /= len(train_loader) return total_loss, record_losses
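# -- hedged sketch: a minimal mse_to_psnr consistent with how the PSNR logging in --
# -- the loops above uses it, assuming pixel intensities in [0,1] so the peak     --
# -- signal is 1.0. The project's own helper may clip differently or assume       --
# -- another peak (e.g. 255).                                                     --
def mse_to_psnr_sketch(mse, peak=1.0, eps=1e-12):
    """Convert an array of per-image MSE values to PSNR in dB."""
    import numpy as np
    mse = np.asarray(mse, dtype=np.float64)
    return 10.0 * np.log10((peak ** 2) / np.maximum(mse, eps))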
def train_loop(cfg, model, optimizer, scheduler, train_loader, epoch, record_losses, writer): # -=-=-=-=-=-=-=-=-=-=- # # Setup for epoch # # -=-=-=-=-=-=-=-=-=-=- model.train() model = model.to(cfg.gpuid) N = cfg.N total_loss = 0 running_loss = 0 szm = ScaleZeroMean() blocksize = 128 unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize) use_record = False if record_losses is None: record_losses = pd.DataFrame({ 'burst': [], 'ave': [], 'ot': [], 'psnr': [], 'psnr_std': [] }) noise_type = cfg.noise_params.ntype # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- align_mse_losses, align_mse_count = 0, 0 rec_mse_losses, rec_mse_count = 0, 0 rec_ot_losses, rec_ot_count = 0, 0 running_loss, total_loss = 0, 0 write_examples = False write_examples_iter = 200 noise_level = cfg.noise_params['g']['stddev'] # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Load Pre-Simulated Random Numbers # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if cfg.use_kindex_lmdb: kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Dataset Augmentation # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- transforms = [tvF.vflip, tvF.hflip, tvF.rotate] aug = RandomChoice(transforms) def apply_transformations(burst, gt_img): N, B = burst.shape[:2] gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w') all_images = torch.cat([gt_img_rs, burst], dim=0) all_images = rearrange(all_images, 'n b c h w -> (n b) c h w') tv_utils.save_image(all_images, 'aug_original.png', nrow=N + 1, normalize=True) aug_images = aug(all_images) tv_utils.save_image(aug_images, 'aug_augmented.png', nrow=N + 1, normalize=True) aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B) aug_gt_img = aug_images[0] aug_burst = aug_images[1:] return aug_burst, aug_gt_img # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Half Precision # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # model.align_info.model.half() # model.denoiser_info.model.half() # model.unet_info.model.half() # models = [model.align_info.model, # model.denoiser_info.model, # model.unet_info.model] # for model_l in models: # model_l.half() # for layer in model_l.modules(): # if isinstance(layer, torch.nn.BatchNorm2d): # layer.float() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Init Loss Functions # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- alignmentLossMSE = BurstRecLoss() denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha, gradient_L1=~cfg.supervised) # denoiseLossOT = BurstResidualLoss() entropyLoss = EntropyLoss() # -=-=-=-=-=-=-=-=-=-=- # # Noise2Noise # # -=-=-=-=-=-=-=-=-=-=- noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False) # -=-=-=-=-=-=-=-=-=-=- # # Final Configs # # -=-=-=-=-=-=-=-=-=-=- random_crop = tvT.RandomCrop(cfg.byol_patchsize) use_timer = False one = torch.FloatTensor([1.]).to(cfg.device) switch = True if use_timer: data_clock = Timer() clock = Timer() train_iter = iter(train_loader) ds_size = len(train_loader) small_ds = ds_size < 500 steps_per_epoch = ds_size if not small_ds else 500 write_examples_iter = steps_per_epoch // 3 all_filters = [] # -=-=-=-=-=-=-=-=-=-=- # # Start Epoch # # -=-=-=-=-=-=-=-=-=-=- for batch_idx in range(steps_per_epoch): # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Setting up for Iteration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- setup iteration timer -- if use_timer: data_clock.tic() clock.tic() # -- grab data batch -- if small_ds and batch_idx >= ds_size: train_iter = iter(train_loader) # reset if too big sample = next(train_iter) burst, raw_img, directions 
= sample['burst'], sample['clean'], sample[ 'directions'] burst = burst.cuda(non_blocking=True) # -- handle possibly cached simulated bursts -- if 'sim_burst' in sample: sim_burst = rearrange(sample['sim_burst'], 'b n k c h w -> n b k c h w') else: sim_burst = None if sim_burst is None and not (cfg.n2n or cfg.supervised): if sim_burst is None: if cfg.use_kindex_lmdb: kindex = kindex_ds[batch_idx].cuda(non_blocking=True) else: kindex = None query = burst[[N // 2]] database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]]) sim_burst = compute_similar_bursts( cfg, query, database, cfg.sim_K, noise_level / 255., patchsize=cfg.sim_patchsize, shuffle_k=cfg.sim_shuffleK, kindex=kindex, only_middle=cfg.sim_only_middle, search_method=cfg.sim_method, db_level="frame") if cfg.n2n or cfg.supervised: sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1) else: sim_burst = sim_burst.cuda(non_blocking=True) if use_timer: data_clock.toc() # -- getting shapes of data -- N, B, C, H, W = burst.shape burst = burst.cuda(non_blocking=True) raw_zm_img = szm(raw_img.cuda(non_blocking=True)) burst_og = burst.clone() mid_img_og = burst[N // 2] # -- shuffle over Simulated Samples -- k_ins, k_outs = create_k_grid(sim_burst, shuffle=True) # k_ins,k_outs = [k_ins[0]],[k_outs[0]] for k_in, k_out in zip(k_ins, k_outs): if k_in == k_out: continue # -- zero gradients; ready 2 go -- optimizer.zero_grad() model.zero_grad() # -- compute input/output data -- if cfg.sim_only_middle: midi = 0 if sim_burst.shape[0] == 1 else N // 2 left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:] burst = torch.cat( [left_burst, sim_burst[[midi], :, k_in], right_burst], dim=0) mid_img = sim_burst[midi, :, k_out] else: burst = sim_burst[:, :, k_in] mid_img = sim_burst[N // 2, :, k_out] # mid_img = sim_burst[N//2,:] # print(burst.shape,mid_img.shape) # print(F.mse_loss(burst,mid_img).item()) if cfg.supervised: gt_img = get_nmlz_img(cfg, raw_img).cuda(non_blocking=True) elif cfg.n2n: gt_img = noise_xform(raw_img).cuda(non_blocking=True) else: gt_img = mid_img # gt_img = torch.normal(raw_zm_img,noise_level/255.) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Dataset Augmentation # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # burst,gt_img = apply_transformations(burst,gt_img) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Experimentally Set Hyperparams # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- [before training] setting the ps and nh -- # test_ps_nh_sizes(cfg,model,burst) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Formatting Images & FP # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- patches = sample_burst_patches(cfg, model, burst + 0.5) input_patches_0 = model.patch_helper.form_input_patches(patches) f_patches = torch.flip(patches, dims=(0, )) # reverse input_patches_1 = model.patch_helper.form_input_patches(f_patches) final_loss = model(input_patches_0) final_loss += model(input_patches_1) # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Record Keeping # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- total loss -- running_loss += final_loss.item() total_loss += final_loss.item() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Gradients & Backpropogration # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # -- compute the gradients! -- final_loss.backward() # -- backprop now. 
-- optimizer.step() model.update_moving_average() # scheduler.step() # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- # # Printing to Stdout # # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0: # -- update losses -- running_loss /= cfg.log_interval # -- write to stdout -- write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch, running_loss) print("[%d/%d][%d/%d]: %2.3e" % write_info) nbatches = 2 burst = burst[:, :nbatches] # limit batch size to run test psnrs_sim = test_sim_search(cfg, burst + 0.5, model) psnrs_ftr = psnrs_sim[cfg.byol_backbone_name] psnrs_pix = psnrs_sim["pix"] print_psnr_results(psnrs_ftr, "[PSNR-ftr]") print_psnr_results(psnrs_pix, "[PSNR-pix]") print_edge_info(burst) # psnrs = test_sim_search(cfg,burst,model) # print_psnr_results(psnrs,"[PSNR-ftr]") # psnrs = test_sim_search_pix(cfg,burst,model) # print_psnr_results(psnrs,"[PSNR-pix]") # -- reset loss -- running_loss = 0 if use_timer: clock.toc() if use_timer: print("data_clock", data_clock.average_time) print("clock", clock.average_time) cfg.global_step += 1 total_loss /= len(train_loader) return total_loss, record_losses
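# -- hedged self-check for the einops rearrange patterns used throughout the loops --
# -- above: the frame axis is folded into the batch axis and unfolded again. The   --
# -- toy shapes below are arbitrary illustration values.                           --
def _rearrange_selfcheck():
    import torch
    from einops import rearrange
    n, b, c, h, w = 3, 2, 1, 4, 4
    burst = torch.arange(n * b * c * h * w, dtype=torch.float32).reshape(n, b, c, h, w)
    flat = rearrange(burst, 'n b c h w -> (b n) c h w')      # fold frames into batch
    back = rearrange(flat, '(b n) c h w -> n b c h w', b=b)  # and unfold again
    assert torch.equal(back, burst)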
def run_train(cfg, rank, model, data, loader): this_proc_prints = (rank == 0 and cfg.use_ddp) or (cfg.use_ddp is False) s = int(npr.rand() * 5 + 1) time.sleep(s) hyperparams = load_hyperparameters(cfg) criterion_inputs = [hyperparams] criterion = ClBlockLoss(hyperparams, cfg.N, cfg.batch_size) criterion = criterion.to(cfg.device) optimizer = load_optimizer(cfg, model) scheduler = load_scheduler(cfg, optimizer, len(loader.tr)) print("Loaded optimizer: ") print(optimizer) print("Loaded scheduler: ") print(scheduler) # apply apex if cfg.use_apex: model, optimizer = amp.initialize(model, optimizer, opt_level='O2') # init writer if this_proc_prints: datetime_now = datetime.datetime.now().strftime("%b%d_%H-%M-%S") writer_dir = cfg.summary_log_dir / Path(datetime_now) writer = SummaryWriter(writer_dir) else: writer = None # init training loop global_step, current_epoch = get_model_epoch_info(cfg) cfg.global_step = global_step cfg.current_epoch = current_epoch # training loop loop_scheduler = get_loop_scheduler(scheduler) test_losses = {} # if this_proc_prints: # spawn_split_eval(cfg,'val',writer,24) print(f"cfg.epochs: {cfg.epochs}") print(f"cfg.use_apex: {cfg.use_apex}") print(f"cfg.use_bn: {cfg.use_bn}") print("len of loader.val", len(loader.val)) print(f"cfg.optim_type: {cfg.optim_type}") print(f"cfg.checkpoint_interval: {cfg.checkpoint_interval}") print(_build_v2_summary(cfg)) t = Timer() for epoch in range(cfg.current_epoch, cfg.epochs): t.tic() loss_epoch = train_loop(cfg, loader.tr, model, criterion, optimizer, epoch, writer, loop_scheduler) t.toc() lr = optimizer.param_groups[0]["lr"] # print(t) # if ms_scheduler: # ms_scheduler.step() if epoch % cfg.checkpoint_interval == 0 and this_proc_prints and epoch > 0: save_simcl_model(cfg, model, optimizer) if scheduler and loop_scheduler is None and epoch % cfg.val_interval == 0 and epoch > 0: val_loss = test_loop(cfg, model, 'val') scheduler.step(val_loss) if this_proc_prints: writer.add_scalar("Loss/val", val_loss, epoch) elif epoch % cfg.val_interval == 0 and this_proc_prints and epoch > 0: if this_proc_prints: # spawn_split_eval(cfg,'val',writer,epoch) val_loss = test_loop(cfg, model, 'val') writer.add_scalar("Loss/val", val_loss, epoch) if epoch % cfg.test_interval == 0 and epoch > 0: if this_proc_prints: # spawn_split_eval(cfg,'test',writer,epoch) te_loss = test_loop(cfg, model, 'test') writer.add_scalar("Loss/test", te_loss, epoch) if this_proc_prints: writer.add_scalar("Loss/train", loss_epoch / len(loader.tr), epoch) writer.add_scalar("Misc/learning_rate", lr, epoch) msg = f"Epoch [{epoch}/{cfg.epochs}]\t" msg += f"Loss: {loss_epoch / len(loader.tr)}\t" msg += "{:2.3e}".format(lr) print(msg) cfg.current_epoch += 1 if this_proc_prints: te_loss = test_loop(cfg, model, 'test') writer.add_scalar("Loss/test", te_loss, epoch) save_simcl_model(cfg, model, optimizer)
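# -- hedged sketch: a tic/toc timer compatible with how Timer() is used in         --
# -- run_train and the train loops above (tic, toc, average_time, printable). The  --
# -- project's Timer class may expose more than this minimal version.              --
import time as _time

class TimerSketch:
    def __init__(self):
        self.total_time, self.calls, self._start = 0., 0, None

    def tic(self):
        self._start = _time.perf_counter()

    def toc(self):
        self.total_time += _time.perf_counter() - self._start
        self.calls += 1
        return self.total_time

    @property
    def average_time(self):
        return self.total_time / max(self.calls, 1)

    def __str__(self):
        return "TimerSketch(avg {:2.3e}s over {} calls)".format(self.average_time, self.calls)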
def thtrain_denoising(cfg, train_loader, model, criterion, optimizer, epoch, writer, scheduler=None): model.train() # model.encoder.eval() idx = 0 loss_epoch = 0 data = train_loader.dataset.data print("N samples:", len(data)) simcl_t, loss_t, optim_t = Timer(), Timer(), Timer() for batch_idx, (noisy_imgs, raw_img) in enumerate(train_loader): optimizer.zero_grad() # setup the forward pass idx += cfg.batch_size noisy_imgs = noisy_imgs.cuda(non_blocking=True) simcl_t.tic() dec_imgs, proj = model(noisy_imgs) simcl_t.toc() loss_t.tic() # print(noisy_imgs.mean().item(),noisy_imgs.max().item(),noisy_imgs.min().item()) # print(dec_imgs.mean().item(),dec_imgs.max().item(),dec_imgs.min().item()) loss = criterion(noisy_imgs, dec_imgs, proj) loss_t.toc() # print(dec_imgs.mean(),dec_imgs.min(),dec_imgs.max()) # compute gradients if cfg.use_apex: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() # print(loss.item()) # print(loss.grad) # psum = 0 # for param in model.decoder.parameters(): # pnorm = param.grad.norm() # print(pnorm) # psum += pnorm # print("Overall: {:2.3e}".format(psum)) # exit() # update weights optim_t.tic() optimizer.step() optim_t.toc() if scheduler: scheduler.step() # print updates if writer: writer.add_scalar("Loss/train_epoch", loss.item(), cfg.global_step) cfg.global_step += 1 if batch_idx % cfg.log_interval == 0 and writer: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, cfg.world_size * batch_idx * cfg.batch_size, len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item())) loss_epoch += loss.item() print(simcl_t) print(loss_t) print(optim_t) return loss_epoch
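# -- hedged usage sketch: the per-epoch call in run_train above, train_loop(cfg,    --
# -- loader.tr, model, criterion, optimizer, epoch, writer, loop_scheduler), lines  --
# -- up with this function's signature rather than with the train_loop variants     --
# -- earlier in the file. The driver below is an illustration only; none of its     --
# -- names beyond thtrain_denoising are project APIs.                               --
def _thtrain_denoising_driver_sketch(cfg, model, criterion, optimizer, loader, writer):
    for epoch in range(cfg.current_epoch, cfg.epochs):
        loss_epoch = thtrain_denoising(cfg, loader.tr, model, criterion,
                                       optimizer, epoch, writer, scheduler=None)
        print("Epoch [%d/%d] train loss: %2.3e" %
              (epoch, cfg.epochs, loss_epoch / len(loader.tr)))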
def compare_sim_images_methods(cfg, burst, K, patchsize=3):
    # -- compare the per-image (v1) and batched-burst (v2) similar-image searches --
    b, n = 0, 0
    ps = patchsize
    img = burst[n, b]
    single_burst = img.unsqueeze(0).unsqueeze(0)
    img_rs = rearrange(img, 'c h w -> h w c').cpu()

    # -- v1: per-image search --
    t_1 = Timer()
    t_1.tic()
    sim_burst_v1 = rearrange(torch.tensor(compute_sim_images(img_rs, patchsize, K)),
                             'k h w c -> k c h w')
    t_1.toc()

    # -- v2: batched burst search --
    t_2 = Timer()
    t_2.tic()
    sim_burst_v2 = compute_similar_bursts(cfg, single_burst, K,
                                          patchsize=ps, shuffle_k=False)[0, 0, 1:]
    t_2.toc()

    # -- report timing, shapes, and pairwise MSEs --
    print(t_1, t_2)
    print("v1", sim_burst_v1.shape)
    print("v2", sim_burst_v2.shape)
    for k in range(K):
        print("mse-v1-{}".format(k), F.mse_loss(sim_burst_v1[k].cpu(), img.cpu()))
        print("mse-v2-{}".format(k), F.mse_loss(sim_burst_v2[k].cpu(), img.cpu()))
        print("mse-{}".format(k), F.mse_loss(sim_burst_v1[k].cpu(), sim_burst_v2[k].cpu()))
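# -- hedged usage sketch for the comparison above, assuming a loader whose samples --
# -- expose a 'burst' key as in the train loops; K=8 is an arbitrary illustration  --
# -- value, not a project default.                                                 --
def _compare_sim_methods_demo():
    args = get_args()
    cfg = get_cfg(args)
    data, loader = get_dataset(cfg, 'dynamic')
    burst = next(iter(loader.tr))['burst'].cuda(non_blocking=True)
    compare_sim_images_methods(cfg, burst, K=8, patchsize=3)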
def main(): args = get_args() cfg = get_cfg(args) gpuid = 2 cfg.device = f"cuda:{gpuid}" torch.cuda.set_device(gpuid) cfg.S = 50000 cfg.N = 30 cfg.noise_type = 'g' cfg.noise_params['g']['stddev'] = 25 cfg.dataset.name = "cifar10" cfg.dynamic = edict() cfg.dynamic.bool = False cfg.dynamic.ppf = 2 cfg.dynamic.frames = cfg.N cfg.dynamic.mode = "global" cfg.dynamic.global_mode = "shift" cfg.dynamic.frame_size = 128 cfg.dynamic.total_pixels = 20 cfg.use_ddp = False cfg.use_collate = True cfg.set_worker_seed = False cfg.batch_size = 8 cfg.num_workers = 4 data, loader = get_dataset(cfg, 'single_denoising') # noisy_trans = data.tr._get_noise_transform(cfg.noise_type,cfg.noise_params[cfg.noise_type]) # motion = GlobalCameraMotionTransform(cfg.dynamic,noisy_trans,True) noisy_l, res_l, raw_l = [], [], [] timer = Timer() timer.tic() for index in range(cfg.batch_size): noisy, raw = data.tr[index] res = data.tr.noise_set[index] print(noisy.shape, raw.shape, res.shape) noisy_l.append(noisy), res_l.append(res), raw_l.append(raw) timer.toc() print(timer) noisy = torch.stack(noisy_l, dim=1) res = torch.stack(res_l, dim=1) raw = torch.stack(raw_l, dim=1) # noisy,raw = next(iter(loader.tr)) # noisy += 0.5 # noisy.clamp_(0,1.) # rec += 0.5 # rec.clamp_(0,1.) raw = raw.expand(noisy.shape) print(noisy.shape) images = torch.cat([noisy, res, raw], dim=1) print("pre", images.shape) images = images.transpose(0, 1) print("post", images.shape) fig, ax = plt.subplots(figsize=(10, 10)) grid = vutils.make_grid(images, nrow=8) print(grid.shape) ax.imshow(np.transpose(grid, (1, 2, 0))) path = f"{settings.ROOT_PATH}/output/vis_cifar10.png" plt.savefig(path) # grids = [vutils.make_grid(images[i],nrow=8) for i in range(cfg.dynamic.frames)] # ims = [[ax.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in grids] # ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True) # Writer = animation.writers['ffmpeg'] # writer = Writer(fps=3, metadata=dict(artist='Me'), bitrate=1800) # path = f"{settings.ROOT_PATH}/test_voc.mp4" # ani.save(path, writer=writer) print(f"Wrote to {path}")
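# -- hedged note on the grid plotting above: vutils.make_grid returns a torch      --
# -- tensor, so converting explicitly before imshow avoids leaning on numpy's      --
# -- tensor interop, e.g.                                                          --
#        ax.imshow(grid.permute(1, 2, 0).cpu().numpy())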
def main(): args = get_args() cfg = get_cfg(args) cfg.S = 50000 cfg.N = 30 cfg.noise_type = 'g' cfg.noise_params['g']['stddev'] = 50 cfg.dataset.name = "voc" cfg.dynamic = edict() cfg.dynamic.bool = True cfg.dynamic.ppf = 2 cfg.dynamic.frames = cfg.N cfg.dynamic.mode = "global" cfg.dynamic.global_mode = "shift" cfg.dynamic.frame_size = 128 cfg.dynamic.total_pixels = 20 cfg.use_ddp = False cfg.use_collate = True cfg.set_worker_seed = False cfg.batch_size = 8 cfg.num_workers = 4 data, loader = get_dataset(cfg, 'dynamic') noisy_trans = data.tr._get_noise_transform( cfg.noise_type, cfg.noise_params[cfg.noise_type]) motion = GlobalCameraMotionTransform(cfg.dynamic, noisy_trans, True) noisy_l, rec_l, raw_l = [], [], [] timer = Timer() timer.tic() for index in range(cfg.batch_size): img = Image.open(data.tr.images[index]) noisy, rec, raw = motion(img) noisy_l.append(noisy), rec_l.append(rec), raw_l.append(raw) timer.toc() print(timer) noisy = torch.stack(noisy_l, dim=1) rec = torch.stack(rec_l, dim=1) raw = torch.stack(raw_l) # noisy,raw = next(iter(loader.tr)) noisy += 0.5 noisy.clamp_(0, 1.) rec += 0.5 rec.clamp_(0, 1.) raw = raw.expand(noisy.shape) images = torch.cat([noisy, rec, raw], dim=1) fig, ax = plt.subplots(figsize=(10, 10)) grids = [ vutils.make_grid(images[i], nrow=8) for i in range(cfg.dynamic.frames) ] ims = [[ax.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in grids] ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True) Writer = animation.writers['ffmpeg'] writer = Writer(fps=3, metadata=dict(artist='Me'), bitrate=1800) path = f"{settings.ROOT_PATH}/test_voc.mp4" ani.save(path, writer=writer) print(f"Wrote to {path}")
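# -- hedged entry point: both `main` definitions above share one name, so this     --
# -- guard runs the later (VOC/dynamic) version; rename one of them if the         --
# -- cifar10 visualization script is the intended target.                          --
if __name__ == "__main__":
    main()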