def train_loop(cfg, model, optimizer, criterion, train_loader, epoch):
    """Train a residual denoiser for one epoch on the first frame of each burst.

    The model predicts a residual for ``burst_imgs[0]``; the reconstruction
    ``img0 - residual`` is shifted by +0.5 and regressed onto ``raw_img`` with MSE.

    Args:
        cfg: experiment config; reads ``device``, ``log_interval`` and ``epochs``.
        model: network mapping a noisy frame to a predicted residual.
        optimizer: optimizer over ``model``'s parameters.
        criterion: unused; kept for interface compatibility.
        train_loader: yields ``(burst_imgs, raw_img)`` pairs; ``burst_imgs[0]``
            is used as the model input.
        epoch: current epoch index (logging only).

    Returns:
        Mean per-batch training loss over the epoch (0 for an empty loader).
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss, running_count = 0, 0

    for batch_idx, (burst_imgs, raw_img) in enumerate(train_loader):
        optimizer.zero_grad()
        model.zero_grad()

        # -- move data to the same device as the model --
        raw_img = raw_img.to(cfg.device, non_blocking=True)
        burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
        img0 = burst_imgs[0]

        # -- predict residual --
        pred_res = model(img0)
        rec_img = img0 - pred_res

        # -- compare with clean image (reconstruction is compared after a +0.5 shift) --
        loss = F.mse_loss(raw_img, rec_img + 0.5)

        # -- update info --
        running_loss += loss.item()
        running_count += 1
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- per-image PSNR for logging only; no autograd graph needed --
            with torch.no_grad():
                BS = raw_img.shape[0]
                mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                                      reduction='none').reshape(BS, -1)
                mse_loss = torch.mean(mse_loss, 1).cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            # average over the batches actually accumulated since the last report
            # (the first report covers log_interval + 1 batches)
            running_loss /= running_count
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr))
            running_loss, running_count = 0, 0

    total_loss /= max(len(train_loader), 1)
    return total_loss
def test_loop(cfg, model, criterion, test_loader, epoch):
    """Evaluate a residual denoiser on the first frame of each test burst.

    The reconstruction ``img0 - model(img0)`` is passed through
    ``rescale_noisy_image`` and compared with ``raw_img``. Every
    ``cfg.test_log_interval`` batches a grid of reconstructions is saved under
    ``{settings.ROOT_PATH}/output/n2n/rec_imgs/e{epoch}``.

    Args:
        cfg: experiment config; reads ``device``, ``test_log_interval`` and
            ``batch_size``.
        model: network mapping a noisy frame to a predicted residual.
        criterion: unused; kept for interface compatibility.
        test_loader: yields ``(burst_imgs, raw_img)`` pairs.
        epoch: epoch index (used in the output directory name).

    Returns:
        Mean of the per-batch average PSNR (0 for an empty loader).
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0

    with torch.no_grad():
        for batch_idx, (burst_imgs, raw_img) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # -- move data to the model's device --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
            img0 = burst_imgs[0]

            # -- denoising: the model predicts the residual --
            pred_res = model(img0)
            rec_img = img0 - pred_res

            # -- compare with clean targets --
            rec_img = rescale_noisy_image(rec_img)
            loss = F.mse_loss(raw_img, rec_img, reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).cpu().numpy()
            psnr = mse_to_psnr(loss)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(f"{settings.ROOT_PATH}/output/n2n/rec_imgs/e{epoch}")
                # exist_ok avoids the check-then-create race of exists()+mkdir()
                root.mkdir(parents=True, exist_ok=True)
                fn = root / f"b{batch_idx}.png"
                nrow = int(np.sqrt(cfg.batch_size))
                grid_imgs = vutils.make_grid(rec_img.cpu(), padding=2,
                                             normalize=True, nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')

    num_batches = max(len(test_loader), 1)
    ave_psnr = total_psnr / num_batches
    ave_loss = total_loss / num_batches
    print("Testing results: Ave psnr %2.3e Ave loss %2.3e" % (ave_psnr, ave_loss))
    return ave_psnr
def train_loop_offset(cfg, model, optimizer, criterion, train_loader, epoch):
    """Train a multi-frame ("offset") denoiser for one epoch.

    For each batch, the first ``cfg.input_N`` selected burst frames are used as the
    model input. The middle frame is excluded unless ``cfg.input_with_middle_frame``
    is set. Optional Gaussian input noise is added, either to the middle frame only
    or to all frames. The frames are concatenated on dim 1 (``cfg.color_cat``) or
    stacked on a new dim 1. The output is regressed with MSE onto the clean middle
    burst frame (``cfg.blind``) or onto ``ScaleZeroMean()(raw_img)``.

    Args:
        cfg: experiment config (device, N, input_N, input_with_middle_frame,
            input_noise, input_noise_level, input_noise_middle_only, color_cat,
            blind, log_interval, epochs).
        model: multi-frame denoiser.
        optimizer: optimizer over ``model``'s parameters.
        criterion: unused; kept for interface compatibility.
        train_loader: yields ``(burst_imgs, res_imgs, raw_img)``; ``burst_imgs``
            is indexed by frame along its first dimension.
        epoch: epoch index (logging only).

    Returns:
        Mean per-batch training loss over the epoch (0 for an empty loader).
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss, running_count = 0, 0
    szm = ScaleZeroMean()

    # -- frame selection depends only on cfg, so compute it once --
    middle_img_idx = cfg.N // 2
    input_order = np.arange(cfg.N)
    if not cfg.input_with_middle_frame:
        # exclude the middle (reference) frame from the model input
        input_order = np.delete(input_order, middle_img_idx)

    for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(train_loader):
        optimizer.zero_grad()
        model.zero_grad()

        # -- add input noise --
        burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(
                    burst_imgs_noisy[middle_img_idx], noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy, noise)

        # -- build the model input from the first cfg.input_N selected frames --
        frames = [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)]
        if cfg.color_cat:
            stacked_burst = torch.cat(frames, dim=1)
        else:
            stacked_burst = torch.stack(frames, dim=1)

        # -- extract target image (clean middle frame when blind) --
        if cfg.blind:
            t_img = burst_imgs[middle_img_idx]
        else:
            t_img = szm(raw_img.to(cfg.device, non_blocking=True))

        # -- denoising + loss --
        rec_img = model(stacked_burst)
        loss = F.mse_loss(t_img, rec_img)

        # -- update info --
        running_loss += loss.item()
        running_count += 1
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- PSNR for logging only; no autograd graph needed --
            with torch.no_grad():
                BS = raw_img.shape[0]
                raw_img = raw_img.to(cfg.device, non_blocking=True)
                mse_loss = F.mse_loss(raw_img, rec_img + 0.5,
                                      reduction='none').reshape(BS, -1)
                mse_loss = torch.mean(mse_loss, 1).cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            # average over the batches actually accumulated since the last report
            running_loss /= running_count
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.3e" %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr))
            running_loss, running_count = 0, 0

    total_loss /= max(len(train_loader), 1)
    return total_loss
def test_loop_offset(cfg, model, criterion, test_loader, epoch):
    """Evaluate the multi-frame ("offset") denoiser over ``test_loader``.

    The model input is built from the first ``cfg.input_N`` selected frames. The
    middle frame is excluded unless ``cfg.input_with_middle_frame`` is set. The
    frames are concatenated on dim 1 (``cfg.color_cat``) or stacked on a new
    dim 1. The output goes through ``rescale_noisy_image`` and is compared with
    ``raw_img``. Every ``cfg.test_log_interval`` batches a grid of
    reconstructions is saved to disk.

    Args:
        cfg: experiment config (device, N, input_N, input_with_middle_frame,
            color_cat, blind, test_log_interval, batch_size).
        model: multi-frame denoiser.
        criterion: unused; kept for interface compatibility.
        test_loader: yields ``(burst_imgs, res_imgs, raw_img)``.
        epoch: epoch index (used in the output directory name).

    Returns:
        Mean of the per-batch average PSNR (0 for an empty loader).
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0

    # -- frame selection depends only on cfg, so compute it once --
    input_order = np.arange(cfg.N)
    if not cfg.input_with_middle_frame:
        input_order = np.delete(input_order, cfg.N // 2)

    with torch.no_grad():
        for batch_idx, (burst_imgs, res_imgs, raw_img) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # -- move data to the model's device --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst_imgs = burst_imgs.to(cfg.device, non_blocking=True)

            frames = [burst_imgs[input_order[x]] for x in range(cfg.input_N)]
            if cfg.color_cat:
                stacked_burst = torch.cat(frames, dim=1)
            else:
                stacked_burst = torch.stack(frames, dim=1)

            # -- direct denoising --
            rec_img = model(stacked_burst)

            # -- compare with clean targets --
            rec_img = rescale_noisy_image(rec_img)
            loss = F.mse_loss(raw_img, rec_img, reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).cpu().numpy()
            psnr = mse_to_psnr(loss)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(f"{settings.ROOT_PATH}/output/n2n/offset_out_noise/rec_imgs/N{cfg.N}/e{epoch}")
                # exist_ok avoids the check-then-create race of exists()+mkdir()
                root.mkdir(parents=True, exist_ok=True)
                fn = root / f"b{batch_idx}.png"
                nrow = int(np.sqrt(cfg.batch_size))
                grid_imgs = vutils.make_grid(rec_img.cpu(), padding=2,
                                             normalize=True, nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')

    num_batches = max(len(test_loader), 1)
    ave_psnr = total_psnr / num_batches
    ave_loss = total_loss / num_batches
    print("[Blind: %d | N: %d] Testing results: Ave psnr %2.3e Ave loss %2.3e" %
          (cfg.blind, cfg.N, ave_psnr, ave_loss))
    return ave_psnr
def test_loop(cfg, model, test_loader, epoch):
    """Evaluate the composite align + denoise model on every batch of ``test_loader``.

    Each sample is a dict with ``'burst'``, ``'clean'`` and ``'directions'`` keys.
    The burst is optionally pre-aligned via ``lpas_spoof``/``lpas_search``
    (``cfg.abps_inputs``), centre-cropped together with the clean image,
    truncated to ``cfg.nframes`` frames and denoised. PSNR is measured between
    the model's fourth output (``denoised_ave``) and ``get_nmlz_tgt_img(cfg, clean)``.

    Args:
        cfg: experiment config (device, use_seed/seed, abps_inputs, lpas_method,
            optical_flow_acc, nblocks, nframes, frame_size, N).
        model: composite model exposing ``align_info.model``,
            ``denoiser_info.model`` and ``unet_info.model``; calling it returns
            six outputs.
        test_loader: loader of sample dicts (must support ``len``).
        epoch: epoch index (used to offset the RNG seed).

    Returns:
        ``(psnr_ave, record_test)``: mean per-image PSNR and a DataFrame that is
        only populated when the local ``use_record`` flag is enabled.
    """
    model.eval()
    model.align_info.model.eval()
    model.denoiser_info.model.eval()
    model.unet_info.model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    use_record = False
    record_test = pd.DataFrame({'psnr': []})

    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)

    test_iter = iter(test_loader)
    num_batches = len(test_iter)
    # per-batch PSNR arrays; a list (rather than a preallocated
    # (num_batches, batch_size) array) so a smaller final batch still works
    psnr_batches = []
    with torch.no_grad():
        for batch_idx in range(num_batches):
            sample = next(test_iter)
            burst, raw_img, motion = sample['burst'], sample['clean'], sample['directions']
            B = raw_img.shape[0]

            # -- move data to the model's device --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst = burst.to(cfg.device, non_blocking=True)

            # -- align images if necessary --
            if cfg.abps_inputs:
                if cfg.lpas_method == "spoof":
                    mtype = "global"
                    acc = cfg.optical_flow_acc
                    _, aligned = lpas_spoof(burst, motion, cfg.nblocks, mtype, acc)
                else:
                    ref_frame = (cfg.nframes + 1) // 2
                    _, aligned, _ = lpas_search(burst, ref_frame, cfg.nblocks,
                                                motion, cfg.lpas_method)
                burst = aligned.clone()

            # -- crop the centre patch and keep the first cfg.nframes frames --
            cropped = crop_center_patch([burst, raw_img], cfg.nframes, cfg.frame_size)
            burst, raw_img = cropped[0], cropped[1]
            burst = burst[:cfg.nframes]

            # -- denoising: only the averaged denoised output is scored --
            _, _, _, denoised_ave, _, _ = model(burst)

            # -- compare with the normalized clean target --
            raw_img = get_nmlz_tgt_img(cfg, raw_img)
            loss = F.mse_loss(raw_img, denoised_ave, reduction='none').reshape(B, -1)
            loss = torch.mean(loss, 1).cpu().numpy()
            psnr = mse_to_psnr(loss)
            psnr_batches.append(np.ravel(np.asarray(psnr)))
            if use_record:
                # DataFrame.append was removed in pandas 2.0
                record_test = pd.concat(
                    [record_test, pd.DataFrame({'psnr': [psnr]})],
                    ignore_index=True)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if batch_idx % 100 == 0:
                print("[%d/%d] Test PSNR: %2.2f" %
                      (batch_idx, num_batches, total_psnr / (batch_idx + 1)),
                      flush=True)

    psnrs = np.concatenate(psnr_batches) if psnr_batches else np.zeros(0)
    psnr_ave = np.mean(psnrs)
    psnr_std = np.std(psnrs)
    ave_loss = total_loss / max(num_batches, 1)
    print("[N: %d] Testing: [psnr: %2.2f +/- %2.2f] [ave loss %2.3e]" %
          (cfg.N, psnr_ave, psnr_std, ave_loss), flush=True)
    return psnr_ave, record_test
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses, writer):
    """Run one training epoch of the composite align + denoise model (work in progress).

    Each step does the following:
      1. Pulls a sample dict (``'burst'``, ``'clean'``, ``'directions'``, ``'iid'``).
      2. Aligns the burst with ``align_burst``.
      3. Builds (input, target) pairs with ``create_training_pairs``.
      4. Runs backward plus an optimizer and scheduler step for each pair.
      5. Adds an alignment loss from ``compute_nnf_loss``.

    Every ``cfg.log_interval`` steps it logs PSNR of several model outputs and a
    BM3D baseline on up to 4 images. It can also write scalars to ``writer``.

    Args:
        cfg: experiment config (device, N, noise_params, kpn_burst_alpha,
            supervised, use_kindex_lmdb, use_seed/seed, log_interval, epochs).
            ``cfg.global_step`` is incremented once per step.
        model: composite model exposing ``align_info.model``,
            ``denoiser_info.model`` and ``unet_info.model``.
        scheduler: LR scheduler; stepped once per training pair.
        train_loader: loader of sample dicts; ``steps_per_epoch`` batches are drawn.
        epoch: epoch index (seed offset and logging).
        record_losses: DataFrame of logged metrics, or None to create an empty one.
        writer: optional summary writer (``add_scalar``/``add_scalars``).

    Returns:
        ``(total_loss / len(train_loader), record_losses)``.

    NOTE(review): as written this function cannot run to completion. It uses the
    undefined names ``optim``, ``losses``, ``gt_nnf`` and ``burst_og``, and the
    logging branch divides by counters that are never incremented. See the
    NOTE(review) comments below.
    """

    # -=-=-=-=-=-=-=-=-=-=-
    #    Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-

    # -- put the three sub-networks in train mode and on cfg.device --
    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)

    N = cfg.N
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()  # NOTE(review): unused in this function
    blocksize = 128
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)  # NOTE(review): unused
    use_record = False  # when True, metrics are appended to record_losses at each log step
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-
    #   Init Record Keeping
    # -=-=-=-=-=-=-=-=-=-=-
    # NOTE(review): none of the *_count counters below are incremented in this
    # function, so `rec_mse_losses / rec_mse_count`, `rec_ot_losses / rec_ot_count`
    # and `dynamics_acc / dynamics_count` in the logging branch divide by zero.
    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0

    write_examples = False
    write_examples_iter = 200  # overwritten below with steps_per_epoch // 3
    noise_level = cfg.noise_params['g']['stddev']  # used as the BM3D sigma (divided by 255)

    # -=-=-=-=-=-=-=-=-=-=-
    #   Load Pre-Simulated Random Numbers
    # -=-=-=-=-=-=-=-=-=-=-
    if cfg.use_kindex_lmdb:
        kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)  # NOTE(review): unused

    # -=-=-=-=-=-=-=-=-=-=-
    #   Dataset Augmentation
    # -=-=-=-=-=-=-=-=-=-=-
    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        """Apply `aug` jointly to a (N, B, C, H, W) burst and its (B, C, H, W) target.

        NOTE(review): not called anywhere in this function. It writes
        'aug_original.png' and 'aug_augmented.png' to the working directory on
        every call. If tvF is torchvision.transforms.functional, `tvF.rotate`
        needs an angle argument, so it would fail when RandomChoice picks it.
        """
        N, B = burst.shape[:2]
        # prepend the target as frame 0 so every image receives the same transform
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images,
                            'aug_original.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images,
                            'aug_augmented.png',
                            nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-
    #    Half Precision
    # -=-=-=-=-=-=-=-=-=-=-

    # model.align_info.model.half()
    # model.denoiser_info.model.half()
    # model.unet_info.model.half()
    # models = [model.align_info.model,
    #           model.denoiser_info.model,
    #           model.unet_info.model]
    # for model_l in models:
    #     model_l.half()
    #     for layer in model_l.modules():
    #         if isinstance(layer, torch.nn.BatchNorm2d):
    #             layer.float()

    # -=-=-=-=-=-=-=-=-=-=-
    #   Init Loss Functions
    # -=-=-=-=-=-=-=-=-=-=-
    # NOTE(review): none of these loss objects are used below.
    alignmentLossMSE = BurstRecLoss()
    # NOTE(review): if cfg.supervised is a plain Python bool, `~` is a bitwise NOT
    # giving -2/-1, which are both truthy; `not cfg.supervised` was presumably intended.
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=~cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-
    #   Add hooks for epoch
    # -=-=-=-=-=-=-=-=-=-=-
    # register a forward hook on every "filter_cls" child of the alignment
    # network's top-level modules; the handles are removed after the epoch
    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=-
    #      Noise2Noise
    # -=-=-=-=-=-=-=-=-=-=-
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)  # NOTE(review): unused

    # -=-=-=-=-=-=-=-=-=-=-
    #     Final Configs
    # -=-=-=-=-=-=-=-=-=-=-
    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)  # NOTE(review): unused
    switch = True  # NOTE(review): unused
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    # NOTE(review): for ds_size < 500 this requests 500 batches, so
    # next(train_iter) raises StopIteration once the loader is exhausted;
    # min(ds_size, 500) was presumably intended.
    steps_per_epoch = ds_size if not small_ds else 500

    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=-
    #      Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        # the seed changes with the epoch and the previously set initial seed
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-
        #   Setting up for Iteration
        # -=-=-=-=-=-=-=-=-=-=-

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'directions']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)  # NOTE(review): unused afterwards
        burst = burst.cuda(non_blocking=True)

        aligned, est_nnf = align_burst(cfg, burst, model)
        sim_images = subsample_aligned(cfg, aligned)
        burst_in, tgt_out = create_training_pairs(burst, sim_images)

        dn_losses = []
        # NOTE(review): this loop rebinds `burst`, so afterwards `burst` is the
        # last training input rather than the original batch. That affects
        # `mid_img_og = burst[N // 2]` in the logging branch.
        for burst, target in zip(burst_in, tgt_out):

            # -- forward pass --
            est_denoised = model(burst)
            dn_loss = compute_denoising_loss(est_denoised, target)

            # -- compute grads --
            # NOTE(review): torch.set_deterministic is deprecated in favour of
            # torch.use_deterministic_algorithms and absent from newer PyTorch
            # releases; confirm against the pinned version.
            if cfg.use_seed:
                torch.set_deterministic(False)
            dn_loss.backward()
            if cfg.use_seed:
                torch.set_deterministic(True)

            # -- backprop --
            # NOTE(review): `optim` is not a parameter or local of this function
            # (there is no optimizer argument). No zero_grad() is called anywhere
            # here, so gradients accumulate across steps.
            optim.step()
            scheduler.step()

            # -- store info --
            # NOTE(review): `losses` is undefined here; presumably `dn_losses`
            # was meant, since `dn_losses` otherwise stays empty.
            losses.append(dn_loss.item())

        # -- average over losses --
        # NOTE(review): torch.mean expects a tensor, but dn_losses is a Python list.
        dn_loss = torch.mean(dn_losses)

        # -- alignment loss --
        # NOTE(review): `gt_nnf` is not defined in this function.
        align_loss = compute_nnf_loss(gt_nnf, est_nnf)

        # -- total loss --
        final_loss = dn_loss + align_loss
        running_loss += final_loss.item()
        total_loss += final_loss.item()

        # -=-=-=-=-=-=-=-=-=-=-
        #   Printing to Stdout
        # -=-=-=-=-=-=-=-=-=-=-
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            # NOTE(review): `burst_og` is not defined in this function. This
            # forward pass also runs with autograd enabled.
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)  # only moved to the GPU on log steps
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis":
                mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # tv_utils.save_image(raw_img,"raw.png",nrow=1,normalize=True,range=(-0.5,1.25))
            # tv_utils.save_image(mis_ave,"mis.png",nrow=1,normalize=True,range=(-0.5,1.25))

            # -- psnr for [bm3d] --
            # BM3D baseline on the middle frame of at most 4 images, run on the CPU;
            # the frame is transposed (dims 0 and 2) and shifted by +0.5 before BM3D
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for input averaged frames --
            # burst_ave = torch.mean(burst_og,dim=0)
            # mse_loss = F.mse_loss(raw_img,burst_ave,reduction='none').reshape(B,-1)
            # mse_loss = torch.mean(mse_loss,1).detach().cpu().numpy()
            # psnr_input_ave = np.mean(mse_to_psnr(mse_loss))
            # psnr_input_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for aligned + denoised --
            # the clean image is repeated R times along dim 1 to match `denoised`
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            # if noise_type == "qis": denoised = quantize_img(cfg,denoised)
            # save_image(denoised_ave,"denoised_ave.png")
            # save_image(denoised,"denoised.png")
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval

            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count  # NOTE(review): rec_mse_count is 0 here
            rec_mse_losses, rec_mse_count = 0, 0

            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count  # NOTE(review): rec_ot_count is 0 here
            rec_ot_losses, rec_ot_count = 0, 0

            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100.  # NOTE(review): dynamics_count is 0 here
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            # NOTE(review): `burst_loss` and `ave_loss` are undefined, and
            # DataFrame.append was removed in pandas 2.0. The branch is dead
            # while use_record is False.
            if use_record:
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave, psnr_aligned_std,
                          psnr_misaligned_ave, psnr_misaligned_std, bm3d_nb_ave,
                          bm3d_nb_std, rec_mse_ave, ave_dyn_acc)  #rec_ot_ave)

            #print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [r-ot]: %.2e" % write_info)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)

            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        # NOTE(review): `stacked_burst` is undefined; this branch is dead while
        # write_examples is False.
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer:
            clock.toc()

        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    # NOTE(review): divides by len(train_loader) although steps_per_epoch
    # batches were run.
    total_loss /= len(train_loader)
    return total_loss, record_losses
def train_loop_mse(cfg, model, optimizer, criterion, train_loader, epoch):
    """Train a single-frame denoiser for one epoch with a bootstrap loss.

    All frames of each burst are flattened into one batch and denoised. The loss
    is the mean of ``compute_bootstrap_loss`` plus an MSE term between the input
    and the output. That term is weighted by ``(1 / (cfg.global_step + 1)) ** 1.2``,
    so it decays as ``cfg.global_step`` (incremented once per batch) grows.

    Args:
        cfg: experiment config (device, noise_params, global_step, log_interval,
            epochs).
        model: single-frame denoiser.
        optimizer: optimizer over ``model``'s parameters.
        criterion: unused; kept for interface compatibility.
        train_loader: yields ``(burst, res_img, raw_img, d)``; ``burst`` is
            rearranged from ``'t b c h w'``.
        epoch: epoch index (logging only).

    Returns:
        Mean per-batch training loss over the epoch (0 for an empty loader).
    """
    model.train()
    model = model.to(cfg.device)
    total_loss = 0
    running_loss, running_count = 0, 0

    # NOTE(review): this mutates the caller's config in place, and the values
    # persist after this function returns. Kept for backward compatibility.
    cfg.noise_params['qis']['alpha'] = 255.0
    cfg.noise_params['qis']['readout'] = 0.0
    cfg.noise_params['qis']['nbits'] = 8

    for batch_idx, (burst, res_img, raw_img, d) in enumerate(train_loader):
        optimizer.zero_grad()
        model.zero_grad()

        # -- move data to the model's device; fold frames into the batch dim --
        BS = raw_img.shape[0]
        raw_img = raw_img.to(cfg.device, non_blocking=True)
        burst = burst.to(cfg.device, non_blocking=True)
        T, B = burst.shape[:2]
        burst = rearrange(burst, 't b c h w -> (t b) c h w')

        # -- bootstrap loss + decaying input-consistency term --
        denoised = model(burst)
        loss = torch.mean(compute_bootstrap_loss(denoised, B, T, R=100))
        loss_weight = (1 / (cfg.global_step + 1.))**1.2
        loss = loss + loss_weight * F.mse_loss(burst, denoised)

        # -- update info --
        running_loss += loss.item()
        running_count += 1
        total_loss += loss.item()

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- PSNR of the first frame's reconstruction; logging only --
            with torch.no_grad():
                frames = rearrange(burst, '(t b) c h w -> t b c h w', t=T)
                rec_img = model(frames[0])
                print_tensor_stats("burst", frames)
                mse = F.mse_loss(raw_img, rec_img + 0.5,
                                 reduction='none').reshape(BS, -1)
                mse = torch.mean(mse, 1).cpu().numpy()
            psnr = mse_to_psnr(mse)
            psnr_ave = np.mean(psnr)
            psnr_std = np.std(psnr)

            print_tensor_stats("rec_img", rec_img + 0.5)
            print_tensor_stats("raw_img", raw_img)
            # average over the batches actually accumulated since the last report
            running_loss /= running_count
            print("[%d/%d][%d/%d]: %2.3e [PSNR] %2.2f +/- %2.2f " %
                  (epoch, cfg.epochs, batch_idx, len(train_loader),
                   running_loss, psnr_ave, psnr_std))
            running_loss, running_count = 0, 0

        # advance the same counter that the loss weight above reads
        cfg.global_step += 1

    total_loss /= max(len(train_loader), 1)
    return total_loss
def test_loop_mse(cfg, model, criterion, test_loader, epoch):
    """Evaluate a single-frame denoiser on the first frame of each test burst.

    PSNR is computed between ``raw_img`` and ``model(burst[0]) + 0.5``. Every
    ``cfg.test_log_interval`` batches a grid of reconstructions is saved under
    ``{settings.ROOT_PATH}/output/mse/rec_imgs/e{epoch}``.

    Args:
        cfg: experiment config (device, test_log_interval, batch_size).
        model: single-frame denoiser.
        criterion: unused; kept for interface compatibility.
        test_loader: yields ``(burst, res_img, raw_img, d)``.
        epoch: epoch index (used in the output directory name).

    Returns:
        Mean of the per-batch average PSNR (0 for an empty loader).
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0

    with torch.no_grad():
        for batch_idx, (burst, res_img, raw_img, d) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # -- move data to the model's device --
            raw_img = raw_img.to(cfg.device, non_blocking=True)
            burst = burst.to(cfg.device, non_blocking=True)

            # -- denoise the first frame --
            rec_img = model(burst[0])

            # -- compare with the clean image after a +0.5 shift --
            loss = F.mse_loss(raw_img, rec_img + 0.5,
                              reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).cpu().numpy()
            psnr = mse_to_psnr(loss)
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                root = Path(f"{settings.ROOT_PATH}/output/mse/rec_imgs/e{epoch}")
                # exist_ok avoids the check-then-create race of exists()+mkdir()
                root.mkdir(parents=True, exist_ok=True)
                fn = root / f"b{batch_idx}.png"
                nrow = int(np.sqrt(cfg.batch_size))
                grid_imgs = tv_utils.make_grid(rec_img.cpu(), padding=2,
                                               normalize=True, nrow=nrow)
                plt.imshow(grid_imgs.permute(1, 2, 0))
                plt.savefig(fn)
                plt.close('all')

    num_batches = max(len(test_loader), 1)
    ave_psnr = total_psnr / num_batches
    ave_loss = total_loss / num_batches
    print("Testing results: Ave psnr %2.3e Ave loss %2.3e" % (ave_psnr, ave_loss))
    return ave_psnr
def train_loop_offset(cfg, model, optimizer, criterion, train_loader, epoch):
    """Run one training epoch for the multi-frame (offset) KPN-style model.

    For each batch the input burst is optionally perturbed with extra Gaussian
    noise. The first ``cfg.input_N`` frames are given to
    ``model(cat_burst, stacked_burst)``, which returns
    ``(aligned, rec_img, filters)``.

    The optimized loss is the sum of the terms returned by
    ``criterion(aligned, rec_img, t_img, cfg.global_step)``. The target
    ``t_img`` is the zero-meaned clean image when ``cfg.supervised`` is set,
    otherwise the middle noisy frame.

    A KL statistic is computed and logged but is NOT added to the optimized
    loss. The sparse-filter and temporal losses are ``-1`` placeholders.

    Every ``cfg.log_interval`` batches, PSNRs are printed for the model output,
    the plain average of the input frames, and a BM3D baseline on the middle
    frame.

    Args:
        cfg: experiment config; ``cfg.global_step`` is incremented per batch.
        model: network called as ``model(cat_burst, stacked_burst)``.
        optimizer: optimizer for ``model``.
        criterion: callable returning an iterable of loss tensors.
        train_loader: yields ``(burst_imgs, res_imgs, raw_img, directions)``.
        epoch: epoch index, used only for printing.

    Returns:
        ``(ave_loss, record)``. ``ave_loss`` is ``total_loss / len(train_loader)``.
        ``record`` is the object from ``init_record()``, only appended to when
        ``use_record`` is enabled.
    """
    model.train()
    model = model.to(cfg.device)
    N = cfg.N
    total_loss = 0
    running_loss = 0
    sf_losses, sf_count = 0, 0
    kl_losses, kl_count = 0, 0
    temporal_losses, temporal_count = 0, 0
    write_examples = True
    write_examples_iter = 800
    szm = ScaleZeroMean()
    record = init_record()
    use_record = False

    for batch_idx, (burst_imgs, res_imgs, raw_img, directions) in enumerate(train_loader):
        optimizer.zero_grad()
        model.zero_grad()

        # -- frame order and reference (middle) frame --
        # Both settings of cfg.input_with_middle_frame currently select the
        # same middle index, so no branching is needed here.
        input_order = np.arange(cfg.N)
        middle_img_idx = input_order[len(input_order) // 2]

        burst_imgs = burst_imgs.cuda(non_blocking=True)

        # -- optionally add extra input noise (random level per batch) --
        burst_imgs_noisy = burst_imgs.clone()
        if cfg.input_noise:
            noise = np.random.rand() * cfg.input_noise_level
            if cfg.input_noise_middle_only:
                burst_imgs_noisy[middle_img_idx] = torch.normal(
                    burst_imgs_noisy[middle_img_idx], noise)
            else:
                burst_imgs_noisy = torch.normal(burst_imgs_noisy, noise)

        # -- create inputs for kpn: stacked along dim 1 and concatenated along dim 1 --
        input_frames = [burst_imgs_noisy[input_order[x]] for x in range(cfg.input_N)]
        stacked_burst = torch.stack(input_frames, dim=1)
        cat_burst = torch.cat(input_frames, dim=1)

        # -- extract target image --
        mid_img = burst_imgs[middle_img_idx]
        raw_img_zm = szm(raw_img.cuda(non_blocking=True))
        if cfg.supervised:
            t_img = szm(raw_img.cuda(non_blocking=True))
        else:
            t_img = burst_imgs[middle_img_idx]

        # -- direct denoising --
        mis_ave = torch.mean(stacked_burst, dim=1)  # naive "misaligned" average baseline
        aligned, rec_img, filters = model(cat_burst, stacked_burst)
        temporal_loss = torch.FloatTensor([-1.]).to(cfg.device)  # placeholder (disabled)
        sf_loss = torch.FloatTensor([-1.]).to(cfg.device)  # placeholder (disabled)

        # -- compute loss to optimize --
        # Builtin ``sum`` adds the tensors directly; ``np.sum`` would first try
        # to coerce the (CUDA, grad-requiring) tensors into a NumPy array.
        losses = criterion(aligned, rec_img, t_img, cfg.global_step)
        kpn_loss = sum(losses)
        kpn_coeff = 1.

        # -- kl statistic: logged only, not part of the optimized loss --
        if cfg.supervised:
            kl_ref = szm(raw_img.cuda(non_blocking=True))
        else:
            kl_ref = rec_img
        residuals = aligned - kl_ref.unsqueeze(1).repeat(1, N, 1, 1, 1)
        residuals = rearrange(residuals, 'b n c h w -> b n (h w) c')
        kl_loss = kl_pairwise_bp(residuals, K=100, supervised=cfg.supervised)

        # -- final loss --
        loss = kpn_coeff * kpn_loss

        # -- update info --
        running_loss += loss.item()
        total_loss += loss.item()
        sf_losses += sf_loss.item()
        sf_count += 1
        temporal_losses += temporal_loss.item()
        temporal_count += 1
        kl_losses += kl_loss.item()
        kl_count += 1

        # -- BP and optimize --
        loss.backward()
        optimizer.step()

        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:
            # -- psnr for [rec img] --
            BS = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            mse_loss = F.mse_loss(raw_img, rec_img + 0.5, reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))
            running_loss /= cfg.log_interval

            # -- psnr for misaligned ave --
            mse_loss = F.mse_loss(raw_img, mis_ave + 0.5, reduction='none').reshape(BS, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            mis_psnr_ave = np.mean(mse_to_psnr(mse_loss))
            mis_psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] on the middle noisy frame --
            # NOTE(review): sigma is hard-coded to 25/255.
            bm3d_nb_psnrs = []
            for b in range(BS):
                bm3d_rec = bm3d.bm3d(mid_img[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=25 / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # One image -> one row; reshaping to (BS, -1) would split the
                # image into BS chunks and average chunk PSNRs instead.
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnrs.append(np.mean(mse_to_psnr(b_loss)))
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- temporal loss --
            ave_temporal_loss = temporal_losses / temporal_count if temporal_count > 0 else 0
            temporal_losses, temporal_count = 0, 0

            # -- sparse filter loss --
            ave_sf_loss = sf_losses / sf_count if sf_count > 0 else 0
            sf_losses, sf_count = 0, 0

            # -- kl loss --
            ave_kl_loss = kl_losses / kl_count if kl_count > 0 else 0
            kl_losses, kl_count = 0, 0

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, len(train_loader), running_loss,
                          psnr_ave, psnr_std, bm3d_nb_ave, bm3d_nb_std,
                          mis_psnr_ave, mis_psnr_std, ave_temporal_loss,
                          ave_sf_loss, ave_kl_loss)
            print("[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [misaligned]: %2.2f +/- %2.2f [loss-t]: %.2e [loss-sf]: %.2e [loss-kl]: %.2e" % write_info)
            running_loss = 0

            # -- record information --
            if use_record:
                rec = rec_img
                raw = raw_img_zm
                frame_results = compute_ot_frame(aligned, rec, raw, reg=0.5)
                burst_results = compute_ot_burst(aligned, rec, raw, reg=0.5)
                psnr_record = {'psnr_ave': psnr_ave, 'psnr_std': psnr_std}
                kpn_record = {'kpn_loss': kpn_loss}
                new_record = merge_records(frame_results, burst_results, psnr_record, kpn_record)
                # NOTE(review): if ``record`` is a pandas DataFrame, ``.append``
                # was removed in pandas 2.0; switch to ``pd.concat`` there.
                record = record.append(new_record, ignore_index=True)

        # -- write examples --
        if write_examples and (batch_idx % write_examples_iter) == 0:
            write_input_output(cfg, model, stacked_burst, aligned, filters, directions)

        cfg.global_step += 1
    total_loss /= len(train_loader)
    return total_loss, record
def test_loop_offset(cfg, model, criterion, test_loader, epoch):
    """Evaluate the multi-frame (offset) model and report PSNR statistics.

    The first ``cfg.input_N`` frames of each burst are stacked along dim 1 and
    concatenated along dim 1, then passed as ``model(cat_burst, stacked_burst)``.
    Output index ``1`` is taken as the reconstruction, passed through
    ``rescale_noisy_image``, and compared with ``raw_img``.

    Args:
        cfg: experiment config. Reads ``device``, ``N``, ``input_N`` and
            ``test_log_interval``.
        model: network called as ``model(cat_burst, stacked_burst)``.
        criterion: unused; kept so all loops share a signature.
        test_loader: yields ``(burst_imgs, res_imgs, raw_img, directions)``.
        epoch: unused (kept for signature parity).

    Returns:
        The mean PSNR over all evaluated images.
    """
    model.eval()
    model = model.to(cfg.device)
    total_psnr = 0
    total_loss = 0
    # Per-image PSNRs, collected batch by batch. A preallocated
    # (len(loader), batch_size) array breaks on a smaller final batch
    # (broadcast error) and would otherwise pad the mean with zeros.
    batch_psnrs = []
    input_order = np.arange(cfg.N)

    with torch.no_grad():
        for batch_idx, (burst_imgs, res_imgs, raw_img, directions) in enumerate(test_loader):
            BS = raw_img.shape[0]

            # -- reshaping of data --
            raw_img = raw_img.cuda(non_blocking=True)
            burst_imgs = burst_imgs.cuda(non_blocking=True)
            input_frames = [burst_imgs[input_order[x]] for x in range(cfg.input_N)]
            stacked_burst = torch.stack(input_frames, dim=1)
            cat_burst = torch.cat(input_frames, dim=1)

            # -- denoising: output index 1 is the reconstruction --
            rec_img = model(cat_burst, stacked_burst)[1]
            rec_img = rescale_noisy_image(rec_img)

            # -- compute per-image mse / psnr --
            loss = F.mse_loss(raw_img, rec_img, reduction='none').reshape(BS, -1)
            loss = torch.mean(loss, 1).detach().cpu().numpy()
            psnr = mse_to_psnr(loss)
            batch_psnrs.append(np.asarray(psnr, dtype=np.float64).reshape(-1))
            total_psnr += np.mean(psnr)
            total_loss += np.mean(loss)

            if (batch_idx % cfg.test_log_interval) == 0:
                print("[%d/%d] Running Test PSNR: %2.2f" % (batch_idx, len(test_loader), total_psnr / (batch_idx + 1)))

    psnrs = np.concatenate(batch_psnrs) if batch_psnrs else np.zeros(0)
    psnr_ave = np.mean(psnrs)
    psnr_std = np.std(psnrs)
    ave_loss = total_loss / len(test_loader)
    print("[N: %d] Testing: [psnr: %2.2f +/- %2.2f] [ave loss %2.3e]" % (cfg.N, psnr_ave, psnr_std, ave_loss))
    return psnr_ave
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses, writer):
    """Run one training epoch of the aligner / denoiser / unet model bundle.

    NOTE: this definition shadows the earlier ``train_loop`` in this module.

    ``model`` bundles three sub-networks (``align_info``, ``denoiser_info``,
    ``unet_info``), each exposing ``.model`` and ``.optim``. All three
    optimizers and ``scheduler`` are stepped after every update.

    How each step's (input burst, target) pair is built:

    * ``cfg.supervised``: target is ``get_nmlz_tgt_img(cfg, raw_img)``.
    * ``cfg.n2n``: target is ``sample['iid']`` (Noise2Noise).
    * ``cfg.abps`` / ``cfg.abps_inputs``: frames are aligned with
      ``lpas_spoof`` / ``lpas_search`` and the target comes from the aligned
      frames.
    * otherwise: ``compute_similar_bursts`` supplies both input and target.

    The optimized loss is ``F.mse_loss(denoised_ave, gt_img)``. The OT term
    has weight 0 and the filter-entropy term is computed but not optimized.

    Args:
        cfg: experiment config; ``cfg.global_step`` is incremented per step.
        model: the model bundle; ``model(burst)`` returns six outputs.
        scheduler: LR scheduler, stepped after every update.
        train_loader: yields dicts with keys ``'burst'``, ``'clean'``,
            ``'flow'``, ``'iid'`` and optionally ``'sim_burst'``.
        epoch: epoch index (used for reseeding).
        record_losses: DataFrame to append records to, or ``None``.
        writer: optional summary writer (``add_scalar`` / ``add_scalars``).

    Returns:
        ``(ave_loss, record_losses)``. ``ave_loss`` is the mean optimized loss
        over the updates actually performed this epoch.
    """
    # -=-=-=-=-=-=-=-=-=-=-
    #   Setup for epoch
    # -=-=-=-=-=-=-=-=-=-=-
    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)
    N = cfg.N
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #   Init Record Keeping
    # -=-=-=-=-=-=-=-=-=-=-=-=-
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    n_updates = 0  # number of optimizer updates actually performed
    dynamics_acc, dynamics_count = 0, 0
    write_examples = False
    noise_level = cfg.noise_params['g']['stddev']

    # -- pre-simulated random numbers --
    if cfg.use_kindex_lmdb:
        kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #   Dataset Augmentation (currently not applied)
    # -=-=-=-=-=-=-=-=-=-=-=-=-
    # NOTE(review): tvF.rotate requires an angle argument, so this choice
    # would fail if augmentation were enabled.
    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        """Apply the same random transform to the target and every burst frame."""
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images, 'aug_original.png', nrow=N + 1, normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images, 'aug_augmented.png', nrow=N + 1, normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #   Init Loss Functions
    # -=-=-=-=-=-=-=-=-=-=-=-=-
    alignmentLossMSE = BurstRecLoss()
    # ``not`` (not ``~``): bitwise-not of a bool gives -1/-2, which is always truthy.
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha, gradient_L1=not cfg.supervised)
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=-
    #   Add hooks for epoch
    # -=-=-=-=-=-=-=-=-=-=-=-=-
    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hooks.append(layer.register_forward_hook(align_hook))

    # -- Noise2Noise transform (kept for the commented alternative target) --
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=-
    #   Final Configs
    # -=-=-=-=-=-=-=-=-=-=-
    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500
    write_examples_iter = max(steps_per_epoch // 3, 1)
    all_filters = []

    # ``torch.set_deterministic`` was renamed ``use_deterministic_algorithms``
    # (and later removed); prefer the new name when available.
    set_deterministic = getattr(torch, "use_deterministic_algorithms", None)
    if set_deterministic is None:
        set_deterministic = torch.set_deterministic

    # -=-=-=-=-=-=-=-=-=-=-
    #   Start Epoch
    # -=-=-=-=-=-=-=-=-=-=-
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch; small datasets re-create the iterator past ds_size --
        if small_ds and batch_idx >= ds_size:
            if cfg.use_seed:
                init = torch.initial_seed()
                torch.manual_seed(cfg.seed + 1 + epoch + init)
            train_iter = iter(train_loader)
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample['flow']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)

        # -- handle possibly cached simulated bursts --
        if 'sim_burst' in sample:
            sim_burst = rearrange(sample['sim_burst'], 'b n k c h w -> n b k c h w')
        else:
            sim_burst = None

        non_sim_method = cfg.n2n or cfg.supervised
        if sim_burst is None and not (non_sim_method or cfg.abps):
            if cfg.use_kindex_lmdb:
                kindex = kindex_ds[batch_idx].cuda(non_blocking=True)
            else:
                kindex = None
            query = burst[[N // 2]]
            database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]])
            sim_burst = compute_similar_bursts(cfg, query, database, cfg.sim_K,
                                               noise_level / 255.,
                                               patchsize=cfg.sim_patchsize,
                                               shuffle_k=cfg.sim_shuffleK,
                                               kindex=kindex,
                                               only_middle=cfg.sim_only_middle,
                                               search_method=cfg.sim_method,
                                               db_level="frame")

        if (sim_burst is None) and cfg.abps:
            if cfg.lpas_method == "spoof":
                mtype = "global"
                acc = cfg.optical_flow_acc
                scores, aligned = lpas_spoof(burst, motion, cfg.nblocks, mtype, acc)
            else:
                ref_frame = (cfg.nframes + 1) // 2
                nblocks = cfg.nblocks
                method = cfg.lpas_method
                scores, aligned, dacc = lpas_search(burst, ref_frame, nblocks, motion, method)
                dynamics_acc_i = dacc
            nsims = cfg.nframes
            sim_aligned = create_sim_from_aligned(burst, aligned, nsims)
            burst_s = rearrange(burst, 't b c h w -> t b 1 c h w')
            sim_burst = torch.cat([burst_s, sim_aligned], dim=2)

        if non_sim_method:
            sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1)
        else:
            sim_burst = sim_burst.cuda(non_blocking=True)
        if use_timer:
            data_clock.toc()

        # -- to cuda --
        burst = burst.cuda(non_blocking=True)

        # -- crop images (frame count used as spacing) --
        images = [burst, sim_burst, raw_img, raw_img_iid]
        spacing = burst.shape[0]
        cropped = crop_center_patch(images, spacing, cfg.frame_size)
        burst, sim_burst = cropped[0], cropped[1]
        raw_img, raw_img_iid = cropped[2], cropped[3]
        if cfg.abps or cfg.abps_inputs:
            # NOTE(review): ``aligned`` only exists when the abps search above
            # ran (cfg.abps with no cached sim_burst).
            aligned = crop_center_patch([aligned], spacing, cfg.frame_size)[0]
        burst = burst[:cfg.nframes]  # last frame is target

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst_og = burst.clone()

        # -- shuffle over simulated samples (only the first pair is used) --
        k_ins, k_outs = create_k_grid(sim_burst, shuffle=True)
        k_ins, k_outs = [k_ins[0]], [k_outs[0]]
        for k_in, k_out in zip(k_ins, k_outs):
            if k_in == k_out:
                continue

            # -- zero gradients; ready 2 go --
            model.align_info.model.zero_grad()
            model.align_info.optim.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            model.unet_info.model.zero_grad()
            model.unet_info.optim.zero_grad()

            # -- compute input/output data --
            if cfg.sim_only_middle and (not cfg.abps):
                # sim_burst.shape == T,B,K,C,H,W
                midi = 0 if sim_burst.shape[0] == 1 else N // 2
                left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:]
                cat_burst = [left_burst, sim_burst[[midi], :, k_in], right_burst]
                burst = torch.cat(cat_burst, dim=0)
                mid_img = sim_burst[midi, :, k_out]
            elif cfg.abps and (not cfg.abps_inputs):
                mid_img = aligned[-1]
            elif cfg.abps_inputs:
                burst = aligned.clone()
                burst_og = aligned.clone()
                mid_img = shuffle_aligned_pixels(burst, cfg.nframes)[0]
            else:
                burst = sim_burst[:, :, k_in]
                mid_img = sim_burst[N // 2, :, k_out]

            if cfg.supervised:
                gt_img = get_nmlz_tgt_img(cfg, raw_img).cuda(non_blocking=True)
            elif cfg.n2n:
                gt_img = raw_img_iid  # alternative: noise_xform(raw_img).cuda(non_blocking=True)
            else:
                gt_img = mid_img

            # -- formatting images for forward pass --
            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -- forward pass --
            outputs = model(burst)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- kernel entropy (placeholder: adds ``one`` per filter set, weight 0) --
            filters_entropy = 0
            filters_entropy_coeff = 0.
            all_filters = []
            L = len(align_hook.filters)
            iter_filters = align_hook.filters if L > 0 else [aligned_filters]
            for filters in iter_filters:
                f_shape = 'b n k2 c h w -> (b n c h w) k2'
                filters_shaped = rearrange(filters, f_shape)
                filters_entropy += one  # entropyLoss(filters_shaped)
                all_filters.append(filters)
            if L > 0:
                filters_entropy /= L
            all_filters = torch.stack(all_filters, dim=1)
            align_hook.clear()

            # -- reconstruction loss (MSE) --
            losses = [F.mse_loss(denoised_ave, gt_img)]
            # Builtin ``sum`` adds tensors directly (no NumPy coercion).
            rec_mse = sum(losses)
            rec_mse_coeff = 1.

            # -- reconstruction loss (distribution): disabled, weight 0 --
            gt_img_rep = gt_img.unsqueeze(1).repeat(1, denoised.shape[1], 1, 1, 1)
            residuals = denoised - gt_img_rep
            rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            if torch.any(torch.isnan(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            if torch.any(torch.isinf(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            rec_ot_coeff = 0.

            # -- final loss --
            rec_loss = rec_mse_coeff * rec_mse + rec_ot_coeff * rec_ot
            final_loss = rec_loss

            # -- record keeping --
            rec_mse_losses += rec_mse.item()
            rec_mse_count += 1
            rec_ot_losses += rec_ot.item()
            rec_ot_count += 1
            dynamics_acc += dynamics_acc_i
            dynamics_count += 1
            running_loss += final_loss.item()
            total_loss += final_loss.item()
            n_updates += 1

            # -- gradients: backward is allowed to be non-deterministic --
            if cfg.use_seed:
                set_deterministic(False)
            final_loss.backward()
            if cfg.use_seed:
                set_deterministic(True)

            # -- optimizer / scheduler steps --
            model.align_info.optim.step()
            model.denoiser_info.optim.step()
            model.unet_info.optim.step()
            scheduler.step()

        # -=-=-=-=-=-=-=-=-=-=-
        #   Printing to Stdout
        # -=-=-=-=-=-=-=-=-=-=-
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave, reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis":
                mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave, reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] on (at most 4 of) the middle frames --
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            M = min(B, 4)
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec, reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnrs.append(np.mean(mse_to_psnr(b_loss)))
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised frames --
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised, reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave, reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses (guard against intervals with no updates) --
            running_loss /= cfg.log_interval
            rec_mse_ave = rec_mse_losses / rec_mse_count if rec_mse_count > 0 else 0
            rec_mse_losses, rec_mse_count = 0, 0
            rec_ot_ave = rec_ot_losses / rec_ot_count if rec_ot_count > 0 else 0
            rec_ot_losses, rec_ot_count = 0, 0
            ave_dyn_acc = dynamics_acc / dynamics_count * 100. if dynamics_count > 0 else 0
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            if use_record:
                # The per-frame "burst" loss is not computed in this loop.
                info = {
                    'burst': float('nan'),
                    'ave': rec_mse_ave,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = pd.concat([record_losses, pd.DataFrame([info])],
                                          ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)

            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss, cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        # NOTE(review): ``aligned`` is only defined on the abps path.
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised, all_filters, motion)

        if use_timer:
            clock.toc()
        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    # Average over updates actually performed. len(train_loader) is wrong for
    # small datasets, which run up to 500 steps, and when steps are skipped.
    total_loss /= max(n_updates, 1)
    return total_loss, record_losses