Exemple #1
0
def _inpaint_latent_from_image(glow, x_init):
    """Encode ``x_init`` with ``glow`` and return a flat latent requiring grad.

    The image is shifted by -0.5 before the forward pass, which is what every
    image-based initialization strategy of ``GlowInpaint`` does.
    """
    z, _, _ = glow(x_init - 0.5)
    z = glow.flatten_z(z).clone().detach()
    return z.clone().detach().requires_grad_(True)


def GlowInpaint(args):
    """Inpaint masked test images by optimizing the latent code of a Glow model.

    For every entry of ``args.gamma`` the images under
    ``./test_images/<args.dataset>`` are loaded batch by batch, masked with
    ``gen_mask`` and recovered by minimizing, over the flattened latent ``z``,

        sum((mask * G(z) - mask * x)**2) + gamma * penalty(z)

    where ``penalty`` is the mean per-sample norm of ``z`` (squared unless
    ``args.z_penalty_unsquared`` is set).  PSNR is printed per step and
    summarized per gamma; metrics and images are optionally written to disk.

    Args:
        args: namespace providing ``gamma`` (iterable of regularization
            weights), ``size``, ``model``, ``dataset``, ``experiment``,
            ``batchsize``, ``device``, ``inpaint_method``, ``mask_size``,
            ``init_strategy``, ``init_std``, ``optim``, ``lr``, ``steps``,
            ``z_penalty_unsquared``, ``save_metrics_text`` and
            ``save_results``.

    Returns:
        None.  Results are printed and, depending on ``args``, saved.

    Raises:
        ValueError: if ``args.init_strategy`` or ``args.optim`` is not one of
            the supported values.
    """
    # Holds the residual computed by the most recent closure() call; the
    # closure rebinds it via ``nonlocal`` so nothing leaks into module globals.
    residual_t = None

    # zip() over a single list yields 1-tuples, so ``gamma`` below becomes a
    # one-element tensor rather than a 0-d scalar.
    loopOver = zip(args.gamma)
    for gamma in loopOver:
        skip_to_next = False  # flag to skip to next gamma if recovery fails due to instability
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)),
             transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batchsize,
            drop_last=False,
            shuffle=False)

        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # regularizor
        gamma = torch.tensor(gamma,
                             requires_grad=True,
                             dtype=torch.float,
                             device=args.device)

        # per-batch results accumulated over the whole test set
        Original = []
        Recovered = []
        Masked = []
        Mask = []
        Residual_Curve = []
        for data in test_dataloader:
            # getting batch of data
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # generate mask: one mask, repeated for every image of the batch
            mask = gen_mask(args.inpaint_method, args.size, args.mask_size)
            mask = np.array([mask for _ in range(n_test)])
            mask = mask.reshape([n_test, 1, args.size, args.size])
            mask = torch.tensor(mask,
                                dtype=torch.float,
                                requires_grad=False,
                                device=args.device)

            # loading glow model (rebuilt per batch; it is deleted at the end
            # of the batch to free memory)
            glow = Glow((3, args.size, args.size),
                        K=configs["K"],
                        L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # initializing z from Gaussian
            if args.init_strategy == "random":
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled,
                                         requires_grad=True,
                                         dtype=torch.float,
                                         device=args.device)
            # initializing z from image with noise added only in masked region
            elif args.init_strategy == "noisy_filled":
                x_noisy_filled = x_test.clone().detach()
                noise = np.random.normal(0, 0.2, x_noisy_filled.size())
                noise = torch.tensor(noise,
                                     dtype=torch.float,
                                     device=args.device)
                noise = noise * (1 - mask)
                x_noisy_filled = x_noisy_filled + noise
                x_noisy_filled = torch.clamp(x_noisy_filled, 0, 1)
                z_sampled = _inpaint_latent_from_image(glow, x_noisy_filled)
            # initializing z from image whose masked region is filled with the
            # image flipped along both spatial axes
            elif args.init_strategy == "inverted_filled":
                x_inverted_filled = x_test.clone().detach()
                missing_x = x_inverted_filled.clone()
                missing_x = missing_x.data.cpu().numpy()
                missing_x = missing_x[:, :, ::-1, ::-1]
                # .copy() because torch.tensor cannot take negative strides
                missing_x = torch.tensor(missing_x.copy(),
                                         dtype=torch.float,
                                         device=args.device)
                missing_x = (1 - mask) * missing_x
                x_inverted_filled = x_inverted_filled * mask
                x_inverted_filled = x_inverted_filled + missing_x
                z_sampled = _inpaint_latent_from_image(glow, x_inverted_filled)
            # initializing z from masked image ( masked region as zeros )
            elif args.init_strategy == "black_filled":
                x_black_filled = x_test.clone().detach()
                # NOTE(review): the mask is applied twice, as in the original
                # code; redundant if gen_mask returns a 0/1 mask -- confirm.
                x_black_filled = mask * x_black_filled
                x_black_filled = x_black_filled * mask
                z_sampled = _inpaint_latent_from_image(glow, x_black_filled)
            # initializing z from noisy complete image
            elif args.init_strategy == "noisy":
                x_noisy = x_test.clone().detach()
                noise = np.random.normal(0, 0.05, x_noisy.size())
                noise = torch.tensor(noise,
                                     dtype=torch.float,
                                     device=args.device)
                x_noisy = x_noisy + noise
                x_noisy = torch.clamp(x_noisy, 0, 1)
                z_sampled = _inpaint_latent_from_image(glow, x_noisy)
            # initializing z from image with only noise in masked region
            elif args.init_strategy == "only_noise_filled":
                x_noisy_filled = x_test.clone().detach()
                noise = np.random.normal(0, 0.2, x_noisy_filled.size())
                noise = torch.tensor(noise,
                                     dtype=torch.float,
                                     device=args.device)
                noise = noise * (1 - mask)
                x_noisy_filled = mask * x_noisy_filled + noise
                x_noisy_filled = torch.clamp(x_noisy_filled, 0, 1)
                z_sampled = _inpaint_latent_from_image(glow, x_noisy_filled)
            else:
                # raising a plain string is itself a TypeError in Python 3
                raise ValueError("Initialization strategy not defined")

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam(
                    [z_sampled],
                    lr=args.lr,
                )
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS(
                    [z_sampled],
                    lr=args.lr,
                )
            else:
                # fail loudly instead of hitting an unbound ``optimizer`` that
                # the recovery guard below would silently swallow
                raise ValueError("Optimizer not defined: %s" % args.optim)

            # metrics to record over training
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []

            def closure():
                """Evaluate the loss at the current z and backpropagate it."""
                nonlocal residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_masked_test = x_test * mask
                x_masked_gen = x_gen * mask
                residual_t = ((x_masked_gen - x_masked_test)**2).view(
                    len(x_masked_test), -1).sum(dim=1).mean()
                if args.z_penalty_unsquared:
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                else:
                    z_reg_loss_t = gamma * (z_sampled.norm(dim=1)**
                                            2).mean()
                loss_t = residual_t + z_reg_loss_t
                # PSNR is computed against the full (unmasked) ground truth
                psnr = psnr_t(x_test, x_gen)
                psnr = 10 * np.log10(1 / psnr.item())
                print(
                    "\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                    % (t, loss_t.item(), residual_t.item(),
                       z_reg_loss_t.item(), psnr),
                    end="\r")
                loss_t.backward()
                return loss_t

            # running optimizer steps
            for t in range(args.steps):
                try:
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except (Exception, KeyboardInterrupt) as err:
                    # Numerical failure or Ctrl-C: abandon this gamma, but
                    # report why instead of failing silently.
                    print("\nrecovery aborted at step %d: %r" % (t, err))
                    skip_to_next = True
                    break

            if skip_to_next:
                break

            # getting recovered and true images (no graph needed here)
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            with torch.no_grad():
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
            x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_gen_np = np.clip(x_gen_np, 0, 1)
            mask_np = mask.data.cpu().numpy()
            x_masked_test = x_test * mask
            x_masked_test_np = x_masked_test.data.cpu().numpy().transpose(
                0, 2, 3, 1)
            x_masked_test_np = np.clip(x_masked_test_np, 0, 1)

            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Masked.append(x_masked_test_np)
            Residual_Curve.append(residual)
            Mask.append(mask_np)

            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, mask
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print(
                "\nskipping current loop due to instability or user triggered quit"
            )
            continue

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Masked = np.vstack(Masked)
        Mask = np.vstack(Mask)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test         = %d\n" % len(Recovered)
        printout = printout + "\t inpaint_method = %s\n" % args.inpaint_method
        printout = printout + "\t mask_size      = %0.3f\n" % args.mask_size
        printout = printout + "\t gamma          = %0.6f\n" % gamma
        printout = printout + "\t PSNR           = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_inpaint_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        # saving images
        if args.save_results:
            gamma = gamma.item()
            # image file name without directory and without anything after the
            # first dot
            file_names = [
                os.path.basename(name[0]).split(".")[0]
                for name in test_dataset.samples
            ]
            if args.init_strategy == 'random':
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size,
                                         gamma, args.steps, args.lr,
                                         args.init_std, args.optim)
            else:
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size,
                                         gamma, args.steps, args.lr,
                                         args.init_strategy, args.optim)

            # Use a fresh directory when possible: base, then "_1", then "_2".
            # If all three already exist the base directory is reused.
            if os.path.exists(save_path):
                for suffix in ("_1", "_2"):
                    candidate = save_path + suffix
                    if not os.path.exists(candidate):
                        os.makedirs(candidate)
                        save_path = candidate
                        break
            else:
                os.makedirs(save_path)

            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
            for x, name in zip(Masked, file_names):
                sio.imsave(save_path + "/" + name + "_masked.jpg", x)
            # residual curve averaged over batches
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/" + "residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/mask.npy", Mask)
            np.save(save_path + "/masked.npy", Masked)
Exemple #2
0
def GlowREDDenoiser(args):
    """Denoise test images by latent optimization of a Glow model with a denoiser term.

    For every entry of ``args.gamma`` the images under
    ``./test_images/<args.dataset>`` are loaded batch by batch, corrupted with
    Gaussian noise of standard deviation ``args.noise_std`` and recovered by
    minimizing, over the flattened latent ``z``,

        sum((G(z) - x_noisy)**2) + gamma * penalty(z)
            + beta * sum((G(z) - (x_f - u))**2)

    where ``penalty`` is the mean per-sample norm of ``z`` (squared when
    ``args.z_penalty_squared`` is set).  Every ``args.update_iter`` steps the
    auxiliary variables are refreshed as

        x_f <- (beta * Denoiser(x_f) + alpha * (G(z) + u)) / (beta + alpha)
        u   <- u + G(z) - x_f

    The clipped reconstruction after every optimizer step is collected in
    ``gen_steps``.  Metrics and images are optionally written to disk.

    Args:
        args: namespace providing ``gamma`` (iterable of regularization
            weights), ``alpha``, ``beta``, ``size``, ``model``, ``dataset``,
            ``experiment``, ``batchsize``, ``device``, ``noise``,
            ``noise_std``, ``init_strategy``, ``init_std``, ``optim``, ``lr``,
            ``steps``, ``update_iter``, ``denoiser``, ``sigma_f``,
            ``z_penalty_squared``, ``save_metrics_text`` and ``save_results``.

    Returns:
        None.  Results are printed and, depending on ``args``, saved.

    Raises:
        NotImplementedError: if ``args.noise`` is ``"laplacian"``.
        ValueError: if ``args.noise``, ``args.init_strategy`` or
            ``args.optim`` is not one of the supported values.
    """
    # Holds the residual computed by the most recent closure() call; the
    # closure rebinds it via ``nonlocal`` so nothing leaks into module globals.
    residual_t = None

    # zip() over a single list yields 1-tuples, so ``gamma`` below becomes a
    # one-element tensor rather than a 0-d scalar.
    loopOver = zip(args.gamma)
    for gamma in loopOver:
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)),
             transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batchsize,
            drop_last=False,
            shuffle=False)

        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # regularizor
        gamma = torch.tensor(gamma,
                             requires_grad=True,
                             dtype=torch.float,
                             device=args.device)
        alpha = args.alpha
        beta = args.beta

        # per-batch results accumulated over the whole test set
        gen_steps = []
        Original = []
        Recovered = []
        Noisy = []
        Residual_Curve = []
        for data in test_dataloader:
            # getting batch of data
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # noise to be added
            if args.noise == "gaussian":
                noise = np.random.normal(0,
                                         args.noise_std,
                                         size=(n_test, 3, args.size,
                                               args.size))
                noise = torch.tensor(noise,
                                     dtype=torch.float,
                                     requires_grad=False,
                                     device=args.device)
            elif args.noise == "laplacian":
                # Rejected up front: the results folder name carries no noise
                # type tag.  (Laplacian noise would be drawn with
                # np.random.laplace(scale=args.noise_std, size=...).)
                raise NotImplementedError(
                    "code only supports gaussian for now")
            else:
                raise ValueError("noise type not defined")

            # loading glow model (rebuilt per batch; it is deleted at the end
            # of the batch to free memory)
            glow = Glow((3, args.size, args.size),
                        K=configs["K"],
                        L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # initializing z from Gaussian
            if args.init_strategy == "random":
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled,
                                         requires_grad=True,
                                         dtype=torch.float,
                                         device=args.device)
            # initializing z from noisy image
            elif args.init_strategy == "from-noisy":
                x_noisy = x_test + noise
                z, _, _ = glow(glow.preprocess(x_noisy * 255, clone=True))
                z = glow.flatten_z(z)
                z_sampled = z.clone().detach().requires_grad_(True)
            else:
                # raising a plain string is itself a TypeError in Python 3
                raise ValueError("Initialization strategy not defined")

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam(
                    [z_sampled],
                    lr=args.lr,
                )
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS(
                    [z_sampled],
                    lr=args.lr,
                )
            else:
                # fail loudly instead of hitting an unbound ``optimizer`` later
                raise ValueError("Optimizer not defined: %s" % args.optim)

            # to be recorded over iteration
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            x_f = (x_test + noise).clone()
            u = torch.zeros_like(x_test)

            def closure():
                """Evaluate the loss at the current z and backpropagate it."""
                nonlocal residual_t
                optimizer.zero_grad()
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_noisy = x_test + noise
                residual_t = ((x_gen - x_noisy)**2).view(
                    len(x_noisy), -1).sum(dim=1).mean()
                if args.z_penalty_squared:
                    z_reg_loss_t = gamma * (z_sampled.norm(dim=1)**
                                            2).mean()
                else:
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                # x_f and u are read from the enclosing scope, so updates made
                # by the step loop below are picked up on the next call
                residual_x = beta * ((x_gen - (x_f - u))**2).view(
                    len(x_noisy), -1).sum(dim=1).mean()
                loss_t = residual_t + z_reg_loss_t + residual_x
                # PSNR is computed against the clean ground truth
                psnr = psnr_t(x_test, x_gen)
                psnr = 10 * np.log10(1 / psnr.item())
                print(
                    "\rAt step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                    % (t, loss_t.item(), residual_t.item(),
                       residual_x.item(), z_reg_loss_t.item(), psnr),
                    end="\r")
                loss_t.backward()
                return loss_t

            def denoiser_step(x_f, u):
                """Return updated ``(x_f, u)`` from the current reconstruction."""
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True,
                             reverse_clone=False).detach()
                x_gen = glow.postprocess(x_gen, floor_clamp=False)

                x_f = 1 / (beta + alpha) * (
                    beta * Denoiser(args.denoiser, args.sigma_f, x_f) +
                    alpha * (x_gen + u))
                u = u + x_gen - x_f
                return x_f, u

            # running optimizer steps; any failure propagates to the caller
            for t in range(args.steps):
                optimizer.step(closure)
                residual.append(residual_t.item())
                # refresh x_f and u on the last step of every update_iter block
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)
                # snapshot of the reconstruction after this step
                with torch.no_grad():
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
                    x_gen_np = np.clip(x_gen_np, 0, 1)
                    gen_steps.append(x_gen_np)

            # getting recovered and true images (no graph needed here)
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            with torch.no_grad():
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
            x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_gen_np = np.clip(x_gen_np, 0, 1)
            x_noisy = x_test + noise
            x_noisy_np = x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_noisy_np = np.clip(x_noisy_np, 0, 1)

            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Noisy.append(x_noisy_np)
            Residual_Curve.append(residual)

            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise
            torch.cuda.empty_cache()
            print("\nbatch completed")

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        gen_steps = np.vstack(gen_steps)
        Noisy = np.vstack(Noisy)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test     = %d\n" % len(Recovered)
        printout = printout + "\t noise_std  = %0.4f\n" % args.noise_std
        printout = printout + "\t update_iter= %0.4f\n" % args.update_iter
        printout = printout + "\t gamma      = %0.6f\n" % gamma
        printout = printout + "\t alpha      = %0.6f\n" % alpha
        printout = printout + "\t beta       = %0.6f\n" % beta
        printout = printout + "\t PSNR       = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt" % args.dataset,
                      "a") as f:
                f.write('\n' + printout)

        # saving images
        if args.save_results:
            gamma = gamma.item()
            # image file name without directory and without anything after the
            # first dot
            file_names = [
                os.path.basename(name[0]).split(".")[0]
                for name in test_dataset.samples
            ]
            save_path = save_path + "/denoising_noisestd_%0.4f_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
            save_path = save_path % (args.noise_std, args.update_iter, gamma,
                                     alpha, beta, args.steps, args.lr,
                                     args.init_std, args.optim)

            # Use a fresh directory when possible: base, then "_1", then "_2".
            # If all three already exist the base directory is reused.
            if os.path.exists(save_path):
                for suffix in ("_1", "_2"):
                    candidate = save_path + suffix
                    if not os.path.exists(candidate):
                        os.makedirs(candidate)
                        save_path = candidate
                        break
            else:
                os.makedirs(save_path)

            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
            for x, name in zip(Noisy, file_names):
                sio.imsave(save_path + "/" + name + "_noisy.jpg", x)
            # residual curve averaged over batches
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/noisy.npy", Noisy)
            np.save(save_path + "/gen_steps.npy", gen_steps)
Exemple #3
0
def GlowCS(args):
    """Run compressed-sensing recovery of test images with a Glow prior.

    ``args.m``, ``args.gamma`` and ``args.init_norms`` are iterated in
    lockstep.  For every ``(m, gamma, init_norm)`` triple the function

    1. draws a Gaussian sensing matrix ``A`` of shape ``(n, m)`` with
       ``n = size * size * 3`` and one noise realization of shape
       ``(batchsize, m)``;
    2. for each batch of test images, loads the Glow model, initializes the
       flattened latent ``z`` according to ``args.init_strategy`` and runs
       ``args.steps`` optimizer steps on
       ``||G(z) A - (x A + noise)||^2 + gamma * penalty(z)``;
    3. prints a summary (optionally appended to a text file) and, when
       ``args.save_results`` is set, writes the recovered images and
       ``.npy`` arrays under the results folder.

    If an optimizer step raises, the traceback is printed and the current
    ``(m, gamma, init_norm)`` configuration is abandoned.

    Args:
        args: option namespace.  Attributes read here: ``m``, ``gamma``,
            ``init_norms``, ``init_strategy``, ``init_std``, ``size``,
            ``model``, ``dataset``, ``experiment``, ``batchsize``,
            ``device``, ``noise``, ``optim``, ``lr``, ``steps``,
            ``z_penalty_unsquared``, ``save_metrics_text`` and
            ``save_results``.  ``args.init_norms`` is replaced in place by a
            list of ``None`` when it is ``None``.

    Returns:
        None.  Results are printed and optionally written to disk.

    Raises:
        AssertionError: if ``init_norms`` is given with an init strategy
            other than ``"random_fixed_norm"``, if the three swept lists
            differ in length, or if a batch is smaller than
            ``args.batchsize``.
        ValueError: if ``args.init_strategy`` or ``args.optim`` is not one
            of the supported values.
    """
    if args.init_norms is None:
        args.init_norms = [None]*len(args.m)
    else:
        assert args.init_strategy == "random_fixed_norm", "init_strategy should be random_fixed_norm if init_norms is used"
    assert len(args.m) == len(args.gamma) == len(args.init_norms), "length of either m, gamma or init_norms are not same"
    loopOver = zip(args.m, args.gamma, args.init_norms)

    # Latest data-fidelity term computed by the optimizer closure.  It lives
    # in this function's scope (rebound via ``nonlocal``) rather than in a
    # module-level global.
    residual_t = None

    for m, gamma, init_norm in loopOver:
        skip_to_next = False # flag to skip to the next configuration if recovery fails due to instability
        n                  = args.size*args.size*3
        modeldir           = os.path.join(root_path, "trained_models/%s/glow-cs-%d"%(args.model, args.size))
        test_folder        = os.path.join(root_path, "test_images/%s_N=12"%args.dataset)
        save_path          = os.path.join(root_path, "results/%s/%s"%(args.dataset,args.experiment))

        # loading dataset
        trans           = transforms.Compose([transforms.Resize((args.size,args.size)),transforms.ToTensor()])
        test_dataset    = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=args.batchsize,drop_last=False,shuffle=False)

        # loading glow configurations
        config_path = modeldir+"/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # sensing matrix: i.i.d. Gaussian entries with std 1/sqrt(m)
        A = np.random.normal(0,1/np.sqrt(m), size=(n,m))
        A = torch.tensor(A,dtype=torch.float, requires_grad=False, device=args.device)

        # regularizer weight on the latent norm
        gamma     = torch.tensor(gamma, requires_grad=True, dtype=torch.float, device=args.device)

        # measurement noise, drawn once per configuration and shared by all batches
        if  args.noise == "random_bora":
            noise = np.random.normal(0,1,size=(args.batchsize,m))
            noise = noise * 0.1/np.sqrt(m)
            noise = torch.tensor(noise,dtype=torch.float,requires_grad=False, device=args.device)
        else:
            # any other value is interpreted as the desired l2 norm of each noise row
            noise = np.random.normal(0,1,size=(args.batchsize,m))
            noise = noise / (np.linalg.norm(noise,2,axis=-1, keepdims=True)) * float(args.noise)
            noise = torch.tensor(noise, dtype=torch.float, requires_grad=False, device=args.device)

        # start solving over batches
        Original = []; Recovered = []; Z_Recovered = []; Residual_Curve = []; Recorded_Z = []
        for i, data in enumerate(test_dataloader):
            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # loading glow model (reloaded per batch; it is deleted at the end of the batch)
            print(f'load glow from: {modeldir}...')
            glow = Glow((3,args.size,args.size),
                        K=configs["K"],L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device).to(args.device)
            print(glow.device)

            glow.load_state_dict(torch.load(modeldir+"/glowmodel.pt", map_location=args.device))
            glow.eval()
            # making a forward pass to record shapes of z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # initializing z from Gaussian with std equal to init_std
            if args.init_strategy == "random":
                z_sampled = np.random.normal(0,args.init_std,[n_test,n])
                z_sampled = torch.tensor(z_sampled,requires_grad=True,dtype=torch.float,device=args.device)
            # initializing z from Gaussian and scaling its norm to init_norm
            elif args.init_strategy == "random_fixed_norm":
                z_sampled = np.random.normal(0,1,[n_test,n])
                z_sampled = z_sampled / np.linalg.norm(z_sampled, axis=-1, keepdims=True)
                z_sampled = z_sampled * init_norm
                z_sampled = torch.tensor(z_sampled,requires_grad=True,dtype=torch.float,device=args.device)
                print("z intialized with a norm equal to = %0.1f"%init_norm)
            # initializing z from pseudo inverse
            elif args.init_strategy == "pseudo_inverse":
                x_test_flat = x_test.view([-1,n])
                y_true      = torch.matmul(x_test_flat, A) + noise
                A_pinv      = torch.pinverse(A)
                x_pinv      = torch.matmul(y_true, A_pinv)
                x_pinv      = x_pinv.view([-1,3,args.size,args.size])
                x_pinv      = torch.clamp(x_pinv,0,1)
                z, _, _     = glow(glow.preprocess(x_pinv*255,clone=True))
                z           = glow.flatten_z(z).clone().detach()
                # z is already a detached copy; turn it into a trainable leaf
                # without the copy-construct warning of torch.tensor(tensor)
                z_sampled   = z.to(device=args.device, dtype=torch.float).requires_grad_(True)
            # initializing z from a solution of lasso-wavelet
            elif args.init_strategy == "lasso_wavelet":
                new_args    = {"batch_size":n_test, "lmbd":0.01,"lasso_solver":"sklearn"}
                new_args    = easydict.EasyDict(new_args)
                estimator   = celebA_estimators.lasso_wavelet_estimator(new_args)
                x_ch_last   = x_test.permute(0,2,3,1)
                x_ch_last   = x_ch_last.contiguous().view([-1,n])
                y_true      = torch.matmul(x_ch_last, A) + noise
                x_lasso     = estimator(np.sqrt(2*m)*A.data.cpu().numpy(), np.sqrt(2*m)*y_true.data.cpu().numpy(), new_args)
                x_lasso     = np.array(x_lasso)
                # channel-last layout matching x_ch_last; uses args.size so the
                # reshape stays consistent with n instead of a hard-coded 64
                # NOTE(review): the estimator itself may only support 64x64 -- confirm
                x_lasso     = x_lasso.reshape(-1,args.size,args.size,3)
                x_lasso     = x_lasso.transpose(0,3,1,2)
                x_lasso     = torch.tensor(x_lasso, dtype=torch.float, device=args.device)
                z, _, _     = glow(x_lasso - 0.5)
                z           = glow.flatten_z(z).clone().detach()
                z_sampled   = z.to(device=args.device, dtype=torch.float).requires_grad_(True)
                print("z intialized from a solution of lasso-wavelet")
            # initializing z from null(A)
            elif args.init_strategy == "null_space":
                x_test_flat    = x_test.view([-1,n])
                x_test_flat_np = x_test_flat.data.cpu().numpy()
                A_np        = A.data.cpu().numpy()
                nullA       = null_space(A_np.T)
                coeff       = np.random.normal(0,1,(args.batchsize, nullA.shape[1]))
                # random linear combination of the null-space basis vectors, one per image
                x_null      = np.array([(nullA * c).sum(axis=-1) for c in coeff])
                pert_norm   = 5 # <-- 5 gives optimal results --  bad initialization and not too unstable
                x_null      = x_null / np.linalg.norm(x_null, axis=1, keepdims=True) * pert_norm
                x_perturbed = x_test_flat_np + x_null
                # no clipping of x_perturbed, so that the measurements y = xA stay unchanged
                err         = np.matmul(x_test_flat_np,A_np) - np.matmul(x_perturbed,A_np)
                assert (err **2).sum() < 1e-6, "null space does not satisfy ||y-A(x+x0)|| <= 1e-6"
                x_perturbed = x_perturbed.reshape(-1,3,args.size,args.size)
                x_perturbed = torch.tensor(x_perturbed, dtype=torch.float, device=args.device)
                z, _, _     = glow(x_perturbed - 0.5)
                z           = glow.flatten_z(z).clone().detach()
                z_sampled   = z.to(device=args.device, dtype=torch.float).requires_grad_(True)
                print("z initialized from a point in null space of A")
            else:
                # raising a bare string is itself a TypeError in Python 3
                raise ValueError("Initialization strategy not defined")

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr,)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr,)
            else:
                raise ValueError("optimizer not defined")

            # to be recorded over iteration
            psnr_t    = torch.nn.MSELoss().to(device=args.device)  # MSE, converted to PSNR inside the closure
            residual  = []; recorded_z = []
            # running optimizer steps
            for t in range(args.steps):
                def closure():
                    """Evaluate the loss at the current z and backpropagate."""
                    nonlocal residual_t
                    optimizer.zero_grad()
                    z_unflat    = glow.unflatten_z(z_sampled, clone=False)
                    x_gen       = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen       = glow.postprocess(x_gen,floor_clamp=False)
                    x_test_flat = x_test.view([-1,n])
                    x_gen_flat  = x_gen.view([-1,n])
                    y_true      = torch.matmul(x_test_flat, A) + noise
                    y_gen       = torch.matmul(x_gen_flat, A)
                    residual_t = ((y_gen - y_true)**2).sum(dim=1).mean()
                    if not args.z_penalty_unsquared:
                        z_reg_loss_t= gamma*(z_sampled.norm(dim=1)**2).mean()
                    else:
                        z_reg_loss_t= gamma*z_sampled.norm(dim=1).mean()
                    loss_t      = residual_t + z_reg_loss_t
                    psnr        = psnr_t(x_test, x_gen)
                    psnr        = 10 * np.log10(1 / psnr.item())
                    print("\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"%(t,loss_t.item(),residual_t.item(),z_reg_loss_t.item(), psnr),end="\r")
                    loss_t.backward(retain_graph=True)
                    return loss_t
                try:
                    optimizer.step(closure)
                    recorded_z.append(z_sampled.data.cpu().numpy())
                    residual.append(residual_t.item())
                except Exception:
                    traceback.print_exc()
                    # the step may fail due to instability in the reverse direction
                    skip_to_next = True
                    break

            if skip_to_next:
                break

            # getting recovered and true images
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0,2,3,1)
                z_unflat  = glow.unflatten_z(z_sampled, clone=False)
                x_gen     = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen     = glow.postprocess(x_gen,floor_clamp=False)
                x_gen_np  = x_gen.data.cpu().numpy().transpose(0,2,3,1)
                x_gen_np  = np.clip(x_gen_np,0,1)
                z_recov   = z_sampled.data.cpu().numpy()
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Z_Recovered.append(z_recov)
            Residual_Curve.append(residual)
            Recorded_Z.append(recorded_z)

            # freeing up memory for the next batch
            glow.zero_grad()
            optimizer.zero_grad()
            residual_t = None  # drop the last retained autograd graph
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow
            # torch.cuda.device() only accepts CUDA devices, so skip it on CPU runs
            if torch.device(args.device).type == "cuda":
                with torch.cuda.device(args.device):
                    torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue

        # collecting everything together
        Original     = np.vstack(Original)
        Recovered    = np.vstack(Recovered)
        Z_Recovered  = np.vstack(Z_Recovered)
        Recorded_Z   = np.vstack(Recorded_Z)
        psnr         = [compare_psnr(x, y) for x,y in zip(Original, Recovered)]
        z_recov_norm = np.linalg.norm(Z_Recovered, axis=-1)

        # print performance analysis
        printout = "+-"*10 + "%s"%args.dataset + "-+"*10 + "\n"
        printout = printout + "\t n_test        = %d\n"%len(Recovered)
        printout = printout + "\t n             = %d\n"%n
        printout = printout + "\t m             = %d\n"%m
        printout = printout + "\t gamma         = %0.6f\n"%gamma
        printout = printout + "\t optimizer     = %s\n"%args.optim
        printout = printout + "\t lr            = %0.3f\n"%args.lr
        printout = printout + "\t steps         = %0.3f\n"%args.steps
        printout = printout + "\t init_strategy = %s\n"%args.init_strategy
        printout = printout + "\t init_std      = %0.3f\n"%args.init_std
        if init_norm is not None:
            printout = printout + "\t init_norm     = %0.3f\n"%init_norm
        printout = printout + "\t z_recov_norm  = %0.3f\n"%np.mean(z_recov_norm)
        printout = printout + "\t PSNR          = %0.3f\n"%(np.mean(psnr))
        print(printout)

        # saving printout
        if args.save_metrics_text:
            with open("%s_cs_glow_results.txt"%args.dataset,"a") as f:
                f.write('\n' + printout)


        # setting folder to save results in
        if args.save_results:
            gamma = gamma.item()
            file_names = [os.path.basename(name[0]) for name in test_dataset.samples]
            if args.init_strategy == "random":
                save_path_template = save_path + "/cs_m_%d_gamma_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path_template%(m,gamma,args.steps,args.lr,args.init_std,args.optim)
            elif args.init_strategy == "random_fixed_norm":
                save_path_template = save_path+"/cs_m_%d_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_%0.3f_optim_%s"
                save_path          = save_path_template%(m,gamma,args.steps,args.lr,args.init_strategy,init_norm, args.optim)
            else:
                save_path_template = save_path + "/cs_m_%d_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path          = save_path_template%(m,gamma,args.steps,args.lr,args.init_strategy,args.optim)
            # use the first free folder among save_path, save_path_1 and
            # save_path_2; if all three exist, results go into save_path
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path_1 = save_path + "_1"
                if not os.path.exists(save_path_1):
                    os.makedirs(save_path_1)
                    save_path = save_path_1
                else:
                    save_path_2 = save_path + "_2"
                    if not os.path.exists(save_path_2):
                        os.makedirs(save_path_2)
                        save_path = save_path_2
            # saving results now
            _ = [sio.imsave(save_path+"/"+name, x) for x,name in zip(Recovered,file_names)]
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path+"/original.npy", Original)
            np.save(save_path+"/recovered.npy", Recovered)
            np.save(save_path+"/z_recovered.npy", Z_Recovered)
            np.save(save_path+"/residual_curve.npy", Residual_Curve)
            if init_norm is not None:
                np.save(save_path+"/Recorded_Z_init_norm_%d.npy"%init_norm, Recorded_Z)
        torch.cuda.empty_cache()