def write_input_output(cfg, model, burst, aligned, filters, directions):
    """Write diagnostic plots and example images for the current training step.

    NOTE(review): a second ``write_input_output`` defined later in this file
    rebinds this name, so this definition is shadowed at module level;
    consider removing or renaming it.

    Output directory: ``./output/n2n-kpn/io_examples/{cfg.exp_name}/``
    (created if missing). Files written:

    * residual histogram (computed from ``aligned``), gradient histogram and
      per-layer gradient-norm plots for ``model``;
    * for every burst ``b``:
        - ``{step}_{b}.png``: one row for ``burst[b]`` and one for
          ``aligned[b]``, each framed by "center - first" and
          "center - last" difference images (N + 2 tiles per row);
        - ``filters_{step}_{b}.png``: the K x K filters, one grid row per
          sampled spatial location (or a single row if filters are global);
        - ``arrows_{step}_{b}.png``: image from ``create_arrow_image``.

    All matplotlib figures are closed afterwards.

    :params cfg: config object; ``cfg.exp_name`` and ``cfg.global_step`` are read
    :params model: model handed to the gradient plotting helpers
    :params burst: input images to the model, :shape [B, N, C, H, W]
    :params aligned: output images from the model, :shape [B, N, C, H, W]
    :params filters: filters used by model, :shape [B, N, K2, 1, Hf, Wf]
        with Hf = (H or 1); K2 must be a perfect square (reshaped to K x K)
    :params directions: per-burst input to ``create_arrow_image``, indexed by
        batch only — exact format not visible here (TODO confirm)
    """
    # -- file path --
    path = Path(f"./output/n2n-kpn/io_examples/{cfg.exp_name}/")
    path.mkdir(parents=True, exist_ok=True)

    # -- init -- (5-tuple unpack doubles as a rank check on ``burst``)
    B, N, _, _, _ = burst.shape
    K = int(np.sqrt(filters.shape[2]))

    def with_diffs(frames):
        # [N, C, H, W] -> [N + 2, C, H, W]: center-first, frames, center-last
        center = frames[[N // 2]]
        return torch.cat([center - frames[[0]], frames, center - frames[[-1]]], dim=0)

    # -- save histogram of residuals --
    denoised_np = aligned.detach().cpu().numpy()
    plot_histogram_residuals_batch(denoised_np, cfg.global_step, path, rand_name=False)

    # -- save histogram of gradients --
    plot_histogram_gradients(model, cfg.global_step, path, rand_name=False)

    # -- save gradient norm by layer --
    plot_histogram_gradient_norms(model, cfg.global_step, path, rand_name=False)

    # -- save file per burst --
    for b in range(B):
        # -- save images --
        fn = path / Path(f"{cfg.global_step}_{b}.png")
        imgs = torch.cat([with_diffs(burst[b]), with_diffs(aligned[b])], dim=0)  # 2(N+2),C,H,W
        tv_utils.save_image(imgs, fn, nrow=N + 2, normalize=True, range=(-0.5, 0.5))

        # -- save filters --
        fn = path / Path(f"filters_{cfg.global_step}_{b}.png")
        # Drop batch and singleton dims first, so the spatial indexing below
        # is unambiguous (torch treats int indices as `select`, unlike numpy).
        filters_b = filters[b, :, :, 0]  # [N, K2, Hf, Wf]
        if filters.shape[-1] > 1:
            Hf, Wf = filters.shape[-2], filters.shape[-1]
            # Sample up to 10 diagonal locations (s, s) that exist on both axes.
            S = torch.as_tensor(npr.permutation(min(Hf, Wf))[:10], device=filters.device)
            filters_b = filters_b[..., S, S].permute(2, 0, 1)  # [nS, N, K2]
        else:
            filters_b = filters_b[..., 0, 0]  # [N, K2]
        # reshape (not view): the permuted tensor is not contiguous
        filters_b = filters_b.reshape(-1, 1, K, K)
        tv_utils.save_image(filters_b, fn, nrow=N, normalize=True)

        # -- save direction image --
        fn = path / Path(f"arrows_{cfg.global_step}_{b}.png")
        arrows = create_arrow_image(directions[b], pad=2)
        tv_utils.save_image([arrows], fn)

    plt.close("all")
    print(f"Wrote example images to file at [{path}]")
def write_input_output(cfg, model, burst, aligned, denoised, filters, motion):
    """Write diagnostic plots and example images for the current training step.

    Output directory: ``./output/n2sim/io_examples/{cfg.exp_name}/``
    (created if missing). Files written:

    * residual histogram computed from ``denoised``;
    * gradient histograms, then per-layer gradient-norm plots, for the
      denoiser (unless ``model.use_unet_only``) and for the alignment
      network (if ``model.use_alignment``);
    * for each of the first ``min(B, 4)`` bursts ``b``:
        - ``image_{step}_{b}.png``: center input frame, mean denoised frame
          and their difference;
        - ``{step}_{b}.png``: rows for ``burst[b]``, ``aligned[b]`` and
          ``denoised[b]``, each framed by "center - first" and
          "center - last" difference images (N + 2 tiles per row);
        - ``filters_{step}_{b}.png``: the K x K filters, one grid row per
          (sampled location, cascade level) pair;
        - ``arrows_{step}_{b}.png``: only when ``motion[b]`` has more than
          one entry and more than one dimension.

    All matplotlib figures are closed afterwards.

    :params cfg: config object; ``cfg.exp_name`` and ``cfg.global_step`` are read
    :params model: reads ``use_unet_only``, ``use_alignment``,
        ``denoiser_info.model`` and ``align_info.model``
    :params burst: input images to the model, :shape [B, N, C, H, W]
    :params aligned: output images from the alignment layers, :shape [B, N, C, H, W]
    :params denoised: output images from the denoiser, :shape [B, N, C, H, W]
    :params filters: filters used by model, :shape [B, L, N, K2, 1, Hf, Wf]
        with Hf = (H or 1) for L = number of cascaded filters; K2 must be a
        perfect square (reshaped to K x K)
    :params motion: per-burst input to ``create_arrow_image``, indexed by
        batch only — exact format not visible here (TODO confirm)
    """
    # -- file path --
    path = Path(f"./output/n2sim/io_examples/{cfg.exp_name}/")
    path.mkdir(parents=True, exist_ok=True)

    # -- init -- (5-tuple unpack doubles as a rank check on ``burst``)
    B, N, _, _, _ = burst.shape
    K = int(np.sqrt(filters.shape[3]))

    def with_diffs(frames):
        # [N, C, H, W] -> [N + 2, C, H, W]: center-first, frames, center-last
        center = frames[[N // 2]]
        return torch.cat([center - frames[[0]], frames, center - frames[[-1]]], dim=0)

    # -- save histogram of residuals --
    denoised_np = denoised.detach().cpu().numpy()
    plot_histogram_residuals_batch(denoised_np, cfg.global_step, path, rand_name=False)

    # -- save gradient histograms, then gradient norms by layer --
    # Call order matches the original: (denoiser, alignment) per plot type.
    for plot_fn in (plot_histogram_gradients, plot_histogram_gradient_norms):
        if not model.use_unet_only:
            plot_fn(model.denoiser_info.model, "denoiser", cfg.global_step, path, rand_name=False)
        if model.use_alignment:
            plot_fn(model.align_info.model, "alignment", cfg.global_step, path, rand_name=False)

    # -- only the first few bursts get example images --
    for b in range(min(B, 4)):
        # -- save dirty & clean & res triplet --
        fn = path / Path(f"image_{cfg.global_step}_{b}.png")
        center = burst[b][N // 2]
        denoised_mean = denoised[b].mean(0)
        imgs = torch.stack([center, denoised_mean, center - denoised_mean], dim=0)
        tv_utils.save_image(imgs, fn, nrow=3, normalize=True, range=(-0.5, 0.5))

        # -- save images --
        fn = path / Path(f"{cfg.global_step}_{b}.png")
        imgs = torch.cat(
            [with_diffs(burst[b]), with_diffs(aligned[b]), with_diffs(denoised[b])], dim=0
        )  # 3(N+2),C,H,W
        tv_utils.save_image(imgs, fn, nrow=N + 2, normalize=True, range=(-0.5, 0.5))

        # -- save filters --
        fn = path / Path(f"filters_{cfg.global_step}_{b}.png")
        # Drop batch and singleton dims first, so the spatial indexing below
        # is unambiguous (torch treats int indices as `select`, unlike numpy).
        filters_b = filters[b, ..., 0, :, :]  # [L, N, K2, Hf, Wf]
        if filters.shape[-1] > 1:
            Hf, Wf = filters.shape[-2], filters.shape[-1]
            # Sample up to 10 diagonal locations (s, s) that exist on both axes.
            S = torch.as_tensor(npr.permutation(min(Hf, Wf))[:10], device=filters.device)
            filters_b = filters_b[..., S, S].permute(3, 0, 1, 2)  # [nS, L, N, K2]
        else:
            filters_b = filters_b[..., 0, 0]  # [L, N, K2]
        # reshape (not view): the permuted tensor is not contiguous
        filters_b = filters_b.reshape(-1, 1, K, K)
        tv_utils.save_image(filters_b, fn, nrow=N, normalize=True)

        # -- save direction image --
        fn = path / Path(f"arrows_{cfg.global_step}_{b}.png")
        if len(motion[b]) > 1 and len(motion[b].shape) > 1:
            arrows = create_arrow_image(motion[b], pad=2)
            tv_utils.save_image([arrows], fn)

    print(f"Wrote example images to file at [{path}]")
    plt.close("all")