Пример #1
0
def _glow_latent_from_image(glow, x_init):
    """Encode ``x_init`` with ``glow`` and return a flat latent to optimize.

    The image is shifted by -0.5 before the forward pass (as every
    image-based initialization in ``GlowInpaint`` does), the returned latents
    are flattened with ``glow.flatten_z``, and the result is returned as a new
    detached leaf tensor with ``requires_grad=True``.
    """
    z, _, _ = glow(x_init - 0.5)
    return glow.flatten_z(z).clone().detach().requires_grad_(True)


def GlowInpaint(args):
    """Inpaint masked test images by optimizing Glow latents, once per gamma.

    For every regularization weight in ``args.gamma`` this:

    * loads the images under ``./test_images/<args.dataset>``,
    * builds a mask with ``gen_mask(args.inpaint_method, args.size,
      args.mask_size)`` (pixels where the mask is 1 are treated as observed),
    * optimizes a flat latent ``z`` so that the Glow model loaded from
      ``./trained_models/<args.model>/glow`` matches the observed pixels,
      with a ``gamma * ||z||`` penalty,
    * prints the mean PSNR against the originals and, depending on
      ``args.save_metrics_text`` / ``args.save_results``, appends the summary
      to ``<dataset>_inpaint_glow_results.txt`` and saves images and ``.npy``
      arrays under ``./results/<args.dataset>/<args.experiment>/...``.

    If an optimizer step raises (e.g. numerical instability) or the user
    presses Ctrl+C, the current gamma is abandoned and the next one is tried.

    Raises:
        NotImplementedError: if ``args.init_strategy`` or ``args.optim`` is
            not one of the supported values.
    """
    for gamma in args.gamma:
        # set when recovery fails (instability) or the user interrupts; the
        # remaining batches of this gamma are then skipped
        skip_to_next = False
        # number of values in one RGB image, also the flat latent length
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)),
             transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batchsize,
            drop_last=False,
            shuffle=False)

        # loading glow configurations
        config_path = modeldir + "/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # regularizer weight on the latent norm
        gamma = torch.tensor(gamma,
                             requires_grad=True,
                             dtype=torch.float,
                             device=args.device)

        # per-batch results, concatenated after the loop
        Original = []
        Recovered = []
        Masked = []
        Mask = []
        Residual_Curve = []
        for data in test_dataloader:
            # getting batch of data
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # generate one mask and repeat it for every image of the batch
            mask = gen_mask(args.inpaint_method, args.size, args.mask_size)
            mask = np.array([mask for _ in range(n_test)])
            mask = mask.reshape([n_test, 1, args.size, args.size])
            mask = torch.tensor(mask,
                                dtype=torch.float,
                                requires_grad=False,
                                device=args.device)

            # loading glow model; map_location lets CUDA-saved weights load
            # on whatever device was requested
            glow = Glow((3, args.size, args.size),
                        K=configs["K"],
                        L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(
                torch.load(modeldir + "/glowmodel.pt",
                           map_location=args.device))
            glow.eval()

            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # initializing the latent according to args.init_strategy
            if args.init_strategy == "random":
                # i.i.d. Gaussian latent
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled,
                                         requires_grad=True,
                                         dtype=torch.float,
                                         device=args.device)
            elif args.init_strategy == "noisy_filled":
                # whole image, Gaussian noise (std 0.2) added where mask is 0
                noise = torch.tensor(np.random.normal(0, 0.2, x_test.size()),
                                     dtype=torch.float,
                                     device=args.device)
                x_init = x_test.clone().detach() + noise * (1 - mask)
                x_init = torch.clamp(x_init, 0, 1)
                z_sampled = _glow_latent_from_image(glow, x_init)
            elif args.init_strategy == "inverted_filled":
                # where mask is 0, use the image flipped along both
                # spatial axes
                flipped = x_test.detach().cpu().numpy()[:, :, ::-1, ::-1]
                flipped = torch.tensor(flipped.copy(),
                                       dtype=torch.float,
                                       device=args.device)
                x_init = x_test.clone().detach() * mask + (1 - mask) * flipped
                z_sampled = _glow_latent_from_image(glow, x_init)
            elif args.init_strategy == "black_filled":
                # observed pixels only, zeros where mask is 0
                x_init = mask * x_test.clone().detach()
                # NOTE(review): this second multiply only changes anything
                # for a non-binary mask; kept to preserve results
                x_init = x_init * mask
                z_sampled = _glow_latent_from_image(glow, x_init)
            elif args.init_strategy == "noisy":
                # whole image with small Gaussian noise (std 0.05)
                noise = torch.tensor(np.random.normal(0, 0.05, x_test.size()),
                                     dtype=torch.float,
                                     device=args.device)
                x_init = torch.clamp(x_test.clone().detach() + noise, 0, 1)
                z_sampled = _glow_latent_from_image(glow, x_init)
            elif args.init_strategy == "only_noise_filled":
                # observed pixels, pure noise (std 0.2) where mask is 0
                noise = torch.tensor(np.random.normal(0, 0.2, x_test.size()),
                                     dtype=torch.float,
                                     device=args.device)
                x_init = mask * x_test.clone().detach() + noise * (1 - mask)
                x_init = torch.clamp(x_init, 0, 1)
                z_sampled = _glow_latent_from_image(glow, x_init)
            else:
                # raising a plain string is itself a TypeError
                raise NotImplementedError(
                    "Initialization strategy not defined")

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise NotImplementedError(
                    "Optimizer %s not defined" % args.optim)

            # MSE criterion, converted to PSNR for progress printing
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            # running optimizer steps
            for t in range(args.steps):

                def closure():
                    optimizer.zero_grad()
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_masked_test = x_test * mask
                    x_masked_gen = x_gen * mask
                    # exposed so the outer loop can record the last residual
                    global residual_t
                    residual_t = ((x_masked_gen - x_masked_test)**2).view(
                        len(x_masked_test), -1).sum(dim=1).mean()
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                    loss_t = residual_t + z_reg_loss_t
                    psnr = psnr_t(x_test, x_gen)
                    psnr = 10 * np.log10(1 / psnr.item())
                    print(
                        "\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                        % (t, loss_t.item(), residual_t.item(),
                           z_reg_loss_t.item(), psnr),
                        end="\r")
                    loss_t.backward()
                    return loss_t

                try:
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except (Exception, KeyboardInterrupt):
                    # unstable recovery or user Ctrl+C: abandon this gamma
                    skip_to_next = True
                    break

            if skip_to_next:
                break

            # getting recovered and true images (NHWC numpy arrays)
            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            z_unflat = glow.unflatten_z(z_sampled, clone=False)
            x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
            x_gen = glow.postprocess(x_gen, floor_clamp=False)
            x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
            x_gen_np = np.clip(x_gen_np, 0, 1)
            mask_np = mask.data.cpu().numpy()
            x_masked_test = x_test * mask
            x_masked_test_np = x_masked_test.data.cpu().numpy().transpose(
                0, 2, 3, 1)
            x_masked_test_np = np.clip(x_masked_test_np, 0, 1)

            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Masked.append(x_masked_test_np)
            Residual_Curve.append(residual)
            Mask.append(mask_np)

            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, mask
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print(
                "\nskipping current loop due to instability or user triggered quit"
            )
            continue

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Masked = np.vstack(Masked)
        Mask = np.vstack(Mask)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test         = %d\n" % len(Recovered)
        printout = printout + "\t inpaint_method = %s\n" % args.inpaint_method
        printout = printout + "\t mask_size      = %0.3f\n" % args.mask_size
        printout = printout + "\t gamma          = %0.6f\n" % gamma
        printout = printout + "\t PSNR           = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_inpaint_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        # saving images
        if args.save_results:
            gamma = gamma.item()
            file_names = [
                name[0].split("/")[-1].split(".")[0]
                for name in test_dataset.samples
            ]
            if args.init_strategy == 'random':
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size,
                                         gamma, args.steps, args.lr,
                                         args.init_std, args.optim)
            else:
                save_path = save_path + "/inpaint_%s_masksize_%0.4f_gamma_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path % (args.inpaint_method, args.mask_size,
                                         gamma, args.steps, args.lr,
                                         args.init_strategy, args.optim)

            # pick the first free directory among base, base_1, base_2, ...
            # so earlier results are never overwritten
            base_path = save_path
            suffix = 0
            while os.path.exists(save_path):
                suffix += 1
                save_path = "%s_%d" % (base_path, suffix)
            os.makedirs(save_path)

            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
            for x, name in zip(Masked, file_names):
                sio.imsave(save_path + "/" + name + "_masked.jpg", x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/" + "residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/mask.npy", Mask)
            np.save(save_path + "/masked.npy", Masked)
Пример #2
0
def trainGlow(args):
    """Train (or resume training of) a Glow model on ``args.dataset``.

    Training images are read from ``./data/<args.dataset>_preprocessed/train``.
    Weights (``glowmodel.pt``), the model configuration (``configs.json``), an
    NLL training curve and periodic samples are written under
    ``./trained_models/<args.dataset>/glow``.

    If ``glowmodel.pt`` already exists there, the saved configs and weights
    are loaded and training resumes. Otherwise a new model is created from
    ``args`` and its configs are written to ``configs.json``.
    """
    save_path = "./trained_models/%s/glow" % args.dataset
    training_folder = "./data/%s_preprocessed/train" % args.dataset

    # setting up configs as json
    config_path = save_path + "/configs.json"
    configs = {
        "K": args.K,
        "L": args.L,
        "coupling": args.coupling,
        "last_zeros": args.last_zeros,
        "batchsize": args.batchsize,
        "size": args.size,
        "lr": args.lr,
        "n_bits_x": args.n_bits_x,
        "warmup_iter": args.warmup_iter
    }

    if args.squeeze_contig:
        configs["squeeze_contig"] = True
    if args.coupling_bias > 0:
        configs["coupling_bias"] = args.coupling_bias
    if not os.path.exists(save_path):
        print("creating directory to save model weights")
        os.makedirs(save_path)

    # loading pre-trained model to resume training
    if os.path.exists(save_path + "/glowmodel.pt"):
        print(
            "loading previous model and saved configs to resume training ...")
        with open(config_path, 'r') as f:
            configs = json.load(f)
        # NOTE(review): every saved config key (including batchsize, lr,
        # size, warmup_iter, last_zeros) is forwarded to Glow here, whereas
        # the fresh-model branch passes explicit keyword arguments; confirm
        # Glow's constructor accepts these keys.
        glow = Glow((3, configs["size"], configs["size"]),
                    device=args.device,
                    **configs)
        # map_location lets CUDA-saved weights load on the requested device
        glow.load_state_dict(
            torch.load(save_path + "/glowmodel.pt", map_location=args.device))
        print("pre-trained model and configs loaded successfully")
        glow.set_actnorm_init()
        print(
            "actnorm initialization flag set to True to avoid data dependant re-initialization"
        )
        glow.train()

    else:
        # creating and initializing glow model
        # NOTE(review): squeeze_contig / coupling_bias are recorded in configs
        # but not passed to Glow here; confirm that is intended.
        print("creating and initializing model for training")
        glow = Glow((3, args.size, args.size),
                    K=args.K,
                    L=args.L,
                    coupling=args.coupling,
                    n_bits_x=args.n_bits_x,
                    nn_init_last_zeros=args.last_zeros,
                    device=args.device)
        glow.train()
        print("saving configs as json file")
        with open(config_path, 'w') as f:
            json.dump(configs, f, sort_keys=True, indent=4, ensure_ascii=False)

    # setting up dataloader
    print("setting up dataloader for the training data")
    trans = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop((args.size, args.size)),
        transforms.ToTensor()
    ])
    dataset = datasets.ImageFolder(training_folder, transform=trans)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batchsize,
                                             drop_last=True,
                                             shuffle=True)

    # setting up optimizer and learning rate scheduler
    opt = torch.optim.Adam(glow.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                              mode="min",
                                                              factor=0.5,
                                                              patience=1000,
                                                              verbose=True,
                                                              min_lr=1e-8)

    # starting training code here
    print("+-" * 10, "starting training", "-+" * 10)
    global_step = 0
    global_loss = []
    warmup_completed = False
    for i in range(args.epochs):
        for j, data in enumerate(dataloader):
            opt.zero_grad()
            glow.zero_grad()
            # loading batch, scaled from [0, 1] to [0, 255] before preprocess
            x = data[0].to(device=args.device) * 255
            # pre-processing data
            x = glow.preprocess(x)
            # computing loss: "nll"
            nll, logdet, logpz, z_mu, z_std = glow.nll_loss(x)
            # the very first batch of a run only does the forward pass (which
            # performs data-dependent initialization on a fresh model) and is
            # not used for an update; this happens on resume too
            if global_step == 0:
                global_step += 1
                continue
            # backpropagating loss and gradient clipping
            nll.backward()
            torch.nn.utils.clip_grad_value_(glow.parameters(), 5)
            grad_norm = torch.nn.utils.clip_grad_norm_(glow.parameters(), 100)
            # linearly increase learning rate till warmup_iter upto args.lr
            if global_step <= args.warmup_iter:
                warmup_lr = args.lr / args.warmup_iter * global_step
                for params in opt.param_groups:
                    params["lr"] = warmup_lr
            # taking optimizer step
            opt.step()
            # learning rate scheduling after warm up iterations
            if global_step > args.warmup_iter:
                lr_scheduler.step(nll.item())
                if not warmup_completed:
                    if args.warmup_iter == 0:
                        print("no model warming...")
                    else:
                        print("\nwarm up completed")
                warmup_completed = True
            # printing training metrics
            print(
                "\repoch=%0.2d..nll=%0.2f..logdet=%0.2f..logpz=%0.2f..mu=%0.2f..std=%0.2f..gradnorm=%0.2f"
                % (i, nll.item(), logdet, logpz, z_mu, z_std, grad_norm),
                end="\r")
            # saving the NLL curve and generated samples during training;
            # failures here are reported but must not stop training, while
            # Ctrl+C (KeyboardInterrupt) is no longer swallowed
            try:
                if j % args.sample_freq == 0:
                    plt.plot(global_loss)
                    plt.xlabel("iterations", size=15)
                    plt.ylabel("nll", size=15)
                    plt.savefig(save_path + "/nll_training_curve.jpg")
                    plt.close()
                    with torch.no_grad():
                        z_sample, z_sample_t = glow.generate_z(n=10,
                                                               mu=0,
                                                               std=0.7,
                                                               to_torch=True)
                        x_gen = glow(z_sample_t, reverse=True)
                        x_gen = glow.postprocess(x_gen)
                        x_gen = make_grid(x_gen, nrow=int(np.sqrt(len(x_gen))))
                        x_gen = x_gen.data.cpu().numpy()
                        x_gen = x_gen.transpose([1, 2, 0])
                        if x_gen.shape[-1] == 1:
                            x_gen = x_gen[..., 0]
                        if not os.path.exists(save_path + "/samples_training"):
                            os.makedirs(save_path + "/samples_training")
                        x_gen = (np.clip(x_gen, 0, 1) * 255).astype("uint8")
                        sio.imsave(
                            save_path +
                            "/samples_training/%0.6d.jpg" % global_step, x_gen)
            except Exception:
                print("\n failed to sample from glow at global step = %d" %
                      global_step)
            global_step = global_step + 1
            global_loss.append(nll.item())
            if global_step % args.save_freq == 0:
                torch.save(glow.state_dict(), save_path + "/glowmodel.pt")

    # saving model weights
    torch.save(glow.state_dict(), save_path + "/glowmodel.pt")
Пример #3
0
def GlowDenoiser(args):
    """Denoise test images by optimizing Glow latents, once per gamma.

    For every regularization weight in ``args.gamma`` this:

    * loads the images under ``./test_images/<args.dataset>_N=12``,
    * corrupts them with ``Noiser(args, configs)`` and clamps to [0, 1],
    * optimizes a flat latent ``z`` so that the Glow model loaded from
      ``./trained_models/<args.dataset>/glow-cs-<args.size>`` minimizes
      ``recon_loss(args.noise, args.noise_loc, args.noise_scale)`` against
      the noisy images, plus ``gamma * ||z||^2``, or ``gamma * ||z||`` when
      ``args.z_penalty_unsquared`` is set,
    * prints a summary (with mean PSNR when a recovery is available) and,
      depending on ``args.save_metrics_text`` / ``args.save_results``,
      appends it to ``<dataset>_denoising_glow_results.txt`` and saves
      images and ``.npy`` arrays under
      ``./results/<args.dataset>/<args.experiment>/...``.

    Optimization errors are printed with a traceback and the run continues
    with whatever was recovered.

    Raises:
        NotImplementedError: if ``args.init_strategy`` or ``args.optim`` is
            not one of the supported values.
    """
    # try different gamma values
    for gamma in args.gamma:
        # set when recovery fails; currently only recorded (the skip logic
        # below is disabled)
        skip_to_next  = False
        # number of values in one RGB image, also the flat latent length
        n             = args.size*args.size*3
        # modeldir      = f"./trained_models/{args.dataset}/glow-denoising"
        modeldir      = f"./trained_models/{args.dataset}/glow-cs-{args.size}"
        test_folder   = f"./test_images/{args.dataset}_N=12"
        save_path     = f"./results/{args.dataset}/{args.experiment}"

        # loading dataset
        trans           = transforms.Compose([transforms.Resize((args.size,args.size)),transforms.ToTensor()])
        test_dataset    = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=args.batchsize,drop_last=False,shuffle=False)

        # loading glow configurations
        config_path = modeldir+"/configs.json"
        with open(config_path, 'r') as f:
            configs = json.load(f)

        # noiser
        noiser = Noiser(args, configs)

        # loss
        loss = recon_loss(args.noise, args.noise_loc, args.noise_scale)

        # regularizer weight on the latent norm
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float, device=args.device)

        # per-batch results, concatenated after the loop
        Original  = []
        Recovered = []
        Noisy     = []
        Noise     = []
        Residual_Curve = []
        for data in test_dataloader:

            x_test = data[0]
            x_test = x_test.clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, \
                "please make sure that no. of images are evenly divided by batchsize"

            # loading glow model

            glow = Glow((3,args.size,args.size),
                        K=configs["K"],L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir+"/glowmodel.pt", map_location=args.device))
            glow.eval()

            # add noise
            noise = noiser(x_test)
            x_noisy = x_test + noise
            x_noisy = torch.clamp(x_noisy, 0., 1.)

            # making a forward to record shapes of z's for reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # np.random.seed(args.random_seed)
            if args.init_strategy == "random":
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
            elif args.init_strategy == "from-noisy":
                z, _, _     = glow(glow.preprocess(x_noisy*255,clone=True))
                z           = glow.flatten_z(z)
                z_sampled   = z.clone().detach().cpu().numpy()
            else:
                raise NotImplementedError("Initialization strategy not defined")

            z_sampled = torch.from_numpy(z_sampled).float().to(args.device)
            z_sampled = nn.Parameter(z_sampled, requires_grad=True)

            # selecting optimizer; an unknown name previously left
            # `optimizer` unbound and failed later with a NameError
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr,)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr,)
            else:
                raise NotImplementedError(f"Optimizer {args.optim} not defined")

            # MSE criterion, converted to PSNR for progress printing
            psnr_t    = torch.nn.MSELoss().to(device=args.device)
            residual  = []

            # running optimizer steps
            for t in range(args.steps):
                try:
                    def closure():
                        optimizer.zero_grad()
                        z_unflat = glow.unflatten_z(z_sampled, clone=False)
                        x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                        x_gen = glow.postprocess(x_gen,floor_clamp=False)
                        # exposed so the outer loop can record the last residual
                        global residual_t
                        residual_t = loss(x_gen, x_noisy)
                        if not args.z_penalty_unsquared:
                            z_reg_loss_t= gamma*(z_sampled.norm(dim=1)**2).mean()
                        else:
                            z_reg_loss_t= gamma*z_sampled.norm(dim=1).mean()
                        loss_t = residual_t + z_reg_loss_t
                        psnr = psnr_t(x_test, x_gen)
                        psnr = 10 * np.log10(1 / psnr.item())
                        print("\rAt step=%0.3d|loss=%0.4f|residual=%0.4f|z_reg=%0.5f|psnr=%0.3f"%(
                            t, loss_t.item(), residual_t.item(), z_reg_loss_t.item(), psnr),
                              end="\r")
                        loss_t.backward(retain_graph=True)
                        return loss_t
                    optimizer.step(closure)
                    residual.append(residual_t.item())
                except Exception:
                    traceback.print_exc()
                    skip_to_next = True
                    break

            # if skip_to_next:
            #     break

            x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
            Original.append(x_test_np)

            noise_np = noise.data.cpu().numpy().transpose(0, 2, 3, 1)
            Noise.append(noise_np)

            x_noisy_np = x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1)
            Noisy.append(x_noisy_np)

            Residual_Curve.append(residual)

            x_gen = None
            try:
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)

                x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
                x_gen_np = np.clip(x_gen_np, 0, 1)
                Recovered.append(x_gen_np)
            except Exception:
                traceback.print_exc()

            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise
            # torch.cuda.device() raises ValueError for a CPU device, so only
            # enter it when actually running on CUDA
            if torch.device(args.device).type == "cuda":
                with torch.cuda.device(args.device):
                    torch.cuda.empty_cache()
            print("\nbatch completed")

            # todo: remove this break after finishing development.
            # NOTE(review): as long as this break is here, only the first
            # batch of the test set is processed.
            break

        # if skip_to_next:
        #     print("\nskipping current loop due to instability or user triggered quit")
        #     continue

        # metric evaluations

        Original = np.vstack(Original)
        Noisy = np.vstack(Noisy)
        Noise = np.vstack(Noise)

        # Recovered may be empty if every reconstruction failed; previously
        # that skipped this gamma entirely, so the noisy inputs were never
        # reported or saved. Now we continue without a PSNR instead.
        psnr = None
        if Recovered:
            Recovered = np.vstack(Recovered)
            try:
                psnr = [compare_psnr(x, y) for x,y in zip(Original, Recovered)]
            except Exception:
                traceback.print_exc()

        # print performance analysis

        printout = "+-"*10 + "%s"%args.dataset + "-+"*10 + "\n"
        printout = printout + "\t n_test = %d\n"%len(Recovered)
        printout = printout + "\t noise_std = %0.4f\n"%args.noise_scale
        printout = printout + "\t gamma = %0.6f\n"%gamma
        if psnr is not None:
            printout = printout + "\t PSNR = %0.3f\n"%np.mean(psnr)
        print(printout)

        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt"%args.dataset,"a") as f:
                f.write('\n' + printout)

        # saving images

        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1].split(".")[0] for name in test_dataset.samples]

            save_path = os.path.join(save_path, f'{args.noise}_'
                                                f'{args.noise_loc}#{args.noise_scale}_'
                                                f'{args.noise_channel}_{args.noise_area}_'
                                                f'{args.init_strategy}_'
                                                f'{round(gamma, 4)}_{gettime()}')
            print(save_path)

            # pick the first free directory among base, base_1, base_2, ...
            # so earlier results are never overwritten
            base_path = save_path
            suffix = 0
            while os.path.exists(save_path):
                suffix += 1
                save_path = f"{base_path}_{suffix}"
            os.makedirs(save_path)

            for x, name in zip(Noisy, file_names):
                sio.imsave(save_path+"/"+name+"_noisy.jpg", x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path+"/residual_curve.npy", Residual_Curve)
            np.save(save_path+"/original.npy", Original)
            np.save(save_path+"/noisy.npy", Noisy)
            np.save(save_path+"/noise.npy", Noise)

            if len(Recovered) > 0:
                np.save(save_path + "/recovered.npy", Recovered)
                for x, name in zip(Recovered, file_names):
                    sio.imsave(save_path + "/" + name + "_recov.jpg", x)
Пример #4
0
def GlowREDCS(args, filename=None):
    """Compressed-sensing recovery with a Glow prior and a RED denoising term.

    For every ``(m, gamma, init_norm)`` triple drawn from ``args.m``,
    ``args.gamma`` and ``args.init_norms`` (``init_norms`` defaults to a list
    of ``None``):

    * A Gaussian sensing matrix ``A`` of shape ``(n, m)``, with
      ``n = 3 * args.size ** 2``, is drawn. One measurement-noise batch of
      shape ``(args.batchsize, m)`` is drawn as well.
    * For every batch of test images, a fresh Glow model is loaded. The
      flattened latent ``z`` is then optimised to minimise
      ``||G(z) A - (x A + noise)||^2 + gamma * ||z|| + beta * ||G(z) - (x_f - u)||^2``
      (each term averaged over the batch), where ``G`` is the Glow reverse
      pass.
    * Every ``args.update_iter`` steps, the auxiliary image ``x_f`` and the
      variable ``u`` are updated with ``Denoiser``, weighted by ``alpha`` and
      ``beta``.
    * The mean PSNR of ``G(z)`` and of ``x_f`` against the originals is
      printed. Depending on ``args.save_metrics_text`` and
      ``args.save_results``, it is also appended to a text file, and the
      results are saved as images and ``.npy`` arrays.

    Args:
        args: experiment namespace. ``args.init_norms`` is overwritten in
            place when it is ``None``.
        filename: optional CSV path. When given, the losses of every closure
            call of a batch are written there after that batch is optimised.
            Later batches overwrite earlier ones.

    Raises:
        AssertionError: on inconsistent argument lengths, on ``init_norms``
            used without ``random_fixed_norm``, or when the number of images
            is not a multiple of ``args.batchsize``.
        ValueError: on an unknown ``args.init_strategy`` or ``args.optim``.
    """
    if args.init_norms is None:
        args.init_norms = [None] * len(args.m)
    else:
        assert args.init_strategy == "random_fixed_norm", "init_strategy should be random_fixed_norm if init_norms is used"
    assert len(args.m) == len(args.gamma) == len(
        args.init_norms), "length of either m, gamma or init_norms are not same"
    for m, gamma, init_norm in zip(args.m, args.gamma, args.init_norms):
        skip_to_next = False  # flag to skip to next loop if recovery fails due to instability
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose([transforms.Resize((args.size, args.size)), transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batchsize, drop_last=False,
                                                      shuffle=False)

        # loading glow configurations
        with open(modeldir + "/configs.json", 'r') as f:
            configs = json.load(f)

        # sensing matrix: i.i.d. N(0, 1/m) entries, shape (n, m)
        A = np.random.normal(0, 1 / np.sqrt(m), size=(n, m))
        A = torch.tensor(A, dtype=torch.float, requires_grad=False, device=args.device)

        # regularizer weights
        gamma = torch.tensor(gamma, requires_grad=True, dtype=torch.float, device=args.device)
        alpha = args.alpha
        beta = args.beta

        # Measurement noise. Both modes draw the same standard-normal sample.
        # It is then either scaled by 0.1/sqrt(m) or rescaled so that every
        # row has L2 norm float(args.noise).
        noise = np.random.normal(0, 1, size=(args.batchsize, m))
        if args.noise == "random_bora":
            noise = noise * 0.1 / np.sqrt(m)
        else:
            noise = noise / (np.linalg.norm(noise, 2, axis=-1, keepdims=True)) * float(args.noise)
        noise = torch.tensor(noise, dtype=torch.float, requires_grad=False, device=args.device)

        # start solving over batches
        Original = []
        Recovered = []
        Recovered_f = []
        Z_Recovered = []
        Residual_Curve = []
        Recorded_Z = []
        for data in test_dataloader:
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"], L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # making a forward pass to record shapes of z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            # Starting value of the RED auxiliary image x_f. Only the lasso
            # strategies provide one here. The others fall back to G(z_init)
            # below; previously they crashed with a NameError on x_lasso.
            x_f_init = None

            if args.init_strategy == "random":
                # z ~ N(0, init_std^2)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float, device=args.device)
            elif args.init_strategy == "random_fixed_norm":
                # random Gaussian direction rescaled to norm init_norm
                z_sampled = np.random.normal(0, 1, [n_test, n])
                z_sampled = z_sampled / np.linalg.norm(z_sampled, axis=-1, keepdims=True)
                z_sampled = z_sampled * init_norm
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float, device=args.device)
                print("z intialized with a norm equal to = %0.1f" % init_norm)
            elif args.init_strategy == "pseudo_inverse":
                # z from the Glow forward pass of clamp(y A^+, 0, 1)
                x_test_flat = x_test.view([-1, n])
                y_true = torch.matmul(x_test_flat, A) + noise
                x_pinv = torch.matmul(y_true, torch.pinverse(A))
                x_pinv = torch.clamp(x_pinv.view([-1, 3, args.size, args.size]), 0, 1)
                z, _, _ = glow(glow.preprocess(x_pinv * 255, clone=True))
                z_sampled = _redcs_trainable_z(glow.flatten_z(z), args.device)
            elif args.init_strategy == "lasso_wavelet":
                x_lasso = _redcs_lasso_image(celebA_estimators.lasso_wavelet_estimator,
                                             x_test, A, noise, m, n, args.size, args.device)
                z, _, _ = glow(x_lasso - 0.5)
                z_sampled = _redcs_trainable_z(glow.flatten_z(z), args.device)
                x_f_init = x_lasso
                print("z intialized from a solution of lasso-wavelet")
            elif args.init_strategy == "lasso_dct":
                x_lasso = _redcs_lasso_image(celebA_estimators.lasso_dct_estimator,
                                             x_test, A, noise, m, n, args.size, args.device)
                z, _, _ = glow(x_lasso - 0.5)
                z_sampled = _redcs_trainable_z(glow.flatten_z(z), args.device)
                x_f_init = x_lasso
                print("z intialized from a solution of lasso-dct")
            elif args.init_strategy == "random_lasso_dct":
                # random z; only the RED variable x_f starts from lasso-dct
                x_lasso = _redcs_lasso_image(celebA_estimators.lasso_dct_estimator,
                                             x_test, A, noise, m, n, args.size, args.device)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled, requires_grad=True, dtype=torch.float, device=args.device)
                x_f_init = x_lasso
                print("z intialized randomly and RED is initialized from a solution of lasso-dct")
            elif args.init_strategy == "null_space":
                # z from x + (a random null(A) component rescaled to norm pert_norm)
                x_test_flat = x_test.view([-1, n])
                x_test_flat_np = x_test_flat.data.cpu().numpy()
                A_np = A.data.cpu().numpy()
                nullA = null_space(A_np.T)
                coeff = np.random.normal(0, 1, (args.batchsize, nullA.shape[1]))
                x_null = np.array([(nullA * c).sum(axis=-1) for c in coeff])
                pert_norm = 5  # <-- 5 gives optimal results --  bad initialization and not too unstable
                x_null = x_null / np.linalg.norm(x_null, axis=1, keepdims=True) * pert_norm
                x_perturbed = x_test_flat_np + x_null
                # no clipping of x_perturbed, so the forward model ||y - Ax|| stays the same
                err = np.matmul(x_test_flat_np, A_np) - np.matmul(x_perturbed, A_np)
                assert (err ** 2).sum() < 1e-6, "null space does not satisfy ||y-A(x+x0)|| <= 1e-6"
                x_perturbed = x_perturbed.reshape(-1, 3, args.size, args.size)
                x_perturbed = torch.tensor(x_perturbed, dtype=torch.float, device=args.device)
                z, _, _ = glow(x_perturbed - 0.5)
                z_sampled = _redcs_trainable_z(glow.flatten_z(z), args.device)
                print("z initialized from a point in null space of A")
            else:
                raise ValueError("Initialization strategy not defined")

            # selecting optimizer
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            elif args.optim == "rk2":
                optimizer = RK2Heun([z_sampled], lr=args.lr)
            elif args.optim == "raghav":
                optimizer = RK2Raghav([z_sampled], lr=args.lr)
            elif args.optim == "sgd":
                optimizer = torch.optim.SGD([z_sampled], lr=args.lr, momentum=0.9)
            else:
                raise ValueError("optimizer not defined")

            # RED auxiliary image and its companion variable u
            if x_f_init is None:
                with torch.no_grad():
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_f_init = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_f_init = glow.postprocess(x_f_init, floor_clamp=False)
            x_f = x_f_init.clone()
            u = torch.zeros_like(x_test)

            # to be recorded over iteration
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            recorded_z = []
            df_losses = pd.DataFrame(columns=["loss_t", "residual_t", "residual_x", "z_reg_loss"])
            residual_t = None  # latest data-fit term, written by closure()

            def denoiser_step(x_f, u):
                """Apply one RED update of ``x_f`` and ``u`` using the current G(z)."""
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False).detach()
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_f = 1 / (beta + alpha) * (beta * Denoiser(args.denoiser, args.sigma_f, x_f) + alpha * (x_gen + u))
                u = u + x_gen - x_f
                return x_f, u

            # running optimizer steps
            for t in range(args.steps):
                def closure():
                    """Compute the loss, backpropagate, and log one loss row."""
                    nonlocal residual_t
                    optimizer.zero_grad()
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_test_flat = x_test.view([-1, n])
                    x_gen_flat = x_gen.view([-1, n])
                    y_true = torch.matmul(x_test_flat, A) + noise
                    y_gen = torch.matmul(x_gen_flat, A)
                    residual_t = ((y_gen - y_true) ** 2).sum(dim=1).mean()
                    z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                    residual_x = beta * ((x_gen - (x_f - u)) ** 2).view(len(x_f), -1).sum(dim=1).mean()
                    loss_t = residual_t + z_reg_loss_t + residual_x
                    psnr = psnr_t(x_test, x_gen)
                    psnr = 10 * np.log10(1 / psnr.item())
                    print("At step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f" % (
                        t, loss_t.item(), residual_t.item(), residual_x.item(), z_reg_loss_t.item(), psnr))
                    loss_t.backward()
                    df_losses.loc[len(df_losses)] = [loss_t.item(), residual_t.item(), residual_x.item(),
                                                     z_reg_loss_t.item()]
                    return loss_t

                optimizer.step(closure)
                recorded_z.append(z_sampled.data.cpu().numpy())
                residual.append(residual_t.item())
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)

            # Written once per batch rather than on every closure call. Later
            # batches overwrite the file, as before.
            if filename is not None:
                df_losses.to_csv(filename)

            # skip_to_next is never set: the try/except guard around
            # optimizer.step is currently disabled.
            if skip_to_next:
                break

            # getting recovered and true images
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = np.clip(x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                x_f_np = np.clip(x_f.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                z_recov = z_sampled.data.cpu().numpy()
            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Recovered_f.append(x_f_np)
            Z_Recovered.append(z_recov)
            Residual_Curve.append(residual)
            Recorded_Z.append(recorded_z)

            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print("\nskipping current loop due to instability or user triggered quit")
            continue

        # collecting everything together
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        Recovered_f = np.vstack(Recovered_f)
        Z_Recovered = np.vstack(Z_Recovered)
        Recorded_Z = np.vstack(Recorded_Z)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]
        psnr_f = [compare_psnr(x, y) for x, y in zip(Original, Recovered_f)]
        z_recov_norm = np.linalg.norm(Z_Recovered, axis=-1)

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test        = %d\n" % len(Recovered)
        printout = printout + "\t n             = %d\n" % n
        printout = printout + "\t m             = %d\n" % m
        printout = printout + "\t update_iter    = %0.4f\n" % args.update_iter
        printout = printout + "\t gamma         = %0.6f\n" % gamma
        printout = printout + "\t alpha          = %0.6f\n" % alpha
        printout = printout + "\t beta           = %0.6f\n" % beta
        printout = printout + "\t optimizer     = %s\n" % args.optim
        printout = printout + "\t lr            = %0.3f\n" % args.lr
        printout = printout + "\t steps         = %0.3f\n" % args.steps
        printout = printout + "\t init_strategy = %s\n" % args.init_strategy
        printout = printout + "\t init_std      = %0.3f\n" % args.init_std
        if init_norm is not None:
            printout = printout + "\t init_norm     = %0.3f\n" % init_norm
        printout = printout + "\t z_recov_norm  = %0.3f\n" % np.mean(z_recov_norm)
        printout = printout + "\t mean PSNR     = %0.3f\n" % (np.mean(psnr))
        printout = printout + "\t mean PSNR_f   = %0.3f\n" % (np.mean(psnr_f))
        print(printout)

        # saving printout
        if args.save_metrics_text:
            with open("%s_cs_glow_results.txt" % args.dataset, "a") as f:
                f.write('\n' + printout)

        # setting folder to save results in
        if args.save_results:
            gamma = gamma.item()
            file_names = [name[0].split("/")[-1] for name in test_dataset.samples]
            if args.init_strategy == "random":
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma, alpha, beta, args.steps, args.lr, args.init_std, args.optim)
            elif args.init_strategy == "random_fixed_norm":
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_%s_%0.3f_optim_%s"
                save_path = save_path_template % (
                    m, args.update_iter, gamma, alpha, beta, args.steps, args.lr, args.init_strategy, init_norm, args.optim)
            else:
                save_path_template = save_path + "/cs_m_%d_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_%s_optim_%s"
                save_path = save_path_template % (m, args.update_iter, gamma, alpha, beta, args.steps, args.lr, args.init_strategy, args.optim)
            # Use the first of save_path, save_path_1, save_path_2 that does
            # not exist yet. If all three exist, save_path is reused.
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path_1 = save_path + "_1"
                if not os.path.exists(save_path_1):
                    os.makedirs(save_path_1)
                    save_path = save_path_1
                else:
                    save_path_2 = save_path + "_2"
                    if not os.path.exists(save_path_2):
                        os.makedirs(save_path_2)
                        save_path = save_path_2
            # saving results now
            _ = [sio.imsave(save_path + "/" + name, x) for x, name in zip(Recovered, file_names)]
            print(save_path+"/"+file_names[0])
            _ = [sio.imsave(save_path + "/f_" + name, x) for x, name in zip(Recovered_f, file_names)]
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/recovered_f.npy", Recovered_f)
            np.save(save_path + "/z_recovered.npy", Z_Recovered)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            if init_norm is not None:
                np.save(save_path + "/Recorded_Z_init_norm_%d.npy" % init_norm, Recorded_Z)
        torch.cuda.empty_cache()


def _redcs_trainable_z(z_flat, device):
    """Return a detached float copy of ``z_flat`` on ``device`` that requires grad."""
    return z_flat.clone().detach().to(device=device, dtype=torch.float).requires_grad_(True)


def _redcs_lasso_image(estimator_factory, x_test, A, noise, m, n, size, device):
    """Solve a lasso problem on measurements of ``x_test``; return the estimate.

    The batch is flattened in channel-last order, measured with ``A``, and
    perturbed by ``noise``. It is then passed, together with ``A`` scaled by
    ``sqrt(2 m)``, to the estimator built by ``estimator_factory``. The result
    is reshaped to ``(batch, 3, size, size)`` and returned as a float tensor
    on ``device``.
    """
    new_args = easydict.EasyDict({"batch_size": x_test.size()[0], "lmbd": 0.01, "lasso_solver": "sklearn"})
    estimator = estimator_factory(new_args)
    x_ch_last = x_test.permute(0, 2, 3, 1).contiguous().view([-1, n])
    y_true = torch.matmul(x_ch_last, A) + noise
    x_lasso = estimator(np.sqrt(2 * m) * A.data.cpu().numpy(), np.sqrt(2 * m) * y_true.data.cpu().numpy(),
                        new_args)
    # previously hard-coded to 64x64; identical when size == 64
    x_lasso = np.array(x_lasso).reshape(-1, size, size, 3).transpose(0, 3, 1, 2)
    return torch.tensor(x_lasso, dtype=torch.float, device=device)
Пример #5
0
def GlowREDDenoiser(args):
    """Denoise test images with a Glow prior coupled to a RED denoiser.

    The following runs for every ``gamma`` in ``args.gamma`` and every batch
    of test images:

    * Gaussian noise with std ``args.noise_std`` is added to the batch, and a
      fresh Glow model is loaded.
    * The flattened latent ``z`` is optimised to minimise
      ``||G(z) - x_noisy||^2 + gamma * ||z|| + beta * ||G(z) - (x_f - u)||^2``
      (each term averaged over the batch; the ``z`` penalty is squared when
      ``args.z_penalty_squared``), where ``G`` is the Glow reverse pass.
    * Every ``args.update_iter`` steps, the auxiliary image ``x_f`` (which
      starts as the noisy image) and ``u`` are updated with ``Denoiser``.
    * The clipped ``G(z)`` is recorded after every step in ``gen_steps``.

    The mean PSNR against the clean images is printed. Depending on
    ``args.save_metrics_text`` and ``args.save_results``, it is also appended
    to a text file, and the results are saved as images and ``.npy`` arrays.

    Args:
        args: experiment namespace (gamma, noise, noise_std, init_strategy,
            optim, lr, steps, alpha, beta, update_iter, denoiser, sigma_f, ...).

    Raises:
        AssertionError: when the number of images is not a multiple of
            ``args.batchsize``.
        NotImplementedError: for ``args.noise == "laplacian"``.
        ValueError: on an unknown noise type, ``init_strategy`` or ``optim``.
    """
    # Iterate the gamma values directly. zip(args.gamma) yielded 1-tuples,
    # which turned gamma into a shape-(1,) tensor instead of a scalar.
    for gamma in args.gamma:
        skip_to_next = False  # flag to skip to next loop if recovery fails due to instability
        n = args.size * args.size * 3
        modeldir = "./trained_models/%s/glow" % args.model
        test_folder = "./test_images/%s" % args.dataset
        save_path = "./results/%s/%s" % (args.dataset, args.experiment)

        # loading dataset
        trans = transforms.Compose(
            [transforms.Resize((args.size, args.size)),
             transforms.ToTensor()])
        test_dataset = datasets.ImageFolder(test_folder, transform=trans)
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batchsize,
            drop_last=False,
            shuffle=False)

        # loading glow configurations
        with open(modeldir + "/configs.json", 'r') as f:
            configs = json.load(f)

        # regularizer weights
        gamma = torch.tensor(gamma,
                             requires_grad=True,
                             dtype=torch.float,
                             device=args.device)
        alpha = args.alpha
        beta = args.beta

        # collected over all batches
        gen_steps = []
        Original = []
        Recovered = []
        Noisy = []
        Residual_Curve = []
        for data in test_dataloader:
            # getting batch of data
            x_test = data[0].clone().to(device=args.device)
            n_test = x_test.size()[0]
            assert n_test == args.batchsize, "please make sure that no. of images are evenly divided by batchsize"

            # noise to be added
            if args.noise == "gaussian":
                noise = np.random.normal(0,
                                         args.noise_std,
                                         size=(n_test, 3, args.size,
                                               args.size))
                noise = torch.tensor(noise,
                                     dtype=torch.float,
                                     requires_grad=False,
                                     device=args.device)
            elif args.noise == "laplacian":
                # Only gaussian noise is supported: the results folder name
                # carries no noise-type tag.
                raise NotImplementedError("code only supports gaussian for now")
            else:
                raise ValueError("noise type not defined")
            # noisy observation; fixed for the whole batch
            x_noisy = x_test + noise

            # loading glow model
            glow = Glow((3, args.size, args.size),
                        K=configs["K"],
                        L=configs["L"],
                        coupling=configs["coupling"],
                        n_bits_x=configs["n_bits_x"],
                        nn_init_last_zeros=configs["last_zeros"],
                        device=args.device)
            glow.load_state_dict(torch.load(modeldir + "/glowmodel.pt"))
            glow.eval()

            # making a forward pass to record shapes of z's for the reverse pass
            _ = glow(glow.preprocess(torch.zeros_like(x_test)))

            if args.init_strategy == "random":
                # z ~ N(0, init_std^2)
                z_sampled = np.random.normal(0, args.init_std, [n_test, n])
                z_sampled = torch.tensor(z_sampled,
                                         requires_grad=True,
                                         dtype=torch.float,
                                         device=args.device)
            elif args.init_strategy == "from-noisy":
                # z from the Glow forward pass of the noisy image
                z, _, _ = glow(glow.preprocess(x_noisy * 255, clone=True))
                z_sampled = glow.flatten_z(z).clone().detach().requires_grad_(True)
            else:
                raise ValueError("Initialization strategy not defined")

            # selecting optimizer (previously an unknown value left it unbound)
            if args.optim == "adam":
                optimizer = torch.optim.Adam([z_sampled], lr=args.lr)
            elif args.optim == "lbfgs":
                optimizer = torch.optim.LBFGS([z_sampled], lr=args.lr)
            else:
                raise ValueError("optimizer not defined")

            # to be recorded over iteration
            psnr_t = torch.nn.MSELoss().to(device=args.device)
            residual = []
            x_f = x_noisy.clone()
            u = torch.zeros_like(x_test)
            residual_t = None  # latest data-fit term, written by closure()

            def denoiser_step(x_f, u):
                """Apply one RED update of ``x_f`` and ``u`` using the current G(z)."""
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True,
                             reverse_clone=False).detach()
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_f = 1 / (beta + alpha) * (
                    beta * Denoiser(args.denoiser, args.sigma_f, x_f) +
                    alpha * (x_gen + u))
                u = u + x_gen - x_f
                return x_f, u

            # running optimizer steps
            for t in range(args.steps):

                def closure():
                    """Compute the loss and backpropagate it into z_sampled."""
                    nonlocal residual_t
                    optimizer.zero_grad()
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    residual_t = ((x_gen - x_noisy)**2).view(
                        len(x_noisy), -1).sum(dim=1).mean()
                    if args.z_penalty_squared:
                        z_reg_loss_t = gamma * (z_sampled.norm(dim=1)**
                                                2).mean()
                    else:
                        z_reg_loss_t = gamma * z_sampled.norm(dim=1).mean()
                    residual_x = beta * ((x_gen - (x_f - u))**2).view(
                        len(x_noisy), -1).sum(dim=1).mean()
                    loss_t = residual_t + z_reg_loss_t + residual_x
                    psnr = psnr_t(x_test, x_gen)
                    psnr = 10 * np.log10(1 / psnr.item())
                    print(
                        "\rAt step=%0.3d|loss=%0.4f|residual_t=%0.4f|residual_x=%0.4f|z_reg=%0.5f|psnr=%0.3f"
                        % (t, loss_t.item(), residual_t.item(),
                           residual_x.item(), z_reg_loss_t.item(), psnr),
                        end="\r")
                    loss_t.backward()
                    return loss_t

                optimizer.step(closure)
                residual.append(residual_t.item())
                if t % args.update_iter == args.update_iter - 1:
                    x_f, u = denoiser_step(x_f, u)
                # record the clipped reconstruction after every step
                with torch.no_grad():
                    z_unflat = glow.unflatten_z(z_sampled, clone=False)
                    x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                    x_gen = glow.postprocess(x_gen, floor_clamp=False)
                    x_gen_np = x_gen.data.cpu().numpy().transpose(0, 2, 3, 1)
                    gen_steps.append(np.clip(x_gen_np, 0, 1))

            # skip_to_next is never set: the try/except guard around
            # optimizer.step is currently disabled.
            if skip_to_next:
                break

            # getting recovered, noisy and true images (no autograd graph needed)
            with torch.no_grad():
                x_test_np = x_test.data.cpu().numpy().transpose(0, 2, 3, 1)
                z_unflat = glow.unflatten_z(z_sampled, clone=False)
                x_gen = glow(z_unflat, reverse=True, reverse_clone=False)
                x_gen = glow.postprocess(x_gen, floor_clamp=False)
                x_gen_np = np.clip(x_gen.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)
                x_noisy_np = np.clip(x_noisy.data.cpu().numpy().transpose(0, 2, 3, 1), 0, 1)

            Original.append(x_test_np)
            Recovered.append(x_gen_np)
            Noisy.append(x_noisy_np)
            Residual_Curve.append(residual)

            # freeing up memory for second loop
            glow.zero_grad()
            optimizer.zero_grad()
            del x_test, x_gen, optimizer, psnr_t, z_sampled, glow, noise, x_noisy
            torch.cuda.empty_cache()
            print("\nbatch completed")

        if skip_to_next:
            print(
                "\nskipping current loop due to instability or user triggered quit"
            )
            continue

        # metric evaluations
        Original = np.vstack(Original)
        Recovered = np.vstack(Recovered)
        gen_steps = np.vstack(gen_steps)
        Noisy = np.vstack(Noisy)
        psnr = [compare_psnr(x, y) for x, y in zip(Original, Recovered)]

        # print performance analysis
        printout = "+-" * 10 + "%s" % args.dataset + "-+" * 10 + "\n"
        printout = printout + "\t n_test     = %d\n" % len(Recovered)
        printout = printout + "\t noise_std  = %0.4f\n" % args.noise_std
        printout = printout + "\t update_iter= %0.4f\n" % args.update_iter
        printout = printout + "\t gamma      = %0.6f\n" % gamma
        printout = printout + "\t alpha      = %0.6f\n" % alpha
        printout = printout + "\t beta       = %0.6f\n" % beta
        printout = printout + "\t PSNR       = %0.3f\n" % np.mean(psnr)
        print(printout)
        if args.save_metrics_text:
            with open("%s_denoising_glow_results.txt" % args.dataset,
                      "a") as f:
                f.write('\n' + printout)

        # saving images
        if args.save_results:
            gamma = gamma.item()
            file_names = [
                name[0].split("/")[-1].split(".")[0]
                for name in test_dataset.samples
            ]
            save_path = save_path + "/denoising_noisestd_%0.4f_updateiter_%0.4f_gamma_%0.6f_alpha_%0.6f_beta_%0.6f_steps_%d_lr_%0.3f_init_std_%0.2f_optim_%s"
            save_path = save_path % (args.noise_std, args.update_iter, gamma,
                                     alpha, beta, args.steps, args.lr,
                                     args.init_std, args.optim)
            # Use the first of save_path, save_path_1, save_path_2 that does
            # not exist yet. If all three exist, save_path is reused.
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path_1 = save_path + "_1"
                if not os.path.exists(save_path_1):
                    os.makedirs(save_path_1)
                    save_path = save_path_1
                else:
                    save_path_2 = save_path + "_2"
                    if not os.path.exists(save_path_2):
                        os.makedirs(save_path_2)
                        save_path = save_path_2
            for x, name in zip(Recovered, file_names):
                sio.imsave(save_path + "/" + name + "_recov.jpg", x)
            for x, name in zip(Noisy, file_names):
                sio.imsave(save_path + "/" + name + "_noisy.jpg", x)
            Residual_Curve = np.array(Residual_Curve).mean(axis=0)
            np.save(save_path + "/residual_curve.npy", Residual_Curve)
            np.save(save_path + "/original.npy", Original)
            np.save(save_path + "/recovered.npy", Recovered)
            np.save(save_path + "/noisy.npy", Noisy)
            np.save(save_path + "/gen_steps.npy", gen_steps)