def GlowInpaint(args):
    """Inpaint masked test images by optimizing Glow latent vectors.

    For every regularization weight in ``args.gamma``, each test batch is
    recovered by gradient descent on a flattened latent ``z`` so that the
    Glow-generated image matches the *unmasked* pixels, with an L2 penalty
    on ``z``. Results (images, residual curves, PSNR) are printed and
    optionally saved under ``./results/<dataset>/<experiment>/...``.

    NOTE(review): ``zip(args.gamma)`` yields 1-tuples, so ``gamma`` is a
    1-element tuple until it is wrapped into a tensor below — presumably
    intentional (torch.tensor accepts it); confirm against callers.
    """
    loopOver = zip(args.gamma)
    for gamma in loopOver:
        skip_to_next = False  # skip to next gamma if recovery fails due to instability
        n = args.size * args.size * 3  # flattened image dimensionality
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)
        # loading dataset
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)), transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.batchsize, drop_last=False, shuffle=False)
        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)
        # regularizer weight as a tensor so it participates in the loss graph
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float,
                             device=args.device)
        # accumulators over all batches
        Original = []
        Recovered = []
        Masked = []
        Mask = []
        Residual_Curve = []
        for i, data in enumerate(test_dataloader):
            # getting batch of data
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"
            # generate mask (one mask replicated across the batch)
            mask = gen_mask(args.inpaint_method, args.size, args.mask_size)
            mask = np.array([mask for i in range(n_test)])
            mask = mask.reshape([n_test, 1, args.size, args.size])
            mask = torch.tensor(mask, dtype=torch.float, requires_grad=False,
                                device=args.device)
            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()
            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))
            # initializing z from Gaussian
            if args.init_strategy == "random":
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
            # initializing z from image with noise filled only in masked region
            elif args.init_strategy == "noisy_filled":
                x_noisy_filled = x_test.clone().detach()
                noise = np.random.normal(0, 0.2, x_noisy_filled.size())
                noise = torch.tensor(noise, dtype=torch.float, device=args.device)
                noise = noise * (1 - mask)
                x_noisy_filled = x_noisy_filled + noise
                x_noisy_filled = torch.clamp(x_noisy_filled, 0, 1)
                z, _, _ = glow(x_noisy_filled - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)
            # initializing z from image with masked region inverted
            elif args.init_strategy == "inverted_filled":
                x_inverted_filled = x_test.clone().detach()
                missing_x = x_inverted_filled.clone()
                missing_x = missing_x.data.cpu().numpy()
                missing_x = missing_x[:, :, ::-1, ::-1]  # flip both spatial axes
                missing_x = torch.tensor(missing_x.copy(), dtype=torch.float,
                                         device=args.device)
                missing_x = (1 - mask) * missing_x
                x_inverted_filled = x_inverted_filled * mask
                x_inverted_filled = x_inverted_filled + missing_x
                z, _, _ = glow(x_inverted_filled - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)
            # initializing z from masked image (masked region as zeros)
            elif args.init_strategy == "black_filled":
                x_black_filled = x_test.clone().detach()
                x_black_filled = mask * x_black_filled
                # mask applied twice in the original; a no-op for binary masks,
                # kept to preserve behavior for non-binary masks
                x_black_filled = x_black_filled * mask
                z, _, _ = glow(x_black_filled - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)
            # initializing z from noisy complete image
            elif args.init_strategy == "noisy":
                x_noisy = x_test.clone().detach()
                noise = np.random.normal(0, 0.05, x_noisy.size())
                noise = torch.tensor(noise, dtype=torch.float, device=args.device)
                x_noisy = x_noisy + noise
                x_noisy = torch.clamp(x_noisy, 0, 1)
                z, _, _ = glow(x_noisy - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)
            # initializing z from image with only noise in masked region
            elif args.init_strategy == "only_noise_filled":
                x_noisy_filled = x_test.clone().detach()
                noise = np.random.normal(0, 0.2, x_noisy_filled.size())
                noise = torch.tensor(noise, dtype=torch.float, device=args.device)
                noise = noise * (1 - mask)
                x_noisy_filled = mask * x_noisy_filled + noise
                x_noisy_filled = torch.clamp(x_noisy_filled, 0, 1)
                z, _, _ = glow(x_noisy_filled - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)
            else:
                # fixed: raising a plain string is a TypeError in Python 3
                raise NotImplementedError("Initialization strategy not defined")
            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            # metrics to record over training
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            # running optimizer steps
            for t in range(args.steps):
                def closure():
                    optimizer.zero_grad()
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_masked_test = x_test * mask
                    x_masked_gen = x_gen * mask
                    global residual_t  # read back after optimizer.step() below
                    residual_t = ((x_masked_gen - x_masked_test) ** 2).view(
                        len(x_masked_test), -1).sum(dim=1).mean()
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                    loss_t = residual_t + z_reg_loss_t
                    psnr = psnr_t(x_test, x_gen)
                    psnr = 10 * np.log10(1 / psnr.item())
                    print("\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                          % (t, loss_t.item(), residual_t.item(),
                             z_reg_loss_t.item(), psnr), end="\r")
                    loss_t.backward()
                    return loss_t
                try:
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except (Exception, KeyboardInterrupt):
                    # fixed: was a bare except; the reverse pass can be
                    # numerically unstable and Ctrl-C ("user triggered quit")
                    # is deliberately treated the same way
                    skip_to_next = True
                    break
            if skip_to_next:
                break
            # getting recovered and true images
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            z_unflat = glow.unflatten_z(z_sampled, clone=False)
            x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
            x_gen = glow.postprocess(x_gen, floor_clamp=False)
            x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_gen_np = np.clip(x_gen_np, 0, 1)
            mask_np = mask.data.cpu().numpy()
            x_masked_test = x_test * mask
            x_masked_test_np = x_masked_test.data.cpu().numpy().transpose(
                0, 2, 3, 1)
            x_masked_test_np = np.clip(x_masked_test_np, 0, 1)
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Masked.append(x_masked_test_np)
            Residual_Curve.append(residual)
            Mask.append(mask_np)
            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, mask
            torch.cuda.empty_cache()
            print("\nbatch completed")
        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue
        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Masked = np.vstack(Masked)
        Mask = np.vstack(Mask)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]
        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t inpaint_method = %s\n" % args.inpaint_method
        printout = printout + "\t mask_size = %0.3f\n" % args.mask_size
        printout = printout + "\t gamma = %0.6f\n" % gamma
        printout = printout + "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_inpaint_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)
        # saving images
        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1].split(".")[0]
                          for name in test_dataset.samples]
            if args.init_strategy == 'random':
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size,
                                         gamma, args.steps, args.lr,
                                         args.init_std, args.optim)
            else:
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size,
                                         gamma, args.steps, args.lr,
                                         args.init_strategy, args.optim)
            # fall back to "_1"/"_2"-suffixed directories if the target exists
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path_1 = save_path + "_1"
                if not os.path.exists(save_path_1):
                    os.makedirs(save_path_1)
                    save_path = save_path_1
                else:
                    save_path_2 = save_path + "_2"
                    if not os.path.exists(save_path_2):
                        os.makedirs(save_path_2)
                    save_path = save_path_2
            _ = [sio.imsave(save_path + "/" + name + "_recov.jpg", x)
                 for x, name in zip(Recovered, file_names)]
            _ = [sio.imsave(save_path + "/" + name + "_masked.jpg", x)
                 for x, name in zip(Masked, file_names)]
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/" + "residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/mask.npy", Mask)
            np.save(save_path + "/masked.npy", Masked)
def trainGlow(args):
    """Train (or resume training of) a Glow model on a preprocessed dataset.

    Weights are saved to ``./trained_models/<dataset>/glow/glowmodel.pt``,
    the configuration to ``configs.json`` next to it. Training uses NLL loss
    with gradient clipping, linear LR warm-up for ``args.warmup_iter`` steps,
    then ReduceLROnPlateau scheduling. Sample grids and an NLL curve are
    saved every ``args.sample_freq`` batches on a best-effort basis.
    """
    save_path = "./trained_models/%s/glow" % args.dataset
    training_folder = "./data/%s_preprocessed/train" % args.dataset
    # setting up configs as json
    config_path = save_path + "/configs.json"
    configs = {"K": args.K,
               "L": args.L,
               "coupling": args.coupling,
               "last_zeros": args.last_zeros,
               "batchsize": args.batchsize,
               "size": args.size,
               "lr": args.lr,
               "n_bits_x": args.n_bits_x,
               "warmup_iter": args.warmup_iter}
    if args.squeeze_contig:
        configs["squeeze_contig"] = True
    if args.coupling_bias > 0:
        configs["coupling_bias"] = args.coupling_bias
    if not os.path.exists(save_path):
        print("creating directory to save model weights")
        os.makedirs(save_path)
    # loading pre-trained model to resume training
    if os.path.exists(save_path + "/glowmodel.pt"):
        print("loading previous model and saved configs to resume training ...")
        with open(config_path, 'r') as f:
            configs = json.load(f)
        # NOTE(review): configs also contains training-only keys (lr,
        # batchsize, warmup_iter, ...) forwarded via **configs — Glow.__init__
        # must tolerate extra kwargs; confirm against the Glow definition.
        glow = Glow((3, configs["size"], configs["size"]),
                    device=args.device, **configs)
        glow.load_state_dict(torch.load(save_path + "/glowmodel.pt"))
        print("pre-trained model and configs loaded successfully")
        glow.set_actnorm_init()
        print("actnorm initialization flag set to True to avoid data dependant re-initialization")
        glow.train()
    else:
        # creating and initializing glow model
        print("creating and initializing model for training")
        glow = Glow((3, args.size, args.size),
                    K=args.K, L=args.L,
                    coupling=args.coupling,
                    n_bits_x=args.n_bits_x,
                    nn_init_last_zeros=args.last_zeros,
                    device=args.device)
        glow.train()
        print("saving configs as json file")
        with open(config_path, 'w') as f:
            json.dump(configs, f, sort_keys=True, indent=4, ensure_ascii=False)
    # setting up dataloader
    print("setting up dataloader for the training data")
    trans = transforms.Compose([transforms.Resize(args.size),
                                transforms.CenterCrop((args.size, args.size)),
                                transforms.ToTensor()])
    dataset = datasets.ImageFolder(training_folder, transform=trans)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchsize,
                                             drop_last=True,
                                             shuffle=True)
    # setting up optimizer and learning rate scheduler
    opt = torch.optim.Adam(glow.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                              mode="min",
                                                              factor=0.5,
                                                              patience=1000,
                                                              verbose=True,
                                                              min_lr=1e-8)
    # starting training code here
    print("+-" * 10, "starting training", "-+" * 10)
    global_step = 0
    global_loss = []
    warmup_completed = False
    for i in range(args.epochs):
        for j, data in enumerate(dataloader):
            opt.zero_grad()
            glow.zero_grad()
            # loading batch (images scaled to [0, 255] for preprocess)
            x = data[0].to(device=args.device) * 255
            # pre-processing data
            x = glow.preprocess(x)
            # computing loss: "nll"
            nll, logdet, logpz, z_mu, z_std = glow.nll_loss(x)
            # skipping first batch due to data dependant initialization (if not initialized)
            if global_step == 0:
                global_step += 1
                continue
            # backpropagating loss and gradient clipping
            nll.backward()
            torch.nn.utils.clip_grad_value_(glow.parameters(), 5)
            grad_norm = torch.nn.utils.clip_grad_norm_(glow.parameters(), 100)
            # linearly increase learning rate till warmup_iter upto args.lr
            if global_step <= args.warmup_iter:
                warmup_lr = args.lr / args.warmup_iter * global_step
                for params in opt.param_groups:
                    params["lr"] = warmup_lr
            # taking optimizer step
            opt.step()
            # learning rate scheduling after warm up iterations
            if global_step > args.warmup_iter:
                lr_scheduler.step(nll)
                if not warmup_completed:
                    if args.warmup_iter == 0:
                        print("no model warming...")
                    else:
                        print("\nwarm up completed")
                    warmup_completed = True
            # printing training metrics
            print("\repoch=%0.2d..nll=%0.2f..logdet=%0.2f..logpz=%0.2f..mu=%0.2f..std=%0.2f..gradnorm=%0.2f"
                  % (i, nll.item(), logdet, logpz, z_mu, z_std, grad_norm),
                  end="\r")
            # saving generated samples during training (best effort)
            try:
                if j % args.sample_freq == 0:
                    plt.plot(global_loss)
                    plt.xlabel("iterations", size=15)
                    plt.ylabel("nll", size=15)
                    plt.savefig(save_path + "/nll_training_curve.jpg")
                    plt.close()
                    with torch.no_grad():
                        z_sample, z_sample_t = glow.generate_z(n=10, mu=0,
                                                               std=0.7,
                                                               to_torch=True)
                        x_gen = glow(z_sample_t, reverse=True)
                        x_gen = glow.postprocess(x_gen)
                        x_gen = make_grid(x_gen, nrow=int(np.sqrt(len(x_gen))))
                        x_gen = x_gen.data.cpu().numpy()
                        x_gen = x_gen.transpose([1, 2, 0])
                        if x_gen.shape[-1] == 1:
                            x_gen = x_gen[..., 0]
                        if not os.path.exists(save_path + "/samples_training"):
                            os.makedirs(save_path + "/samples_training")
                        x_gen = (np.clip(x_gen, 0, 1) * 255).astype("uint8")
                        sio.imsave(save_path + "/samples_training/%0.6d.jpg" % global_step,
                                   x_gen)
            except Exception:
                # fixed: was a bare except that also swallowed KeyboardInterrupt;
                # sampling failures must not abort training
                print("\n failed to sample from glow at global step = %d" % global_step)
            global_step = global_step + 1
            global_loss.append(nll.item())
            # periodic checkpoint
            if global_step % args.save_freq == 0:
                torch.save(glow.state_dict(), save_path + "/glowmodel.pt")
    # saving model weights
    torch.save(glow.state_dict(), save_path + "/glowmodel.pt")
def GlowDenoiser(args):
    """Denoise test images by optimizing Glow latents against a noisy target.

    For every regularization weight in ``args.gamma``, noise from ``Noiser``
    is added to each test batch and a latent ``z`` is optimized so that the
    generated image minimizes ``recon_loss`` against the noisy image plus an
    L2 (or squared-L2) penalty on ``z``. Metrics and images are printed and
    optionally saved.

    NOTE(review): the inner batch loop contains a development-time ``break``
    (marked with a todo below) so only the first batch is processed.
    """
    loopOver = zip(args.gamma)  # try different gamma values
    for gamma in loopOver:
        skip_to_next = False  # skip to next gamma if recovery fails due to instability
        n = args.size * args.size * 3
        # modeldir = f"./trained_models/{args.dataset}/glow-denoising"
        modeldir = f"./trained_models/{args.dataset}/glow-cs-{args.size}"
        test_folder = f"./test_images/{args.dataset}_N=12"
        save_path = f"./results/{args.dataset}/{args.experiment}"
        # loading dataset
        trans = transforms.Compose([transforms.Resize((args.size, args.size)),
                                    transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=args.batchsize,
                                                      drop_last=False,
                                                      shuffle=False)
        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)
        # noiser
        noiser = Noiser(args, configs)
        # loss
        loss = recon_loss(args.noise, args.noise_loc, args.noise_scale)
        # regularizor
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float,
                             device=args.device)
        # getting test images
        Original = []
        Recovered = []
        Noisy = []
        Noise = []
        Residual_Curve = []
        for i, data in enumerate(test_dataloader):
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, \
                "please make sure that no. of images are evenly divided by batchsize"
            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt",
                                            map_location=args.device))
            glow.eval()
            # add noise
            noise = noiser(x_test)
            x_noisy = x_test + noise
            x_noisy = torch.clamp(x_noisy, 0., 1.)
            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))
            # np.random.seed(args.random_seed)
            if args.init_strategy == "random":
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
            elif args.init_strategy == "from-noisy":
                z, _, _ = glow(glow.preprocess(x_noisy * 255, clone=True))
                z = glow.flatten_z(z)
                z_sampled = z.clone().detach().cpu().numpy()
            else:
                raise NotImplementedError("Initialization strategy not defined")
            z_sampled = torch.from_numpy(z_sampled).float().to(args.device)
            z_sampled = nn.Parameter(z_sampled, requires_grad=True)
            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            # to be recorded over iteration
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            # running optimizer steps
            for t in range(args.steps):
                try:
                    def closure():
                        optimizer.zero_grad()
                        z_unflat = glow.unflatten_z(z_sampled, clone=False)
                        x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                        x_gen = glow.postprocess(x_gen, floor_clamp=False)
                        global residual_t  # read back after optimizer.step()
                        residual_t = loss(x_gen, x_noisy)
                        if not args.z_penalty_unsquared:
                            z_reg_loss_t = gamma * (z_sampled.norm(dim=1) ** 2).mean()
                        else:
                            z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                        loss_t = residual_t + z_reg_loss_t
                        psnr = psnr_t(x_test, x_gen)
                        psnr = 10 * np.log10(1 / psnr.item())
                        print("\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f" % (
                            t, loss_t.item(), residual_t.item(),
                            z_reg_loss_t.item(), psnr), end="\r")
                        loss_t.backward(retain_graph=True)
                        return loss_t
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except Exception:
                    # reverse pass may be numerically unstable; log and move on
                    traceback.print_exc()
                    skip_to_next = True
                    break
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            Original.append(x_test_np)
            noise_np = noise.data.cpu().numpy().transpose(0, 2, 3, 1)
            Noise.append(noise_np)
            x_noisy_np = x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1)
            Noisy.append(x_noisy_np)
            Residual_Curve.append(residual)
            # best-effort recovery of the final generated image
            x_gen = None
            try:
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
                x_gen_np = np.clip(x_gen_np, 0, 1)
                Recovered.append(x_gen_np)
            except Exception:
                traceback.print_exc()
            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise
            with torch.cuda.device(args.device):
                torch.cuda.empty_cache()
            print("\nbatch completed")
            # todo: remove this break after finishing development.
            break
        # metric evaluations
        Original = np.vstack(Original)
        Noisy = np.vstack(Noisy)
        Noise = np.vstack(Noise)
        psnr = None
        try:
            Recovered = np.vstack(Recovered)
            psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]
        except Exception:
            continue
        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t noise_std = %0.4f\n" % args.noise_scale
        printout = printout + "\t gamma = %0.6f\n" % gamma
        if psnr is not None:
            printout = printout + "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)
        # saving images
        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1].split(".")[0]
                          for name in test_dataset.samples]
            save_path = os.path.join(save_path,
                                     f'{args.noise}_'
                                     f'{args.noise_loc}#{args.noise_scale}_'
                                     f'{args.noise_channel}_{args.noise_area}_'
                                     f'{args.init_strategy}_'
                                     f'{round(gamma, 4)}_{gettime()}')
            print(save_path)
            # fall back to "_1"/"_2"-suffixed directories if the target exists
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path_1 = save_path + "_1"
                if not os.path.exists(save_path_1):
                    os.makedirs(save_path_1)
                    save_path = save_path_1
                else:
                    save_path_2 = save_path + "_2"
                    if not os.path.exists(save_path_2):
                        os.makedirs(save_path_2)
                    save_path = save_path_2
            _ = [sio.imsave(save_path + "/" + name + "_noisy.jpg", x)
                 for x, name in zip(Noisy, file_names)]
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/noisy.npy", Noisy)
            np.save(save_path + "/noise.npy", Noise)
            if len(Recovered) > 0:
                np.save(save_path + "/recovered.npy", Recovered)
                _ = [sio.imsave(save_path + "/" + name + "_recov.jpg", x)
                     for x, name in zip(Recovered, file_names)]
def GlowREDCS(args, filename=None):
    """Compressed-sensing recovery with Glow latents plus a RED prior.

    For each (m, gamma, init_norm) triple, a Gaussian sensing matrix ``A``
    (n x m) is drawn, measurements ``y = xA + noise`` are formed, and a
    latent ``z`` is optimized to fit the measurements with an L2 penalty on
    ``z`` and an ADMM-style coupling to a denoised estimate ``x_f``
    (updated every ``args.update_iter`` steps via ``Denoiser``).
    Per-step losses are appended to a CSV at ``filename``.

    NOTE(review): ``x_f`` is initialized from ``x_lasso``, which only exists
    for the lasso-based init strategies — other strategies raise a NameError
    there; confirm intended usage before relying on them.
    """
    if args.init_norms is None:  # fixed: was `== None`
        args.init_norms = [None] * len(args.m)
    else:
        assert args.init_strategy == "random_fixed_norm", \
            "init_strategy should be random_fixed_norm if init_norms is used"
    assert len(args.m) == len(args.gamma) == len(args.init_norms), \
        "length of either m, gamma or init_norms are not same"
    loopOver = zip(args.m, args.gamma, args.init_norms)
    for m, gamma, init_norm in loopOver:
        skip_to_next = False  # skip to next loop if recovery fails due to instability
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)
        # loading dataset
        trans = transforms.Compose([transforms.Resize((args.size, args.size)),
                                    transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=args.batchsize,
                                                      drop_last=False,
                                                      shuffle=False)
        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)
        # sensing matrix
        A = np.random.normal(0, 1 / np.sqrt(m), size=(n, m))
        A = torch.tensor(A, dtype=torch.float, requires_grad=False,
                         device=args.device)
        # regularizor
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float,
                             device=args.device)
        alpha = args.alpha
        beta = args.beta
        # adding noise
        if args.noise == "random_bora":
            noise = np.random.normal(0, 1, size=(args.batchsize, m))
            noise = noise * 0.1 / np.sqrt(m)
            noise = torch.tensor(noise, dtype=torch.float, requires_grad=False,
                                 device=args.device)
        else:
            # fixed-norm noise: unit direction scaled to float(args.noise)
            noise = np.random.normal(0, 1, size=(args.batchsize, m))
            noise = noise / np.linalg.norm(noise, 2, axis=-1, keepdims=True) * float(args.noise)
            noise = torch.tensor(noise, dtype=torch.float, requires_grad=False,
                                 device=args.device)
        # start solving over batches
        Original = []
        Recovered = []
        Recovered_f = []
        Z_Recovered = []
        Residual_Curve = []
        Recorded_Z = []
        for i, data in enumerate(test_dataloader):
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"
            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()
            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))
            # initializing z from Gaussian with std equal to init_std
            if args.init_strategy == "random":
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
            # intializing z from Gaussian and scaling its norm to init_norm
            elif args.init_strategy == "random_fixed_norm":
                z_sampled = np.random.normal(0, 1, [n_test, n])
                z_sampled = z_sampled / np.linalg.norm(z_sampled, axis=-1,
                                                       keepdims=True)
                z_sampled = z_sampled * init_norm
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
                print("z intialized with a norm equal to = %0.1f" % init_norm)
            # initializing z from pseudo inverse
            elif args.init_strategy == "pseudo_inverse":
                x_test_flat = x_test.view([-1, n])
                y_true = torch.matmul(x_test_flat, A) + noise
                A_pinv = torch.pinverse(A)
                x_pinv = torch.matmul(y_true, A_pinv)
                x_pinv = x_pinv.view([-1, 3, args.size, args.size])
                x_pinv = torch.clamp(x_pinv, 0, 1)
                z, _, _ = glow(glow.preprocess(x_pinv * 255, clone=True))
                z = glow.flatten_z(z).clone().detach()
                z_sampled = torch.tensor(z, requires_grad=True,
                                         dtype=torch.float, device=args.device)
            # initializing z from a solution of lasso-wavelet
            elif args.init_strategy == "lasso_wavelet":
                new_args = {"batch_size": n_test, "lmbd": 0.01,
                            "lasso_solver": "sklearn"}
                new_args = easydict.EasyDict(new_args)
                estimator = celebA_estimators.lasso_wavelet_estimator(new_args)
                x_ch_last = x_test.permute(0, 2, 3, 1)
                x_ch_last = x_ch_last.contiguous().view([-1, n])
                y_true = torch.matmul(x_ch_last, A) + noise
                x_lasso = estimator(np.sqrt(2 * m) * A.data.cpu().numpy(),
                                    np.sqrt(2 * m) * y_true.data.cpu().numpy(),
                                    new_args)
                x_lasso = np.array(x_lasso)
                x_lasso = x_lasso.reshape(-1, 64, 64, 3)
                x_lasso = x_lasso.transpose(0, 3, 1, 2)
                x_lasso = torch.tensor(x_lasso, dtype=torch.float,
                                       device=args.device)
                z, _, _ = glow(x_lasso - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = torch.tensor(z, requires_grad=True,
                                         dtype=torch.float, device=args.device)
                print("z intialized from a solution of lasso-wavelet")
            # initializing z from a solution of lasso-dct
            elif args.init_strategy == "lasso_dct":
                new_args = {"batch_size": n_test, "lmbd": 0.01,
                            "lasso_solver": "sklearn"}
                new_args = easydict.EasyDict(new_args)
                estimator = celebA_estimators.lasso_dct_estimator(new_args)
                x_ch_last = x_test.permute(0, 2, 3, 1)
                x_ch_last = x_ch_last.contiguous().view([-1, n])
                y_true = torch.matmul(x_ch_last, A) + noise
                x_lasso = estimator(np.sqrt(2 * m) * A.data.cpu().numpy(),
                                    np.sqrt(2 * m) * y_true.data.cpu().numpy(),
                                    new_args)
                x_lasso = np.array(x_lasso)
                x_lasso = x_lasso.reshape(-1, 64, 64, 3)
                x_lasso = x_lasso.transpose(0, 3, 1, 2)
                x_lasso = torch.tensor(x_lasso, dtype=torch.float,
                                       device=args.device)
                z, _, _ = glow(x_lasso - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = torch.tensor(z, requires_grad=True,
                                         dtype=torch.float, device=args.device)
                print("z intialized from a solution of lasso-dct")
            # z random, but RED estimate x_f seeded from lasso-dct
            elif args.init_strategy == "random_lasso_dct":
                new_args = {"batch_size": n_test, "lmbd": 0.01,
                            "lasso_solver": "sklearn"}
                new_args = easydict.EasyDict(new_args)
                estimator = celebA_estimators.lasso_dct_estimator(new_args)
                x_ch_last = x_test.permute(0, 2, 3, 1)
                x_ch_last = x_ch_last.contiguous().view([-1, n])
                y_true = torch.matmul(x_ch_last, A) + noise
                x_lasso = estimator(np.sqrt(2 * m) * A.data.cpu().numpy(),
                                    np.sqrt(2 * m) * y_true.data.cpu().numpy(),
                                    new_args)
                x_lasso = np.array(x_lasso)
                x_lasso = x_lasso.reshape(-1, 64, 64, 3)
                x_lasso = x_lasso.transpose(0, 3, 1, 2)
                x_lasso = torch.tensor(x_lasso, dtype=torch.float,
                                       device=args.device)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
                print("z intialized randomly and RED is initialized from a solution of lasso-dct")
            # intializing z from null(A)
            elif args.init_strategy == "null_space":
                x_test_flat = x_test.view([-1, n])
                x_test_flat_np = x_test_flat.data.cpu().numpy()
                A_np = A.data.cpu().numpy()
                nullA = null_space(A_np.T)
                coeff = np.random.normal(0, 1, (args.batchsize, nullA.shape[1]))
                x_null = np.array([(nullA * c).sum(axis=-1) for c in coeff])
                pert_norm = 5  # <-- 5 gives optimal results -- bad initialization and not too unstable
                x_null = x_null / np.linalg.norm(x_null, axis=1, keepdims=True) * pert_norm
                x_perturbed = x_test_flat_np + x_null
                # no clipping x_perturbed to make sure forward model is ||y-Ax|| is the same
                err = np.matmul(x_test_flat_np, A_np) - np.matmul(x_perturbed, A_np)
                assert (err ** 2).sum() < 1e-6, "null space does not satisfy ||y-A(x+x0)|| <= 1e-6"
                x_perturbed = x_perturbed.reshape(-1, 3, args.size, args.size)
                x_perturbed = torch.tensor(x_perturbed, dtype=torch.float,
                                           device=args.device)
                z, _, _ = glow(x_perturbed - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = torch.tensor(z, requires_grad=True,
                                         dtype=torch.float, device=args.device)
                print("z initialized from a point in null space of A")
            else:
                # fixed: raising a plain string is a TypeError in Python 3
                raise NotImplementedError("Initialization strategy not defined")
            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            elif args.optim == "rk2":
                optimizer = RK2Heun([z_sampled], lr=args.lr)
            elif args.optim == "raghav":
                optimizer = RK2Raghav([z_sampled], lr=args.lr)
            elif args.optim == "sgd":
                optimizer = torch.optim.SGD([z_sampled], lr=args.lr, momentum=0.9)
            else:
                # fixed: raising a plain string is a TypeError in Python 3
                raise NotImplementedError("optimizer not defined")
            # to be recorded over iteration
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            recorded_z = []
            # RED state: denoised estimate x_f and ADMM-style dual u.
            # NOTE(review): x_lasso only exists for lasso-based strategies;
            # other init strategies raise NameError here — confirm.
            x_f = x_lasso.clone()
            u = torch.zeros_like(x_test)
            df_losses = pd.DataFrame(columns=["loss_t", "residual_t",
                                              "residual_x", "z_reg_loss"])
            alpha = args.alpha
            beta = args.beta
            # running optimizer steps
            for t in range(args.steps):
                def closure():
                    optimizer.zero_grad()
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_test_flat = x_test.view([-1, n])
                    x_gen_flat = x_gen.view([-1, n])
                    y_true = torch.matmul(x_test_flat, A) + noise
                    y_gen = torch.matmul(x_gen_flat, A)
                    global residual_t  # read back after optimizer.step() below
                    residual_t = ((y_gen - y_true) ** 2).sum(dim=1).mean()
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                    # coupling to the current RED estimate (x_f - u)
                    residual_x = beta * ((x_gen - (x_f - u)) ** 2).view(
                        len(x_f), -1).sum(dim=1).mean()
                    loss_t = residual_t + z_reg_loss_t + residual_x
                    psnr = psnr_t(x_test, x_gen)
                    psnr = 10 * np.log10(1 / psnr.item())
                    print("At step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f" % (
                        t, loss_t.item(), residual_t.item(), residual_x.item(),
                        z_reg_loss_t.item(), psnr))
                    loss_t.backward()
                    update = [loss_t.item(), residual_t.item(),
                              residual_x.item(), z_reg_loss_t.item()]
                    df_losses.loc[len(df_losses)] = update
                    df_losses.to_csv(filename)
                    return loss_t

                def denoiser_step(x_f, u):
                    # one RED/ADMM update of (x_f, u) given the current x_gen
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False).detach()
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_f = 1 / (beta + alpha) * (
                        beta * Denoiser(args.denoiser, args.sigma_f, x_f)
                        + alpha * (x_gen + u))
                    u = u + x_gen - x_f
                    return x_f, u

                optimizer.step(closure)
                recorded_z.append(z_sampled.data.cpu().numpy())
                residual.append(residual_t.item())
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)
            if skip_to_next:
                break
            # getting recovered and true images
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
                x_gen_np = np.clip(x_gen_np, 0, 1)
                x_f_np = x_f.cpu().numpy().transpose(0, 2, 3, 1)
                x_f_np = np.clip(x_f_np, 0, 1)
                z_recov = z_sampled.data.cpu().numpy()
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Recovered_f.append(x_f_np)
            Z_Recovered.append(z_recov)
            Residual_Curve.append(residual)
            Recorded_Z.append(recorded_z)
            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow
            torch.cuda.empty_cache()
            print("\nbatch completed")
        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue
        # collecting everything together
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Recovered_f = np.vstack(Recovered_f)
        Z_Recovered = np.vstack(Z_Recovered)
        Recorded_Z = np.vstack(Recorded_Z)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]
        psnr_f = [compare_psnr(x, y) for x, y in zip(Original, Recovered_f)]
        z_recov_norm = np.linalg.norm(Z_Recovered, axis=-1)
        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t n = %d\n" % n
        printout = printout + "\t m = %d\n" % m
        printout = printout + "\t update_iter = %0.4f\n" % args.update_iter
        printout = printout + "\t gamma = %0.6f\n" % gamma
        printout = printout + "\t alpha = %0.6f\n" % alpha
        printout = printout + "\t beta = %0.6f\n" % beta
        printout = printout + "\t optimizer = %s\n" % args.optim
        printout = printout + "\t lr = %0.3f\n" % args.lr
        printout = printout + "\t steps = %0.3f\n" % args.steps
        printout = printout + "\t init_strategy = %s\n" % args.init_strategy
        printout = printout + "\t init_std = %0.3f\n" % args.init_std
        if init_norm is not None:
            printout = printout + "\t init_norm = %0.3f\n" % init_norm
        printout = printout + "\t z_recov_norm = %0.3f\n" % np.mean(z_recov_norm)
        printout = printout + "\t mean PSNR = %0.3f\n" % (np.mean(psnr))
        printout = printout + "\t mean PSNR_f = %0.3f\n" % (np.mean(psnr_f))
        print(printout)
        # saving printout
        if args.save_metrics_text:
            with open("%s_cs_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)
        # setting folder to save results in
        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1] for name in test_dataset.samples]
            if args.init_strategy == "random":
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma,
                                                  alpha, beta, args.steps,
                                                  args.lr, args.init_std,
                                                  args.optim)
            elif args.init_strategy == "random_fixed_norm":
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_%s_%0.3f_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma,
                                                  alpha, beta, args.steps,
                                                  args.lr, args.init_strategy,
                                                  init_norm, args.optim)
            else:
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma,
                                                  alpha, beta, args.steps,
                                                  args.lr, args.init_strategy,
                                                  args.optim)
            # fall back to "_1"/"_2"-suffixed directories if the target exists
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path_1 = save_path + "_1"
                if not os.path.exists(save_path_1):
                    os.makedirs(save_path_1)
                    save_path = save_path_1
                else:
                    save_path_2 = save_path + "_2"
                    if not os.path.exists(save_path_2):
                        os.makedirs(save_path_2)
                    save_path = save_path_2
            # saving results now
            _ = [sio.imsave(save_path + "/" + name, x)
                 for x, name in zip(Recovered, file_names)]
            print(save_path + "/" + file_names[0])
            _ = [sio.imsave(save_path + "/f_" + name, x)
                 for x, name in zip(Recovered_f, file_names)]
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/recovered_f.npy", Recovered_f)
            np.save(save_path + "/z_recovered.npy", Z_Recovered)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            if init_norm is not None:
                np.save(save_path + "/Recorded_Z_init_norm_%d.npy" % init_norm,
                        Recorded_Z)
        torch.cuda.empty_cache()
def GlowREDDenoiser(args):
    """Denoise test images with a Glow prior regularized by RED (ADMM-style).

    For each regularization weight in ``args.gamma``, loads a pretrained Glow
    model, adds synthetic noise to every test batch, and recovers clean images
    by optimizing a latent vector ``z`` to minimize

        ||G(z) - x_noisy||^2 + gamma * ||z||  +  beta * ||G(z) - (x_f - u)||^2

    where ``x_f``/``u`` are the RED/ADMM auxiliary image and dual variable,
    periodically refreshed by an external ``Denoiser`` every
    ``args.update_iter`` steps. Results (recovered/noisy images, residual
    curves, PSNR printout) are saved under ``./results/<dataset>/<experiment>``.

    Args:
        args: argparse-style namespace. Fields read here include: gamma
            (iterable of regularization weights), size, model, dataset,
            experiment, batchsize, device, noise ("gaussian" only; "laplacian"
            is stubbed out), noise_std, init_strategy ("random" or
            "from-noisy"), init_std, optim ("adam" or "lbfgs"), lr, steps,
            update_iter, alpha, beta, z_penalty_squared, denoiser, sigma_f,
            save_metrics_text, save_results.

    Raises:
        NotImplementedError: for the laplacian noise stub.
        ValueError: for an unknown noise type or init strategy.
    """
    loopOver = zip(args.gamma)
    for gamma in loopOver:
        # Flag to skip to the next gamma if recovery fails due to instability.
        skip_to_next = False
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # Loading dataset.
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)), transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.batchsize, drop_last=False, shuffle=False)

        # Loading glow configurations.
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # Regularizer weight as a tensor (gamma is a 1-tuple from zip()).
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float,
                             device=args.device)
        alpha = args.alpha
        beta = args.beta

        # Accumulators over all test batches.
        gen_steps = []
        Original = []
        Recovered = []
        Noisy = []
        Residual_Curve = []
        for i, data in enumerate(test_dataloader):
            # Getting batch of data.
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, \
                "please make sure that no. of images are evenly divided by batchsize"

            # Noise to be added.
            if args.noise == "gaussian":
                noise = np.random.normal(0, args.noise_std,
                                         size=(n_test, 3, args.size, args.size))
                noise = torch.tensor(noise, dtype=torch.float,
                                     requires_grad=False, device=args.device)
            elif args.noise == "laplacian":
                noise = np.random.laplace(scale=args.noise_std,
                                          size=(n_test, 3, args.size, args.size))
                noise = torch.tensor(noise, dtype=torch.float,
                                     requires_grad=False, device=args.device)
                # Deliberately disabled: the save-folder name carries no noise
                # type tag, so laplacian results would overwrite gaussian ones.
                raise NotImplementedError("code only supports gaussian for now")
            else:
                raise ValueError("noise type not defined")

            # Loading glow model.
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # A forward pass records the shapes of z's for the reverse pass.
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # Initializing z.
            if args.init_strategy == "random":
                # From a Gaussian.
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
            elif args.init_strategy == "from-noisy":
                # From the noisy image itself.
                x_noisy = x_test + noise
                z, _, _ = glow(glow.preprocess(x_noisy * 255, clone=True))
                z = glow.flatten_z(z)
                z_sampled = z.clone().detach().requires_grad_(True)
            else:
                raise ValueError("Initialization strategy not defined")

            # Selecting optimizer.
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)

            # To be recorded over iterations.
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            # RED/ADMM auxiliary image and dual variable.
            x_f = (x_test + noise).clone()
            u = torch.zeros_like(x_test)
            # One-slot box so closure() can hand the data residual back to the
            # outer loop without a module-level global.
            box = {"residual_t": None}

            # Running optimizer steps.
            for t in range(args.steps):
                def closure():
                    # Loss and gradient at the current z (LBFGS may call this
                    # several times per step).
                    optimizer.zero_grad()
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_noisy = x_test + noise
                    residual_t = ((x_gen - x_noisy)**2).view(
                        len(x_noisy), -1).sum(dim=1).mean()
                    box["residual_t"] = residual_t
                    if args.z_penalty_squared:
                        z_reg_loss_t = gamma * (z_sampled.norm(dim=1)**2).mean()
                    else:
                        z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                    # RED consensus term tying G(z) to the denoised estimate.
                    residual_x = beta * ((x_gen - (x_f - u))**2).view(
                        len(x_noisy), -1).sum(dim=1).mean()
                    loss_t = residual_t + z_reg_loss_t + residual_x
                    psnr = psnr_t(x_test, x_gen)
                    psnr = 10 * np.log10(1 / psnr.item())
                    print(
                        "\rAt step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                        % (t, loss_t.item(), residual_t.item(),
                           residual_x.item(), z_reg_loss_t.item(), psnr),
                        end="\r")
                    loss_t.backward()
                    return loss_t

                def denoiser_step(x_f, u):
                    # ADMM-style update: blend the external denoiser's output
                    # with the current Glow reconstruction, then update the dual.
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True,
                                 reverse_clone=False).detach()
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_f = 1 / (beta + alpha) * (
                        beta * Denoiser(args.denoiser, args.sigma_f, x_f)
                        + alpha * (x_gen + u))
                    u = u + x_gen - x_f
                    return x_f, u

                optimizer.step(closure)
                residual.append(box["residual_t"].item())
                # Refresh x_f/u every update_iter steps and snapshot the image.
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)
                    with torch.no_grad():
                        z_unflat = glow.unflatten_z(z_sampled, clone=False)
                        x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                        x_gen = glow.postprocess(x_gen, floor_clamp=False)
                        x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
                        x_gen_np = np.clip(x_gen_np, 0, 1)
                        gen_steps.append(x_gen_np)

            if skip_to_next:
                break

            # Getting recovered and true images.
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            z_unflat = glow.unflatten_z(z_sampled, clone=False)
            x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
            x_gen = glow.postprocess(x_gen, floor_clamp=False)
            x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_gen_np = np.clip(x_gen_np, 0, 1)
            x_noisy = x_test + noise
            x_noisy_np = x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_noisy_np = np.clip(x_noisy_np, 0, 1)
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Noisy.append(x_noisy_np)
            Residual_Curve.append(residual)

            # Freeing up memory before the next batch.
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue

        # Metric evaluations.
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        gen_steps = np.vstack(gen_steps)
        Noisy = np.vstack(Noisy)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        # Print performance analysis.
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t noise_std = %0.4f\n" % args.noise_std
        printout = printout + "\t update_iter= %0.4f\n" % args.update_iter
        printout = printout + "\t gamma = %0.6f\n" % gamma
        printout = printout + "\t alpha = %0.6f\n" % alpha
        printout = printout + "\t beta = %0.6f\n" % beta
        printout = printout + "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)

        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        # Saving images.
        if args.save_results:
            gamma = gamma.item()
            file_names = [
                name[0].split("/")[-1].split(".")[0]
                for name in test_dataset.samples
            ]
            save_path = save_path + "/denoising_noisestd_%0.4f_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
            save_path = save_path % (args.noise_std, args.update_iter, gamma,
                                     alpha, beta, args.steps, args.lr,
                                     args.init_std, args.optim)
            # Fall back to "_1"/"_2" suffixes if the results folder exists.
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path_1 = save_path + "_1"
                if not os.path.exists(save_path_1):
                    os.makedirs(save_path_1)
                    save_path = save_path_1
                else:
                    save_path_2 = save_path + "_2"
                    if not os.path.exists(save_path_2):
                        os.makedirs(save_path_2)
                    save_path = save_path_2
            _ = [
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
                for x, name in zip(Recovered, file_names)
            ]
            _ = [
                sio.imsave(save_path + "/" + name + "_noisy.jpg", x)
                for x, name in zip(Noisy, file_names)
            ]
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/noisy.npy", Noisy)
            np.save(save_path + "/gen_steps.npy", gen_steps)