def __init__(self, root, split, edition, nframes, noise_info, nnf_K, nnf_ps, nnf_exists=True):
    """Index the KITTI burst dataset split and build its noise transform.

    Records the configuration, resolves the on-disk layout for `root`,
    reads the dataset index for the requested split/edition, and builds
    the noise transform described by `noise_info`.
    """
    # -- record raw configuration --
    self.root = root
    self.split = split
    self.edition = edition
    self.nframes = nframes
    self.noise_info = noise_info
    self.nnf_K = nnf_K
    self.nnf_ps = nnf_ps
    self.nnf_exists = nnf_exists
    self.istest = split == "test"
    # fixed (H, W) used when reading/resizing frames
    self.read_resize = (370, 1224)
    # -- resolve filesystem layout and load the split's index --
    kitti_paths = get_kitti_path(root)
    self.paths = kitti_paths
    split_parts = self._get_split_parts_name(split)
    self.dataset = self._read_dataset_paths(kitti_paths, edition, split_parts,
                                            nframes, self.read_resize,
                                            nnf_K, nnf_ps, nnf_exists)
    # -- noise synthesis (tensor in, tensor out; no ToTensor step) --
    self.noise_xform = get_noise_transform(noise_info, use_to_tensor=False)
    print(split, len(self.dataset['burst_id']))
def init_exp(cfg, exp):
    """Copy experiment settings from `exp` onto `cfg` and build the
    experiment's transforms.

    Returns a tuple ``(noise_xform, dynamic_xform, score_function)``.
    """
    # -- patch size --
    cfg.patchsize = int(exp.patchsize)
    # -- number of frames (cfg.N is the legacy alias) --
    cfg.nframes = int(exp.nframes)
    cfg.N = cfg.nframes
    # -- number of blocks (cfg.nh_size is the legacy "neighborhood size" name) --
    cfg.nblocks = int(exp.nblocks)
    cfg.nh_size = cfg.nblocks
    # -- noise transform from the experiment's noise type --
    noise_cfg = get_noise_config(cfg, exp.noise_type)
    noise_xform = get_noise_transform(noise_cfg, use_to_tensor=False)
    # -- dynamics configuration --
    cfg.dynamic.ppf = exp.ppf
    cfg.dynamic.bool = True
    cfg.dynamic.random_eraser = False
    cfg.dynamic.frame_size = cfg.frame_size
    cfg.dynamic.total_pixels = cfg.dynamic.ppf * (cfg.nframes - 1)
    cfg.dynamic.frames = exp.nframes

    # identity stand-in: the dynamics transform adds no noise itself
    def _identity(image):
        return image

    raw_dyn_xform = get_dynamic_transform(cfg.dynamic, _identity)
    dynamic_xform = dynamic_wrapper(raw_dyn_xform)
    # -- scoring --
    score_function = get_score_function(exp.score_function)
    return noise_xform, dynamic_xform, score_function
def __init__(self, root, split, isize, nsamples, noise_info, dynamic_info):
    """Set up the dataset: store configuration, build the noise and
    dynamics transforms, and prepare the fixed-random-state helpers."""
    # -- raw configuration --
    self.root = root
    self.split = split
    self.isize = isize
    self.nsamples = nsamples
    self.noise_info = noise_info
    self.dynamic_info = dynamic_info
    # -- transforms --
    self.noise_trans = get_noise_transform(noise_info, noise_only=True)
    self.dynamic_trans = get_dynamic_transform(dynamic_info, None, load_res)
    # -- load paths (empty here; nothing is read from disk in this init) --
    self.paths = []
    # -- limit the number of samples --
    self.indices = enumerate_indices(len(self.paths), nsamples)
    self.nsamples = len(self.indices)
    # -- optionally freeze per-sample random dynamics --
    self.dyn_once = return_optional(dynamic_info, "sim_once", False)
    self.fixRandDynamics = RandomOnce(self.dyn_once, nsamples)
    # -- optionally freeze per-sample random noise (two independent streams) --
    self.noise_once = return_optional(noise_info, "sim_once", False)
    self.fixRandNoise_1 = RandomOnce(self.noise_once, nsamples)
    self.fixRandNoise_2 = RandomOnce(self.noise_once, nsamples)
def transforms_from_cfg(cfg):
    """Build the (noise, dynamics) transform pair described by `cfg`.

    Returns ``(noise_xform, dynamic_xform)``; `dynamic_xform` maps a
    tensor image to a ``(burst, flow)`` pair.
    """
    # -- noise transform --
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # identity noise: dynamics are generated from the clean image
    def _no_noise(image):
        return image

    # adapt the raw dynamics transform: tensor in -> (burst, flow) out
    def _wrap(raw_xform):
        def wrapped(image):
            pil_image = tvT.ToPILImage()(image).convert("RGB")
            outputs = raw_xform(pil_image)
            # outputs[0] is the burst, outputs[3] the flow; the +0.5
            # presumably undoes a zero-centering in the raw transform
            # -- TODO confirm against get_dynamic_transform
            return outputs[0] + 0.5, outputs[3]
        return wrapped

    dynamic_xform = _wrap(get_dynamic_transform(cfg.dynamic, _no_noise))
    return noise_xform, dynamic_xform
def __init__(self, iroot, froot, sroot, split, isize, ps, nsamples, nframes, noise_info):
    """Set up the dataset: store configuration, build the noise
    transform, read the file index, and limit to `nsamples` samples."""
    # -- raw configuration --
    self.iroot = iroot
    self.froot = froot
    self.sroot = sroot
    self.split = split
    self.isize = isize
    self.ps = ps
    self.nsamples = nsamples
    self.noise_info = noise_info
    # -- noise transform --
    self.noise_trans = get_noise_transform(noise_info, noise_only=True)
    # -- read the file index; warn when burst lengths differ --
    self.paths, self.nframes, all_eq = read_files(iroot, froot, sroot, split,
                                                  isize, ps, nframes)
    if not all_eq:
        print("\n\n\n\nWarning: Not all bursts are same length!!!\n\n\n\n")
    self.groups = sorted(list(self.paths['images'].keys()))
    # -- limit the number of samples --
    self.indices = enumerate_indices(len(self.paths['images']), nsamples)
    self.nsamples = len(self.indices)
    # -- optionally freeze per-sample random noise (two independent streams) --
    self.noise_once = return_optional(noise_info, "sim_once", False)
    self.fixRandNoise_1 = RandomOnce(self.noise_once, nsamples)
    self.fixRandNoise_2 = RandomOnce(self.noise_once, nsamples)
def get_align_noise(align_name):
    """Return a function ``(noisy, clean) -> frames`` used as the
    alignment input.

    ``"same"`` reuses the provided noisy frames unchanged; any other
    value is treated as a noise config, and a fresh noisy sample is
    drawn from the clean frames at that noise level.
    """
    if align_name == "same":
        # reuse the given noisy samples as-is
        def align_noise_fxn(noisy, clean):
            return noisy
        return align_noise_fxn

    # otherwise: re-noise the clean frames at the requested level
    apply_noise = get_noise_transform(align_name, noise_only=True)

    def align_noise_fxn(noisy, clean):
        return apply_noise(clean)
    return align_noise_fxn
def single_image_unet(cfg, queue, full_image, device):
    """Experiment driver: train small UNets on aligned vs. misaligned
    noisy crops of a single image and print PSNR comparisons.

    NOTE(review): `queue` and each `results = test(...)` value are never
    used in this body; they look like leftovers from a multiprocessing
    harness. The function's effects are side effects only (prints and
    saved PNGs).
    """
    full_image = full_image.to(device)
    # fixed 32x32 crop used as the clean reference frame
    image = tvF.crop(full_image, 128, 128, 32, 32)
    T = 5  # burst length
    # -- poisson noise --
    noise_type = "pn"
    cfg.noise_type = noise_type
    cfg.noise_params['pn']['alpha'] = 40.0
    cfg.noise_params['pn']['readout'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)
    # T copies of the same clean crop, plus per-frame noise
    clean = torch.stack([image for i in range(T)], dim=0)
    noisy = noise_xform(clean)
    save_image(clean, "clean.png", normalize=True)
    # misaligned variants: the first T//2 frames come from a crop shifted
    # one pixel down (row 129 instead of 128)
    m_clean, m_noisy = clean.clone(), noisy.clone()
    for i in range(T // 2):
        image_mis = tvF.crop(full_image, 128 + 1, 128, 32, 32)
        m_clean[i] = image_mis
        m_noisy[i] = noise_xform(m_clean[i])
    # NOTE(review): m_clean/m_noisy are not read below; only `image_mis`
    # (bound in the loop above) is reused later.
    noise_level = 50. / 255.
    # -- model trained on the fully-aligned burst --
    print("-- All Aligned --")
    model = UNet_small(3)  # UNet_n2n(1)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr, betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    # +0.5: presumably undoes the model's zero-centered output -- TODO confirm
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))
    # keep a copy of the aligned burst, then misalign frame 0 in-place
    og_clean, og_noisy = clean.clone(), noisy.clone()
    clean[0] = image_mis
    noisy[0] = noise_xform(clean[[0]])[0]
    # -- model trained with frame 0 misaligned --
    print("-- All Misligned --")
    model = UNet_small(3)  # UNet_n2n(1)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr, betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))
    # -- three repeats of pairwise 2-frame comparisons --
    for j in range(3):
        # -- data: three 2-frame bursts --
        noisy1 = torch.stack([noisy[0], noisy[1]], dim=0)    # misaligned frame 0 + aligned frame 1
        clean1 = torch.stack([clean[0], clean[1]], dim=0)
        noisy2 = torch.stack([noisy[1], noisy[2]], dim=0)    # two aligned frames
        clean2 = torch.stack([clean[1], clean[2]], dim=0)
        noisy3 = torch.stack([og_noisy[0], noisy[1]], dim=0)  # original (aligned) frame 0 + frame 1
        clean3 = torch.stack([og_clean[0], clean[1]], dim=0)
        # -- model 1: trained on the (misaligned, aligned) pair --
        model = UNet_small(3)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr, betas=(0.9, 0.99))
        train(cfg, image, clean1, noisy1, model, optim)
        results = test(cfg, image, clean1, noisy1, model, 0)
        rec = model(noisy1) + 0.5
        xrec = model(noisy2) + 0.5  # cross-evaluate on the other pair
        save_image(rec, "rec1.png", normalize=True)
        print("[misaligned] Single Image Unet:", images_to_psnrs(clean1, rec),
              images_to_psnrs(clean2, xrec),
              images_to_psnrs(rec - 0.5, noisy1),
              images_to_psnrs(xrec - 0.5, noisy2),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))
        # -- model 2: trained on the aligned pair --
        model = UNet_small(3)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr, betas=(0.9, 0.99))
        train(cfg, image, clean2, noisy2, model, optim)
        results = test(cfg, image, clean2, noisy2, model, 0)
        rec = model(noisy2) + 0.5
        xrec = model(noisy1) + 0.5  # cross-evaluate on the other pair
        save_image(rec, "rec2.png", normalize=True)
        print("[aligned] Single Image Unet:", images_to_psnrs(clean2, rec),
              images_to_psnrs(clean1, xrec),
              images_to_psnrs(rec - 0.5, noisy2),
              images_to_psnrs(xrec - 0.5, noisy1),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))
        # -- model 3: trained on the (original frame 0, aligned) pair --
        model = UNet_small(3)  # UNet_n2n(1)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr, betas=(0.9, 0.99))
        train(cfg, image, clean3, noisy3, model, optim)
        results = test(cfg, image, clean3, noisy3, model, 0)
        rec = model(noisy3) + 0.5
        rec_2 = model(noisy2) + 0.5
        rec_1 = model(noisy1) + 0.5
        save_image(rec, "rec1.png", normalize=True)
        # compare model-3 reconstructions against every pairing
        print("[aligned (v3)] Single Image Unet:")
        print("clean-rec", images_to_psnrs(clean3, rec))
        print("clean1-rec1", images_to_psnrs(clean1, rec_1))
        print("clean2-rec2", images_to_psnrs(clean2, rec_2))
        print("rec-noisy3", images_to_psnrs(rec - 0.5, noisy3))
        print("rec1-noisy1", images_to_psnrs(rec_1 - 0.5, noisy1))
        print("rec2-noisy2", images_to_psnrs(rec_2 - 0.5, noisy2))
        print("[v3]: rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
        print("[v1]: rec0-rec1", images_to_psnrs(rec_1[[0]], rec_1[[1]]))
        print("[v2]: rec0-rec1", images_to_psnrs(rec_2[[0]], rec_2[[1]]))
        print("[v1-v2](a):", images_to_psnrs(rec_1[[1]], rec_2[[1]]))
        print("[v1-v2](b):", images_to_psnrs(rec_1[[1]], rec_2[[0]]))
        print("[v1-v2](c):", images_to_psnrs(rec_1[[0]], rec_2[[1]]))
        print("[v1-v2](d):", images_to_psnrs(rec_1[[0]], rec_2[[0]]))
        print("-" * 20)
        print("[v2-v3](a):", images_to_psnrs(rec_2[[1]], rec[[1]]))
        print("[v2-v3](b):", images_to_psnrs(rec_2[[1]], rec[[0]]))
        print("[v2-v3](c):", images_to_psnrs(rec_2[[0]], rec[[1]]))
        print("[v2-v3](d):", images_to_psnrs(rec_2[[0]], rec[[0]]))
        print("-" * 20)
        print("[v1-v3](a):", images_to_psnrs(rec_1[[1]], rec[[1]]))
        print("[v1-v3](b):", images_to_psnrs(rec_1[[1]], rec[[0]]))
        print("[v1-v3](c):", images_to_psnrs(rec_1[[0]], rec[[1]]))
        print("[v1-v3](d):", images_to_psnrs(rec_1[[0]], rec[[0]]))
        print("-" * 20)
        print("rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
def run_experiment(cfg, data, record_fn, bss_fn):
    """Run the COG block-search experiment and write results to CSV.

    Builds an HxH neighborhood of 32x32 crops per frame around a fixed
    location, noises them (gaussian, stddev 50), scores a (possibly
    random-subset) block search space with COG operator consistency,
    and writes a pandas record to `record_fn`.

    Fixes vs. previous revision:
      * `np.int` was removed in NumPy 1.24 -- use the builtin `int`.
      * the debug-save condition tested `idx == 40` twice.
      * the unpack `t, l = tl_list[t]` clobbered the loop index `t`.
      * removed the unused local `noise_level_str`.

    Returns the pandas DataFrame of per-proposal results.
    """
    # -- setup noise (gaussian, stddev 50) --
    cfg.noise_type = 'g'
    cfg.ntype = cfg.noise_type
    cfg.noise_params.ntype = cfg.noise_type
    noise_level = 50.
    cfg.noise_params['g']['stddev'] = noise_level
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -- set configs --
    T = cfg.nframes  # frames per burst
    H = cfg.nblocks  # neighborhood (block) side count

    # -- create our neighborhood: per frame, an HxH grid of 32x32 crops --
    full_image = data.tr[0][2]
    clean = []
    # per-frame (top, left) offsets; all zero here
    tl_list = np.zeros((T, 2)).astype(int)
    for t in range(T):
        clean_t = []
        top, left = tl_list[t]
        for i in range(-H // 2 + 1, H // 2 + 1):
            for j in range(-H // 2 + 1, H // 2 + 1):
                clean_t.append(
                    tvF.crop(full_image, top + 128 + i, left + 128 + j, 32, 32))
        clean.append(torch.stack(clean_t, dim=0))
    clean = torch.stack(clean, dim=0)
    REF_H = get_ref_block_index(cfg.nblocks)
    image = clean[T // 2, REF_H]  # the reference (center) crop

    # -- normalize to [0, 1] --
    clean -= clean.min()
    clean /= clean.max()
    save_image(clean, "fast_unet_clean.png", normalize=True)
    save_image(image, "fast_unet_image.png", normalize=True)

    # -- apply noise --
    noisy = noise_xform(clean)

    record = []
    gridT = torch.arange(T)

    # -- exhaustive search space (subset) --
    block_search_space = get_block_arangements_subset(cfg.nblocks, cfg.nframes,
                                                      tcount=3)

    # -- random subset of at most 100 arrangements (plus ground truth) --
    use_rand = True
    if use_rand:
        # NOTE: the cache-read branch is deliberately disabled ("and False")
        if bss_fn.exists() and False:
            print(f"Reading bss {bss_fn}")
            block_search_space = np.load(bss_fn, allow_pickle=True)
        else:
            print(f"Original block search space: [{len(block_search_space)}]")
            if len(block_search_space) >= 100:
                rand_blocks = random.sample(list(block_search_space), 100)
                block_search_space = [np.array([REF_H] * T)]  # include gt
                block_search_space.extend(rand_blocks)
            print(f"Writing block search space: [{bss_fn}]")
            np.save(bss_fn, np.array(block_search_space))
    print(f"Search Space Size: {len(block_search_space)}")

    clean = clean.to(cfg.device)
    noisy = noisy.to(cfg.device)
    for idx, prop in enumerate(tqdm(block_search_space)):
        # -- fetch the proposed arrangement --
        clean_prop = clean[gridT, prop]
        noisy_prop = noisy[gridT, prop]
        # save two fixed proposals for visual inspection
        if idx in (40, 49):
            save_image(noisy_prop, f"noisy_prop_{idx}.png", normalize=True)
            save_image(clean_prop, f"clean_prop_{idx}.png", normalize=True)
        # -- score the arrangement via COG operator consistency --
        train_steps = 500
        cog = COG(UNet_small, T, noisy.device, nn_params=None,
                  train_steps=train_steps)
        cog.train_models(noisy_prop)
        recs = cog.test_models(noisy_prop)
        score = cog.operator_consistency(recs, noisy_prop)
        results = fill_results(cfg, image, clean_prop, noisy_prop, None, idx)
        results['cog'] = score
        print(score, prop)
        # -- update --
        record.append(results)

    record = pd.DataFrame(record)
    print(f"Writing record to {record_fn}")
    record.to_csv(record_fn)
    return record
def train_loop(cfg, model, scheduler, train_loader, epoch, record_losses, writer):
    """Train the joint align/denoise/unet model for one epoch.

    Draws `steps_per_epoch` batches from `train_loader` (re-iterating a
    small loader as needed), builds the training target per the config
    flags (supervised / n2n / sim-burst / abps alignment), takes one
    optimizer step per (k_in, k_out) pair, and periodically logs PSNR
    stats (including a bm3d baseline) to stdout and `writer`.

    Returns ``(total_loss, record_losses)`` where `total_loss` is the
    accumulated loss divided by ``len(train_loader)``.
    """
    # -=-=-=-=-=-=-=-=-=-=- #
    #    Setup for epoch    #
    # -=-=-=-=-=-=-=-=-=-=- #
    # put all three sub-models in train mode and on the configured device
    model.align_info.model.train()
    model.denoiser_info.model.train()
    model.unet_info.model.train()
    model.denoiser_info.model = model.denoiser_info.model.to(cfg.device)
    model.align_info.model = model.align_info.model.to(cfg.device)
    model.unet_info.model = model.unet_info.model.to(cfg.device)
    N = cfg.N  # number of burst frames
    total_loss = 0
    running_loss = 0
    szm = ScaleZeroMean()
    blocksize = 128
    # NOTE(review): `unfold` and `switch` (below) are never used in this body.
    unfold = torch.nn.Unfold(blocksize, 1, 0, blocksize)
    use_record = False
    if record_losses is None:
        record_losses = pd.DataFrame({
            'burst': [],
            'ave': [],
            'ot': [],
            'psnr': [],
            'psnr_std': []
        })
    noise_type = cfg.noise_params.ntype

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    #        Init Record Keeping          #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    align_mse_losses, align_mse_count = 0, 0
    rec_mse_losses, rec_mse_count = 0, 0
    rec_ot_losses, rec_ot_count = 0, 0
    running_loss, total_loss = 0, 0
    dynamics_acc, dynamics_count = 0, 0
    write_examples = False
    write_examples_iter = 200
    noise_level = cfg.noise_params['g']['stddev']

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    #   Load Pre-Simulated Random Numbers     #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    if cfg.use_kindex_lmdb:
        kindex_ds = kIndexPermLMDB(cfg.batch_size, cfg.N)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    #        Dataset Augmentation         #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    transforms = [tvF.vflip, tvF.hflip, tvF.rotate]
    aug = RandomChoice(transforms)

    def apply_transformations(burst, gt_img):
        # Apply one random flip/rotation jointly to gt image and burst.
        # NOTE(review): defined but its only call site below is commented out.
        N, B = burst.shape[:2]
        gt_img_rs = rearrange(gt_img, 'b c h w -> 1 b c h w')
        all_images = torch.cat([gt_img_rs, burst], dim=0)
        all_images = rearrange(all_images, 'n b c h w -> (n b) c h w')
        tv_utils.save_image(all_images, 'aug_original.png', nrow=N + 1,
                            normalize=True)
        aug_images = aug(all_images)
        tv_utils.save_image(aug_images, 'aug_augmented.png', nrow=N + 1,
                            normalize=True)
        aug_images = rearrange(aug_images, '(n b) c h w -> n b c h w', b=B)
        aug_gt_img = aug_images[0]
        aug_burst = aug_images[1:]
        return aug_burst, aug_gt_img

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    #           Half Precision            #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    # (half-precision experiment removed; the whole section was commented out)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    #        Init Loss Functions          #
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
    # NOTE(review): only F.mse_loss is used in the loss below; these objects
    # are kept for parity with the commented-out loss variants.
    alignmentLossMSE = BurstRecLoss()
    denoiseLossMSE = BurstRecLoss(alpha=cfg.kpn_burst_alpha,
                                  gradient_L1=~cfg.supervised)
    # denoiseLossOT = BurstResidualLoss()
    entropyLoss = EntropyLoss()

    # -=-=-=-=-=-=-=-=-=-=-=-=- #
    #   Add hooks for epoch     #
    # -=-=-=-=-=-=-=-=-=-=-=-=- #
    # capture the alignment filters emitted by each "filter_cls" layer
    align_hook = AlignmentFilterHooks(cfg.N)
    align_hooks = []
    for kpn_module in model.align_info.model.children():
        for name, layer in kpn_module.named_children():
            if name == "filter_cls":
                align_hook_handle = layer.register_forward_hook(align_hook)
                align_hooks.append(align_hook_handle)

    # -=-=-=-=-=-=-=-=-=-=- #
    #      Noise2Noise      #
    # -=-=-=-=-=-=-=-=-=-=- #
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -=-=-=-=-=-=-=-=-=-=- #
    #     Final Configs     #
    # -=-=-=-=-=-=-=-=-=-=- #
    use_timer = False
    one = torch.FloatTensor([1.]).to(cfg.device)
    switch = True
    if use_timer:
        data_clock = Timer()
        clock = Timer()
    ds_size = len(train_loader)
    # a "small" dataset is re-iterated to reach 500 steps per epoch
    small_ds = ds_size < 500
    steps_per_epoch = ds_size if not small_ds else 500
    write_examples_iter = steps_per_epoch // 3
    all_filters = []

    # -=-=-=-=-=-=-=-=-=-=- #
    #      Start Epoch      #
    # -=-=-=-=-=-=-=-=-=-=- #
    dynamics_acc_i = -1.
    if cfg.use_seed:
        init = torch.initial_seed()
        torch.manual_seed(cfg.seed + 1 + epoch + init)
    train_iter = iter(train_loader)
    for batch_idx in range(steps_per_epoch):

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #      Setting up for Iteration       #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #

        # -- setup iteration timer --
        if use_timer:
            data_clock.tic()
            clock.tic()

        # -- grab data batch --
        if small_ds and batch_idx >= ds_size:
            # loader exhausted before steps_per_epoch: reseed and restart
            if cfg.use_seed:
                init = torch.initial_seed()
                torch.manual_seed(cfg.seed + 1 + epoch + init)
            train_iter = iter(train_loader)  # reset if too big
        sample = next(train_iter)
        burst, raw_img, motion = sample['burst'], sample['clean'], sample[
            'flow']
        raw_img_iid = sample['iid']
        raw_img_iid = raw_img_iid.cuda(non_blocking=True)
        burst = burst.cuda(non_blocking=True)

        # -- handle possibly cached simulated bursts --
        if 'sim_burst' in sample:
            sim_burst = rearrange(sample['sim_burst'],
                                  'b n k c h w -> n b k c h w')
        else:
            sim_burst = None
        non_sim_method = cfg.n2n or cfg.supervised
        if sim_burst is None and not (non_sim_method or cfg.abps):
            if sim_burst is None:
                if cfg.use_kindex_lmdb:
                    kindex = kindex_ds[batch_idx].cuda(non_blocking=True)
                else:
                    kindex = None
                # search the non-middle frames for patches similar to the
                # middle (reference) frame
                query = burst[[N // 2]]
                database = torch.cat([burst[:N // 2], burst[N // 2 + 1:]])
                sim_burst = compute_similar_bursts(
                    cfg, query, database, cfg.sim_K, noise_level / 255.,
                    patchsize=cfg.sim_patchsize, shuffle_k=cfg.sim_shuffleK,
                    kindex=kindex, only_middle=cfg.sim_only_middle,
                    search_method=cfg.sim_method, db_level="frame")
        if (sim_burst is None) and cfg.abps:
            # align the burst: spoofed global motion, or an LPAS search
            if cfg.lpas_method == "spoof":
                mtype = "global"
                acc = cfg.optical_flow_acc
                scores, aligned = lpas_spoof(burst, motion, cfg.nblocks,
                                             mtype, acc)
            else:
                ref_frame = (cfg.nframes + 1) // 2
                nblocks = cfg.nblocks
                method = cfg.lpas_method
                scores, aligned, dacc = lpas_search(burst, ref_frame, nblocks,
                                                    motion, method)
                dynamics_acc_i = dacc
            # build a simulated burst from the aligned frames
            nsims = cfg.nframes
            sim_aligned = create_sim_from_aligned(burst, aligned, nsims)
            burst_s = rearrange(burst, 't b c h w -> t b 1 c h w')
            sim_burst = torch.cat([burst_s, sim_aligned], dim=2)
            # (large block of commented-out debugging/visualization code removed)
        if non_sim_method:
            sim_burst = burst.unsqueeze(2).repeat(1, 1, 2, 1, 1, 1)
        else:
            sim_burst = sim_burst.cuda(non_blocking=True)
        if use_timer:
            data_clock.toc()

        # -- to cuda --
        burst = burst.cuda(non_blocking=True)
        # NOTE(review): raw_zm_img is never read below.
        raw_zm_img = szm(raw_img.cuda(non_blocking=True))

        # -- crop images --
        if True:  # cfg.abps or cfg.abps_inputs:
            images = [burst, sim_burst, raw_img, raw_img_iid]
            spacing = burst.shape[0]  # we use frames as spacing
            cropped = crop_center_patch(images, spacing, cfg.frame_size)
            burst, sim_burst = cropped[0], cropped[1]
            raw_img, raw_img_iid = cropped[2], cropped[3]
            if cfg.abps or cfg.abps_inputs:
                aligned = crop_center_patch([aligned], spacing,
                                            cfg.frame_size)[0]
            burst = burst[:cfg.nframes]  # last frame is target

        # -- getting shapes of data --
        N, B, C, H, W = burst.shape
        burst_og = burst.clone()

        # -- shuffle over Simulated Samples --
        k_ins, k_outs = create_k_grid(sim_burst, shuffle=True)
        # keep only the first (in, out) pair
        k_ins, k_outs = [k_ins[0]], [k_outs[0]]
        for k_in, k_out in zip(k_ins, k_outs):
            if k_in == k_out:
                continue

            # -- zero gradients; ready 2 go --
            model.align_info.model.zero_grad()
            model.align_info.optim.zero_grad()
            model.denoiser_info.model.zero_grad()
            model.denoiser_info.optim.zero_grad()
            model.unet_info.model.zero_grad()
            model.unet_info.optim.zero_grad()

            # -- compute input/output data --
            if cfg.sim_only_middle and (not cfg.abps):
                # sim_burst.shape == T,B,K,C,H,W
                midi = 0 if sim_burst.shape[0] == 1 else N // 2
                left_burst, right_burst = burst[:N // 2], burst[N // 2 + 1:]
                cat_burst = [
                    left_burst, sim_burst[[midi], :, k_in], right_burst
                ]
                burst = torch.cat(cat_burst, dim=0)
                mid_img = sim_burst[midi, :, k_out]
            elif cfg.abps and (not cfg.abps_inputs):
                # -- v1: target is the last aligned frame --
                # (alternative target constructions v2-v4 removed; all
                # were commented out)
                mid_img = aligned[-1]
            elif cfg.abps_inputs:
                burst = aligned.clone()
                burst_og = aligned.clone()
                mid_img = shuffle_aligned_pixels(burst, cfg.nframes)[0]
            else:
                burst = sim_burst[:, :, k_in]
                mid_img = sim_burst[N // 2, :, k_out]
            # choose the training target per the supervision mode
            if cfg.supervised:
                gt_img = get_nmlz_tgt_img(cfg, raw_img).cuda(non_blocking=True)
            elif cfg.n2n:
                gt_img = raw_img_iid  # noise_xform(raw_img).cuda(non_blocking=True)
            else:
                gt_img = mid_img
            # (commented-out sanity-check/visualization code removed)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #        Dataset Augmentation         #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            # burst,gt_img = apply_transformations(burst,gt_img)

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #      Formatting Images for FP       #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            stacked_burst = rearrange(burst, 'n b c h w -> b n c h w')
            cat_burst = rearrange(burst, 'n b c h w -> (b n) c h w')

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #            Foward Pass              #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            outputs = model(burst)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #   Decrease Entropy within a Kernel      #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            filters_entropy = 0
            filters_entropy_coeff = 0.  # 1000.
            all_filters = []
            L = len(align_hook.filters)
            iter_filters = align_hook.filters if L > 0 else [aligned_filters]
            for filters in iter_filters:
                f_shape = 'b n k2 c h w -> (b n c h w) k2'
                filters_shaped = rearrange(filters, f_shape)
                # entropy term currently disabled: accumulate the constant `one`
                filters_entropy += one  # entropyLoss(filters_shaped)
                all_filters.append(filters)
            if L > 0:
                filters_entropy /= L
            all_filters = torch.stack(all_filters, dim=1)
            align_hook.clear()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #  Reconstruction Losses (MSE)    #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            losses = [F.mse_loss(denoised_ave, gt_img)]
            # losses = denoiseLossMSE(denoised,denoised_ave,gt_img,cfg.global_step)
            # ave_loss,burst_loss = [loss.item() for loss in losses]
            # single-tensor list; np.sum returns that tensor unchanged
            rec_mse = np.sum(losses)
            rec_mse_coeff = 1.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #   Reconstruction Losses (Distribution)      #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            gt_img_rep = gt_img.unsqueeze(1).repeat(1, denoised.shape[1], 1, 1,
                                                    1)
            residuals = denoised - gt_img_rep
            # distribution loss currently disabled; kept as a zero tensor
            # with nan/inf guards for the commented-out variants
            rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            # rec_ot = kl_gaussian_bp(residuals,noise_level,flip=True)
            # rec_ot = kl_gaussian_bp_patches(residuals,noise_level,flip=True,patchsize=16)
            if torch.any(torch.isnan(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            if torch.any(torch.isinf(rec_ot)):
                rec_ot = torch.FloatTensor([0.]).to(cfg.device)
            rec_ot_coeff = 0.

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #               Final Losses                  #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            rec_loss = rec_mse_coeff * rec_mse + rec_ot_coeff * rec_ot
            final_loss = rec_loss

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #              Record Keeping                 #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            # -- reconstruction MSE --
            rec_mse_losses += rec_mse.item()
            rec_mse_count += 1
            # -- reconstruction Dist. --
            rec_ot_losses += rec_ot.item()
            rec_ot_count += 1
            # -- dynamic acc --
            dynamics_acc += dynamics_acc_i
            dynamics_count += 1
            # -- total loss --
            running_loss += final_loss.item()
            total_loss += final_loss.item()

            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            #       Gradients & Backpropogration          #
            # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
            # -- compute the gradients! --
            # determinism is toggled off around backward(); some backward
            # kernels have no deterministic implementation
            if cfg.use_seed:
                torch.set_deterministic(False)
            final_loss.backward()
            if cfg.use_seed:
                torch.set_deterministic(True)

            # -- backprop now. --
            model.align_info.optim.step()
            model.denoiser_info.optim.step()
            model.unet_info.optim.step()
            scheduler.step()

        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        #            Printing to Stdout               #
        # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- #
        if (batch_idx % cfg.log_interval) == 0 and batch_idx > 0:

            # -- recompute model output for original images --
            outputs = model(burst_og)
            m_aligned, m_aligned_ave, denoised, denoised_ave = outputs[:4]
            aligned_filters, denoised_filters = outputs[4:]

            # -- compute mse for fun --
            B = raw_img.shape[0]
            raw_img = raw_img.cuda(non_blocking=True)
            raw_img = get_nmlz_tgt_img(cfg, raw_img)

            # -- psnr for [average of aligned frames] --
            mse_loss = F.mse_loss(raw_img, m_aligned_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_aligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_aligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [average of input, misaligned frames] --
            mis_ave = torch.mean(burst_og, dim=0)
            if noise_type == "qis":
                mis_ave = quantize_img(cfg, mis_ave)
            mse_loss = F.mse_loss(raw_img, mis_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_misaligned_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_misaligned_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [bm3d] baseline (at most 4 batch elements) --
            mid_img_og = burst[N // 2]
            bm3d_nb_psnrs = []
            M = 4 if B > 4 else B
            for b in range(M):
                bm3d_rec = bm3d.bm3d(mid_img_og[b].cpu().transpose(0, 2) + 0.5,
                                     sigma_psd=noise_level / 255,
                                     stage_arg=bm3d.BM3DStages.ALL_STAGES)
                bm3d_rec = torch.FloatTensor(bm3d_rec).transpose(0, 2)
                # maybe an issue here
                b_loss = F.mse_loss(raw_img[b].cpu(), bm3d_rec,
                                    reduction='none').reshape(1, -1)
                b_loss = torch.mean(b_loss, 1).detach().cpu().numpy()
                bm3d_nb_psnr = np.mean(mse_to_psnr(b_loss))
                bm3d_nb_psnrs.append(bm3d_nb_psnr)
            bm3d_nb_ave = np.mean(bm3d_nb_psnrs)
            bm3d_nb_std = np.std(bm3d_nb_psnrs)

            # -- psnr for aligned + denoised (per-frame) --
            R = denoised.shape[1]
            raw_img_repN = raw_img.unsqueeze(1).repeat(1, R, 1, 1, 1)
            mse_loss = F.mse_loss(raw_img_repN, denoised,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr_denoised_ave = np.mean(mse_to_psnr(mse_loss))
            psnr_denoised_std = np.std(mse_to_psnr(mse_loss))

            # -- psnr for [model output image] --
            mse_loss = F.mse_loss(raw_img, denoised_ave,
                                  reduction='none').reshape(B, -1)
            mse_loss = torch.mean(mse_loss, 1).detach().cpu().numpy()
            psnr = np.mean(mse_to_psnr(mse_loss))
            psnr_std = np.std(mse_to_psnr(mse_loss))

            # -- update losses --
            running_loss /= cfg.log_interval
            # -- reconstruction MSE --
            rec_mse_ave = rec_mse_losses / rec_mse_count
            rec_mse_losses, rec_mse_count = 0, 0
            # -- reconstruction Dist. --
            rec_ot_ave = rec_ot_losses / rec_ot_count
            rec_ot_losses, rec_ot_count = 0, 0
            # -- ave dynamic acc --
            ave_dyn_acc = dynamics_acc / dynamics_count * 100.
            dynamics_acc, dynamics_count = 0, 0

            # -- write record --
            if use_record:
                # NOTE(review): burst_loss/ave_loss are only defined by a
                # commented-out line in the loss section; enabling
                # use_record as-is raises NameError.
                info = {
                    'burst': burst_loss,
                    'ave': ave_loss,
                    'ot': rec_ot_ave,
                    'psnr': psnr,
                    'psnr_std': psnr_std
                }
                record_losses = record_losses.append(info, ignore_index=True)

            # -- write to stdout --
            write_info = (epoch, cfg.epochs, batch_idx, steps_per_epoch,
                          running_loss, psnr, psnr_std, psnr_denoised_ave,
                          psnr_denoised_std, psnr_aligned_ave,
                          psnr_aligned_std, psnr_misaligned_ave,
                          psnr_misaligned_std, bm3d_nb_ave, bm3d_nb_std,
                          rec_mse_ave, ave_dyn_acc)  # rec_ot_ave)
            print(
                "[%d/%d][%d/%d]: %2.3e [PSNR]: %2.2f +/- %2.2f [den]: %2.2f +/- %2.2f [al]: %2.2f +/- %2.2f [mis]: %2.2f +/- %2.2f [bm3d]: %2.2f +/- %2.2f [r-mse]: %.2e [dyn]: %.2e"
                % write_info,
                flush=True)

            # -- write to summary writer --
            if writer:
                writer.add_scalar('train/running-loss', running_loss,
                                  cfg.global_step)
                writer.add_scalars('train/model-psnr', {
                    'ave': psnr,
                    'std': psnr_std
                }, cfg.global_step)
                writer.add_scalars('train/dn-frame-psnr', {
                    'ave': psnr_denoised_ave,
                    'std': psnr_denoised_std
                }, cfg.global_step)

            # -- reset loss --
            running_loss = 0

        # -- write examples --
        # NOTE(review): disabled (write_examples = False); when enabled,
        # `aligned` only exists on cfg.abps paths.
        if write_examples and (batch_idx % write_examples_iter) == 0 and (
                batch_idx > 0 or cfg.global_step == 0):
            write_input_output(cfg, model, stacked_burst, aligned, denoised,
                               all_filters, motion)

        if use_timer:
            clock.toc()
        if use_timer:
            print("data_clock", data_clock.average_time)
            print("clock", clock.average_time)
        cfg.global_step += 1

    # -- remove hooks --
    for hook in align_hooks:
        hook.remove()

    total_loss /= len(train_loader)
    return total_loss, record_losses
def main():
    """Demo: compare 'simple' vs 'split' LPAS alignment on a noisy burst.

    Builds a synthetic dynamic burst from the Cameraman image, corrupts it
    with Poisson noise, runs `lpas_search` with both methods, prints PSNRs
    against the reference frame, and writes qualitative result images.

    Side effects: reads ./data/512-512-grayscale-image-Cameraman.png and
    writes several lpas_demo_*.png files to the working directory.
    """

    #
    # -- init experiment --
    #

    cfg = edict()
    cfg.gpuid = 1
    cfg.noise_params = edict()
    cfg.noise_params.g = edict()
    # data = load_dataset(cfg)
    torch.manual_seed(143)  # 131 = 80% vs 20%

    #
    # -- pick our noise --
    #

    # -- gaussian noise --
    # cfg.noise_type = 'g'
    # cfg.noise_params['g']['mean'] = 0.
    # cfg.noise_params['g']['stddev'] = 125.
    # cfg.noise_params.ntype = cfg.noise_type

    # -- poisson noise (the active configuration) --
    cfg.noise_type = "pn"
    cfg.noise_params['pn'] = edict()
    cfg.noise_params['pn']['alpha'] = 1.0
    cfg.noise_params['pn']['std'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type

    # -- low-light noise --
    # cfg.noise_type = "qis"
    # cfg.noise_params['qis'] = edict()
    # cfg.noise_params['qis']['alpha'] = 4.0
    # cfg.noise_params['qis']['readout'] = 0.0
    # cfg.noise_params['qis']['nbits'] = 3
    # cfg.noise_params['qis']['use_adc'] = True
    # cfg.noise_params.ntype = cfg.noise_type

    #
    # -- setup the dynamics --
    #

    cfg.nframes = 5
    cfg.frame_size = 350
    cfg.nblocks = 5
    T = cfg.nframes

    cfg.dynamic = edict()
    cfg.dynamic.frames = cfg.nframes
    cfg.dynamic.bool = True
    cfg.dynamic.ppf = 1  # pixels-per-frame of motion
    cfg.dynamic.mode = "global"
    cfg.dynamic.random_eraser = False
    cfg.dynamic.frame_size = cfg.frame_size
    cfg.dynamic.total_pixels = cfg.dynamic.ppf * (cfg.nframes - 1)

    # -- setup noise and dynamics --
    noise_xform = get_noise_transform(cfg.noise_params, noise_only=True)

    def null(image):
        # identity "noise" so the dynamic transform returns clean frames
        return image

    dynamics_xform = get_dynamic_transform(cfg.dynamic, null)

    # -- sample data --
    image_path = "./data/512-512-grayscale-image-Cameraman.png"
    image = Image.open(image_path).convert("RGB")
    image = image.crop((0, 0, cfg.frame_size, cfg.frame_size))
    clean, res, raw, flow = dynamics_xform(image)
    clean = clean[:, None]  # add batch dim: (t, 1, c, h, w)
    burst = noise_xform(clean + 0.5)
    flow = flow[None, :]
    # reference = middle clean frame, repeated across time for per-frame PSNR
    reference = repeat(clean[[T // 2]], '1 b c h w -> t b c h w', t=T)
    print("Flow")
    print(flow)

    # -- our method ("simple" search) --
    ref_frame = T // 2
    nblocks = cfg.nblocks
    method = "simple"
    noise_info = cfg.noise_params
    scores, aligned_simp, dacc_simp = lpas_search(burst, ref_frame, nblocks,
                                                  flow, method, clean,
                                                  noise_info)

    # -- split search --
    ref_frame = T // 2
    nblocks = cfg.nblocks
    method = "split"
    noise_info = cfg.noise_params
    scores, aligned_split, dacc_split = lpas_search(burst, ref_frame, nblocks,
                                                    flow, method, clean,
                                                    noise_info)

    # -- quantitative comparison (center-cropped to avoid border artifacts) --
    crop_size = 256
    image1, image2 = cc(aligned_simp, crop_size), cc(reference, crop_size)
    psnrs = images_to_psnrs(image1, image2)
    print("Aligned Simple Method: ", psnrs, dacc_simp.item())
    image1, image2 = cc(aligned_split, crop_size), cc(reference, crop_size)
    psnrs = images_to_psnrs(image1, image2)
    print("Aligned Split Method: ", psnrs, dacc_split.item())

    # -- compute noise 2 sim --
    # T,K = cfg.nframes,cfg.nframes
    # patchsize = 31
    # query = burst[[T//2]]
    # database = torch.cat([burst[:T//2],burst[T//2+1:]])
    # clean_db = clean
    # sim_outputs = compute_similar_bursts_analysis(cfg,query,database,clean_db,K,-1.,
    #                                               patchsize=patchsize,shuffle_k=False,
    #                                               kindex=None,only_middle=False,
    #                                               search_method="l2",db_level="burst")
    # sims,csims,wsims,b_dist,b_indx = sim_outputs

    # -- display images --
    print(aligned_simp.shape)
    print(aligned_split.shape)
    print_tensor_stats("aligned", aligned_simp)
    # print(csims.shape)
    save_image(burst, "lpas_demo_burst.png", [-0.5, 0.5])
    save_image(clean, "lpas_demo_clean.png")
    save_image(aligned_simp, "lpas_demo_aligned_simp.png")
    save_image(aligned_split, "lpas_demo_aligned_split.png")
    save_image(cc(aligned_simp, crop_size), "lpas_demo_aligned_simp_ccrop.png")
    save_image(cc(aligned_split, crop_size),
               "lpas_demo_aligned_split_ccrop.png")

    # -- residuals w.r.t. the middle frame --
    delta_full_simp = aligned_simp - aligned_simp[T // 2]
    delta_full_split = aligned_split - aligned_split[T // 2]
    save_image(delta_full_simp, "lpas_demo_aligned_full_delta_simp.png",
               [-0.5, 0.5])
    save_image(delta_full_split, "lpas_demo_aligned_full_delta_split.png",
               [-0.5, 0.5])
    delta_cc_simp = cc(delta_full_simp, crop_size)
    delta_cc_split = cc(delta_full_split, crop_size)
    # FIX: save the center-cropped deltas under the "_cc_" filenames;
    # previously the full-size deltas were written here and the cropped
    # tensors were computed but never used.
    save_image(delta_cc_simp, "lpas_demo_aligned_cc_delta_simp.png")
    save_image(delta_cc_split, "lpas_demo_aligned_cc_delta_split.png")

    # -- zoomed 64x64 inspection crops --
    top = 75
    size = 64
    simp = tvF.crop(aligned_simp, top, 200, size, size)
    split = tvF.crop(aligned_split, top, 200, size, size)
    print_tensor_stats("delta", simp)
    save_image(simp, "lpas_demo_aligned_simp_inspect.png")
    save_image(split, "lpas_demo_aligned_split_inspect.png")
    delta_simp = simp - simp[T // 2]
    delta_split = split - split[T // 2]
    print_tensor_stats("delta", delta_simp)
    save_image(delta_simp, "lpas_demo_aligned_simp_inspect_delta.png",
               [-1, 1.])
    save_image(delta_split, "lpas_demo_aligned_split_inspect_delta.png",
               [-1, 1.])
def test_sim_search_attn_v2(cfg, clean, model):
    """Evaluate similarity search using the model's "attn" feature image.

    For each noise level in the grid: corrupt the clean burst patches, build
    a noisy pixel image and an attention-based feature image, run the
    similarity search between frame 0 (query) and frame 1 (database), and
    record PSNR statistics of the matched (clean) patches vs. clean frame 0.

    NOTE(review): mutates `cfg.noise_type` / `cfg.noise_params` in place on
    every iteration; callers should not rely on cfg's noise fields afterward.

    Returns a dict: noise-level name -> edict(psnrs, ave, std, min, max).
    """
    # -- init --
    # assumes clean is a burst tensor (nframes, batch, c, h, w) — TODO confirm
    N, B, C, H, W = clean.shape
    ps = cfg.byol_patchsize

    # -- unfold clean image --
    patches = model.patch_helper.prepare_burst_patches(clean)
    patches = patches.cuda(non_blocking=True)
    # R,N,B,L,C,H,W = patches.shape

    # -- start loop --
    psnrs = {}
    noisy_grid = create_noise_level_grid(cfg)
    for noise_params in noisy_grid:

        # -- setup noise xform (rewrites cfg so the transform picks it up) --
        cfg.noise_type = noise_params.ntype
        cfg.noise_params.ntype = cfg.noise_type
        cfg.noise_params[cfg.noise_type] = noise_params
        noise_func = get_noise_transform(cfg.noise_params,
                                         use_to_tensor=False)

        # -- apply noise --
        noisy_patches = noise_func(
            patches)  # shape = (r n b nh_size^2 c ps_B ps_B)

        # -- create noisy img: take the center patch's center pixel,
        #    then fold the flattened (h w) grid back into an image --
        f_mid = cfg.byol_nh_size**2 // 2
        p_mid = cfg.byol_patchsize // 2
        noisy_img = noisy_patches[:, :, :, f_mid, :, p_mid, p_mid]
        noisy_img = rearrange(noisy_img, '(h w) n b c -> n b c h w',
                              h=cfg.frame_size)
        ftr_img = get_feature_image(cfg, noisy_patches, model, "attn")
        print("[ftr_img.shape]", ftr_img.shape)

        # print("[emd] PSNR: ",np.mean(images_to_psnrs(embeddings_0,embeddings_1)))
        # print("[ftr] PSNR: ",np.mean(images_to_psnrs(ftr_img_0,ftr_img_1)))

        # -- construct similar image: frame 0 queries against frame 1 --
        query = edict()
        query.pix = noisy_img[[0]]
        query.ftr = ftr_img[[0]]
        query.shape = query.pix.shape

        database = edict()
        database.pix = noisy_img[[1]]
        database.ftr = ftr_img[[1]]
        database.shape = database.pix.shape

        # clean database mirrors the noisy one; pixels double as features
        clean_db = edict()
        clean_db.pix = clean[[1]]
        clean_db.ftr = clean_db.pix
        clean_db.shape = clean_db.pix.shape

        sim_outputs = compute_similar_bursts_analysis(
            cfg,
            query,
            database,
            clean_db,
            1,
            patchsize=cfg.sim_patchsize,
            shuffle_k=False,
            kindex=None,
            only_middle=cfg.sim_only_middle,
            db_level='frame',
            search_method=cfg.sim_method,
            noise_level=None)

        # -- compute psnr of the matched clean sims vs. clean frame 0 --
        ref = clean[0]
        clean_sims = sim_outputs[1][0, :, 0]
        psnrs_np = images_to_psnrs(ref.cpu(), clean_sims.cpu())
        psnrs[noise_params.name] = edict()
        psnrs[noise_params.name].psnrs = psnrs_np
        psnrs[noise_params.name].ave = np.mean(psnrs_np)
        psnrs[noise_params.name].std = np.std(psnrs_np)
        psnrs[noise_params.name].min = np.min(psnrs_np)
        psnrs[noise_params.name].max = np.max(psnrs_np)
        # print(noise_params.name,psnrs[noise_params.name])
    return psnrs
def test_sim_search_pix_v2(cfg, clean, model):
    """Evaluate similarity search using raw pixels as both pix and ftr.

    For each noise level in the grid: corrupt every clean frame, run the
    similarity search between frame 0 (query) and frame 1 (database), and
    record PSNR statistics of the matched (clean) patches vs. clean frame 0.

    Returns a dict: noise-level name -> edict(psnrs, ave, std, min, max).
    """
    # -- flatten the burst into PIL images for the noise transform --
    N, B, C, H, W = clean.shape
    flat = rearrange(clean, 'n b c h w -> (b n) c h w')
    pil_frames = [
        tvT.ToPILImage()(flat[idx] + 0.5).convert("RGB")
        for idx in range(B * N)
    ]
    ps = cfg.byol_patchsize
    unfold = nn.Unfold(ps, 1, 0, 1)

    psnrs = {}
    for noise_params in create_noise_level_grid(cfg):

        # -- point cfg at this noise level and build its transform --
        cfg.noise_type = noise_params.ntype
        cfg.noise_params.ntype = cfg.noise_type
        cfg.noise_params[cfg.noise_type] = noise_params
        noise_func = get_noise_transform(cfg.noise_params)

        # -- noise each frame, then restore the burst layout --
        noisy_flat = torch.stack([noise_func(img) for img in pil_frames],
                                 dim=0)
        noisy = rearrange(noisy_flat, '(b n) c h w -> n b c h w', b=B)

        # -- frame 0 queries against frame 1; pixels double as features --
        query = edict()
        query.pix = noisy[[0]]
        query.ftr = noisy[[0]]
        query.shape = query.pix.shape

        database = edict()
        database.pix = noisy[[1]]
        database.ftr = noisy[[1]]
        database.shape = database.pix.shape

        clean_db = edict()
        clean_db.pix = clean[[1]]
        clean_db.ftr = clean_db.pix
        clean_db.shape = clean_db.pix.shape

        sim_outputs = compute_similar_bursts_analysis(
            cfg,
            query,
            database,
            clean_db,
            1,
            patchsize=cfg.sim_patchsize,
            shuffle_k=False,
            kindex=None,
            only_middle=cfg.sim_only_middle,
            db_level='frame',
            search_method=cfg.sim_method,
            noise_level=None)

        # -- PSNR of the matched clean sims vs. the clean reference frame --
        clean_sims = sim_outputs[1][0, :, 0]
        psnrs_np = images_to_psnrs(clean[0].cpu(), clean_sims.cpu())
        stats = edict()
        stats.psnrs = psnrs_np
        stats.ave = np.mean(psnrs_np)
        stats.std = np.std(psnrs_np)
        stats.min = np.min(psnrs_np)
        stats.max = np.max(psnrs_np)
        psnrs[noise_params.name] = stats
        # print(noise_params.name,psnrs[noise_params.name])
    return psnrs
def test_sim_search_ftr(cfg, clean, model, ftr_types):
    """Evaluate similarity search across several feature types per noise level.

    For each noise level in the grid and each feature type in `ftr_types`:
    corrupt the clean burst patches, extract a feature image, and compute
    similarity-search PSNRs via `compute_similar_psnr`.

    NOTE(review): mutates `cfg.noise_type` / `cfg.noise_params` in place per
    noise level, and temporarily overrides `cfg.sim_patchsize` (restored
    after use) for non-pixel features.

    Returns a nested edict: psnrs[ftr_type][noise_level_name].psnrs plus
    summary fields filled in by `compute_psnrs_summary`.
    """
    # -- init --
    # assumes clean is a burst tensor (nframes, batch, c, h, w) — TODO confirm
    N, B, C, H, W = clean.shape
    ps = cfg.byol_patchsize
    if clean.min() < 0: clean += 0.5  # non-negative pixels

    # -- unfold clean image --
    patches = model.patch_helper.prepare_burst_patches(clean)
    patches = patches.cuda(non_blocking=True)
    ps = cfg.byol_patchsize
    # shape = (r n b nh_size^2 c ps_B ps_B)

    # -- start loop --
    psnrs = edict({})
    for ftr_type in ftr_types:
        psnrs[ftr_type] = edict({})
    noisy_grid = create_noise_level_grid(cfg)
    with torch.no_grad():  # evaluation only; no gradients needed
        for noise_params in noisy_grid:

            # -- setup noise xform (rewrites cfg so the transform sees it) --
            cfg.noise_type = noise_params.ntype
            cfg.noise_params.ntype = cfg.noise_type
            cfg.noise_params[cfg.noise_type] = noise_params
            noise_func = get_noise_transform(cfg.noise_params,
                                             noise_only=True)

            # -- apply noise --
            noisy_patches = noise_func(
                patches)  # shape = (r n b nh_size^2 c ps_B ps_B)

            # -- create noisy img --
            noisy_img = get_pixel_features(cfg, noisy_patches)

            # -- get features for each requested feature type --
            for ftype in ftr_types:
                ftr_img = get_feature_image(cfg, noisy_patches, model, ftype)

                # -- some debugging code (disabled by default) --
                vis = False
                if vis: vis_noisy_features(cfg, noisy_img, ftr_img, clean,
                                           ftype)
                testing_indexing = False
                if testing_indexing:
                    test_patch_helper_indexing(cfg, noisy_img, ftr_img, clean,
                                               ftype)

                # -- construct similar image; non-pixel features search with
                #    patchsize 1 (original value restored afterward) --
                if ftype != "pix":
                    sim_patchsize = cfg.sim_patchsize
                    cfg.sim_patchsize = 1
                    psnrs_np = compute_similar_psnr(cfg, noisy_img, ftr_img,
                                                    clean)
                    cfg.sim_patchsize = sim_patchsize
                else:
                    psnrs_np = compute_similar_psnr(cfg, noisy_img, ftr_img,
                                                    clean)

                # -- compute psnr summary stats in place --
                psnrs[ftype][noise_params.name] = edict()
                psnrs[ftype][noise_params.name].psnrs = psnrs_np
                compute_psnrs_summary(psnrs[ftype][noise_params.name])
                # psnrs[ftype][noise_params.name].ave = np.mean(psnrs_np)
                # psnrs[ftype][noise_params.name].std = np.std(psnrs_np)
                # psnrs[ftype][noise_params.name].min = np.min(psnrs_np)
                # psnrs[ftype][noise_params.name].max = np.max(psnrs_np)
                del ftr_img  # free feature tensor before the next type
    return psnrs