def run_fnet(cfg, burst, score_fxn):
    model = UNet_small(3).to(burst.device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.99))
    train(cfg, burst, model, optim, 1000)
    score, scores_t = test(cfg, burst, model, score_fxn)
    return score, scores_t
def run_cog(cfg, clean, noisy, directions, results):
    image_volume = noisy
    backbone = UNet_small(3).to(cfg.device)
    nn_params = {'lr': 1e-3, 'init_params': None}
    train_steps = 1000
    score = score_cog(cfg, image_volume, backbone, nn_params, train_steps)
    results['cog'] = score
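# A hypothetical driver sketch for the two wrappers above; `cfg`, `burst`,
# `clean`, and `noisy` are assumed to be provided by the surrounding
# experiment code, and the MSE-style score_fxn is illustrative only, not a
# function from this repo.
def _demo_run_wrappers(cfg, burst, clean, noisy):
    score_fxn = lambda ref, img: -torch.mean((ref - img) ** 2)  # illustrative
    score, scores_t = run_fnet(cfg, burst, score_fxn)
    results = {}
    run_cog(cfg, clean, noisy, None, results)  # fills results['cog'] in place
    return score, scores_t, results['cog']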
def load_model_v2(cfg):
    simclr = None
    denoise_model = UNet_small(3 * cfg.input_N)
    input_patch, output_patch = cfg.patch_sizes
    d_model = cfg.d_model_attn
    xformer_args = [simclr, d_model, cfg.dynamic.frame_size,
                    input_patch, output_patch, cfg.input_N,
                    cfg.dataset.bw, denoise_model]
    model = TransformerNetwork32_noxform(*xformer_args)
    return model.to(cfg.device)
def load_model_v4(cfg):
    simclr = None  # optionally load_model_simclr(cfg)
    denoise_model = UNet_small(cfg.d_model_attn * cfg.input_N)
    input_patch, output_patch = cfg.patch_sizes
    d_model = cfg.d_model_attn
    xformer_args = [simclr, d_model, cfg.dynamic.frame_size,
                    input_patch, output_patch, cfg.input_N,
                    cfg.dataset.bw, denoise_model]
    model = TransformerNetwork32_v4(*xformer_args)
    return model.to(cfg.device)
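# The two factories above differ only in the denoiser width (3 * cfg.input_N
# vs. cfg.d_model_attn * cfg.input_N) and the transformer class they build.
# Below is a minimal sketch of the cfg fields both factories read; every
# value is an illustrative assumption, not taken from the repo.
def _demo_cfg():
    cfg = edict()
    cfg.input_N = 5                        # frames per burst (assumed)
    cfg.patch_sizes = (10, 8)              # (input_patch, output_patch)
    cfg.d_model_attn = 3                   # attention feature dimension
    cfg.dynamic = edict({'frame_size': 32})
    cfg.dataset = edict({'bw': False})
    cfg.device = 'cuda:0'
    return cfg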
def __init__(self, num_frames, num_in_ftrs, num_out_ftrs, num_out_channels,
             patchsize, nh_size, img_size, attn_params=None):
    super().__init__()

    # -- init --
    self.num_frames = num_frames
    self.num_in_ftrs = num_in_ftrs
    self.num_out_ftrs = num_out_ftrs
    self.nh_size = nh_size  # number of patches around the center pixel
    self._patchsize = patchsize
    self.img_size = img_size
    self.layer_norm = nn.LayerNorm(num_out_channels * patchsize**2)

    # -- create model --
    self.unet = UNet_small(num_in_ftrs + 3, num_out_channels)
    # round an odd patch size up to the next even size
    patchsize_e = patchsize + (patchsize % 2)
    self.unet_out_size = num_out_channels * patchsize_e**2
    self.mlp = MLP(self.unet_out_size, num_out_ftrs, hidden_size=256)
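# A quick self-contained check of the even-rounding rule above; illustrative
# only, not part of the model.
def _check_even_patchsize():
    for patchsize in (7, 8, 9):
        patchsize_e = patchsize + (patchsize % 2)  # 7 -> 8, 8 -> 8, 9 -> 10
        assert patchsize_e % 2 == 0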
def __init__(self, ftr_model, d_model, im_size, input_patch, output_patch,
             n_frames, bw=False, denoise_model=None, batch_size=8):
    """
    Transformer for denoising.

    im_size : int : the length and width of the square input image
    """
    super(TransformerNetwork32_dip_v2, self).__init__()
    self.d_model = d_model
    self.ftr_model = {'spoof': ftr_model}
    print(d_model * n_frames, d_model)
    self.denoise_model_preproc_a = nn.Conv2d(d_model * n_frames,
                                             d_model * n_frames, 3, 1, 1)
    self.denoise_model_preproc_b = nn.Conv2d(d_model, 3, 3, 1, 1)
    self.denoise_model = denoise_model
    self.std = 5. / 255

    # -- transformer hyperparameters --
    nhead = 4
    num_enc_layers = 2
    num_dec_layers = 2
    dim_ff = 256
    dropout = 0.0
    xform_args = [d_model, nhead, num_enc_layers, num_dec_layers,
                  dim_ff, dropout]
    self.use_pos_enc = True
    self.use_pos_enc_values = False
    self.use_all_loss = False
    self.all_loss_v = "v1"
    self.color_code = nn.Linear(d_model, 3)
    d_model = 3
    self.clusters = np.load(f"{settings.ROOT_PATH}/data/kmeans_centers.npy")

    # -- contrastive learning loss --
    num_transforms = n_frames
    hyperparams = edict()
    hyperparams.temperature = 0.1
    self.simclr_loss = ClBlockLoss(hyperparams, num_transforms, batch_size)

    # -- attention stack (note: qkv_same_params is not an argument of stock
    # torch.nn.MultiheadAttention; this assumes a patched build) --
    nhead = 1
    vdim = 1 if bw else 3
    self.pos_enc = PositionalEncoder(3, 32 * 32 * n_frames)
    self.attn_1 = nn.MultiheadAttention(3, nhead, dropout, qkv_same_params=False)
    self.attn_2 = nn.MultiheadAttention(3, nhead, dropout, qkv_same_params=False)
    self.attn_3 = nn.MultiheadAttention(3, nhead, dropout, qkv_same_params=False)
    self.attn = [self.attn_1, self.attn_2, self.attn_3]

    # -- pin the value/output projections to the identity so each attention
    # output stays a (soft) selection of the shifted inputs --
    for i in range(3):
        self.attn[i].v_proj_weight.data = torch.eye(3)
        self.attn[i].v_proj_weight = self.attn[i].v_proj_weight.requires_grad_(False)
        self.attn[i].out_proj.weight.data = torch.eye(3)
        self.attn[i].out_proj.bias.data = torch.zeros_like(self.attn[i].out_proj.bias.data)
        self.attn[i].out_proj = self.attn[i].out_proj.requires_grad_(False)

    self.input_patch, self.output_patch = input_patch, output_patch
    self.im_size = im_size
    ftr_size = (input_patch**2) * d_model
    img_size = (input_patch**2) * 3
    in_channels = d_model * n_frames

    # -- conv settings; maintain imsize (A in, A out) --
    stride = 1
    padding = 1
    out_chann = 1 if bw else 3

    # -- using MEAN as denoiser --
    in_channels = d_model
    self.end_conv = nn.Sequential(nn.Conv2d(3, 3, 1), nn.LeakyReLU(),
                                  nn.Conv2d(3, 3, 1))
    self.unet = UNet_small(3)

    # -- patch extraction: center each input_patch window on an output_patch
    # tile and step one tile at a time --
    padding = (input_patch - output_patch) // 2
    stride = output_patch
    self.unfold_input = nn.Unfold(input_patch, 1, padding, stride)
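# A minimal, self-contained check of the unfold geometry above; the numbers
# (im_size=32, input_patch=10, output_patch=8) are illustrative assumptions,
# not values taken from any cfg.
def _demo_unfold_geometry(im_size=32, input_patch=10, output_patch=8):
    import torch
    import torch.nn as nn
    padding = (input_patch - output_patch) // 2
    unfold = nn.Unfold(input_patch, 1, padding, output_patch)
    x = torch.randn(1, 3, im_size, im_size)
    patches = unfold(x)  # shape: (1, 3 * input_patch**2, n_patches)
    n_tiles = im_size // output_patch
    # one overlapping input_patch window per output_patch tile
    assert patches.shape == (1, 3 * input_patch**2, n_tiles**2)
    return patches.shape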
def run_experiment(cfg, data, record_fn, bss_dir):

    #
    # Experiment Setup
    #

    # -- init variables for access --
    T = cfg.nframes
    H = cfg.nblocks
    framesize = 156
    patchsize = 32
    P = 9
    REF_H = get_ref_block_index(cfg.nblocks)
    nn_params = edict({'lr': 1e-3, 'init_params': None})
    gridT = torch.arange(T)

    # -- setup noise --
    cfg.noise_type = 'g'
    cfg.ntype = cfg.noise_type
    cfg.noise_params.ntype = cfg.noise_type
    noise_level = 25.
    cfg.noise_params['g']['stddev'] = noise_level
    noise_level_str = f"{int(noise_level)}"
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    # -- simulate no motion --
    nomotion = np.zeros((T, 2)).astype(np.int64)  # np.long is removed in recent NumPy
    image_index = 10
    full_image = data.tr[image_index][2]
    full_burst, aligned, motion = simulate_dynamics(full_image, T, nomotion,
                                                    0, framesize)
    save_image(full_burst, "full_burst.png")
    clean_full_ref = full_burst[T // 2]

    # -- apply noise --
    full_noisy = noise_xform(full_burst)
    save_image(full_noisy, "full_noisy.png")
    noisy_full_ref = full_noisy[T // 2]

    #
    # Start Method
    #

    # -- find good patches from the full noisy image --
    init_tl_list = sample_good_init_tl(noisy_full_ref, P, patchsize)

    # -- grab blocks from the selected patches for the burst --
    clean = crop_burst_to_blocks(full_burst, cfg.nblocks, init_tl_list[0],
                                 patchsize)
    noisy = crop_burst_to_blocks(full_noisy, cfg.nblocks, init_tl_list[0],
                                 patchsize)

    # -- image for the "test" function --
    REF_PATCH = 0  # backward compat. for functions without patch-dim support
    image = noisy[T // 2, REF_H, REF_PATCH] + 0.5  # legacy

    # -- inspect tensors --
    print_tensor_stats("clean", clean)
    print_tensor_stats("noisy", noisy)
    print_tensor_stats("full_burst", full_burst)
    print_tensor_stats("full_noisy", full_noisy)
    save_image(clean, "fast_unet_clean.png", normalize=True)
    save_image(noisy, "fast_unet_noisy.png", normalize=True)
    save_image(image, "fast_unet_image.png", normalize=True)

    # -- select search space --
    use_rand = True
    if use_rand:
        tcount = 3
        size = 30
        bss = get_small_test_block_arangements(bss_dir, cfg.nblocks,
                                               cfg.nframes, tcount, size,
                                               difficult=True)
        print("LEN BSS", len(bss))
        block_search_space = bss

    # -- setup loop --
    clean = clean.to(cfg.device)
    noisy = noisy.to(cfg.device)
    image = image.to(cfg.device)
    record, idx = [], 0

    # -- search over the search space --
    for prop in tqdm(block_search_space):

        # -- fetch --
        clean_prop = clean[gridT, prop].to(cfg.device)
        noisy_prop = noisy[gridT, prop].to(cfg.device)
        save_image(clean_prop[:, 0], "clean_prop.png")
        save_image(noisy_prop[:, 0], "noisy_prop.png")

        # -- compute COG --
        backbone = UNet_small
        if nn_params['init_params'] is None:
            train_steps = cfg.cog_train_steps
        else:
            train_steps = cfg.cog_second_train_steps
        score = 0.
        # score = score_cog(cfg, image_volume, backbone, nn_params, train_steps)

        # -- compute single UNet --
        model = UNet_small(3).to(cfg.device)
        init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(), lr=init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean_prop[:, REF_PATCH], noisy_prop[:, REF_PATCH],
              model, optim)
        print("noisy_prop.shape", noisy_prop.shape)
        results = test(cfg, image, clean_prop[:, REF_PATCH],
                       noisy_prop[:, REF_PATCH], model, idx)
        results['cog'] = score
        print(score, results['ave'], results['psnr_clean'], prop)

        # -- update --
        record.append(results)
        idx += 1

    record = pd.DataFrame(record)
    print(f"Writing record to {record_fn}")
    record.to_csv(record_fn)
    return record
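# A hypothetical follow-up sketch: once run_experiment writes its record, the
# block proposals can be ranked offline. The column names 'cog', 'ave', and
# 'psnr_clean' come from the results dict above; the filename and cutoff are
# illustrative assumptions.
def _rank_record(record_fn="record.csv", k=5):
    record = pd.read_csv(record_fn)
    top = record.sort_values("psnr_clean", ascending=False).head(k)
    print(top[["cog", "ave", "psnr_clean"]])
    return top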
def single_image_unet(cfg, queue, full_image, device):
    full_image = full_image.to(device)
    image = tvF.crop(full_image, 128, 128, 32, 32)
    T = 5

    # -- poisson noise --
    noise_type = "pn"
    cfg.noise_type = noise_type
    cfg.noise_params['pn']['alpha'] = 40.0
    cfg.noise_params['pn']['readout'] = 0.0
    cfg.noise_params.ntype = cfg.noise_type
    noise_xform = get_noise_transform(cfg.noise_params, use_to_tensor=False)

    clean = torch.stack([image for i in range(T)], dim=0)
    noisy = noise_xform(clean)
    save_image(clean, "clean.png", normalize=True)

    # -- misaligned copies (shifted one pixel down) --
    m_clean, m_noisy = clean.clone(), noisy.clone()
    for i in range(T // 2):
        image_mis = tvF.crop(full_image, 128 + 1, 128, 32, 32)
        m_clean[i] = image_mis
        m_noisy[i] = noise_xform(m_clean[i])
    noise_level = 50. / 255.

    # -- model over all aligned frames --
    print("-- All Aligned --")
    model = UNet_small(3)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr,
                             betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))

    og_clean, og_noisy = clean.clone(), noisy.clone()
    clean[0] = image_mis
    noisy[0] = noise_xform(clean[[0]])[0]

    # -- model with one misaligned frame --
    print("-- All Misaligned --")
    model = UNet_small(3)
    cfg.init_lr = 1e-4
    optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr,
                             betas=(0.9, 0.99))
    train(cfg, image, clean, noisy, model, optim)
    results = test(cfg, image, clean, noisy, model, 0)
    rec = model(noisy) + 0.5
    save_image(rec, "rec_all.png", normalize=True)
    print("Single Image Unet:")
    print(images_to_psnrs(clean, rec))
    print(images_to_psnrs(rec - 0.5, noisy))
    print(images_to_psnrs(rec[[0]], rec[[1]]))
    print(images_to_psnrs(rec[[0]], rec[[2]]))
    print(images_to_psnrs(rec[[1]], rec[[2]]))

    for j in range(3):

        # -- data --
        noisy1 = torch.stack([noisy[0], noisy[1]], dim=0)
        clean1 = torch.stack([clean[0], clean[1]], dim=0)
        noisy2 = torch.stack([noisy[1], noisy[2]], dim=0)
        clean2 = torch.stack([clean[1], clean[2]], dim=0)
        noisy3 = torch.stack([og_noisy[0], noisy[1]], dim=0)
        clean3 = torch.stack([og_clean[0], clean[1]], dim=0)

        # -- model 1: misaligned pair --
        model = UNet_small(3)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean1, noisy1, model, optim)
        results = test(cfg, image, clean1, noisy1, model, 0)
        rec = model(noisy1) + 0.5
        xrec = model(noisy2) + 0.5
        save_image(rec, "rec1.png", normalize=True)
        print("[misaligned] Single Image Unet:",
              images_to_psnrs(clean1, rec),
              images_to_psnrs(clean2, xrec),
              images_to_psnrs(rec - 0.5, noisy1),
              images_to_psnrs(xrec - 0.5, noisy2),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 2: aligned pair --
        model = UNet_small(3)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean2, noisy2, model, optim)
        results = test(cfg, image, clean2, noisy2, model, 0)
        rec = model(noisy2) + 0.5
        xrec = model(noisy1) + 0.5
        save_image(rec, "rec2.png", normalize=True)
        print("[aligned] Single Image Unet:",
              images_to_psnrs(clean2, rec),
              images_to_psnrs(clean1, xrec),
              images_to_psnrs(rec - 0.5, noisy2),
              images_to_psnrs(xrec - 0.5, noisy1),
              images_to_psnrs(rec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[0]], xrec[[1]]),
              images_to_psnrs(xrec[[0]], rec[[1]]),
              images_to_psnrs(xrec[[1]], rec[[0]]))

        # -- model 3: aligned pair from the original frames --
        model = UNet_small(3)
        cfg.init_lr = 1e-4
        optim = torch.optim.Adam(model.parameters(), lr=cfg.init_lr,
                                 betas=(0.9, 0.99))
        train(cfg, image, clean3, noisy3, model, optim)
        results = test(cfg, image, clean3, noisy3, model, 0)
        rec = model(noisy3) + 0.5
        rec_2 = model(noisy2) + 0.5
        rec_1 = model(noisy1) + 0.5
        save_image(rec, "rec3.png", normalize=True)
        print("[aligned (v3)] Single Image Unet:")
        print("clean-rec", images_to_psnrs(clean3, rec))
        print("clean1-rec1", images_to_psnrs(clean1, rec_1))
        print("clean2-rec2", images_to_psnrs(clean2, rec_2))
        print("rec-noisy3", images_to_psnrs(rec - 0.5, noisy3))
        print("rec1-noisy1", images_to_psnrs(rec_1 - 0.5, noisy1))
        print("rec2-noisy2", images_to_psnrs(rec_2 - 0.5, noisy2))
        print("[v3]: rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))
        print("[v1]: rec0-rec1", images_to_psnrs(rec_1[[0]], rec_1[[1]]))
        print("[v2]: rec0-rec1", images_to_psnrs(rec_2[[0]], rec_2[[1]]))
        print("[v1-v2](a):", images_to_psnrs(rec_1[[1]], rec_2[[1]]))
        print("[v1-v2](b):", images_to_psnrs(rec_1[[1]], rec_2[[0]]))
        print("[v1-v2](c):", images_to_psnrs(rec_1[[0]], rec_2[[1]]))
        print("[v1-v2](d):", images_to_psnrs(rec_1[[0]], rec_2[[0]]))
        print("-" * 20)
        print("[v2-v3](a):", images_to_psnrs(rec_2[[1]], rec[[1]]))
        print("[v2-v3](b):", images_to_psnrs(rec_2[[1]], rec[[0]]))
        print("[v2-v3](c):", images_to_psnrs(rec_2[[0]], rec[[1]]))
        print("[v2-v3](d):", images_to_psnrs(rec_2[[0]], rec[[0]]))
        print("-" * 20)
        print("[v1-v3](a):", images_to_psnrs(rec_1[[1]], rec[[1]]))
        print("[v1-v3](b):", images_to_psnrs(rec_1[[1]], rec[[0]]))
        print("[v1-v3](c):", images_to_psnrs(rec_1[[0]], rec[[1]]))
        print("[v1-v3](d):", images_to_psnrs(rec_1[[0]], rec[[0]]))
        print("-" * 20)
        print("rec0-rec1", images_to_psnrs(rec[[0]], rec[[1]]))