def _inpaint_init_image(init_strategy, x_test, mask, device):
    """Return the image whose Glow encoding initializes ``z`` for inpainting.

    ``mask`` multiplies the observed pixels and ``1 - mask`` selects the
    missing region. Each strategy keeps (or perturbs) ``x_test`` on the
    observed part and fills the missing part differently.

    Args:
        init_strategy: one of ``"noisy_filled"``, ``"inverted_filled"``,
            ``"black_filled"``, ``"noisy"`` or ``"only_noise_filled"``.
        x_test: channel-first image batch ``(N, C, H, W)``.
        mask: tensor broadcastable against ``x_test``.
        device: device on which the NumPy-drawn noise is placed.

    Returns:
        A new tensor shaped like ``x_test``; ``x_test`` itself is not modified.

    Raises:
        ValueError: if ``init_strategy`` is not one of the names above.
    """
    x = x_test.clone().detach()
    if init_strategy == "noisy_filled":
        # N(0, 0.2) noise added only on the missing region
        noise = torch.tensor(np.random.normal(0, 0.2, x.size()), dtype=torch.float, device=device)
        return torch.clamp(x + noise * (1 - mask), 0, 1)
    if init_strategy == "inverted_filled":
        # missing region filled with the image flipped along both spatial axes
        return x * mask + (1 - mask) * torch.flip(x, dims=[2, 3])
    if init_strategy == "black_filled":
        # masked image: missing region set to zero
        return x * mask
    if init_strategy == "noisy":
        # N(0, 0.05) noise over the whole image
        noise = torch.tensor(np.random.normal(0, 0.05, x.size()), dtype=torch.float, device=device)
        return torch.clamp(x + noise, 0, 1)
    if init_strategy == "only_noise_filled":
        # observed pixels kept, missing region replaced by N(0, 0.2) noise only
        noise = torch.tensor(np.random.normal(0, 0.2, x.size()), dtype=torch.float, device=device)
        return torch.clamp(mask * x + noise * (1 - mask), 0, 1)
    raise ValueError("Initialization strategy not defined: %r" % (init_strategy,))


def GlowInpaint(args):
    """Inpaint masked test images by optimizing the latent of a pretrained Glow.

    For each regularization weight in ``args.gamma``:

    * load the test images batch by batch;
    * apply a mask built by ``gen_mask``;
    * optimize the flattened Glow latent ``z`` to minimize the squared error
      on the observed (masked) pixels plus ``gamma * mean ||z||``.

    PSNR over the full images is printed, optionally appended to
    ``<dataset>_inpaint_glow_results.txt``, and images/arrays are optionally
    saved under ``./results/<dataset>/<experiment>/...``.

    Args:
        args: experiment namespace. Reads ``gamma`` (iterable of floats),
            ``size``, ``model``, ``dataset``, ``experiment``, ``batchsize``,
            ``device``, ``inpaint_method``, ``mask_size``, ``init_strategy``,
            ``init_std``, ``optim`` (``"adam"`` or ``"lbfgs"``), ``lr``,
            ``steps``, ``save_metrics_text`` and ``save_results``.

    Raises:
        ValueError: if ``args.init_strategy`` or ``args.optim`` is unknown.
    """
    for gamma in args.gamma:
        # set when the optimization of a batch fails (instability / Ctrl-C);
        # the remaining batches of this gamma are then skipped
        skip_to_next = False
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)), transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.batchsize, drop_last=False, shuffle=False)

        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # regularizer weight
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float, device=args.device)

        Original = []
        Recovered = []
        Masked = []
        Mask = []
        Residual_Curve = []
        for i, data in enumerate(test_dataloader):
            # getting batch of data
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, \
                "please make sure that no. of images are evenly divided by batchsize"

            # the same mask is used for every image of the batch
            mask = gen_mask(args.inpaint_method, args.size, args.mask_size)
            mask = np.array([mask] * n_test)
            mask = mask.reshape([n_test, 1, args.size, args.size])
            mask = torch.tensor(mask, dtype=torch.float, requires_grad=False, device=args.device)

            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # a forward pass records the shapes of the z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            if args.init_strategy == "random":
                # z drawn from N(0, init_std^2)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float,
                                         device=args.device)
            else:
                # z obtained by encoding an image derived from x_test
                x_init = _inpaint_init_image(args.init_strategy, x_test, mask, args.device)
                z, _, _ = glow(x_init - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise ValueError("optimizer not defined: %r" % (args.optim,))

            # metrics to record over training
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            residual_t = None  # residual of the most recent closure evaluation

            def closure():
                nonlocal residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_masked_test = x_test * mask
                x_masked_gen = x_gen * mask
                # per-image squared error on observed pixels, averaged over the batch
                residual_t = ((x_masked_gen - x_masked_test) ** 2).view(
                    len(x_masked_test), -1).sum(dim=1).mean()
                z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                loss_t = residual_t + z_reg_loss_t
                psnr = psnr_t(x_test, x_gen)
                psnr = 10 * np.log10(1 / psnr.item())
                print("\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                      % (t, loss_t.item(), residual_t.item(), z_reg_loss_t.item(), psnr),
                      end="\r")
                loss_t.backward()
                return loss_t

            # running optimizer steps
            for t in range(args.steps):
                try:
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except (Exception, KeyboardInterrupt) as err:
                    # the reverse Glow pass may become numerically unstable;
                    # KeyboardInterrupt lets the user abandon this gamma
                    print("\n%s: %s" % (type(err).__name__, err))
                    skip_to_next = True
                    break
            if skip_to_next:
                break

            # getting recovered and true images (channel-last for saving)
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            z_unflat = glow.unflatten_z(z_sampled, clone=False)
            x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
            x_gen = glow.postprocess(x_gen, floor_clamp=False)
            x_gen_np = np.clip(x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
            mask_np = mask.data.cpu().numpy()
            x_masked_test = x_test * mask
            x_masked_test_np = np.clip(
                x_masked_test.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)

            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Masked.append(x_masked_test_np)
            Residual_Curve.append(residual)
            Mask.append(mask_np)

            # freeing up memory for the next batch
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, mask
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Masked = np.vstack(Masked)
        Mask = np.vstack(Mask)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t inpaint_method = %s\n" % args.inpaint_method
        printout = printout + "\t mask_size = %0.3f\n" % args.mask_size
        printout = printout + "\t gamma = %0.6f\n" % gamma
        printout = printout + "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_inpaint_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        # saving images
        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1].split(".")[0] for name in test_dataset.samples]
            if args.init_strategy == 'random':
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size, gamma, args.steps,
                                         args.lr, args.init_std, args.optim)
            else:
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size, gamma, args.steps,
                                         args.lr, args.init_strategy, args.optim)

            # never overwrite an earlier run: append _1, _2, ... until unused
            base_path = save_path
            suffix = 0
            while os.path.exists(save_path):
                suffix += 1
                save_path = "%s_%d" % (base_path, suffix)
            os.makedirs(save_path)

            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
            for x, name in zip(Masked, file_names):
                sio.imsave(save_path + "/" + name + "_masked.jpg", x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/" + "residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/mask.npy", Mask)
            np.save(save_path + "/masked.npy", Masked)
def _redcs_lasso_image(estimator_factory, x_test, A, noise, m, n, size, device):
    """Estimate images from compressed measurements with a LASSO estimator.

    Measurements are ``y = x_ch_last @ A + noise``, where ``x_ch_last`` is the
    channel-last flattening of ``x_test``. Both ``A`` and ``y`` are scaled by
    ``sqrt(2 m)`` before being passed to the estimator built by
    ``estimator_factory``.

    Returns:
        Channel-first float tensor of shape ``(-1, 3, size, size)`` on ``device``.
    """
    n_test = x_test.size()[0]
    new_args = easydict.EasyDict({"batch_size": n_test, "lmbd": 0.01, "lasso_solver": "sklearn"})
    estimator = estimator_factory(new_args)
    x_ch_last = x_test.permute(0, 2, 3, 1).contiguous().view([-1, n])
    y_true = torch.matmul(x_ch_last, A) + noise
    scale = np.sqrt(2 * m)
    x_lasso = estimator(scale * A.data.cpu().numpy(), scale * y_true.data.cpu().numpy(), new_args)
    x_lasso = np.array(x_lasso).reshape(-1, size, size, 3).transpose(0, 3, 1, 2)
    return torch.tensor(x_lasso, dtype=torch.float, device=device)


def GlowREDCS(args, filename=None):
    """Compressed sensing with a Glow prior plus a periodically denoised auxiliary image.

    For every ``(m, gamma, init_norm)`` triple:

    * draw a Gaussian sensing matrix ``A`` of shape ``(n, m)`` with entries
      ``N(0, 1/m)`` and a measurement noise;
    * optimize the Glow latent ``z`` of every test batch to minimize
      ``||G(z) A - y||^2 + gamma * ||z|| + beta * ||G(z) - (x_f - u)||^2``;
    * every ``args.update_iter`` steps, refresh the auxiliary image ``x_f`` with
      ``Denoiser`` and update the running residual ``u``.

    Metrics for ``G(z)`` (PSNR) and ``x_f`` (PSNR_f) are printed and optionally
    saved.

    Args:
        args: experiment namespace. Reads, among others, ``m``, ``gamma`` and
            ``init_norms`` (equal-length lists, or ``init_norms=None``),
            ``init_strategy``, ``optim``, ``alpha``, ``beta``, ``update_iter``,
            ``denoiser``, ``sigma_f``, ``noise``, ``steps`` and ``lr``.
            ``args.init_norms`` is replaced by a list of ``None`` when it is
            ``None``.
        filename: optional CSV path. When given, per-step losses are written
            to it after every closure evaluation.

    Raises:
        AssertionError: if list lengths disagree, or ``init_norms`` is given
            without ``init_strategy == "random_fixed_norm"``.
        ValueError: if ``args.init_strategy`` or ``args.optim`` is unknown.
    """
    if args.init_norms is None:
        args.init_norms = [None] * len(args.m)
    else:
        assert args.init_strategy == "random_fixed_norm", \
            "init_strategy should be random_fixed_norm if init_norms is used"
    assert len(args.m) == len(args.gamma) == len(args.init_norms), \
        "length of either m, gamma or init_norms are not same"

    for m, gamma, init_norm in zip(args.m, args.gamma, args.init_norms):
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose([transforms.Resize((args.size, args.size)), transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batchsize,
                                                      drop_last=False, shuffle=False)

        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # sensing matrix
        A = np.random.normal(0, 1 / np.sqrt(m), size=(n, m))
        A = torch.tensor(A, dtype=torch.float, requires_grad=False, device=args.device)

        # regularizer weights
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float, device=args.device)
        alpha = args.alpha
        beta = args.beta

        # measurement noise: "random_bora" -> N(0, (0.1/sqrt(m))^2) entries,
        # otherwise each row is rescaled to have norm float(args.noise)
        noise = np.random.normal(0, 1, size=(args.batchsize, m))
        if args.noise == "random_bora":
            noise = noise * 0.1 / np.sqrt(m)
        else:
            noise = noise / np.linalg.norm(noise, 2, axis=-1, keepdims=True) * float(args.noise)
        noise = torch.tensor(noise, dtype=torch.float, requires_grad=False, device=args.device)

        # start solving over batches
        Original = []
        Recovered = []
        Recovered_f = []
        Z_Recovered = []
        Residual_Curve = []
        Recorded_Z = []
        for i, data in enumerate(test_dataloader):
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, \
                "please make sure that no. of images are evenly divided by batchsize"

            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # a forward pass records the shapes of the z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            strategy = args.init_strategy
            x_lasso = None  # set only by the lasso-based strategies
            if strategy == "random":
                # z drawn from N(0, init_std^2)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float,
                                         device=args.device)
            elif strategy == "random_fixed_norm":
                # random direction rescaled to have norm init_norm
                z_sampled = np.random.normal(0, 1, [n_test, n])
                z_sampled = z_sampled / np.linalg.norm(z_sampled, axis=-1, keepdims=True)
                z_sampled = z_sampled * init_norm
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float,
                                         device=args.device)
                print("z intialized with a norm equal to = %0.1f" % init_norm)
            elif strategy == "pseudo_inverse":
                # encode the clamped pseudo-inverse reconstruction y A^+
                y_true = torch.matmul(x_test.view([-1, n]), A) + noise
                x_pinv = torch.matmul(y_true, torch.pinverse(A))
                x_pinv = torch.clamp(x_pinv.view([-1, 3, args.size, args.size]), 0, 1)
                z, _, _ = glow(glow.preprocess(x_pinv * 255, clone=True))
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)
            elif strategy in ("lasso_wavelet", "lasso_dct", "random_lasso_dct"):
                if strategy == "lasso_wavelet":
                    factory = celebA_estimators.lasso_wavelet_estimator
                else:
                    factory = celebA_estimators.lasso_dct_estimator
                x_lasso = _redcs_lasso_image(factory, x_test, A, noise, m, n, args.size, args.device)
                if strategy == "random_lasso_dct":
                    # only x_f comes from LASSO; z itself is random
                    z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                    z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float,
                                             device=args.device)
                    print("z intialized randomly and RED is initialized from a solution of lasso-dct")
                else:
                    z, _, _ = glow(x_lasso - 0.5)
                    z = glow.flatten_z(z).clone().detach()
                    z_sampled = z.clone().detach().requires_grad_(True)
                    print("z intialized from a solution of %s" % strategy.replace("_", "-"))
            elif strategy == "null_space":
                # perturb x_test by a fixed-norm vector in null(A^T) so that the
                # measurements are unchanged, then encode the perturbed image
                x_test_flat_np = x_test.view([-1, n]).data.cpu().numpy()
                A_np = A.data.cpu().numpy()
                nullA = null_space(A_np.T)
                coeff = np.random.normal(0, 1, (args.batchsize, nullA.shape[1]))
                x_null = coeff @ nullA.T  # row b = nullA @ coeff[b]
                pert_norm = 5  # 5 gives a bad-but-not-too-unstable initialization
                x_null = x_null / np.linalg.norm(x_null, axis=1, keepdims=True) * pert_norm
                # no clipping, so that ||y - A x|| is exactly preserved
                x_perturbed = x_test_flat_np + x_null
                err = np.matmul(x_test_flat_np, A_np) - np.matmul(x_perturbed, A_np)
                assert (err ** 2).sum() < 1e-6, "null space does not satisfy ||y-A(x+x0)|| <= 1e-6"
                x_perturbed = x_perturbed.reshape(-1, 3, args.size, args.size)
                x_perturbed = torch.tensor(x_perturbed, dtype=torch.float, device=args.device)
                z, _, _ = glow(x_perturbed - 0.5)
                z = glow.flatten_z(z).clone().detach()
                z_sampled = z.clone().detach().requires_grad_(True)
                print("z initialized from a point in null space of A")
            else:
                raise ValueError("Initialization strategy not defined: %r" % (strategy,))

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            elif args.optim == "rk2":
                optimizer = RK2Heun([z_sampled], lr=args.lr)
            elif args.optim == "raghav":
                optimizer = RK2Raghav([z_sampled], lr=args.lr)
            elif args.optim == "sgd":
                optimizer = torch.optim.SGD([z_sampled], lr=args.lr, momentum=0.9)
            else:
                raise ValueError("optimizer not defined: %r" % (args.optim,))

            def generate():
                """Decode the current latent through the reverse Glow pass and postprocess it."""
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                return glow.postprocess(x_gen, floor_clamp=False)

            # to be recorded over iterations
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            recorded_z = []
            residual_t = None  # measurement residual of the latest closure call

            # auxiliary image: the LASSO estimate when available, otherwise the
            # decoding of the initial latent (other strategies produce no x_lasso)
            if x_lasso is not None:
                x_f = x_lasso.clone()
            else:
                with torch.no_grad():
                    x_f = generate().detach()
            u = torch.zeros_like(x_test)
            df_losses = pd.DataFrame(columns=["loss_t", "residual_t", "residual_x", "z_reg_loss"])

            def closure():
                nonlocal residual_t
                optimizer.zero_grad()
                x_gen = generate()
                y_true = torch.matmul(x_test.view([-1, n]), A) + noise
                y_gen = torch.matmul(x_gen.view([-1, n]), A)
                residual_t = ((y_gen - y_true) ** 2).sum(dim=1).mean()
                z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                residual_x = beta * ((x_gen - (x_f - u)) ** 2).view(len(x_f), -1).sum(dim=1).mean()
                loss_t = residual_t + z_reg_loss_t + residual_x
                psnr = psnr_t(x_test, x_gen)
                psnr = 10 * np.log10(1 / psnr.item())
                print("At step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f" % (
                    t, loss_t.item(), residual_t.item(), residual_x.item(), z_reg_loss_t.item(), psnr))
                loss_t.backward()
                df_losses.loc[len(df_losses)] = [loss_t.item(), residual_t.item(),
                                                 residual_x.item(), z_reg_loss_t.item()]
                if filename is not None:
                    # rewritten every evaluation so a crashed run still leaves a log
                    df_losses.to_csv(filename)
                return loss_t

            def denoiser_step(x_f, u):
                """Blend the denoised x_f with G(z) + u and update the running residual u."""
                with torch.no_grad():
                    x_gen = generate()
                x_f = 1 / (beta + alpha) * (beta * Denoiser(args.denoiser, args.sigma_f, x_f)
                                            + alpha * (x_gen + u))
                u = u + x_gen - x_f
                return x_f, u

            # running optimizer steps
            for t in range(args.steps):
                optimizer.step(closure)
                recorded_z.append(z_sampled.data.cpu().numpy())
                residual.append(residual_t.item())
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)

            # getting recovered and true images (channel-last for saving)
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
                x_gen = generate()
                x_gen_np = np.clip(x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                x_f_np = np.clip(x_f.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                z_recov = z_sampled.data.cpu().numpy()
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Recovered_f.append(x_f_np)
            Z_Recovered.append(z_recov)
            Residual_Curve.append(residual)
            Recorded_Z.append(recorded_z)

            # freeing up memory for the next batch
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow
            torch.cuda.empty_cache()
            print("\nbatch completed")

        # collecting everything together
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Recovered_f = np.vstack(Recovered_f)
        Z_Recovered = np.vstack(Z_Recovered)
        Recorded_Z = np.vstack(Recorded_Z)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]
        psnr_f = [compare_psnr(x, y) for x, y in zip(Original, Recovered_f)]
        z_recov_norm = np.linalg.norm(Z_Recovered, axis=-1)

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t n = %d\n" % n
        printout = printout + "\t m = %d\n" % m
        printout = printout + "\t update_iter = %0.4f\n" % args.update_iter
        printout = printout + "\t gamma = %0.6f\n" % gamma
        printout = printout + "\t alpha = %0.6f\n" % alpha
        printout = printout + "\t beta = %0.6f\n" % beta
        printout = printout + "\t optimizer = %s\n" % args.optim
        printout = printout + "\t lr = %0.3f\n" % args.lr
        printout = printout + "\t steps = %0.3f\n" % args.steps
        printout = printout + "\t init_strategy = %s\n" % args.init_strategy
        printout = printout + "\t init_std = %0.3f\n" % args.init_std
        if init_norm is not None:
            printout = printout + "\t init_norm = %0.3f\n" % init_norm
        printout = printout + "\t z_recov_norm = %0.3f\n" % np.mean(z_recov_norm)
        printout = printout + "\t mean PSNR = %0.3f\n" % (np.mean(psnr))
        printout = printout + "\t mean PSNR_f = %0.3f\n" % (np.mean(psnr_f))
        print(printout)

        # saving printout
        if args.save_metrics_text:
            with open("%s_cs_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        # setting folder to save results in
        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1] for name in test_dataset.samples]
            if args.init_strategy == "random":
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma, alpha, beta, args.steps,
                                                  args.lr, args.init_std, args.optim)
            elif args.init_strategy == "random_fixed_norm":
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_%s_%0.3f_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma, alpha, beta, args.steps,
                                                  args.lr, args.init_strategy, init_norm, args.optim)
            else:
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma, alpha, beta, args.steps,
                                                  args.lr, args.init_strategy, args.optim)

            # never overwrite an earlier run: append _1, _2, ... until unused
            base_path = save_path
            suffix = 0
            while os.path.exists(save_path):
                suffix += 1
                save_path = "%s_%d" % (base_path, suffix)
            os.makedirs(save_path)

            # saving results now
            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name, x)
            print(save_path + "/" + file_names[0])
            for x, name in zip(Recovered_f, file_names):
                sio.imsave(save_path + "/f_" + name, x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/recovered_f.npy", Recovered_f)
            np.save(save_path + "/z_recovered.npy", Z_Recovered)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            if init_norm is not None:
                np.save(save_path + "/Recorded_Z_init_norm_%d.npy" % init_norm, Recorded_Z)
        torch.cuda.empty_cache()
def image_noise(unused_loc, scale, **image_prior):
    """Build a sampler that returns images generated by a pretrained prior as noise.

    Args:
        unused_loc: ignored (kept in the signature, presumably so this factory
            matches other noise factories taking a location -- confirm).
        scale: factor multiplying every generated batch.
        **image_prior: options read from the dict:
            ``noise`` -- ``'glow'`` (default) or ``'dcgan'``;
            ``size`` -- image side length (glow only);
            ``bsz`` -- number of images per batch;
            ``configs`` -- Glow configuration dict (glow only);
            ``device`` -- device for the model and returned tensor;
            ``dataset`` -- name used in the checkpoint directory.

    Returns:
        A function taking one ignored argument and returning ``bsz`` generated
        images multiplied by ``scale``. Every call reseeds NumPy with 1 (and
        torch with 1 for glow), so every call resets the global RNG state.

    Raises:
        NotImplementedError: if ``noise`` is neither ``'glow'`` nor ``'dcgan'``.
    """
    noise = image_prior.get('noise', 'glow')
    size = image_prior.get('size')
    bsz = image_prior.get('bsz')
    configs = image_prior.get('configs')
    device = image_prior.get('device')
    dataset = image_prior.get('dataset')

    if noise == 'glow':
        modeldir = f"./trained_models/{dataset}/glow-cs-{size}"
        glow = Glow((3, size, size),
                    K=configs["K"], L=configs["L"],
                    coupling=configs["coupling"],
                    n_bits_x=configs["n_bits_x"],
                    nn_init_last_zeros=configs["last_zeros"],
                    device=device)
        glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt", map_location=device))
        glow.eval()
        # a forward pass records the shapes of the z's for the reverse pass
        _ = glow(glow.preprocess(torch.zeros(size=(bsz, 3, size, size), device=device)))

        def _image_noise(unused_sample_size):
            np.random.seed(1)
            torch.manual_seed(1)
            _, z = glow.generate_z(n=bsz, mu=0, std=0.5, to_torch=True)
            # NOTE(review): runs with autograd enabled, so the result may carry a
            # graph into the prior's parameters -- confirm callers rely on this
            return glow.postprocess(glow.forward(z, reverse=True)) * scale

        return _image_noise

    if noise == 'dcgan':
        modeldir = "./trained_models/%s/dcgan" % dataset
        generator = Generator(ngpu=1).to(device=device)
        # map_location lets a checkpoint saved on GPU load on the target device
        generator.load_state_dict(torch.load(modeldir + '/dcgan_G.pt', map_location=device))
        generator.eval()
        n = 100  # latent dimension of the DCGAN generator

        def _image_noise(unused_sample_size):
            np.random.seed(1)
            z = np.random.normal(size=(bsz, n, 1, 1))
            z = torch.tensor(z, dtype=torch.float, requires_grad=False, device=device)
            generated = generator(z)
            # todo: why? -- presumably maps a [-1, 1] generator output to [0, 1]; confirm
            generated = (generated + 1) / 2
            return generated * scale

        return _image_noise

    raise NotImplementedError("unsupported image prior noise: %r" % (noise,))
def GlowDenoiser(args):
    """Denoise test images by optimizing the latent of a pretrained Glow model.

    For each regularization weight in ``args.gamma``:

    * noise produced by ``Noiser(args, configs)`` is added to every test batch,
      and the sum is clamped to [0, 1];
    * the flattened Glow latent ``z`` is optimized to minimize
      ``recon_loss(G(z), x_noisy) + gamma * ||z||^2``. The penalty is unsquared
      ``||z||`` when ``args.z_penalty_unsquared`` is set.

    Optimization failures are reported with a traceback and the partial
    results are still collected, reported and (optionally) saved.

    Args:
        args: experiment namespace. Reads ``gamma`` (iterable of floats),
            ``size``, ``dataset``, ``experiment``, ``batchsize``, ``device``,
            ``noise``, ``noise_loc``, ``noise_scale``, ``noise_channel``,
            ``noise_area``, ``init_strategy`` (``"random"`` or
            ``"from-noisy"``), ``init_std``, ``optim`` (``"adam"`` or
            ``"lbfgs"``), ``lr``, ``steps``, ``z_penalty_unsquared``,
            ``save_metrics_text`` and ``save_results``.

    Raises:
        NotImplementedError: if ``args.init_strategy`` is unknown.
        ValueError: if ``args.optim`` is unknown.
    """
    for gamma in args.gamma:  # one full experiment per regularization weight
        n = args.size * args.size * 3
        modeldir = f"./trained_models/{args.dataset}/glow-cs-{args.size}"
        test_folder = f"./test_images/{args.dataset}_N=12"
        save_path = f"./results/{args.dataset}/{args.experiment}"

        # loading dataset
        trans = transforms.Compose([transforms.Resize((args.size, args.size)), transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batchsize,
                                                      drop_last=False, shuffle=False)

        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        noiser = Noiser(args, configs)
        loss = recon_loss(args.noise, args.noise_loc, args.noise_scale)

        # regularizer weight
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float, device=args.device)

        Original = []
        Recovered = []
        Noisy = []
        Noise = []
        Residual_Curve = []
        for i, data in enumerate(test_dataloader):
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, \
                "please make sure that no. of images are evenly divided by batchsize"

            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt", map_location=args.device))
            glow.eval()

            # add noise
            noise = noiser(x_test)
            x_noisy = torch.clamp(x_test + noise, 0., 1.)

            # a forward pass records the shapes of the z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            if args.init_strategy == "random":
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
            elif args.init_strategy == "from-noisy":
                z, _, _ = glow(glow.preprocess(x_noisy * 255, clone=True))
                z = glow.flatten_z(z)
                z_sampled = z.clone().detach().cpu().numpy()
            else:
                raise NotImplementedError("Initialization strategy not defined")
            z_sampled = torch.from_numpy(z_sampled).float().to(args.device)
            z_sampled = nn.Parameter(z_sampled, requires_grad=True)

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise ValueError("optimizer not defined: %r" % (args.optim,))

            # to be recorded over iterations
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            residual_t = None  # reconstruction loss of the latest closure call

            def closure():
                nonlocal residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                residual_t = loss(x_gen, x_noisy)
                if not args.z_penalty_unsquared:
                    z_reg_loss_t = gamma * (z_sampled.norm(dim=1) ** 2).mean()
                else:
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                loss_t = residual_t + z_reg_loss_t
                psnr = psnr_t(x_test, x_gen)
                psnr = 10 * np.log10(1 / psnr.item())
                print("\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f" % (
                    t, loss_t.item(), residual_t.item(), z_reg_loss_t.item(), psnr), end="\r")
                loss_t.backward(retain_graph=True)
                return loss_t

            # running optimizer steps; on failure keep what was reached so far
            for t in range(args.steps):
                try:
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except Exception:
                    traceback.print_exc()
                    break

            Original.append(x_test.data.cpu().numpy().transpose(0, 2, 3, 1))
            Noise.append(noise.data.cpu().numpy().transpose(0, 2, 3, 1))
            Noisy.append(x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1))
            Residual_Curve.append(residual)

            # decoding the final latent may itself fail if z diverged
            x_gen = None
            try:
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = np.clip(x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                Recovered.append(x_gen_np)
            except Exception:
                traceback.print_exc()

            # freeing up memory for the next batch
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise
            # torch.cuda.device() rejects non-CUDA devices, so only enter it on CUDA
            if torch.device(args.device).type == "cuda":
                with torch.cuda.device(args.device):
                    torch.cuda.empty_cache()
            print("\nbatch completed")
            # todo: remove this break after finishing development.
            break

        # metric evaluations
        Original = np.vstack(Original)
        Noisy = np.vstack(Noisy)
        Noise = np.vstack(Noise)
        psnr = None
        if Recovered:
            Recovered = np.vstack(Recovered)
            psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]
        else:
            print("\nno image could be recovered; reporting noisy data only")

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t noise_std = %0.4f\n" % args.noise_scale
        printout = printout + "\t gamma = %0.6f\n" % gamma
        if psnr is not None:
            printout = printout + "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        # saving images
        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1].split(".")[0] for name in test_dataset.samples]
            save_path = os.path.join(save_path,
                                     f'{args.noise}_'
                                     f'{args.noise_loc}#{args.noise_scale}_'
                                     f'{args.noise_channel}_{args.noise_area}_'
                                     f'{args.init_strategy}_'
                                     f'{round(gamma, 4)}_{gettime()}')
            print(save_path)

            # never overwrite an earlier run: append _1, _2, ... until unused
            base_path = save_path
            suffix = 0
            while os.path.exists(save_path):
                suffix += 1
                save_path = "%s_%d" % (base_path, suffix)
            os.makedirs(save_path)

            for x, name in zip(Noisy, file_names):
                sio.imsave(save_path + "/" + name + "_noisy.jpg", x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/noisy.npy", Noisy)
            np.save(save_path + "/noise.npy", Noise)
            if len(Recovered) > 0:
                np.save(save_path + "/recovered.npy", Recovered)
                for x, name in zip(Recovered, file_names):
                    sio.imsave(save_path + "/" + name + "_recov.jpg", x)
def GlowREDDenoiser(args):
    """Denoise test images with a pretrained Glow prior plus a plug-in denoiser.

    For every regularization weight in ``args.gamma`` this function:

    1. loads the test images from ``./test_images/<args.dataset>`` and the Glow
       model/config from ``./trained_models/<args.model>/glow``;
    2. for each batch, adds synthetic noise (only ``"gaussian"`` is supported)
       and optimizes a flattened latent ``z_sampled`` so that the Glow reverse
       pass fits the noisy image. The per-step loss is
       ``sum((x_gen - x_noisy)**2) + gamma * ||z|| (or ||z||**2 when
       args.z_penalty_squared) + beta * sum((x_gen - (x_f - u))**2)``,
       averaged over the batch;
    3. every ``args.update_iter`` optimizer steps, updates the auxiliary image
       ``x_f`` with ``Denoiser(args.denoiser, args.sigma_f, x_f)`` and the
       variable ``u`` (the update pattern looks like a scaled-ADMM / RED-style
       split between the Glow image and the denoised image);
    4. prints PSNR and optionally appends it to a text file and saves images
       and ``.npy`` arrays under ``./results/<args.dataset>/<args.experiment>``.

    Args:
        args: argparse-style namespace. Fields read here: gamma, size, model,
            dataset, experiment, batchsize, device, alpha, beta, noise,
            noise_std, init_strategy, init_std, optim, lr, steps,
            z_penalty_squared, update_iter, denoiser, sigma_f,
            save_metrics_text, save_results.

    Raises:
        NotImplementedError: ``args.noise == "laplacian"``.
        ValueError: unknown ``args.noise``, ``args.init_strategy`` or
            ``args.optim``.
        AssertionError: a batch is smaller than ``args.batchsize`` (the number
            of test images must be a multiple of the batch size).
    """
    # NOTE: zip() over a single iterable yields 1-tuples, so each ``gamma``
    # is ``(g,)`` and becomes a 1-element tensor below; kept unchanged so the
    # tensor shape (and everything derived from it) stays the same.
    loopOver = zip(args.gamma)
    for gamma in loopOver:
        # Flag meant to abandon the current gamma when recovery is unstable.
        # NOTE(review): nothing currently sets it to True.
        skip_to_next = False
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)), transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.batchsize, drop_last=False,
            shuffle=False)

        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # regularizer weight on the latent norm
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float,
                             device=args.device)
        alpha = args.alpha
        beta = args.beta

        # accumulated over all batches of the current gamma
        gen_steps = []  # one clipped reconstruction per optimizer step
        Original = []
        Recovered = []
        Noisy = []
        Residual_Curve = []

        for i, data in enumerate(test_dataloader):
            # getting batch of data
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, \
                "please make sure that no. of images are evenly divided by batchsize"

            # noise to be added
            if args.noise == "gaussian":
                noise = np.random.normal(
                    0, args.noise_std, size=(n_test, 3, args.size, args.size))
                noise = torch.tensor(noise, dtype=torch.float,
                                     requires_grad=False, device=args.device)
            elif args.noise == "laplacian":
                # No noise-type tag exists in the results folder name yet, so
                # laplacian runs are rejected explicitly.
                raise NotImplementedError("code only supports gaussian for now")
            else:
                raise ValueError("noise type not defined: %r" % (args.noise,))

            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            if args.init_strategy == "random":
                # initializing z from Gaussian
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True,
                                         dtype=torch.float, device=args.device)
            elif args.init_strategy == "from-noisy":
                # initializing z from the encoding of the noisy image
                x_noisy = x_test + noise
                z, _, _ = glow(glow.preprocess(x_noisy * 255, clone=True))
                z = glow.flatten_z(z)
                z_sampled = z.clone().detach().requires_grad_(True)
            else:
                raise ValueError("Initialization strategy not defined: %r"
                                 % (args.init_strategy,))

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise ValueError("optimizer not defined: %r" % (args.optim,))

            # MSE module; turned into PSNR with a peak value of 1 in closure()
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            x_f = (x_test + noise).clone()  # auxiliary (denoiser-side) image
            u = torch.zeros_like(x_test)
            # latest data-fidelity term, written by closure() on every call
            residual_t = None

            def closure():
                """Compute the loss, backpropagate into z_sampled, return it."""
                nonlocal residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_noisy = x_test + noise
                residual_t = ((x_gen - x_noisy)**2).view(
                    len(x_noisy), -1).sum(dim=1).mean()
                if args.z_penalty_squared:
                    z_reg_loss_t = gamma * (z_sampled.norm(dim=1)**2).mean()
                else:
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                # coupling term pulling x_gen towards (x_f - u)
                residual_x = beta * ((x_gen - (x_f - u))**2).view(
                    len(x_noisy), -1).sum(dim=1).mean()
                loss_t = residual_t + z_reg_loss_t + residual_x
                psnr = psnr_t(x_test, x_gen)
                psnr = 10 * np.log10(1 / psnr.item())
                # ``t`` is read at call time, i.e. the current step index
                print(
                    "\rAt step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                    % (t, loss_t.item(), residual_t.item(), residual_x.item(),
                       z_reg_loss_t.item(), psnr),
                    end="\r")
                loss_t.backward()
                return loss_t

            def denoiser_step(x_f, u):
                """Update (x_f, u) from the current Glow reconstruction."""
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True,
                             reverse_clone=False).detach()
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_f = 1 / (beta + alpha) * (
                    beta * Denoiser(args.denoiser, args.sigma_f, x_f)
                    + alpha * (x_gen + u))
                u = u + x_gen - x_f
                return x_f, u

            # running optimizer steps
            for t in range(args.steps):
                optimizer.step(closure)
                residual.append(residual_t.item())
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)
                # snapshot of the reconstruction after this step
                with torch.no_grad():
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
                    x_gen_np = np.clip(x_gen_np, 0, 1)
                    gen_steps.append(x_gen_np)

            if skip_to_next:
                break

            # getting recovered and true images (no graph needed here)
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            with torch.no_grad():
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
            x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_gen_np = np.clip(x_gen_np, 0, 1)
            x_noisy = x_test + noise
            x_noisy_np = x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_noisy_np = np.clip(x_noisy_np, 0, 1)
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Noisy.append(x_noisy_np)
            Residual_Curve.append(residual)

            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print(
                "\nskipping current loop due to instability or user triggered quit"
            )
            continue

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        gen_steps = np.vstack(gen_steps)
        Noisy = np.vstack(Noisy)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test = %d\n" % len(Recovered)
        printout = printout + "\t noise_std = %0.4f\n" % args.noise_std
        printout = printout + "\t update_iter= %0.4f\n" % args.update_iter
        printout = printout + "\t gamma = %0.6f\n" % gamma
        printout = printout + "\t alpha = %0.6f\n" % alpha
        printout = printout + "\t beta = %0.6f\n" % beta
        printout = printout + "\t PSNR = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt" % args.dataset,
                      "a") as f:
                f.write('\n' + printout)

        # saving images
        if args.save_results:
            gamma = gamma.item()
            file_names = [
                name[0].split("/")[-1].split(".")[0]
                for name in test_dataset.samples
            ]
            save_path = save_path + "/denoising_noisestd_%0.4f_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
            save_path = save_path % (args.noise_std, args.update_iter, gamma,
                                     alpha, beta, args.steps, args.lr,
                                     args.init_std, args.optim)
            # Pick the first unused directory (base, base_1, base_2, ...) so
            # earlier results are never overwritten.
            candidate = save_path
            suffix = 0
            while os.path.exists(candidate):
                suffix += 1
                candidate = "%s_%d" % (save_path, suffix)
            save_path = candidate
            os.makedirs(save_path)

            _ = [
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
                for x, name in zip(Recovered, file_names)
            ]
            _ = [
                sio.imsave(save_path + "/" + name + "_noisy.jpg", x)
                for x, name in zip(Noisy, file_names)
            ]
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/noisy.npy", Noisy)
            np.save(save_path + "/gen_steps.npy", gen_steps)