def GlowInpaint(args):
    """Inpaint masked test images by optimizing Glow latent codes.

    For each regularization weight in ``args.gamma`` the images under
    ``./test_images/<args.dataset>`` are processed batch by batch:

    1. a mask is built with ``gen_mask`` and shared by every image of the
       batch; the mask multiplies the image, so pixels where it is 0 are
       treated as missing;
    2. a fresh Glow model is loaded from ``./trained_models/<args.model>/glow``;
    3. a latent ``z`` is initialised according to ``args.init_strategy``;
    4. ``z`` is optimized (Adam or L-BFGS) to minimise the squared error
       between the masked reconstruction and the masked test image (summed per
       image, averaged over the batch) plus ``gamma * ||z||**2``
       (``gamma * ||z||`` if ``args.z_penalty_unsquared``).

    A PSNR summary is printed and, with ``args.save_metrics_text``, appended to
    ``<dataset>_inpaint_glow_results.txt``. With ``args.save_results`` the
    recovered and masked images plus ``.npy`` arrays are written to a new
    sub-folder of ``./results/<dataset>/<experiment>``; existing folders are
    never overwritten. If an optimizer step raises (e.g. an unstable reverse
    pass) or is interrupted with Ctrl+C, the remaining batches for that
    ``gamma`` are skipped and the next ``gamma`` is tried.

    Args:
        args: options namespace. Reads ``gamma`` (iterable of numbers),
            ``size``, ``model``, ``dataset``, ``experiment``, ``batchsize``,
            ``device``, ``inpaint_method``, ``mask_size``, ``init_strategy``,
            ``init_std``, ``optim``, ``lr``, ``steps``,
            ``z_penalty_unsquared``, ``save_metrics_text`` and
            ``save_results``.

    Raises:
        ValueError: unknown ``args.init_strategy`` or ``args.optim``.
        AssertionError: the number of test images is not a multiple of
            ``args.batchsize``.
    """
    import traceback  # local import keeps this function self-contained

    n = args.size * args.size * 3
    modeldir = "./trained_models/%s/glow" % args.model
    test_folder = "./test_images/%s" % args.dataset
    base_save_path = "./results/%s/%s" % (args.dataset, args.experiment)

    # The dataset and the Glow configuration do not depend on gamma: load once.
    trans = transforms.Compose(
        [transforms.Resize((args.size, args.size)), transforms.ToTensor()])
    test_dataset = datasets.ImageFolder(test_folder, transform=trans)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batchsize, drop_last=False, shuffle=False)
    with open(modeldir + "/configs.json", "r") as f:
        configs = json.load(f)

    # Iterate the values themselves: ``zip(args.gamma)`` yielded 1-tuples,
    # which turned gamma into a shape-(1,) tensor instead of a scalar.
    for gamma in args.gamma:
        skip_to_next = False  # set when recovery fails (instability or Ctrl+C)
        save_path = base_save_path

        # regularizer weight
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float,
                             device=args.device)

        Original = []
        Recovered = []
        Masked = []
        Mask = []
        Residual_Curve = []
        for data in test_dataloader:
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # one mask, replicated for every image of the batch
            mask = gen_mask(args.inpaint_method, args.size, args.mask_size)
            mask = np.array([mask] * n_test).reshape(
                [n_test, 1, args.size, args.size])
            mask = torch.tensor(mask, dtype=torch.float, requires_grad=False,
                                device=args.device)

            # a fresh Glow model for every batch
            glow = Glow((3, args.size, args.size), K=configs["K"],
                        L=configs["L"], coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            # map_location lets a checkpoint saved on another device load here
            glow.load_state_dict(
                torch.load(modeldir + "/glowmodel.pt", map_location=args.device))
            glow.eval()
            # a forward pass records the shapes of the z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            if args.init_strategy == "random":
                # z drawn i.i.d. from N(0, init_std^2)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
            else:
                # every other strategy builds a starting image and encodes it
                if args.init_strategy == "noisy_filled":
                    # test image with N(0, 0.2^2) noise added in the missing region
                    noise = np.random.normal(0, 0.2, x_test.size())
                    noise = torch.tensor(noise, dtype=torch.float,
                                         device=args.device)
                    x_init = torch.clamp(x_test + noise * (1 - mask), 0, 1)
                elif args.init_strategy == "inverted_filled":
                    # missing region filled with the image flipped along both
                    # spatial axes
                    flipped = torch.flip(x_test, dims=[2, 3])
                    x_init = x_test * mask + (1 - mask) * flipped
                elif args.init_strategy == "black_filled":
                    # missing region set to zero (masking once is sufficient)
                    x_init = mask * x_test
                elif args.init_strategy == "noisy":
                    # whole image with N(0, 0.05^2) noise
                    noise = np.random.normal(0, 0.05, x_test.size())
                    noise = torch.tensor(noise, dtype=torch.float,
                                         device=args.device)
                    x_init = torch.clamp(x_test + noise, 0, 1)
                elif args.init_strategy == "only_noise_filled":
                    # missing region replaced by N(0, 0.2^2) noise
                    noise = np.random.normal(0, 0.2, x_test.size())
                    noise = torch.tensor(noise, dtype=torch.float,
                                         device=args.device)
                    x_init = torch.clamp(mask * x_test + noise * (1 - mask), 0, 1)
                else:
                    raise ValueError("Initialization strategy not defined: %s"
                                     % args.init_strategy)
                z, _, _ = glow(x_init - 0.5)
                z_sampled = glow.flatten_z(z).clone().detach().requires_grad_(True)

            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise ValueError("optimizer not defined: %s" % args.optim)

            psnr_t = torch.nn.MSELoss().to(device=args.device)  # MSE, turned into PSNR below
            residual = []
            x_masked_test = x_test * mask  # observed pixels only

            def closure():
                # residual_t is published as a module global so the step loop
                # can log it after optimizer.step() returns
                global residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_masked_gen = x_gen * mask
                residual_t = ((x_masked_gen - x_masked_test) ** 2).view(
                    len(x_masked_test), -1).sum(dim=1).mean()
                if args.z_penalty_unsquared:
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                else:
                    z_reg_loss_t = gamma * (z_sampled.norm(dim=1) ** 2).mean()
                loss_t = residual_t + z_reg_loss_t
                psnr = 10 * np.log10(1 / psnr_t(x_test, x_gen).item())
                print(
                    "\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                    % (t, loss_t.item(), residual_t.item(), z_reg_loss_t.item(), psnr),
                    end="\r")
                loss_t.backward()
                return loss_t

            for t in range(args.steps):
                try:
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except (Exception, KeyboardInterrupt):
                    # an unstable reverse pass or a user interrupt: report it and
                    # abandon the remaining batches for this gamma
                    traceback.print_exc()
                    skip_to_next = True
                    break
            if skip_to_next:
                break

            # recovered and true images, channel-last, for saving and metrics
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = np.clip(
                    x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                mask_np = mask.data.cpu().numpy()
                x_masked_test_np = np.clip(
                    x_masked_test.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)

            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Masked.append(x_masked_test_np)
            Residual_Curve.append(residual)
            Mask.append(mask_np)

            # free memory before the next batch
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, mask, x_masked_test
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Masked = np.vstack(Masked)
        Mask = np.vstack(Mask)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout += "\t n_test = %d\n" % len(Recovered)
        printout += "\t inpaint_method = %s\n" % args.inpaint_method
        printout += "\t mask_size = %0.3f\n" % args.mask_size
        printout += "\t gamma = %0.6f\n" % gamma
        printout += "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_inpaint_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1].split(".")[0]
                          for name in test_dataset.samples]
            if args.init_strategy == 'random':
                save_path += (
                    "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                    % (args.inpaint_method, args.mask_size, gamma, args.steps,
                       args.lr, args.init_std, args.optim))
            else:
                save_path += (
                    "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                    % (args.inpaint_method, args.mask_size, gamma, args.steps,
                       args.lr, args.init_strategy, args.optim))
            # never overwrite an earlier run: use the first free of
            # path, path_1, path_2, ...
            candidate, suffix = save_path, 0
            while os.path.exists(candidate):
                suffix += 1
                candidate = "%s_%d" % (save_path, suffix)
            os.makedirs(candidate)
            save_path = candidate

            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
            for x, name in zip(Masked, file_names):
                sio.imsave(save_path + "/" + name + "_masked.jpg", x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/mask.npy", Mask)
            np.save(save_path + "/masked.npy", Masked)
def GlowREDDenoiser(args):
    """Denoise test images with a Glow prior coupled to an external denoiser.

    For each regularization weight in ``args.gamma`` the images under
    ``./test_images/<args.dataset>`` are corrupted with Gaussian noise of std
    ``args.noise_std`` and recovered batch by batch. A latent ``z`` (random or
    encoded from the noisy image, per ``args.init_strategy``) is optimized to
    minimise, with ``G`` the Glow reverse pass and ``y`` the noisy image::

        ||G(z) - y||^2 + gamma * ||z||^p + beta * ||G(z) - (x_f - u)||^2

    where ``p = 2`` if ``args.z_penalty_squared`` else ``p = 1``. Squared
    errors are summed per image and averaged over the batch. Every
    ``args.update_iter`` steps the auxiliary images ``x_f`` and ``u`` are
    updated from ``Denoiser(args.denoiser, args.sigma_f, x_f)`` and ``G(z)``
    with weights ``args.alpha`` / ``args.beta`` (presumably a RED/ADMM-style
    splitting -- confirm against the method description).

    A PSNR summary is printed and, with ``args.save_metrics_text``, appended to
    ``<dataset>_denoising_glow_results.txt``. With ``args.save_results`` the
    recovered and noisy images plus ``.npy`` arrays (including the ``G(z)``
    snapshots in ``gen_steps.npy``) are written to a new sub-folder of
    ``./results/<dataset>/<experiment>``; existing folders are never
    overwritten.

    Args:
        args: options namespace. Reads ``gamma`` (iterable of numbers),
            ``alpha``, ``beta``, ``size``, ``model``, ``dataset``,
            ``experiment``, ``batchsize``, ``device``, ``noise``,
            ``noise_std``, ``init_strategy``, ``init_std``, ``optim``, ``lr``,
            ``steps``, ``update_iter``, ``z_penalty_squared``, ``denoiser``,
            ``sigma_f``, ``save_metrics_text`` and ``save_results``.

    Raises:
        NotImplementedError: ``args.noise == "laplacian"`` (not supported yet).
        ValueError: unknown ``args.noise``, ``args.init_strategy`` or
            ``args.optim``.
        AssertionError: the number of test images is not a multiple of
            ``args.batchsize``.
    """
    n = args.size * args.size * 3
    modeldir = "./trained_models/%s/glow" % args.model
    test_folder = "./test_images/%s" % args.dataset
    base_save_path = "./results/%s/%s" % (args.dataset, args.experiment)
    alpha = args.alpha
    beta = args.beta

    # The dataset and the Glow configuration do not depend on gamma: load once.
    trans = transforms.Compose(
        [transforms.Resize((args.size, args.size)), transforms.ToTensor()])
    test_dataset = datasets.ImageFolder(test_folder, transform=trans)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batchsize, drop_last=False, shuffle=False)
    with open(modeldir + "/configs.json", "r") as f:
        configs = json.load(f)

    # Iterate the values themselves: ``zip(args.gamma)`` yielded 1-tuples,
    # which turned gamma into a shape-(1,) tensor instead of a scalar.
    for gamma in args.gamma:
        save_path = base_save_path

        # regularizer weight
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float,
                             device=args.device)

        gen_steps = []  # channel-last G(z) snapshots taken during optimization
        Original = []
        Recovered = []
        Noisy = []
        Residual_Curve = []
        for data in test_dataloader:
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            if args.noise == "gaussian":
                noise = np.random.normal(0, args.noise_std,
                                         size=(n_test, 3, args.size, args.size))
                noise = torch.tensor(noise, dtype=torch.float,
                                     requires_grad=False, device=args.device)
            elif args.noise == "laplacian":
                # results folders carry no noise-type tag, so only Gaussian
                # noise is supported for now
                raise NotImplementedError("code only supports gaussian for now")
            else:
                raise ValueError("noise type not defined: %s" % args.noise)
            x_noisy = x_test + noise

            # a fresh Glow model for every batch
            glow = Glow((3, args.size, args.size), K=configs["K"],
                        L=configs["L"], coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            # map_location lets a checkpoint saved on another device load here
            glow.load_state_dict(
                torch.load(modeldir + "/glowmodel.pt", map_location=args.device))
            glow.eval()
            # a forward pass records the shapes of the z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            if args.init_strategy == "random":
                # z drawn i.i.d. from N(0, init_std^2)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
            elif args.init_strategy == "from-noisy":
                # encode the noisy observation
                z, _, _ = glow(glow.preprocess(x_noisy * 255, clone=True))
                z_sampled = glow.flatten_z(z).clone().detach().requires_grad_(True)
            else:
                raise ValueError("Initialization strategy not defined: %s"
                                 % args.init_strategy)

            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise ValueError("optimizer not defined: %s" % args.optim)

            psnr_t = torch.nn.MSELoss().to(device=args.device)  # MSE, turned into PSNR below
            residual = []
            # auxiliary images, updated by denoiser_step and read by closure
            x_f = x_noisy.clone()
            u = torch.zeros_like(x_test)

            def closure():
                # residual_t is published as a module global so the step loop
                # can log it after optimizer.step() returns
                global residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                residual_t = ((x_gen - x_noisy) ** 2).view(
                    len(x_noisy), -1).sum(dim=1).mean()
                if args.z_penalty_squared:
                    z_reg_loss_t = gamma * (z_sampled.norm(dim=1) ** 2).mean()
                else:
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                # x_f and u are read at call time, so updates are picked up
                residual_x = beta * ((x_gen - (x_f - u)) ** 2).view(
                    len(x_noisy), -1).sum(dim=1).mean()
                loss_t = residual_t + z_reg_loss_t + residual_x
                psnr = 10 * np.log10(1 / psnr_t(x_test, x_gen).item())
                print(
                    "\rAt step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                    % (t, loss_t.item(), residual_t.item(), residual_x.item(),
                       z_reg_loss_t.item(), psnr),
                    end="\r")
                loss_t.backward()
                return loss_t

            def denoiser_step(x_f, u):
                # x_f <- (beta * D(x_f) + alpha * (G(z) + u)) / (beta + alpha)
                # u   <- u + G(z) - x_f
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False).detach()
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_f = 1 / (beta + alpha) * (
                    beta * Denoiser(args.denoiser, args.sigma_f, x_f)
                    + alpha * (x_gen + u))
                u = u + x_gen - x_f
                return x_f, u

            for t in range(args.steps):
                optimizer.step(closure)
                residual.append(residual_t.item())
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)
                    # NOTE(review): confirm the G(z) snapshot is meant to be
                    # taken only after denoiser updates, not at every step
                    with torch.no_grad():
                        z_unflat = glow.unflatten_z(z_sampled, clone=False)
                        x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                        x_gen = glow.postprocess(x_gen, floor_clamp=False)
                        x_gen_np = np.clip(
                            x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                        gen_steps.append(x_gen_np)

            # recovered, true and noisy images, channel-last
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = np.clip(
                    x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                x_noisy_np = np.clip(
                    x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)

            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Noisy.append(x_noisy_np)
            Residual_Curve.append(residual)

            # free memory before the next batch
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise, x_noisy
            torch.cuda.empty_cache()
            print("\nbatch completed")

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        # np.vstack rejects an empty list; no snapshot exists when
        # steps < update_iter
        if gen_steps:
            gen_steps = np.vstack(gen_steps)
        else:
            gen_steps = np.empty((0, args.size, args.size, 3), dtype=np.float32)
        Noisy = np.vstack(Noisy)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout += "\t n_test = %d\n" % len(Recovered)
        printout += "\t noise_std = %0.4f\n" % args.noise_std
        printout += "\t update_iter= %0.4f\n" % args.update_iter
        printout += "\t gamma = %0.6f\n" % gamma
        printout += "\t alpha = %0.6f\n" % alpha
        printout += "\t beta = %0.6f\n" % beta
        printout += "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1].split(".")[0]
                          for name in test_dataset.samples]
            save_path += (
                "/denoising_noisestd_%0.4f_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                % (args.noise_std, args.update_iter, gamma, alpha, beta,
                   args.steps, args.lr, args.init_std, args.optim))
            # never overwrite an earlier run: use the first free of
            # path, path_1, path_2, ...
            candidate, suffix = save_path, 0
            while os.path.exists(candidate):
                suffix += 1
                candidate = "%s_%d" % (save_path, suffix)
            os.makedirs(candidate)
            save_path = candidate

            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
            for x, name in zip(Noisy, file_names):
                sio.imsave(save_path + "/" + name + "_noisy.jpg", x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/noisy.npy", Noisy)
            np.save(save_path + "/gen_steps.npy", gen_steps)
def GlowCS(args):
    """Compressed-sensing recovery with a Glow prior, once per (m, gamma, init_norm).

    For each triple from ``zip(args.m, args.gamma, args.init_norms)``:

    * a Gaussian sensing matrix ``A`` of shape ``(n, m)`` with entries of
      variance ``1/m`` is drawn, where ``n = 3 * size * size``;
    * a noise vector per batch row is drawn, either ``N(0, 1) * 0.1 / sqrt(m)``
      (``args.noise == "random_bora"``) or a random direction rescaled to norm
      ``float(args.noise)``;
    * for every batch, measurements ``y = x A + noise`` of the channel-first
      flattened images are formed, ``z`` is initialised per
      ``args.init_strategy`` and optimized to minimise ``||G(z) A - y||^2``
      (summed per image, averaged over the batch) plus ``gamma * ||z||^2``
      (``gamma * ||z||`` if ``args.z_penalty_unsquared``), where ``G`` is the
      Glow reverse pass.

    A summary (PSNR, mean ``||z||``) is printed and, with
    ``args.save_metrics_text``, appended to ``<dataset>_cs_glow_results.txt``.
    With ``args.save_results`` the recovered images and ``.npy`` arrays are
    written to a new sub-folder of ``<root_path>/results/<dataset>/<experiment>``;
    existing folders are never overwritten. If an optimizer step raises, the
    traceback is printed and the remaining batches for that triple are skipped.

    If ``args.init_norms`` is ``None`` it is replaced in place by a list of
    ``None`` of length ``len(args.m)``.

    Args:
        args: options namespace. Reads ``m``, ``gamma``, ``init_norms``,
            ``size``, ``model``, ``dataset``, ``experiment``, ``batchsize``,
            ``device``, ``noise``, ``init_strategy``, ``init_std``, ``optim``,
            ``lr``, ``steps``, ``z_penalty_unsquared``, ``save_metrics_text``
            and ``save_results``.

    Raises:
        AssertionError: ``init_norms`` given without the
            ``random_fixed_norm`` strategy; ``m``, ``gamma`` and
            ``init_norms`` differ in length; the image count is not a multiple
            of ``args.batchsize``; or the null-space perturbation changes the
            measurements.
        ValueError: unknown ``args.init_strategy`` or ``args.optim``.
    """
    if args.init_norms is None:
        args.init_norms = [None] * len(args.m)
    else:
        assert args.init_strategy == "random_fixed_norm", "init_strategy should be random_fixed_norm if init_norms is used"
    assert len(args.m) == len(args.gamma) == len(args.init_norms), "length of either m, gamma or init_norms are not same"

    n = args.size * args.size * 3
    modeldir = os.path.join(root_path, "trained_models/%s/glow-cs-%d" % (args.model, args.size))
    test_folder = os.path.join(root_path, "test_images/%s_N=12" % args.dataset)
    base_save_path = os.path.join(root_path, "results/%s/%s" % (args.dataset, args.experiment))

    # The dataset and the Glow configuration do not depend on (m, gamma): load once.
    trans = transforms.Compose(
        [transforms.Resize((args.size, args.size)), transforms.ToTensor()])
    test_dataset = datasets.ImageFolder(test_folder, transform=trans)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batchsize, drop_last=False, shuffle=False)
    with open(modeldir + "/configs.json", "r") as f:
        configs = json.load(f)

    for m, gamma, init_norm in zip(args.m, args.gamma, args.init_norms):
        skip_to_next = False  # set when recovery fails due to instability
        save_path = base_save_path

        # sensing matrix: entries ~ N(0, 1/m), shape (n, m)
        A = np.random.normal(0, 1 / np.sqrt(m), size=(n, m))
        A = torch.tensor(A, dtype=torch.float, requires_grad=False, device=args.device)

        # regularizer weight
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float, device=args.device)

        # measurement noise, one row per image of a batch
        noise = np.random.normal(0, 1, size=(args.batchsize, m))
        if args.noise == "random_bora":
            noise = noise * 0.1 / np.sqrt(m)
        else:
            noise = noise / np.linalg.norm(noise, 2, axis=-1, keepdims=True) * float(args.noise)
        noise = torch.tensor(noise, dtype=torch.float, requires_grad=False, device=args.device)

        Original = []
        Recovered = []
        Z_Recovered = []
        Residual_Curve = []
        Recorded_Z = []
        for data in test_dataloader:
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # a fresh Glow model for every batch
            print(f'load glow from: {modeldir}...')
            glow = Glow((3, args.size, args.size), K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"], n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device).to(args.device)
            print(glow.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt", map_location=args.device))
            glow.eval()
            # a forward pass records the shapes of the z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # measurements of the channel-first flattened images (constant
            # during optimization, so computed once)
            x_test_flat = x_test.view([-1, n])
            y_true = torch.matmul(x_test_flat, A) + noise

            if args.init_strategy == "random":
                # z drawn i.i.d. from N(0, init_std^2)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float, device=args.device)
            elif args.init_strategy == "random_fixed_norm":
                # random direction scaled to norm init_norm
                z_sampled = np.random.normal(0, 1, [n_test, n])
                z_sampled = z_sampled / np.linalg.norm(z_sampled, axis=-1, keepdims=True)
                z_sampled = z_sampled * init_norm
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float, device=args.device)
                print("z intialized with a norm equal to = %0.1f" % init_norm)
            elif args.init_strategy == "pseudo_inverse":
                # encode the clamped pseudo-inverse estimate y A^+
                A_pinv = torch.pinverse(A)
                x_pinv = torch.matmul(y_true, A_pinv).view([-1, 3, args.size, args.size])
                x_pinv = torch.clamp(x_pinv, 0, 1)
                z, _, _ = glow(glow.preprocess(x_pinv * 255, clone=True))
                z_sampled = glow.flatten_z(z).clone().detach().requires_grad_(True)
            elif args.init_strategy == "lasso_wavelet":
                new_args = easydict.EasyDict(
                    {"batch_size": n_test, "lmbd": 0.01, "lasso_solver": "sklearn"})
                estimator = celebA_estimators.lasso_wavelet_estimator(new_args)
                # the estimator's output is reshaped as channel-last images
                # below, so the measurements are taken in that layout too
                x_ch_last = x_test.permute(0, 2, 3, 1).contiguous().view([-1, n])
                y_ch_last = torch.matmul(x_ch_last, A) + noise
                # NOTE(review): the sqrt(2m) scaling presumably matches the
                # estimator's expected normalisation -- confirm
                x_lasso = estimator(np.sqrt(2 * m) * A.data.cpu().numpy(),
                                    np.sqrt(2 * m) * y_ch_last.data.cpu().numpy(),
                                    new_args)
                # use the actual image size (this was hard-coded to 64)
                x_lasso = np.array(x_lasso).reshape(-1, args.size, args.size, 3)
                x_lasso = x_lasso.transpose(0, 3, 1, 2)
                x_lasso = torch.tensor(x_lasso, dtype=torch.float, device=args.device)
                z, _, _ = glow(x_lasso - 0.5)
                z_sampled = glow.flatten_z(z).clone().detach().requires_grad_(True)
                print("z intialized from a solution of lasso-wavelet")
            elif args.init_strategy == "null_space":
                # perturb the true image by a vector v with v A = 0, so the
                # starting point has exactly the same measurements
                x_test_flat_np = x_test_flat.data.cpu().numpy()
                A_np = A.data.cpu().numpy()
                nullA = null_space(A_np.T)
                coeff = np.random.normal(0, 1, (args.batchsize, nullA.shape[1]))
                x_null = np.array([(nullA * c).sum(axis=-1) for c in coeff])
                pert_norm = 5  # author's choice: a bad start that is not too unstable
                x_null = x_null / np.linalg.norm(x_null, axis=1, keepdims=True) * pert_norm
                # not clipped, so that ||y - A x|| is unchanged by the perturbation
                x_perturbed = x_test_flat_np + x_null
                err = np.matmul(x_test_flat_np, A_np) - np.matmul(x_perturbed, A_np)
                assert (err ** 2).sum() < 1e-6, "null space does not satisfy ||y-A(x+x0)|| <= 1e-6"
                x_perturbed = x_perturbed.reshape(-1, 3, args.size, args.size)
                x_perturbed = torch.tensor(x_perturbed, dtype=torch.float, device=args.device)
                z, _, _ = glow(x_perturbed - 0.5)
                z_sampled = glow.flatten_z(z).clone().detach().requires_grad_(True)
                print("z initialized from a point in null space of A")
            else:
                raise ValueError("Initialization strategy not defined: %s" % args.init_strategy)

            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise ValueError("optimizer not defined: %s" % args.optim)

            psnr_t = torch.nn.MSELoss().to(device=args.device)  # MSE, turned into PSNR below
            residual = []
            recorded_z = []

            def closure():
                # residual_t is published as a module global so the step loop
                # can log it after optimizer.step() returns
                global residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                y_gen = torch.matmul(x_gen.view([-1, n]), A)
                residual_t = ((y_gen - y_true) ** 2).sum(dim=1).mean()
                if not args.z_penalty_unsquared:
                    z_reg_loss_t = gamma * (z_sampled.norm(dim=1) ** 2).mean()
                else:
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                loss_t = residual_t + z_reg_loss_t
                psnr = 10 * np.log10(1 / psnr_t(x_test, x_gen).item())
                print("\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                      % (t, loss_t.item(), residual_t.item(), z_reg_loss_t.item(), psnr),
                      end="\r")
                loss_t.backward(retain_graph=True)
                return loss_t

            for t in range(args.steps):
                try:
                    optimizer.step(closure)
                    recorded_z.append(z_sampled.data.cpu().numpy())
                    residual.append(residual_t.item())
                except Exception:
                    # the reverse pass may be numerically unstable: report it
                    # and abandon the remaining batches for this triple
                    traceback.print_exc()
                    skip_to_next = True
                    break
            if skip_to_next:
                break

            # recovered and true images, channel-last, plus the final z
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = np.clip(x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                z_recov = z_sampled.data.cpu().numpy()
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Z_Recovered.append(z_recov)
            Residual_Curve.append(residual)
            Recorded_Z.append(recorded_z)

            # free memory before the next batch
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow
            # torch.cuda.device rejects non-CUDA devices, so only clear the
            # cache when running on a GPU
            if torch.device(args.device).type == "cuda":
                with torch.cuda.device(args.device):
                    torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue

        # collecting everything together
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Z_Recovered = np.vstack(Z_Recovered)
        Recorded_Z = np.vstack(Recorded_Z)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]
        z_recov_norm = np.linalg.norm(Z_Recovered, axis=-1)

        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout += "\t n_test = %d\n" % len(Recovered)
        printout += "\t n = %d\n" % n
        printout += "\t m = %d\n" % m
        printout += "\t gamma = %0.6f\n" % gamma
        printout += "\t optimizer = %s\n" % args.optim
        printout += "\t lr = %0.3f\n" % args.lr
        printout += "\t steps = %0.3f\n" % args.steps
        printout += "\t init_strategy = %s\n" % args.init_strategy
        printout += "\t init_std = %0.3f\n" % args.init_std
        if init_norm is not None:
            printout += "\t init_norm = %0.3f\n" % init_norm
        printout += "\t z_recov_norm = %0.3f\n" % np.mean(z_recov_norm)
        printout += "\t PSNR = %0.3f\n" % (np.mean(psnr))
        print(printout)
        if args.save_metrics_text:
            with open("%s_cs_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1] for name in test_dataset.samples]
            if args.init_strategy == "random":
                save_path_template = save_path + "/cs_m_%d_gamma_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path_template % (m, gamma, args.steps, args.lr, args.init_std, args.optim)
            elif args.init_strategy == "random_fixed_norm":
                save_path_template = save_path + "/cs_m_%d_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_%0.3f_optim_%s"
                save_path = save_path_template % (m, gamma, args.steps, args.lr, args.init_strategy, init_norm, args.optim)
            else:
                save_path_template = save_path + "/cs_m_%d_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path_template % (m, gamma, args.steps, args.lr, args.init_strategy, args.optim)
            # never overwrite an earlier run: use the first free of
            # path, path_1, path_2, ...
            candidate, suffix = save_path, 0
            while os.path.exists(candidate):
                suffix += 1
                candidate = "%s_%d" % (save_path, suffix)
            os.makedirs(candidate)
            save_path = candidate

            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name, x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/z_recovered.npy", Z_Recovered)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            if init_norm is not None:
                np.save(save_path + "/Recorded_Z_init_norm_%d.npy" % init_norm, Recorded_Z)
        torch.cuda.empty_cache()