def get_train_log_info(cfg, model, denoised, loss, dyn_noisy, dyn_clean, sims,
                       masks, aligned, flow, flow_gt):
    """Gather quality diagnostics for one training step.

    Keys of the returned dict, in insertion order:
      * 'image_psnrs'   -- ``images_to_psnrs(denoised, middle clean frame)``.
      * 'sim_psnrs'     -- PSNRs of ``sims[1:]`` against the middle clean
                           frame, rearranged to (num_sims, num_images).
      * 'aligned_psnrs' -- PSNRs of ``align_from_flow(dyn_clean, flow, ...)``
                           against the middle clean frame, rearranged to
                           (nframes, num_images).
      * 'epe'           -- ``compute_epe(flow, flow_gt)``.
      * 'nnf_acc'       -- ``compute_pair_flow_acc(flow, flow_gt)``.
    When ``flow`` is None the last three entries are ``np.zeros(1)``
    placeholders.

    ``model``, ``loss``, ``masks`` and ``aligned`` are accepted but not read.
    """
    info = {}
    nframes, _, _, _, _ = dyn_clean.shape
    mid_clean = dyn_clean[nframes // 2]

    # -- denoised image vs. middle clean frame --
    info['image_psnrs'] = images_to_psnrs(denoised, mid_clean)

    # -- similar images (sims[0] is excluded) vs. middle clean frame --
    n_images = sims.shape[1]
    n_sims = sims.shape[0] - 1
    sims_ref = rearrange(repeat(mid_clean, 'b c h w -> s b c h w', s=n_sims),
                         's b c h w -> (s b) c h w')
    sims_flat = rearrange(sims[1:], 's b c h w -> (s b) c h w')
    sim_psnrs = images_to_psnrs(sims_ref, sims_flat)
    info['sim_psnrs'] = rearrange(sim_psnrs, '(t b) -> t b', b=n_images)

    # -- flow-aligned clean frames vs. middle clean frame --
    _, _, _, H, W = dyn_noisy.shape
    isize = edict({'h': H, 'w': W})
    frames_ref = rearrange(repeat(mid_clean, 'b c h w -> t b c h w', t=nframes),
                           't b c h w -> (t b) c h w')
    if flow is None:
        # No flow estimate: every flow-based metric gets a placeholder.
        info['aligned_psnrs'] = np.zeros(1)
        info['epe'] = np.zeros(1)
        info['nnf_acc'] = np.zeros(1)
        return info

    warped = align_from_flow(dyn_clean, flow, cfg.nblocks, isize=isize)
    warped = warped.to(dyn_clean.device, non_blocking=True)
    warped_flat = rearrange(warped, 't b c h w -> (t b) c h w')
    aligned_psnrs = images_to_psnrs(frames_ref, warped_flat)
    info['aligned_psnrs'] = rearrange(aligned_psnrs, '(t b) -> t b', t=nframes)

    # -- flow quality vs. ground truth --
    info['epe'] = compute_epe(flow, flow_gt)
    info['nnf_acc'] = compute_pair_flow_acc(flow, flow_gt)
    return info
def fill_results(cfg, image, clean, burst, model, idx):
    """Return a results dict filled mostly with placeholder values.

    Only 'psnr_clean' is actually computed: the mean PSNR of ``image`` (tiled
    once per burst frame) against ``clean``.  The scalar metrics are set to
    -1, and each score-function name gets a ``0.`` entry both as
    ``fu_<name>`` and as ``<name>``.

    ``cfg``, ``model`` and ``idx`` are accepted but not read.
    """
    results = {}

    # -- fill in old result params --
    n_frames = burst.shape[0]
    tiled = repeat(image, 'c h w -> tile c h w', tile=n_frames)
    results['psnr_clean'] = float(np.mean(images_to_psnrs(tiled, clean)))

    placeholder_keys = ('params_norm_mean', 'trace_norm', 'mse', 'psnr_rec',
                        'psnr_burst', 'psnr_intra_input', 'psnr_bc_v1',
                        'psnr_noisy')
    for key in placeholder_keys:
        results[key] = -1

    score_names = ('lgsubset_v_ref', 'lgsubset', 'ave', 'lgsubset_v_indices',
                   'gaussian_ot')
    results.update({f"fu_{name}": 0. for name in score_names})
    results.update({name: 0. for name in score_names})
    return results
def align_psnr(aligned, isize):
    """Average PSNR of every non-reference frame against the reference frame.

    Args:
        aligned: frames indexed along dim 0; the reference is frame
            ``len(aligned) // 2``.
        isize: mapping giving the center-crop size, normally with 'h' and 'w'
            entries (e.g. ``edict({'h': H, 'w': W})``).

    Returns:
        float: mean over the ``len(aligned) - 1`` non-reference frames of the
        frame-averaged PSNR.  A single-frame input divides by zero.
    """
    # Read the crop size by key so it does not depend on the mapping's key
    # insertion order happening to be (h, w); fall back to the previous
    # value-order behavior for mappings without 'h'/'w'.
    if 'h' in isize and 'w' in isize:
        crop_size = [isize['h'], isize['w']]
    else:
        crop_size = [isize[k] for k in isize.keys()]

    nframes = len(aligned)
    ref = nframes // 2
    aligned = tvF.center_crop(aligned, crop_size)

    total = 0.
    for t in range(nframes):
        if t == ref:
            continue
        total += np.mean(images_to_psnrs(aligned[t], aligned[ref])).item()
    return total / (nframes - 1)
def compute_aligned_psnr(aligned_a, aligned_b, csize):
    """Per-frame PSNRs between two center-cropped aligned bursts.

    Args:
        aligned_a, aligned_b: bursts indexed by frame along dim 0.
        csize: object with ``.h`` / ``.w`` giving the center-crop size.

    Returns:
        np.ndarray: ``images_to_psnrs`` results for each frame of
        ``aligned_a``, stacked along axis 0.
    """
    crop_hw = (csize.h, csize.w)
    cropped_a = tvF.center_crop(aligned_a, crop_hw)
    cropped_b = tvF.center_crop(aligned_b, crop_hw)
    per_frame = [images_to_psnrs(cropped_a[t], cropped_b[t])
                 for t in range(aligned_a.shape[0])]
    return np.stack(per_frame, axis=0)
def compute_recs_psnrs(recs, clean):
    """PSNR of every reconstruction against its batch's clean image.

    Args:
        recs: tensor laid out as 'b tm1 t c h w' (per the rearrange below);
            0.5 is added to it before comparison.
        clean: tensor laid out as 'b c h w'.

    Returns:
        torch.Tensor of shape (b, tm1, t).
    """
    n_batch, n_t = recs.shape[0], recs.shape[2]
    n_total = recs.shape[1] * n_t
    flat_recs = rearrange(recs, 'b tm1 t c h w -> b (tm1 t) c h w')
    tiled_clean = repeat(clean, 'b c h w -> b tile c h w', tile=n_total)
    per_batch = torch.stack(
        [torch.FloatTensor(images_to_psnrs(flat_recs[i] + 0.5, tiled_clean[i]))
         for i in range(n_batch)],
        dim=0)
    return rearrange(per_batch, 'b (tm1 t) -> b tm1 t', b=n_batch, t=n_t)
def compute_similar_psnr(cfg, noisy_img, ftr_img, clean, q_index, db_index,
                         crop=False):
    """Score the clean "similar" image for a query frame built from a database frame.

    Frame ``q_index`` is the query and frame ``db_index`` is the database
    searched by ``compute_similar_bursts_analysis``.  The resulting clean
    similar images (``sim_outputs[1][0, :, 0]``) are compared against the
    query's clean frame.

    Args:
        cfg: config; reads ``sim_patchsize``, ``sim_only_middle`` and
            ``sim_method``.
        noisy_img: noisy frames indexed along dim 0.
        ftr_img: feature images indexed like ``noisy_img``.
        clean: clean frames indexed like ``noisy_img``.
        q_index: index of the query frame.
        db_index: index of the database frame.
        crop: if True, compare only the 48x48 region whose top-left corner
            is (10, 10).

    Returns:
        (psnrs_np, clean_sims): the ``images_to_psnrs`` output and the
        (possibly cropped) clean similar images.
    """
    def _frames(pix, ftr):
        # Package a frame selection as the pix/ftr/shape edict passed to
        # compute_similar_bursts_analysis.
        frames = edict()
        frames.pix = pix
        frames.ftr = ftr
        frames.shape = frames.pix.shape
        return frames

    # -- construct similar image --
    query = _frames(noisy_img[[q_index]], ftr_img[[q_index]])
    database = _frames(noisy_img[[db_index]], ftr_img[[db_index]])
    clean_db_pix = clean[[db_index]]
    clean_db = _frames(clean_db_pix, clean_db_pix)

    sim_outputs = compute_similar_bursts_analysis(
        cfg, query, database, clean_db, 1,
        patchsize=cfg.sim_patchsize, shuffle_k=False, kindex=None,
        only_middle=cfg.sim_only_middle, db_level='frame',
        search_method=cfg.sim_method, noise_level=None)

    # -- compute psnr --
    # The similar images are assembled for the *query* frame, so compare them
    # with the query's clean frame; clean[0] was previously hard-coded, which
    # is only correct when q_index == 0.
    ref = clean[q_index]
    clean_sims = sim_outputs[1][0, :, 0]
    if crop:
        ref = tvF.crop(ref, 10, 10, 48, 48)
        clean_sims = tvF.crop(clean_sims, 10, 10, 48, 48)
    psnrs_np = images_to_psnrs(ref.cpu(), clean_sims.cpu())
    return psnrs_np, clean_sims
def compute_consistency_mat(self, recs, cmpr):
    """Mean-PSNR similarity between every reconstruction and every comparison image.

    Args:
        self: unused.
        recs: indexed as ``recs[b, i, l]`` with shape (B, Tm1, T, ...).
        cmpr: indexed as ``cmpr[b, j, k]`` with shape (B, c1, c2, ...).

    Returns:
        torch.Tensor of shape (B, Tm1, c1, T, c2) where entry
        ``[b, i, j, l, k]`` is ``np.mean(images_to_psnrs(recs[b, i, l],
        cmpr[b, j, k]))``.
    """
    n_batch = recs.shape[0]
    n_rec_sets, n_rec_frames = recs.shape[1:3]
    n_cmp_sets, n_cmp_frames = cmpr.shape[1:3]
    simmat = torch.zeros(n_batch, n_rec_sets, n_cmp_sets, n_rec_frames,
                         n_cmp_frames)
    index_grid = ((b, i, j, l, k)
                  for b in range(n_batch)
                  for i in range(n_rec_sets)
                  for j in range(n_cmp_sets)
                  for l in range(n_rec_frames)
                  for k in range(n_cmp_frames))
    for b, i, j, l, k in index_grid:
        pair_psnrs = images_to_psnrs(recs[b, i, l], cmpr[b, j, k])
        simmat[b, i, j, l, k] = np.mean(pair_psnrs)
    return simmat
def get_test_log_info(cfg, model, denoised, loss, dyn_noisy, dyn_clean):
    """Collect evaluation diagnostics for one test step.

    Only 'image_psnrs' (``denoised`` vs. the middle clean frame) is computed.
    'aligned_psnrs', 'sim_psnrs', 'epe' and 'nnf_acc' are fresh empty lists
    so the dict has the same keys as the training log.
    """
    nframes, _, _, _, _ = dyn_clean.shape
    info = {'image_psnrs': images_to_psnrs(denoised, dyn_clean[nframes // 2])}

    # -- empty for square matrix later --
    for key in ('aligned_psnrs', 'sim_psnrs', 'epe', 'nnf_acc'):
        info[key] = []
    return info
def best_global_image_arrangement_psnr(ref, frames):
    """Pick the candidate frame with the highest PSNR against ``ref``.

    Args:
        ref: reference image.
        frames: candidate images, unpacked as (NH2, C, H, W).

    Returns:
        (best_score, best_index): the maximum PSNR (a tensor element) and its
        index from ``torch.argmax``.
    """
    _, _, height, width = frames.shape

    # Both crops start at the top-left corner and use the frames' size.
    ref_crop = tvF.crop(ref, 0, 0, height, width)
    frames_crop = tvF.crop(frames, 0, 0, height, width)

    psnrs = torch.FloatTensor(images_to_psnrs(ref_crop, frames_crop))
    winner = torch.argmax(psnrs)
    return psnrs[winner], winner
def alignment_optimizer(cfg, score_fxn, blocks, clean, block_search_space,
                        scores_path):
    """Score every candidate block arrangement and return the best one.

    Args:
        cfg: config forwarded to ``score_fxn``.
        score_fxn: callable ``score_fxn(cfg, expanded)``; the candidate with
            the minimum score is selected.
        blocks: tensor unpacked as (R, B, T, N, C, PS1, PS2); R and B must
            both be 1.
        clean: clean blocks indexed like ``blocks`` (``clean[0, 0, t, n]``).
        block_search_space: candidate arrangements; row ``e`` gives, for each
            of the T frames, the block index to take from that frame.
        scores_path: destination passed to ``np.save`` for the score array.

    Returns:
        dict with 'scores' (``scores_path``), 'best_idx', 'best_score',
        'best_block' (the best row's indices concatenated as a string) and
        'psnr' (``images_to_psnrs`` of the best arrangement's clean blocks
        against the reference-frame reference block repeated T times).
    """
    # -- vectorize search since single patch --
    R, B, T, N, C, PS1, PS2 = blocks.shape
    REF_N = get_ref_block_index(int(np.sqrt(N)))
    assert (R == 1) and (B == 1), "single pixel's block and single sample please."
    # Advanced indexing pairs frame t with block_search_space[:, t], producing
    # one expanded burst per candidate along dim 2.
    expanded = blocks[:, :, np.arange(T), block_search_space]
    E = expanded.shape[2]

    # -- evaluate block --
    scores = score_fxn(cfg, expanded)
    scores = scores[0, 0]
    best_index = torch.argmin(scores).item()
    best_score = torch.min(scores).item()
    # best_index must address one of the E candidates.  The previous check,
    # `E >= best_index`, could never fail and also accepted best_index == E.
    assert best_index < E, "Best index must refer to one of the candidate blocks."

    # -- select the best block --
    best_block = block_search_space[best_index]
    best_block_str = ''.join(str(i) for i in best_block.cpu().numpy())

    # -- construct image and compute the associated psnr --
    ref = repeat(clean[0, 0, T // 2, REF_N], 'c h w -> tile c h w', tile=T)
    aligned = clean[0, 0, np.arange(T), best_block]
    psnr = images_to_psnrs(ref, aligned)

    # -- save scores to numpy array --
    np.save(scores_path, scores.cpu().numpy())

    # -- compute results --
    results = {'scores': scores_path,
               'best_idx': best_index,
               'best_score': best_score,
               'best_block': best_block_str,
               'psnr': psnr}
    return results
def test(cfg, image, clean, burst, model, idx):
    """Evaluate a burst denoiser on one sample and return a dict of metrics.

    Runs ``model(burst) + 0.5`` as the reconstruction ``rec`` and records:
    parameter norm, activation trace norm, MSE/PSNR of ``rec`` vs. ``image``,
    several MSE sums, PSNRs of the noisy and clean bursts vs. ``image``, and
    scores from the named score functions on both ``rec`` (``fnet_*`` keys)
    and the raw ``burst`` (un-prefixed keys).

    Side effects: saves ``fast_unet_rec_{idx}.png`` when idx is 40, 49 or 60,
    and always overwrites ``fast_unet_rec.png`` and ``fast_unet_burst.png``.

    NOTE(review): fill_results emits ``fu_<name>`` placeholder keys while this
    function emits ``fnet_<name>``; confirm which naming is intended.
    """
    # i3 = len(image.shape) == 3
    # c3 = len(clean.shape) == 3
    # b3 = len(burst.shape) == 3
    #
    # assert (i3 == c3) and (i3 == b3), "All three dims same"
    # if i3 and c3 and b3:
    T = burst.shape[0]

    # -- create results --
    results = {}

    # -- repeat along axis --
    rep = repeat(image, 'c h w -> tile c h w', tile=T)

    # -- reconstruct a clean image --
    rec = model(burst) + 0.5

    # -- parameters --
    # L2 norm of all model parameters flattened into one vector.
    params = torch.cat([param.view(-1) for param in model.parameters()])
    params_norm_mean = float(torch.norm(params).item())
    results['params_norm_mean'] = params_norm_mean

    # -- parameters --
    # named_params = dict(model.named_parameters())
    # print(named_params.keys())
    # filters = named_params['conv1.single_conv.0.weight']
    # print(filters.shape)
    #
    # params = torch.cat([param.view(-1) for param in model.parameters()])
    #
    # params_norm_mean = float(torch.norm(params).item())
    # results['params_filter_diff'] = params_filter_diff

    # -- size of params for each sample's activations path --
    trace_norm = activation_trace(model, burst, 'norm')
    results['trace_norm'] = trace_norm

    # -- save --
    if idx == 49 or idx == 40 or idx == 60:
        save_image(rec, f"fast_unet_rec_{idx}.png", normalize=True)

    # -- compute results --
    loss = F.mse_loss(rec, rep)
    psnr = float(np.mean(images_to_psnrs(image, rec[T // 2])))
    results['mse'] = loss.item()
    results['psnr_rec'] = psnr
    psnr = float(np.mean(images_to_psnrs(rec, rep)))
    results['psnr_burst'] = psnr

    # -- intra and input --
    # Despite the 'psnr_' key prefix this is a sum of MSE values.
    # NOTE(review): rec carries the +0.5 offset but burst[T // 2] does not
    # (psnr_noisy below adds 0.5 to burst); confirm this is intended.
    intra_input = 0
    for t in range(T):
        intra_input += F.mse_loss(rec[t], rec[T // 2]).item()
        intra_input += F.mse_loss(rec[t], burst[T // 2]).item()
    results['psnr_intra_input'] = intra_input

    # -- this n2n training creates a barycenter for center image --
    # Also a sum of MSE values, not a PSNR.
    bc_loss = 0
    for t in range(T):
        bc_loss += F.mse_loss(burst[t], rec[T // 2]).item()
    results['psnr_bc_v1'] = bc_loss

    # -- compute psnr of clean and noisy frames --
    psnr_noisy = float(np.mean(images_to_psnrs(rep, burst + 0.5)))
    results['psnr_noisy'] = psnr_noisy
    psnr_clean = float(np.mean(images_to_psnrs(rep, clean)))
    results['psnr_clean'] = psnr_clean

    # -- compute scores --
    score_fxn_names = [
        'lgsubset_v_ref', 'lgsubset', 'ave', 'lgsubset_v_indices',
        'gaussian_ot'
    ]
    wrapped_l = []  # not used below
    for name in score_fxn_names:
        score_fxn = get_score_function(name)
        wrapped_score = score_function_wrapper(score_fxn)
        # gaussian_ot is scored on the residual rec - rep, the others on rec.
        if name == "gaussian_ot":
            score, scores_t = wrapped_score(cfg, rec - rep)
        else:
            score, scores_t = wrapped_score(cfg, rec)
        results[f"fnet_{name}"] = score.item()
        for t in range(T):
            results[f"fnet_{name}_{t}f"] = scores_t[t].item()

    # -- on raw pixels too --
    for name in score_fxn_names:
        if name == "gaussian_ot":
            continue
        score_fxn = get_score_function(name)
        wrapped_score = score_function_wrapper(score_fxn)
        score, scores_t = wrapped_score(cfg, burst)
        results[name] = score.item()
        for t in range(T):
            results[f"{name}_{t}f"] = scores_t[t].item()

    # print("Test Loss",loss.item())
    # print("Test PSNR: %2.3e" % np.mean(images_to_psnrs(rec+0.5,rep)))
    tv_utils.save_image(rec, "fast_unet_rec.png", normalize=True)
    tv_utils.save_image(burst, "fast_unet_burst.png", normalize=True)

    return results
def _fit_fresh_unet(cfg, image, clean, noisy):
    """Train a new UNet_small(3) on (clean, noisy), run ``test`` on it, and return it."""
    model = UNet_small(3)  # UNet_n2n(1)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(),
                             lr=cfg.init_lr,
                             betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    # test() is called for its side effects (it saves images); its metrics
    # are not used here.
    test(cfg, image, clean, noisy, model, 0)
    return model


def _print_burst_psnrs(clean, noisy, rec):
    """Print rec-vs-clean, rec-vs-noisy and pairwise rec PSNRs for frames 0-2."""
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))


def single_image_unet(cfg, queue, full_image, device):
    """Train small UNets on bursts built from one 32x32 crop and print PSNRs.

    The burst is the crop at (128, 128) of ``full_image`` repeated 5 times,
    with Poisson ("pn", alpha=40) noise.  Experiments:
      1. all frames aligned;
      2. frame 0 replaced by the crop shifted one row down (top=129);
      3. (repeated 3 times) two-frame bursts: misaligned pair (0, 1),
         aligned pair (1, 2), and the original frame 0 with frame 1.

    Writes clean.png, rec_all.png (overwritten by experiment 2), rec1.png,
    rec2.png and rec3.png.  Mutates ``cfg`` (noise settings and init_lr).
    ``queue`` is not used.
    """
    full_image = full_image.to(device)
    image = tvF.crop(full_image, 128, 128, 32, 32)
    T = 5

    # -- poisson noise --
    noise_type = "pn"
    cfg.noise_type = noise_type
    cfg.noise_params['pn']['alpha'] = 40.0
    cfg.noise_params['pn']['readout'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)
    clean = torch.stack([image for i in range(T)], dim=0)
    noisy = noise_xform(clean)
    save_image(clean, "clean.png", normalize=True)

    # -- one-row-shifted copies of the first T // 2 frames --
    # m_clean / m_noisy are not used afterwards, but the noise draws here are
    # kept so the random stream seen by later code is unchanged.
    m_clean, m_noisy = clean.clone(), noisy.clone()
    for i in range(T // 2):
        image_mis = tvF.crop(full_image, 128 + 1, 128, 32, 32)
        m_clean[i] = image_mis
        m_noisy[i] = noise_xform(m_clean[i])

    # -- model all --
    print("-- All Aligned --")
    model = _fit_fresh_unet(cfg, image, clean, noisy)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    _print_burst_psnrs(clean, noisy, rec)

    # -- replace frame 0 with the shifted crop --
    og_clean, og_noisy = clean.clone(), noisy.clone()
    clean[0] = image_mis
    noisy[0] = noise_xform(clean[[0]])[0]

    # -- model all --
    print("-- All Misligned --")
    model = _fit_fresh_unet(cfg, image, clean, noisy)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    _print_burst_psnrs(clean, noisy, rec)

    for j in range(3):

        # -- data --
        noisy1 = torch.stack([noisy[0], noisy[1]], dim=0)
        clean1 = torch.stack([clean[0], clean[1]], dim=0)
        noisy2 = torch.stack([noisy[1], noisy[2]], dim=0)
        clean2 = torch.stack([clean[1], clean[2]], dim=0)
        noisy3 = torch.stack([og_noisy[0], noisy[1]], dim=0)
        clean3 = torch.stack([og_clean[0], clean[1]], dim=0)

        # -- model 1 --
        model = _fit_fresh_unet(cfg, image, clean1, noisy1)
        rec = model(noisy1) + 0.5
        xrec = model(noisy2) + 0.5
        save_image(rec, "rec1.png", normalize=True)
        print("[misaligned] Single Image Unet:",
              images_to_psnrs(clean1, rec),
              images_to_psnrs(clean2, xrec),
              images_to_psnrs(rec - 0.5, noisy1),
              images_to_psnrs(xrec - 0.5, noisy2),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 2 --
        model = _fit_fresh_unet(cfg, image, clean2, noisy2)
        rec = model(noisy2) + 0.5
        xrec = model(noisy1) + 0.5
        save_image(rec, "rec2.png", normalize=True)
        print("[aligned] Single Image Unet:",
              images_to_psnrs(clean2, rec),
              images_to_psnrs(clean1, xrec),
              images_to_psnrs(rec - 0.5, noisy2),
              images_to_psnrs(xrec - 0.5, noisy1),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 3 --
        model = _fit_fresh_unet(cfg, image, clean3, noisy3)
        rec = model(noisy3) + 0.5
        rec_2 = model(noisy2) + 0.5
        rec_1 = model(noisy1) + 0.5
        # Previously saved to "rec1.png", overwriting model 1's output.
        save_image(rec, "rec3.png", normalize=True)
        print("[aligned (v3)] Single Image Unet:")
        print("clean-rec", images_to_psnrs(clean3, rec))
        print("clean1-rec1", images_to_psnrs(clean1, rec_1))
        print("clean2-rec2", images_to_psnrs(clean2, rec_2))
        print("rec-noisy3", images_to_psnrs(rec - 0.5, noisy3))
        print("rec1-noisy1", images_to_psnrs(rec_1 - 0.5, noisy1))
        print("rec2-noisy2", images_to_psnrs(rec_2 - 0.5, noisy2))
        print("[v3]: rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
        print("[v1]: rec0-rec1", images_to_psnrs(rec_1[[0]], rec_1[[1]]))
        print("[v2]: rec0-rec1", images_to_psnrs(rec_2[[0]], rec_2[[1]]))
        print("[v1-v2](a):", images_to_psnrs(rec_1[[1]], rec_2[[1]]))
        print("[v1-v2](b):", images_to_psnrs(rec_1[[1]], rec_2[[0]]))
        print("[v1-v2](c):", images_to_psnrs(rec_1[[0]], rec_2[[1]]))
        print("[v1-v2](d):", images_to_psnrs(rec_1[[0]], rec_2[[0]]))
        print("-" * 20)
        print("[v2-v3](a):", images_to_psnrs(rec_2[[1]], rec[[1]]))
        print("[v2-v3](b):", images_to_psnrs(rec_2[[1]], rec[[0]]))
        print("[v2-v3](c):", images_to_psnrs(rec_2[[0]], rec[[1]]))
        print("[v2-v3](d):", images_to_psnrs(rec_2[[0]], rec[[0]]))
        print("-" * 20)
        print("[v1-v3](a):", images_to_psnrs(rec_1[[1]], rec[[1]]))
        print("[v1-v3](b):", images_to_psnrs(rec_1[[1]], rec[[0]]))
        print("[v1-v3](c):", images_to_psnrs(rec_1[[0]], rec[[1]]))
        print("[v1-v3](d):", images_to_psnrs(rec_1[[0]], rec[[0]]))
        print("-" * 20)
        print("rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
def test_sim_search_attn_v2(cfg, clean, model):
    """Evaluate attention-feature similarity search over a grid of noise levels.

    For every entry of ``create_noise_level_grid(cfg)``:
      1. configure ``cfg.noise_params`` for that noise and build the transform;
      2. noise the burst patches produced by ``model.patch_helper``;
      3. take each patch's middle pixel of the middle neighbor as the noisy
         image, and compute the "attn" feature image;
      4. search frame 1 for matches to frame 0 and score the resulting clean
         similar image against ``clean[0]``.

    Mutates ``cfg`` (noise_type / noise_params) on every iteration.

    Returns:
        dict keyed by ``noise_params.name``; each value is an edict with
        ``psnrs`` (array) and its ``ave``, ``std``, ``min`` and ``max``.
    """
    N, B, C, H, W = clean.shape  # unpacked only to require a 5-d input
    ps = cfg.byol_patchsize  # read but not used below

    def _frames(pix, ftr=None):
        # Build the pix/ftr/shape edict compute_similar_bursts_analysis takes;
        # with no ftr the pixels double as features.
        frames = edict()
        frames.pix = pix
        frames.ftr = frames.pix if ftr is None else ftr
        frames.shape = frames.pix.shape
        return frames

    # -- unfold clean image --
    patches = model.patch_helper.prepare_burst_patches(clean)
    patches = patches.cuda(non_blocking=True)
    # R,N,B,L,C,H,W = patches.shape

    results = {}
    for noise_params in create_noise_level_grid(cfg):

        # -- setup noise xform --
        cfg.noise_type = noise_params.ntype
        cfg.noise_params.ntype = cfg.noise_type
        cfg.noise_params[cfg.noise_type] = noise_params
        noise_func = get_noise_transform(cfg.noise_params, use_to_tensor=False)

        # -- apply noise --
        noisy_patches = noise_func(patches)  # shape = (r n b nh_size^2 c ps_B ps_B)

        # -- create noisy img from the center pixel of the center neighbor --
        mid_neighbor = cfg.byol_nh_size**2 // 2
        mid_pixel = cfg.byol_patchsize // 2
        centers = noisy_patches[:, :, :, mid_neighbor, :, mid_pixel, mid_pixel]
        noisy_img = rearrange(centers, '(h w) n b c -> n b c h w',
                              h=cfg.frame_size)
        ftr_img = get_feature_image(cfg, noisy_patches, model, "attn")
        print("[ftr_img.shape]", ftr_img.shape)

        # -- construct similar image: query = frame 0, database = frame 1 --
        query = _frames(noisy_img[[0]], ftr_img[[0]])
        database = _frames(noisy_img[[1]], ftr_img[[1]])
        clean_db = _frames(clean[[1]])
        sim_outputs = compute_similar_bursts_analysis(
            cfg, query, database, clean_db, 1,
            patchsize=cfg.sim_patchsize, shuffle_k=False, kindex=None,
            only_middle=cfg.sim_only_middle, db_level='frame',
            search_method=cfg.sim_method, noise_level=None)

        # -- compute psnr --
        clean_sims = sim_outputs[1][0, :, 0]
        psnrs_np = images_to_psnrs(clean[0].cpu(), clean_sims.cpu())

        stats = edict()
        stats.psnrs = psnrs_np
        stats.ave = np.mean(psnrs_np)
        stats.std = np.std(psnrs_np)
        stats.min = np.min(psnrs_np)
        stats.max = np.max(psnrs_np)
        results[noise_params.name] = stats

    return results
def checkout_haar():
    """Visual and PSNR sanity check of a one-level DWT approximation band.

    Poisson-noises the pywt "camera" image (alpha=20), takes the approximation
    band of ``pywt.dwt2`` (halved), and compares it with nearest-neighbor
    downscaled noisy and clean images.  Saves a three-panel figure to
    ./tmp.png and prints value ranges and PSNRs.
    """
    std = 50.
    alpha = 20.
    # filter_name = "bior1.3"
    filter_name = "haar"

    image = pywt.data.camera()
    # The Gaussian draw is immediately replaced by the Poisson draw; it only
    # advances numpy's global random state.
    noisy = npr.normal(image, scale=std)
    noisy = npr.poisson(alpha * image / 255.) / alpha * 255.

    approx = pywt.dwt2(np.copy(noisy), filter_name)[0]
    ratio = approx.shape[-1] / image.shape[-1]

    def _shrink(img):
        # Nearest-neighbor resize to the approximation band's size.
        return cv2.resize(img, None, fx=ratio, fy=ratio,
                          interpolation=cv2.INTER_NEAREST)

    image_small = _shrink(image)
    noisy_small = _shrink(noisy)
    # haar_image -= haar_image.min()
    approx /= 2.
    print(image.min(), image.max(), image.mean())
    print(approx.min(), approx.max(), approx.mean())

    # -- plot result --
    fig, axes = plt.subplots(1, 3, figsize=(3 * 4, 4))
    panels = [(approx, "Haar Image"),
              (noisy_small, "Noisy"),
              (image_small, "Original Image")]
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(255. * (panel / 255.))
        ax.set_title(title, fontsize=20)
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    plt.savefig("./tmp.png", dpi=300)
    print(approx)

    # -- psnrs on [0, 1]-scaled images --
    approx_unit = approx / 255.
    noisy_unit = noisy_small / 255.
    image_unit = image_small / 255.
    haar_psnrs = images_to_psnrs(approx_unit, image_unit)
    noisy_psnrs = images_to_psnrs(noisy_unit, image_unit)
    print(haar_psnrs, noisy_psnrs, ratio, approx_unit.shape, image.shape)
def main():
    """Demo comparing the "simple" and "split" ``lpas_search`` alignment methods.

    Builds a 5-frame burst from a 350x350 crop of the Cameraman image using
    the global dynamics and Poisson noise configured below.  It aligns the
    burst with both methods and prints PSNRs of ``cc(..., 256)`` crops
    against the repeated reference frame.  A series of ``lpas_demo_*.png``
    images is written to the working directory.
    """
    #
    # -- init experiment --
    #

    cfg = edict()
    cfg.gpuid = 1
    cfg.noise_params = edict()
    cfg.noise_params.g = edict()
    # data = load_dataset(cfg)
    torch.manual_seed(143)  # 131 = 80% vs 20%

    #
    # -- pick our noise --
    #

    # -- gaussian noise --
    # cfg.noise_type = 'g'
    # cfg.noise_params['g']['mean'] = 0.
    # cfg.noise_params['g']['stddev'] = 125.
    # cfg.noise_params.ntype = cfg.noise_type

    # -- poisson noise --
    cfg.noise_type = "pn"
    cfg.noise_params['pn'] = edict()
    cfg.noise_params['pn']['alpha'] = 1.0
    cfg.noise_params['pn']['std'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type

    # -- low-light noise --
    # cfg.noise_type = "qis"
    # cfg.noise_params['qis'] = edict()
    # cfg.noise_params['qis']['alpha'] = 4.0
    # cfg.noise_params['qis']['readout'] = 0.0
    # cfg.noise_params['qis']['nbits'] = 3
    # cfg.noise_params['qis']['use_adc'] = True
    # cfg.noise_params.ntype = cfg.noise_type

    #
    # -- setup the dynamics --
    #

    cfg.nframes = 5
    cfg.frame_size = 350
    cfg.nblocks = 5
    T = cfg.nframes

    cfg.dynamic = edict()
    cfg.dynamic.frames = cfg.nframes
    cfg.dynamic.bool = True
    cfg.dynamic.ppf = 1
    cfg.dynamic.mode = "global"
    cfg.dynamic.random_eraser = False
    cfg.dynamic.frame_size = cfg.frame_size
    cfg.dynamic.total_pixels = cfg.dynamic.ppf * (cfg.nframes - 1)

    # -- setup noise and dynamics --
    noise_xform = get_noise_transform(cfg.noise_params, noise_only=True)

    def null(image):
        return image

    dynamics_xform = get_dynamic_transform(cfg.dynamic, null)

    # -- sample data --
    image_path = "./data/512-512-grayscale-image-Cameraman.png"
    # convert() returns a new, fully loaded image, so the file can be closed.
    with Image.open(image_path) as pil_image:
        image = pil_image.convert("RGB")
    image = image.crop((0, 0, cfg.frame_size, cfg.frame_size))
    clean, res, raw, flow = dynamics_xform(image)
    clean = clean[:, None]
    burst = noise_xform(clean + 0.5)
    flow = flow[None, :]
    reference = repeat(clean[[T // 2]], '1 b c h w -> t b c h w', t=T)
    print("Flow")
    print(flow)

    # -- our method --
    ref_frame = T // 2
    nblocks = cfg.nblocks
    noise_info = cfg.noise_params
    method = "simple"
    scores, aligned_simp, dacc_simp = lpas_search(burst, ref_frame, nblocks,
                                                  flow, method, clean,
                                                  noise_info)

    # -- split search --
    method = "split"
    scores, aligned_split, dacc_split = lpas_search(burst, ref_frame, nblocks,
                                                    flow, method, clean,
                                                    noise_info)

    # -- quantitative comparison --
    crop_size = 256
    image1, image2 = cc(aligned_simp, crop_size), cc(reference, crop_size)
    psnrs = images_to_psnrs(image1, image2)
    print("Aligned Simple Method: ", psnrs, dacc_simp.item())
    image1, image2 = cc(aligned_split, crop_size), cc(reference, crop_size)
    psnrs = images_to_psnrs(image1, image2)
    print("Aligned Split Method: ", psnrs, dacc_split.item())

    # -- compute noise 2 sim --
    # T,K = cfg.nframes,cfg.nframes
    # patchsize = 31
    # query = burst[[T//2]]
    # database = torch.cat([burst[:T//2],burst[T//2+1:]])
    # clean_db = clean
    # sim_outputs = compute_similar_bursts_analysis(cfg,query,database,clean_db,K,-1.,
    #                                               patchsize=patchsize,shuffle_k=False,
    #                                               kindex=None,only_middle=False,
    #                                               search_method="l2",db_level="burst")
    # sims,csims,wsims,b_dist,b_indx = sim_outputs

    # -- display images --
    print(aligned_simp.shape)
    print(aligned_split.shape)
    print_tensor_stats("aligned", aligned_simp)

    # print(csims.shape)
    save_image(burst, "lpas_demo_burst.png", [-0.5, 0.5])
    save_image(clean, "lpas_demo_clean.png")

    save_image(aligned_simp, "lpas_demo_aligned_simp.png")
    save_image(aligned_split, "lpas_demo_aligned_split.png")
    save_image(cc(aligned_simp, crop_size), "lpas_demo_aligned_simp_ccrop.png")
    save_image(cc(aligned_split, crop_size), "lpas_demo_aligned_split_ccrop.png")

    # -- differences from the reference frame --
    delta_full_simp = aligned_simp - aligned_simp[T // 2]
    delta_full_split = aligned_split - aligned_split[T // 2]
    save_image(delta_full_simp, "lpas_demo_aligned_full_delta_simp.png",
               [-0.5, 0.5])
    save_image(delta_full_split, "lpas_demo_aligned_full_delta_split.png",
               [-0.5, 0.5])

    # Save the cropped deltas under the *_cc_delta_* names; these files used
    # to receive the full-size deltas, leaving delta_cc_* unused.
    delta_cc_simp = cc(delta_full_simp, crop_size)
    delta_cc_split = cc(delta_full_split, crop_size)
    save_image(delta_cc_simp, "lpas_demo_aligned_cc_delta_simp.png")
    save_image(delta_cc_split, "lpas_demo_aligned_cc_delta_split.png")

    # -- inspect a size x size region at (top, 200) --
    top = 75
    size = 64
    simp = tvF.crop(aligned_simp, top, 200, size, size)
    split = tvF.crop(aligned_split, top, 200, size, size)
    print_tensor_stats("delta", simp)
    save_image(simp, "lpas_demo_aligned_simp_inspect.png")
    save_image(split, "lpas_demo_aligned_split_inspect.png")

    delta_simp = simp - simp[T // 2]
    delta_split = split - split[T // 2]
    print_tensor_stats("delta", delta_simp)
    save_image(delta_simp, "lpas_demo_aligned_simp_inspect_delta.png",
               [-1, 1.])
    save_image(delta_split, "lpas_demo_aligned_split_inspect_delta.png",
               [-1, 1.])
def test_abp_search_exhaustive_global_dynamics(noisy,
                                               clean,
                                               burst_indices,
                                               PS,
                                               NH,
                                               K=-1,
                                               nh_grids=None):
    """Exhaustively score neighborhood grids for globally-moving bursts.

    For every grid in ``nh_grids`` (default ``create_nh_grids(BI, NH)``):
      * The selected patches are gathered next to the reference patch.
      * Up to 100 frame splits from ``create_n_grids(BI)`` are scored.
        ``score`` is the mean cross-residual MSE; ``score_old`` is the summed
        MSE between the two split means.
      * PSNRs of the clean image reconstructed from that grid are recorded.

    Prints best-score / best-PSNR diagnostics.  The normalized scores and the
    PSNRs are written as .txt and .png under ``output/abps/``, and that
    directory is created if needed.

    ``noisy`` is indexed with 7 axes (``noisy[:, :, n, nh, :, :, :]``) and
    ``clean`` is rearranged from '(h w) b c'.  Presumably both are
    (R, B, N, NH**2, C, PS, PS) patch tensors -- TODO confirm.

    Returns:
        ``(best_score, best_select)`` when ``K == -1``.  Otherwise the ``K``
        lowest normalized scores and the matching entries of ``nh_grids``,
        which must then support fancy indexing.
    """
    # -- init vars --
    R, B, N = noisy.shape[:3]
    # np.float (an alias of the builtin float, i.e. float64) was removed in
    # NumPy 1.24; np.finfo(np.float) raises AttributeError there.
    FMAX = np.finfo(np.float64).max
    REF_NH = get_ref_nh(NH)
    print(f"REF_NH: {REF_NH}")
    BI = burst_indices.shape[0]
    ref_patch = noisy[:, :, [N // 2], [REF_NH], :, :, :]

    # -- create clean testing image --
    H = int(np.sqrt(clean.shape[0]))
    clean_img = rearrange(clean[..., N // 2, REF_NH, :, PS // 2, PS // 2],
                          '(h w) b c -> b c h w', h=H)
    clean_img = repeat(clean_img, 'b c h w -> tile b c h w', tile=N)

    # -- create search grids --
    if nh_grids is None:
        nh_grids = create_nh_grids(BI, NH)
    n_grids = create_n_grids(BI)
    print(f"NH_GRIDS {len(nh_grids)} | N_GRIDS {len(n_grids)}")

    # -- randomly initialize grids --
    # np.random.shuffle(nh_grids)
    # np.random.shuffle(n_grids)

    # -- init loop vars --
    psnrs = np.zeros((len(nh_grids), BI))
    scores = np.zeros(len(nh_grids))
    scores_old = np.zeros(len(nh_grids))
    best_score, best_select = FMAX, None

    # -- remove boundary --
    aug_burst_indices = insert_n_middle(burst_indices, N)
    aug_burst_indices = torch.LongTensor(aug_burst_indices)
    subR = torch.arange(H * H // 3 * 2) + NH * H
    search = noisy[subR]
    ref_patch = ref_patch[subR]

    # -- exhaustive search over nh_grids --
    for nh_index, nh_grid in enumerate(nh_grids):

        # -- compute score --
        grid_patches = search[:, :, burst_indices, nh_grid, :, :, :]
        grid_patches = torch.cat([ref_patch, grid_patches], dim=2)
        score, score_old, count = 0, 0, 0
        for (nset0, nset1) in n_grids[:100]:
            denoised0 = torch.mean(grid_patches[:, :, nset0], dim=2)
            denoised1 = torch.mean(grid_patches[:, :, nset1], dim=2)
            score_old += F.mse_loss(denoised0, denoised1).item()

            # -- neurips 2019 --
            # Residuals of each split around its own mean, compared across
            # every (i, j) pair of frames from the two splits.
            n0, n1 = len(nset0), len(nset1)
            rep0 = repeat(denoised0, 'r b c p1 p2 -> r b tile c p1 p2', tile=n0)
            res0 = grid_patches[:, :, nset0] - rep0
            rep1 = repeat(denoised1, 'r b c p1 p2 -> r b tile c p1 p2', tile=n1)
            res1 = grid_patches[:, :, nset1] - rep1
            xterms0, xterms1 = np.mgrid[:n0, :n1]
            xterms0, xterms1 = xterms0.ravel(), xterms1.ravel()
            score += F.mse_loss(res0[:, :, xterms0], res1[:, :, xterms1]).item()
            count += 1
        score /= count

        # -- store best score --
        if score < best_score:
            best_score = score
            best_select = nh_grid

        # -- add score to results --
        scores[nh_index] = score
        scores_old[nh_index] = score_old

        # -- compute and store psnrs --
        # pgrid / bgrid are only consumed by the commented-out alternative
        # call signature below.
        pgrid = insert_nh_middle(nh_grid, NH, BI)[None, ]
        bgrid = aug_burst_indices
        nh_grid = nh_grid[None, ]
        rec_img = aligned_burst_image_from_indices_global_dynamics(
            clean, burst_indices, nh_grid)  # bgrid,pgrid)
        nh_psnrs = images_to_psnrs(rec_img, clean_img[burst_indices])
        psnrs[nh_index, :] = nh_psnrs

    score_idx = np.argmin(scores)
    print(f"Best Score [{scores[score_idx]}] PSNRS @ [{score_idx}]:",
          psnrs[score_idx])

    psnr_idx = np.argmax(np.mean(psnrs, 1))
    print(f"Best PSNR @ [{psnr_idx}]", psnrs[psnr_idx])

    old_score_idx = np.argmin(scores_old)
    print(
        f"Best OLD Score [{scores_old[old_score_idx]}] PSNRS @ [{old_score_idx}]:",
        psnrs[old_score_idx])
    print(f"Current Score @ OLD Score [{scores[old_score_idx]}]")
    print(
        f"[Old score idx v.s. Current score idx v.s. Best PSNR] {old_score_idx} v.s. {score_idx} v.s. {psnr_idx}"
    )

    # Without this, np.savetxt / savefig fail on a fresh checkout.
    out_dir = Path("output/abps")
    out_dir.mkdir(parents=True, exist_ok=True)

    #
    # Recording Score Info
    #

    # -- save score info (normalized to sum to one) --
    scores /= np.sum(scores)
    score_fn = f"scores_{NH}_{N}_{len(nh_grids)}_{len(n_grids)}"
    txt_fn = out_dir / f"{score_fn}.txt"
    np.savetxt(txt_fn, scores)

    # -- plot score --
    plot_fn = out_dir / f"{score_fn}.png"
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot(np.arange(scores.shape[0]), scores, '-+')
    ax.axvline(x=psnr_idx, color="r")
    ax.axvline(x=score_idx, color="k")
    plt.savefig(plot_fn, dpi=300)
    plt.close("all")

    #
    # Recording PSNR Info
    #

    # -- save psnr info --
    psnr_fn = f"psnrs_{NH}_{N}_{len(nh_grids)}_{len(n_grids)}"
    txt_fn = out_dir / f"{psnr_fn}.txt"
    np.savetxt(txt_fn, psnrs)

    # -- plot psnr --
    plot_fn = out_dir / f"{psnr_fn}.png"
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot(np.arange(psnrs.shape[0]), psnrs, '-+')
    ax.axvline(x=psnr_idx, color="r")
    ax.axvline(x=score_idx, color="k")
    plt.savefig(plot_fn, dpi=300)
    plt.close("all")

    print(f"Wrote {score_fn} and {psnr_fn}")

    if K == -1:
        return best_score, best_select
    search_indices_topK = np.argsort(scores)[:K]
    scores_topK = scores[search_indices_topK]
    nh_grids_topK = nh_grids[search_indices_topK]
    return scores_topK, nh_grids_topK
def test_sim_search_pix_v2(cfg, clean, model):
    """Evaluate pixel-space similarity search over a grid of noise levels.

    For each entry of ``create_noise_level_grid(cfg)``, every clean frame
    (shifted by +0.5 and converted to an RGB PIL image) is noised.  Frame 1
    is then searched for matches to frame 0, using the pixels themselves as
    features.  The clean similar image is scored against ``clean[0]``.

    Mutates ``cfg`` (noise_type / noise_params) on every iteration.
    ``model`` is not used.

    Returns:
        dict keyed by ``noise_params.name``; each value is an edict with
        ``psnrs`` (array) and its ``ave``, ``std``, ``min`` and ``max``.
    """
    N, B, C, H, W = clean.shape
    n_images = B * N
    clean_flat = rearrange(clean, 'n b c h w -> (b n) c h w')
    to_pil = tvT.ToPILImage()
    clean_pil = [to_pil(clean_flat[i] + 0.5).convert("RGB")
                 for i in range(n_images)]
    ps = cfg.byol_patchsize
    unfold = nn.Unfold(ps, 1, 0, 1)  # constructed but not used below

    def _frames(pix, ftr):
        # Build the pix/ftr/shape edict compute_similar_bursts_analysis takes.
        frames = edict()
        frames.pix = pix
        frames.ftr = ftr
        frames.shape = frames.pix.shape
        return frames

    results = {}
    for noise_params in create_noise_level_grid(cfg):

        # -- get noisy images --
        cfg.noise_type = noise_params.ntype
        cfg.noise_params.ntype = cfg.noise_type
        cfg.noise_params[cfg.noise_type] = noise_params
        noise_func = get_noise_transform(cfg.noise_params)
        noisy_flat = torch.stack([noise_func(pil) for pil in clean_pil], dim=0)
        noisy = rearrange(noisy_flat, '(b n) c h w -> n b c h w', b=B)

        # -- construct similar image: query = frame 0, database = frame 1 --
        query = _frames(noisy[[0]], noisy[[0]])
        database = _frames(noisy[[1]], noisy[[1]])
        clean_db_pix = clean[[1]]
        clean_db = _frames(clean_db_pix, clean_db_pix)
        sim_outputs = compute_similar_bursts_analysis(
            cfg, query, database, clean_db, 1,
            patchsize=cfg.sim_patchsize, shuffle_k=False, kindex=None,
            only_middle=cfg.sim_only_middle, db_level='frame',
            search_method=cfg.sim_method, noise_level=None)

        # -- compute psnr --
        clean_sims = sim_outputs[1][0, :, 0]
        psnrs_np = images_to_psnrs(clean[0].cpu(), clean_sims.cpu())

        stats = edict()
        stats.psnrs = psnrs_np
        stats.ave = np.mean(psnrs_np)
        stats.std = np.std(psnrs_np)
        stats.min = np.min(psnrs_np)
        stats.max = np.max(psnrs_np)
        results[noise_params.name] = stats

    return results