Example 1
def main():

    parser = argparse.ArgumentParser(description="Train a DCGAN on CIFAR10")
    parser.add_argument("--n_epochs", type=int, default=25, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=2 ** 5, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument(
        "--sample_interval",
        type=int,
        default=100,
        help="interval between image sampling. The number refers to the number of minibatch updates.",
    )
    parser.add_argument(
        "--save_model_interval", type=int, default=10, help="Save the generator once every this many epochs."
    )
    parser.add_argument("--prob_model_dir", type=str, help="interval between image sampling")
    parser.add_argument(
        "--classes", type=int, help="a list of integers (0-9) denoting the classes to consider", nargs="+"
    )

    # --------------------------------
    args = parser.parse_args()

    # op is a dict
    op = vars(args)
    if op["classes"] is None:
        # classes not specified => consider all classes
        op["classes"] = list(range(10))

    classes = sorted(op["classes"])
    cls_str = "".join(map(str, classes))
    if op["prob_model_dir"] is None:
        # use the list of classes to name the prob_model_dir
        prob_model_dir_name = "cifar10_c{}-dcgan".format(cls_str)
        op["prob_model_dir"] = glo.prob_model_folder(prob_model_dir_name)

    log.l().info("Options used: ")
    pprint.pprint(op)

    dcgan = DCGAN(**op)
    model_fname = "cifar10_c{}-dcgan-ep{}_bs{}.pt".format(cls_str, op["n_epochs"], op["batch_size"])
    model_fpath = os.path.join(dcgan.prob_model_dir, model_fname)

    # train
    log.l().info("Starting training a DCGAN on CIFAR10")
    dcgan.train()

    # save the generator
    g = dcgan.generator
    log.l().info("Saving the trained model to: {}".format(model_fpath))
    g.save(model_fpath)
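
For illustration, here is a minimal sketch of how the naming convention above resolves for a hypothetical class subset (the classes [3, 5] are only an example, not part of the script):

# Minimal sketch of the naming convention used above; the class subset [3, 5] is hypothetical.
classes = sorted([3, 5])
cls_str = "".join(map(str, classes))                                  # "35"
prob_model_dir_name = "cifar10_c{}-dcgan".format(cls_str)             # "cifar10_c35-dcgan"
model_fname = "cifar10_c{}-dcgan-ep{}_bs{}.pt".format(cls_str, 25, 2 ** 5)
print(model_fname)                                                    # cifar10_c35-dcgan-ep25_bs32.pt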
Example 2
    def train(self):
        """
        Train a DCGAN model with the training hyperparameters specified in
        the constructor. Directly modifies the state of this object to store all
        relevant variables.

        * self.generator stores the trained generator.
        """

        # Loss function
        adversarial_loss = torch.nn.BCELoss()

        # Initialize generator and discriminator
        img_size = 32
        minmax = (0.0, 1.0)

        # f_noise = lambda n: sample_standard_normal(n, self.latent_dim)
        # generator = ConvTranGenerator1(latent_dim=self.latent_dim,
        #         f_noise=f_noise, channels=3,  minmax=minmax)
        # generator = ReluGenerator1(latent_dim=self.latent_dim,
        #         f_noise=f_noise, channels=3,  minmax=minmax)
        generator = PatsornGenerator1(latent_dim=self.latent_dim, channels=3, minmax=minmax)
        # generator = SlowConvTransGenerator1(latent_dim=self.latent_dim,
        #         channels=3,  minmax=minmax)
        discriminator = Discriminator(channels=3, minmax=minmax)
        cuda = torch.cuda.is_available()

        if self.use_cuda and cuda:
            generator.cuda()
            discriminator.cuda()
            adversarial_loss.cuda()

        # Initialize weights
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

        # Configure data loader
        os.makedirs(self.data_dir, exist_ok=True)
        # trdata = torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True,
        #                         transform=transforms.Compose([
        #                            transforms.ToTensor(),
        # #                            transforms.Normalize((0.1307,), (0.3081,))
        #                        ]))

        print("classes to use to train: {}".format(self.classes))
        trdata = cifar10_util.load_cifar10_class_subsets(self.classes, train=True, device="cpu", dtype=torch.float)
        print("dataset size: {}".format(len(trdata)))

        dataloader = torch.utils.data.DataLoader(trdata, batch_size=self.batch_size, shuffle=True, drop_last=True)

        # Optimizers
        optimizer_G = torch.optim.Adam(generator.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

        # ----------
        #  Training
        # ----------

        # fixed noise vectors used when saving sample images
        z_save = generator.sample_noise(25).type(Tensor)
        for epoch in range(self.n_epochs):
            for i, (imgs, _) in enumerate(dataloader):

                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))

                # -----------------
                #  Train Generator
                # -----------------

                optimizer_G.zero_grad()

                # Sample noise as generator input
                z = Variable(generator.sample_noise(imgs.shape[0]).type(Tensor))

                # Generate a batch of images
                gen_imgs = generator(z)

                # Loss measures generator's ability to fool the discriminator
                g_loss = adversarial_loss(discriminator(gen_imgs), valid)

                g_loss.backward()
                optimizer_G.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                optimizer_D.zero_grad()
                # Measure discriminator's ability to classify real from generated samples
                real_loss = adversarial_loss(discriminator(real_imgs), valid)
                fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2

                d_loss.backward()
                optimizer_D.step()

                print(
                    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                    % (epoch, self.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
                )

                batches_done = epoch * len(dataloader) + i
                if batches_done % self.sample_interval == 0:
                    with torch.no_grad():
                        generator.eval()
                        gen_save = generator(z_save)
                    # switch back to training mode; eval() would otherwise keep BatchNorm/Dropout frozen
                    generator.train()
                    save_image(
                        gen_save.data[:25], "%s/%06d.png" % (self.prob_model_dir, batches_done), nrow=5, normalize=False
                    )

                # keep the state of the generator
                self.generator = generator

            # Save the model once in a while
            if (epoch + 1) % self.save_model_interval == 0:
                model_fname = DCGAN.make_model_file_name(self.classes, epoch + 1, self.batch_size)
                model_fpath = os.path.join(self.prob_model_dir, model_fname)
                log.l().info("Save the generator after {} epochs. Save to: {}".format(epoch + 1, model_fpath))
                generator.save(model_fpath)
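
After training, the saved generator can be reloaded for sampling. A minimal sketch, assuming PatsornGenerator1 is importable from the repository and exposes a load() method mirroring the save() call above (an assumption; adjust to the actual API):

# Hedged sketch: reload the generator saved by train() and draw a few samples.
# load() is assumed to mirror save(); sample_noise() is used exactly as in train().
import torch

model_fpath = "path/to/saved_generator.pt"   # e.g. the file written by Example 1
generator = PatsornGenerator1(latent_dim=100, channels=3, minmax=(0.0, 1.0))
generator.load(model_fpath)                  # hypothetical counterpart of generator.save(model_fpath)
generator.eval()
with torch.no_grad():
    z = generator.sample_noise(16)
    samples = generator(z)                   # pixel values in [0, 1], per minmax above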
Example 3
def pt_gkmm(
    g,
    cond_imgs,
    extractor,
    k,
    Z,
    optimizer,
    sum_writer,
    input_weights=None,
    z_penalty=TPNull(),
    device=torch.device("cpu"),
    tensor_type=torch.FloatTensor,
    n_opt_iter=500,
    seed=1,
    texture=0,
    img_log_steps=10,
    weigh_logits=0,
    log_img_dir="",
):
    """
    Conditionally generate images by conditioning on the input images (cond_imgs),
    using kernel moment matching.

    * g: a generator of type torch.nn.Module (forward() takes noise vectors
        and transforms them into images). Needs to be differentiable for the optimization.
    * cond_imgs: a stack of input images to condition on. Pixel value range
        should be [0,1]
    * extractor: an instance of torch.nn.Module representing a
        feature extractor for image input.
    * k: cadgan.kernel.PTKernel representing a kernel on top of the extracted
        features.
    * Z: a stack of noise vectors to be optimized. These are fed to the
        generator g for the optimization.
    * optimizer: a Pytorch optimizer. The list of variables to optimize has to
        contain Z.
    * sum_writer: SummaryWriter for tensorboard.
    * input_weights: a one-dimensional Torch tensor (vector) whose length is the
        same as the number of conditioned images. Specifies weights of the
        conditioned images. 0 <= w_i <= 1 and weights sum to 1.
        If None, automatically set to uniform weights.
    * z_penalty: a TensorPenalty to penalize Z. Set to TPNull() for no
        penalty.
    * device: an object constructed with torch.device(...). Likely this is
        torch.device('cuda') or torch.device('cpu'). Uses CPU by default.
    * tensor_type: Default Pytorch tensor type to use e.g., torch.FloatTensor
        or torch.cuda.FloatTensor. Use torch.FloatTensor by default (for cpu)
    * n_opt_iter: number of iterations for the optimization
    * seed: random seed (positive integer)
    * img_log_steps: record generated images once every this many
        optimization steps.
    * weigh_logits: whether to weight the output logits of the feature extractor so that we can
        backpropagate w.r.t. certain image features.

    Writes output to a Tensorboard log.
    """
    # Check generator's output range and image pixel range
    # We work with [0, 1]
    pixel_values_check(cond_imgs, (0, 1), "cond_imgs")
    tmp_sam = g.forward(Z)
    pixel_values_check(tmp_sam, (0, 1), "generator's output")

    # number of images to condition on
    n_cond = cond_imgs.shape[0]

    if input_weights is None:
        # None => set to uniform weights.
        input_weights = torch.ones(
            n_cond, device=device).type(tensor_type) / float(n_cond)

    # Check the range of input_weights. Has to be in [0, 1].
    if not ((input_weights >= 0.0).all() and (input_weights <= 1.0).all()):
        raise ValueError(
            '"input_weights" contains at least one weight which is outside [0,1] interval. Was {}'
            .format(input_weights))
    # Check that the weights sum to 1
    if torch.abs(input_weights.sum() - 1.0) > 1e-3:
        raise ValueError('"input_weights" does not sum to one. Was {}'.format(
            input_weights.sum()))

    gens_cpu = tmp_sam.to(torch.device("cpu"))
    arranged_init_imgs = torchvision.utils.make_grid(gens_cpu,
                                                     nrow=2,
                                                     normalize=True)
    log.l().debug('Adding initial generated images to Tensorboard')
    sum_writer.add_image("Init_Images", arranged_init_imgs)

    del tmp_sam

    # Setting requires_grad=True is very important. We will optimize Z.
    Z.requires_grad = True
    # number of images to generate
    n_sample = Z.shape[0]

    # Put models on gpu if needed
    # with torch.enable_grad():
    #    g = g.to(device)

    # Arrange the conditioned images into a grid for logging
    arranged_cond_imgs = torchvision.utils.make_grid(cond_imgs,
                                                     nrow=2,
                                                     normalize=True)
    sum_writer.add_image("Cond_Images", arranged_cond_imgs)

    with torch.no_grad():
        FX_ = extractor.forward(cond_imgs)
        FX = FX_
        if weigh_logits:
            FX = weighing_logits(FX)

    # mean_KFX = torch.mean(k.eval(FX, FX))
    kFX = k.eval(FX, FX)
    mean_KFX = kFX.mv(input_weights).dot(input_weights)
    time_per_itr = []
    loss_all = []
    for t in range(n_opt_iter):

        def closure():
            Z.data.clamp_(-3.3, 3.3)
            optimizer.zero_grad()

            gens = g.forward(Z)
            if gens.size()[3] == 1024:
                # Downsample images, otherwise the optimization takes a lot of time.
                # TODO: WJ: It is better to downsample before calling this function.
                # The conditional generation function does not need to handle this.
                downsample = torch.nn.AvgPool2d(3, stride=2)
                gens = downsample(downsample(gens))

            if t <= -1 or t % img_log_steps == 0 or t == n_opt_iter - 1:
                gens_cpu = gens.to(torch.device("cpu"))
                imutil.save_images(
                    gens_cpu, os.path.join(log_img_dir, "output_images",
                                           str(t)))
                arranged_gens = torchvision.utils.make_grid(gens_cpu,
                                                            nrow=2,
                                                            normalize=True)
                log.l().debug(
                    'Logging generated images at iteration {}'.format(t + 1))
                sum_writer.add_image("Generated_Images", arranged_gens, t)

            F_gz = extractor.forward(gens)
            # import pdb; pdb.set_trace()
            if t <= -1 or t % img_log_steps == 0 or t == n_opt_iter - 1:
                feature_size = int(np.sqrt(F_gz.shape[1]))
                # import pdb; pdb.set_trace()
                try:
                    feat_out = F_gz.view(F_gz.shape[0], 1, feature_size,
                                         feature_size)
                    gens_cpu = feat_out.to(torch.device("cpu"))
                    imutil.save_images(
                        gens_cpu,
                        os.path.join(log_img_dir, "feature_images", str(t)))
                    arranged_init_imgs = torchvision.utils.make_grid(
                        gens_cpu, nrow=2, normalize=True)
                    sum_writer.add_image("feature_images", arranged_init_imgs,
                                         t)
                except Exception:
                    if t == 0:
                        log.l().debug(
                            "Unable to plot features as image. Okay. Will skip plotting features."
                        )

            if weigh_logits:
                # WJ: This option is not really used. Should be removed.
                F_gz = weighing_logits(F_gz)
            KF_gz = k.eval(F_gz, F_gz)

            Z_loss = z_penalty(Z)
            mmd2 = torch.mean(KF_gz) - 2.0 * torch.mean(
                k.eval(F_gz, FX).mv(input_weights)) + mean_KFX
            loss = mmd2 + Z_loss

            # compute the gradients
            loss.backward(retain_graph=True)

            # record losses
            sum_writer.add_scalar("loss/total", loss.item(), t)
            sum_writer.add_scalar("loss/mmd2", mmd2.item(), t)
            sum_writer.add_scalar("loss/Z_penalty", Z_loss, t)

            # record some statistics
            sum_writer.add_scalar("Z/max_z", torch.max(Z), t)
            sum_writer.add_scalar("Z/min_z", torch.min(Z), t)
            sum_writer.add_scalar("Z/avg_z", torch.mean(Z), t)
            sum_writer.add_scalar("Z/std_z", torch.std(Z), t)
            sum_writer.add_histogram("Z/hist", Z.reshape(-1), t)

            loss_all.append(mmd2.item())

            if t <= 20 or t % 20 == 0:
                log.l().info("Iter [{}], overall_loss: {}".format(
                    t, loss.item()))
            return loss

        #    start_time = datetime.datetime.now()
        optimizer.step(closure)
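
For reference, the quantity minimized inside closure() is a weighted squared MMD between the extracted features of the generated images and those of the conditioned images. Below is a self-contained sketch of the same estimator, using a plain Gaussian kernel purely for illustration (the repository's kernel classes such as PTKGauss are not used here):

import torch

def gauss_kernel(A, B, sigma2=1.0):
    # pairwise Gaussian kernel matrix between the rows of A and the rows of B
    d2 = torch.cdist(A, B) ** 2
    return torch.exp(-d2 / (2.0 * sigma2))

def weighted_mmd2(F_gz, FX, w):
    # F_gz: features of generated images (n x d); FX: features of conditioned images (m x d)
    # w: weights over the conditioned images, each in [0, 1] and summing to 1
    K_gg = gauss_kernel(F_gz, F_gz)
    K_gx = gauss_kernel(F_gz, FX)
    K_xx = gauss_kernel(FX, FX)
    # matches the code above: mean(K_gg) - 2 * mean(K_gx @ w) + w^T K_xx w
    return K_gg.mean() - 2.0 * K_gx.mv(w).mean() + K_xx.mv(w).dot(w)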
Example 4
        def closure():
            Z.data.clamp_(-3.3, 3.3)
            optimizer.zero_grad()

            gens = g.forward(Z)
            if gens.size()[3] == 1024:
                # Downsample images, otherwise the optimization takes a lot of time.
                # TODO: WJ: It is better to downsample before calling this function.
                # The conditional generation function does not need to handle this.
                downsample = torch.nn.AvgPool2d(3, stride=2)
                gens = downsample(downsample(gens))

            if t <= -1 or t % img_log_steps == 0 or t == n_opt_iter - 1:
                gens_cpu = gens.to(torch.device("cpu"))
                imutil.save_images(
                    gens_cpu, os.path.join(log_img_dir, "output_images",
                                           str(t)))
                arranged_gens = torchvision.utils.make_grid(gens_cpu,
                                                            nrow=2,
                                                            normalize=True)
                log.l().debug(
                    'Logging generated images at iteration {}'.format(t + 1))
                sum_writer.add_image("Generated_Images", arranged_gens, t)

            F_gz = extractor.forward(gens)
            # import pdb; pdb.set_trace()
            if t <= -1 or t % img_log_steps == 0 or t == n_opt_iter - 1:
                feature_size = int(np.sqrt(F_gz.shape[1]))
                # import pdb; pdb.set_trace()
                try:
                    feat_out = F_gz.view(F_gz.shape[0], 1, feature_size,
                                         feature_size)
                    gens_cpu = feat_out.to(torch.device("cpu"))
                    imutil.save_images(
                        gens_cpu,
                        os.path.join(log_img_dir, "feature_images", str(t)))
                    arranged_init_imgs = torchvision.utils.make_grid(
                        gens_cpu, nrow=2, normalize=True)
                    sum_writer.add_image("feature_images", arranged_init_imgs,
                                         t)
                except Exception:
                    if t == 0:
                        log.l().debug(
                            "Unable to plot features as image. Okay. Will skip plotting features."
                        )

            if weigh_logits:
                # WJ: This option is not really used. Should be removed.
                F_gz = weighing_logits(F_gz)
            KF_gz = k.eval(F_gz, F_gz)

            Z_loss = z_penalty(Z)
            mmd2 = torch.mean(KF_gz) - 2.0 * torch.mean(
                k.eval(F_gz, FX).mv(input_weights)) + mean_KFX
            loss = mmd2 + Z_loss

            # compute the gradients
            loss.backward(retain_graph=True)

            # record losses
            sum_writer.add_scalar("loss/total", loss.item(), t)
            sum_writer.add_scalar("loss/mmd2", mmd2.item(), t)
            sum_writer.add_scalar("loss/Z_penalty", Z_loss, t)

            # record some statistics
            sum_writer.add_scalar("Z/max_z", torch.max(Z), t)
            sum_writer.add_scalar("Z/min_z", torch.min(Z), t)
            sum_writer.add_scalar("Z/avg_z", torch.mean(Z), t)
            sum_writer.add_scalar("Z/std_z", torch.std(Z), t)
            sum_writer.add_histogram("Z/hist", Z.reshape(-1), t)

            loss_all.append(mmd2.item())

            if t <= 20 or t % 20 == 0:
                log.l().info("Iter [{}], overall_loss: {}".format(
                    t, loss.item()))
            return loss
Example 5
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch GKMM. Some paths are relative to the "(share_path)/prob_models/". See settings.ini for (share_path).'
    )

    parser.add_argument(
        "--extractor_type",
        type=str,
        default="vgg",
        help=
        "The feature extractor. The saved object should be a torch.nn.Module representing a \
        feature extractor. Currently supported: [vgg | vgg_face | alexnet_365 | resnet18_365 | resnet50_365 | hed | mnist_cnn | pixel]",
        required=True,
    )
    parser.add_argument(
        "--extractor_layers",
        nargs="+",
        default=["4", "9", "18", "27"],
        help=
        "Number of layers to include. Only for VGG feature extractor. Default:[]",
    )
    parser.add_argument(
        "--texture",
        type=float,
        default=0,
        help="Use texture (grammatrix) of extracted features. Default=0")
    parser.add_argument(
        "--depth_process",
        nargs="?",
        choices=["avg", "max", "no"],
        default="no",
        help="Processing module to run on the output from \
            each filter in the specified layer(s).",
    )
    parser.add_argument(
        "--g_path",
        type=str,
        required=True,
        help="Relative path \
            (relative to (share_path)/prob_models) to the file that can be loaded \
            to get a cadgan.gen.PTNoiseTransformer representing an image generator.",
    )
    parser.add_argument(
        "--g_type",
        type=str,
        default="celebAHQ.yaml",
        help="Generator type based on the data it is trained for.")
    parser.add_argument(
        "--g_min",
        type=float,
        help="The minimum value of the pixel output from the generator.",
        required=True)
    parser.add_argument(
        "--g_max",
        type=float,
        help="The maximum value of the pixel output from the generator.",
        required=True)
    parser.add_argument(
        "--logdir",
        type=str,
        required=True,
        help="full path to the folder to contain Tensorboard log files")
    parser.add_argument("--device",
                        nargs="?",
                        choices=["cpu", "gpu"],
                        default="cpu",
                        help="Device to use for computation.")
    parser.add_argument("--n_sample",
                        type=int,
                        default=16,
                        metavar="n",
                        help="Number of images to generate")
    parser.add_argument("--n_opt_iter",
                        type=int,
                        default=500,
                        help="Number of optimization iterations")

    parser.add_argument("--lr",
                        type=float,
                        default=0.001,
                        metavar="LR",
                        help="learning rate (for the optimizer)")
    parser.add_argument("--n_init_resample",
                        type=float,
                        default=1,
                        help="number of time to resample z for the heuristic")
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        metavar="S",
        help=
        "Random seed. Among others, this affects the initialization of the noise vectors of the generator in the optimization.",
    )
    parser.add_argument(
        "--img_log_steps",
        type=int,
        default=10,
        metavar="N",
        help=
        "how many optimization iterations to wait before logging generated images",
    )
    parser.add_argument("--img_size",
                        type=int,
                        default=224,
                        help="image size nxn default 256")
    # parser.add_argument('--data_dir', type=str,
    #        default='mnist/', help='Relative path (relative to the data folder) \
    #        containing Mnist training data. Mnist data will be downloaded if \
    #        not existed already.')
    # parser.add_argument('--cond', nargs='+', type=int, dest='cond',
    #        action='append', required=True, help='Digit label and number of images from that label to condition on. For example, "--cond 3 4" means 4 images of digit 3. --cond can be used multiple times. For instance, use --cond 1 2 --cond 3 1 to condition on 2 digits of 1, and 1 digit of 3')
    parser.add_argument("--cond_path",
                        type=str,
                        required=True,
                        help="Path to imgs for conditioning")
    parser.add_argument(
        "--kernel",
        nargs="?",
        required=True,
        choices=["linear", "gauss", "imq"],
        help=
        "choice of kernel to put on top of extracted features.  May need to specify also --kparams.",
    )
    parser.add_argument(
        "--kparams",
        nargs="*",
        type=float,
        dest="kparams",
        default=[],
        help=
        "A list of kernel parameters (float). Semantic of parameters depends on the chosen kernel",
    )

    parser.add_argument(
        "--w_input",
        nargs="+",
        default=[],
        help=
        "weight of the input, must be equal to the number of cond images and sum to 1. if none specified, equal weights will be used.",
    )

    img_transform = target_transform()
    # glo.data_file('mnist/')
    args = parser.parse_args()
    print("Training options: ")
    args_dict = vars(args)
    pprint.pprint(args_dict, width=5)

    # ---------------------------------

    # Check if texture and extractor are called correctly
    if args.texture and (not args.extractor_layers or not args.extractor_type):
        parser.error(
            "--texture requires both --extractor_layers and --extractor_type to be specified."
        )

    # True to use GPU
    use_cuda = args.device == "gpu" and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    tensor_type = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    torch.set_default_tensor_type(tensor_type)

    # load option depends on whether GPU is used
    device_load_options = {} if use_cuda else {
        "map_location": lambda storage, loc: storage
    }

    # initialize the noise vectors for the generator
    # Set the random seed
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    n_sample = args.n_sample

    if args.g_type.endswith(".yaml"):
        # sample a stack of noise vectors
        latent_dim = 256
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        # Loading Configs for LarsGAN
        yaml_folder = os.path.dirname(ganstab.configs.__file__)
        yaml_config_path = os.path.join(yaml_folder, args.g_type)
        config = load_config(yaml_config_path)

        # load generator
        nlabels = config["data"]["nlabels"]
        out_dir = config["training"]["out_dir"]
        checkpoint_dir = os.path.join(out_dir, "chkpts")

        generator = build_generator(config)

        # Put models on gpu if needed
        #with torch.enable_grad():  # use_cuda??????
        generator = generator.to(device)
        # for celebA HQ generator,
        # if args.g_type == 'celebAHQ.yaml':
        #    generator.add_resize(args.img_size)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)
        # Logger
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

        # Register modules to checkpoint
        checkpoint_io.register_modules(generator=generator)
        # Test generator
        if config["test"]["use_model_average"]:
            generator_test = copy.deepcopy(generator)
            checkpoint_io.register_modules(generator_test=generator_test)
        else:
            generator_test = generator

        # Loading Generator
        ydist = get_ydist(nlabels, device=device)

        full_g_path = glo.prob_model_folder(args.g_path)
        if not os.path.exists(full_g_path):
            # download a LarsGAN pre-trained model file if it does not exist
            print(
                "Generator file does not exist: {}\n I will load a pretrained model for you. Please wait ..."
                .format(full_g_path),
                end='')

            dict_url = {
                'lsun_bedroom.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bedroom-df4e7dd2.pt',
                'lsun_bridge.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bridge-82887d22.pt',
                'celebAHQ.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celebahq-baab46b2.pt',
                'lsun_tower.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_tower-1af5e570.pt'
            }

            assert args.g_type in dict_url, 'g_type of {} not supported'.format(args.g_type)
            url = dict_url[args.g_type]
            r = requests.get(url)
            os.makedirs(os.path.dirname(full_g_path), exist_ok=True)
            with open(full_g_path, 'wb') as f:
                f.write(r.content)

            print('done')
        load_options = {} if use_cuda else {
            "map_location": lambda storage, loc: storage
        }
        it = checkpoint_io.load(full_g_path, **load_options)

    elif args.g_type == "mnist_dcgan":
        # TODO: should probably reorganize these
        latent_dim = 100
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        full_g_path = glo.prob_model_folder(args.g_path)
        # load option depends on whether GPU is used
        load_options = {} if use_cuda else {
            "map_location": lambda storage, loc: storage
        }

        generator = mnist_dcgan.Generator()
        if os.path.exists(full_g_path):
            generator.load(full_g_path)
        else:
            print(
                "Generator file does not exist: {}\nLoading pretrain model...".
                format(full_g_path))
            generator.download_pretrain(
                output=full_g_path)  # .load(full_g_path, **load_options)

        generator = generator.to(device)

        generator_test = generator
        ydist = None

    elif args.g_type == "colormnist_dcgan":
        # TODO: should probably reorganize these
        latent_dim = 100
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        full_g_path = glo.prob_model_folder(args.g_path)
        generator = cmnist_dcgan.Generator()
        if os.path.exists(full_g_path):
            generator.load(full_g_path)
        else:
            print(
                "Generator file does not exist: {}\nLoading pretrain model...".
                format(full_g_path))
            generator.download_pretrain(
                output=full_g_path)  # .load(full_g_path, **load_options)

        generator = generator.to(device)

        generator_test = generator
        ydist = None

    # Noise distribution is Gaussian. Unlikely that the magnitude of the
    # coordinate is above the bound.
    z_penalty = kmain.TPNull()  # kmain.TPSymLogBarrier(bound=4.2, scale=1e-4)
    args_dict["zpen"] = z_penalty

    # output range of the generator (according to what the user specifies)
    g_range = (args.g_min, args.g_max)

    # Sanity check. Check that the specified g-range is plausible
    g_out_uncontrolled = Generator(ydist=ydist,
                                   generator=generator_test.to(device))

    temp_sample = g_out_uncontrolled.forward(Z0)
    kmain.pixel_values_check(temp_sample, g_range, "Generator's samples")

    extractor_in_size = args.img_size

    # transform the output range of g to (0,1)
    g = nn.Sequential(
        g_out_uncontrolled,
        nn.AdaptiveAvgPool2d((extractor_in_size, extractor_in_size)),
        gen.LinearRangeTransform(from_range=g_range, to_range=(0, 1)),
    )
    depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
    feature_size = 128

    if args.texture == 1:
        post_process = nn.Sequential(depth_process_map[args.depth_process],
                                     GramMatrix())
    else:
        post_process = nn.Sequential(depth_process_map[args.depth_process])

    # Loading Extractor
    if args.extractor_type == "vgg":
        extractor_layers = [int(i) for i in args.extractor_layers]
        extractor = ext.VGG19(layers=extractor_layers,
                              layer_postprocess=post_process)
    elif args.extractor_type == "vgg_face":
        extractor_layers = [int(i) for i in args.extractor_layers]
        extractor = ext.VGG19_face(layers=extractor_layers,
                                   layer_postprocess=post_process)
    elif args.extractor_type == "alexnet_365":
        extractor = ext.AlexNet_365()
    elif args.extractor_type == "resnet18_365":
        extractor = ext.ResNet18_365()
    elif args.extractor_type == "resnet50_365":
        extractor = ext.ResNet50_365(n_remove_last_layers=2,
                                     layer_postprocess=post_process)
    elif args.extractor_type == "hed":
        # extractor_in_size = 256
        extractor = ext.HED(device=device, resize=feature_size)
    elif args.extractor_type == "hed_color":
        # stacking features from HED and a tiny image to get both edge and color information
        hed = ext.HED(device=device, resize=feature_size)
        tiny = ext.TinyImage(device=device, grid_size=(10, 10))
        extractor = ext.StackModule(device=device,
                                    module_list=[hed, tiny],
                                    weights=[0.01, 0.99])
    elif args.extractor_type == "hed_vgg":
        # stacking features from HED and VGG to get both edge and high-level VGG information
        feature_size = 128
        hed = ext.HED(device=device, resize=feature_size)
        extractor_layers = [int(i) for i in args.extractor_layers]
        vgg = ext.VGG19(layers=extractor_layers,
                        layer_postprocess=post_process)
        extractor = ext.StackModule(device=device,
                                    module_list=[hed, vgg],
                                    weights=[0.99, 0.01])
    elif args.extractor_type == "hed_color_vgg":
        # stacking features from HED, a tiny image, and VGG to get edge, color, and high-level VGG information
        feature_size = 128
        hed = ext.HED(device=device, resize=feature_size)
        extractor_layers = [int(i) for i in args.extractor_layers]
        vgg = ext.VGG19(layers=extractor_layers,
                        layer_postprocess=post_process)
        tiny = ext.TinyImage(device=device, grid_size=(10, 10))
        extractor = ext.StackModule(device=device,
                                    module_list=[hed, vgg, tiny],
                                    weights=[0.005, 0.005, 0.99])
    elif args.extractor_type == "color":
        extractor = ext.TinyImage(device=device, grid_size=(128, 128))
    elif args.extractor_type == "color_count":
        # to use with Waleed color mnist only:
        # the purpose is to count color based on the template, currently not working as expected.
        prototypes = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0],
                                   [1, 0, 1], [0.4, 0.2, 0]])
        extractor = ext.SoftCountPixels(prototypes=prototypes,
                                        gwidth2=0.3,
                                        device=device,
                                        tensor_type=tensor_type)
    elif args.extractor_type == "mnist_cnn":
        depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
        if args.texture == 1:
            post_process = nn.Sequential(depth_process_map[args.depth_process],
                                         GramMatrix())
        else:
            post_process = nn.Sequential(depth_process_map[args.depth_process])
        extractor = ext.MnistCNN(device="cuda" if use_cuda else "cpu",
                                 layer_postprocess=post_process,
                                 layer=int(args.extractor_layers[0]))
    elif args.extractor_type == "mnist_cnn_digit_layer":
        # using the last layer of the MNIST CNN (digit classification)
        depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
        if args.texture == 1:
            post_process = nn.Sequential(depth_process_map[args.depth_process],
                                         GramMatrix())
        else:
            post_process = nn.Sequential(depth_process_map[args.depth_process])
        extractor = ext.MnistCNN(device="cuda" if use_cuda else "cpu",
                                 layer_postprocess=post_process,
                                 layer=3)
    elif args.extractor_type == "mnist_cnn_digit_layer_color":
        # using the last layer of MNIST CNN (digit classification) stacking with color information from tiny image
        depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
        if args.texture == 1:
            post_process = nn.Sequential(depth_process_map[args.depth_process],
                                         GramMatrix())
        else:
            post_process = nn.Sequential(depth_process_map[args.depth_process])
        mnistcnn = ext.MnistCNN(device="cuda" if use_cuda else "cpu",
                                layer_postprocess=post_process,
                                layer=3)
        color = ext.MaxColor(device=device)
        extractor = ext.StackModule(device=device,
                                    module_list=[mnistcnn, color],
                                    weights=[1, 99])
    elif args.extractor_type == "pixel":
        # raw pixels as features
        extractor = ext.Identity(
            flatten=True,
            slice_dim=0 if args.g_type == "mnist_dcgan" else None)
    else:
        raise ValueError("Unknown extractor type. Check --extractor_type")

    if use_cuda:
        extractor = extractor.cuda()
    assert isinstance(extractor, torch.nn.Module)

    print("Summary of the extractor:")
    try:
        torchsummary.summary(extractor,
                             input_size=(3, extractor_in_size,
                                         extractor_in_size))
    except Exception:
        log.l().info(
            "Exception occurred when getting a summary of the extractor")

    # run a forward pass through the extractor just to test
    tmp_extracted = extractor(g(Z0[[0]]))
    n_features = torch.prod(torch.tensor(tmp_extracted.shape))
    print("Number of extracted features = {}".format(n_features))
    del tmp_extracted

    def load_multiple_images(list_imgs):
        for path_img in list_imgs:
            loaded = imutil.load_resize_image(path_img,
                                              extractor_in_size).copy()
            cond_img = img_transform(loaded).unsqueeze(0).type(
                tensor_type)  # .to(device)
            try:
                cond_imgs = torch.cat((cond_imgs.clone(), cond_img))
            except NameError:
                cond_imgs = cond_img.clone()
        return cond_imgs

    if not os.path.isdir(glo.data_file(args.cond_path)):
        # read list of imgs if it's a text file
        if args.cond_path.endswith(".txt"):
            img_txt_path = glo.data_file(args.cond_path)
            with open(img_txt_path, "r") as f:
                data = f.readlines()

            list_imgs = [
                glo.data_file(x.strip()) for x in data if len(x.strip()) != 0
            ]
            if not list_imgs:
                raise ValueError(
                    "Empty list of images to condiiton. Make sure that {} is valid"
                    .format(img_txt_path))

            cond_imgs = load_multiple_images(list_imgs)
        elif args.cond_path.endswith(".png") or args.cond_path.endswith(
                ".jpg"):
            path_img = glo.data_file(args.cond_path)
            loaded = imutil.load_resize_image(path_img,
                                              extractor_in_size).copy()
            cond_imgs = img_transform(loaded).unsqueeze(0).type(
                tensor_type)  # .to(device)
        else:
            raise ValueError(
                'Unsupported input type at {} (currently supported: a folder or a text file with a list of images)'.format(
                    glo.data_file(args.cond_path)))
    else:
        # using all images in the folder
        list_imgs = glob.glob(glo.data_file(args.cond_path) + "*")
        cond_imgs = load_multiple_images(list_imgs)

    cond_imgs = cond_imgs.to(device).type(tensor_type)

    # kernel on top of the extracted features
    k_map = {
        "linear": kernel.PTKLinear,
        "gauss": kernel.PTKGauss,
        "imq": kernel.PTKIMQ
    }
    kernel_key = args.kernel
    kernel_params = args.kparams
    k_constructor = k_map[kernel_key]
    # construct the chosen kernel with the specified parameters
    k = k_constructor(*kernel_params)

    # texture flag
    texture = args.texture
    # run the kernel moment matching optimization
    n_opt_iter = args.n_opt_iter
    logdir = args.logdir
    print("LOGDIR: ", logdir)

    # dictionary containing key-value pairs for experimental settings.
    log_str_dict = dict((ke, str(va)) for (ke, va) in args_dict.items())

    # logdir is just a parent folder.
    # Form the actual file name by concatenating the values of all
    # hyperparameters used.
    log_str_dict2 = copy.deepcopy(log_str_dict)

    now = datetime.datetime.now()
    time_str = "{:02}.{:02}.{}_{:02}{:02}{:02}".format(now.day, now.month,
                                                       now.year, now.hour,
                                                       now.minute, now.second)
    log_str_dict2["t"] = time_str
    util.translate_keys(
        log_str_dict2,
        {
            "cond_path": "co",
            "data_dir": "dat",
            "depth_process": "dp",
            "extractor_path": "ep",
            "extractor_type": "et",
            "extractor_layers": "el",
            "g_type": "gt",
            "kernel": "k",
            "kparams": "kp",
            "n_opt_iter": "it",
            "n_sample": "n",
            "seed": "s",
            "texture": "te",
        },
    )

    parameters_str = util.dict_to_string(
        log_str_dict2,
        exclude=[
            "device", "img_log_steps", "logdir", "g_min", "g_max", "g_path",
            "t"
        ],
        entry_sep="-",
        kv_sep="_",
    )
    img_log_steps = args.img_log_steps
    logdir_fname = util.clean_filename(parameters_str, replace="/\\[]")
    log_dir_path = glo.result_folder(os.path.join(logdir, logdir_fname))

    # multiple restarts to refine the drawn Z. This is just a heuristic
    # so we start (hopefully) from a good initial point.
    k_img = kernel.PTKFuncCompose(k, f=extractor)
    # multi_restarts_refiner = kmain.ZRMMDMultipleRestarts(
    #         g, z_sampler=f_noise, k=k_img, X=cond_imgs,
    #         n_restarts=100,
    #         n_sample=Z0.shape[0],
    #         )

    tmp_gen = g(Z0)
    assert tmp_gen.shape[-1] == extractor_in_size and tmp_gen.shape[
        -2] == extractor_in_size
    del tmp_gen

    if len(args.w_input) == 0:
        input_weights = None
    else:
        assert cond_imgs.shape[0] == len(
            args.w_input
        ), "number of input weights must equal to number of input images"
        input_weights = torch.Tensor([float(x) for x in args.w_input],
                                     device=device).type(tensor_type)

    # A heuristic to pick good Z to start the optimization
    multi_restarts_refiner = kmain.ZRMMDIterGreedy(
        g,
        z_sampler=f_noise,
        k=k_img,
        X=cond_imgs,
        n_draws=int(
            args.n_init_resample
        ),  # number of times to draw each z_i --> set to 1 since I want to test the latent optimization,
        n_sample=Z0.shape[0],
        device=device,
        tensor_type=tensor_type,
        input_weights=input_weights,
    )

    # Summary writer for Tensorboard logging
    sum_writer = SummaryWriter(log_dir=log_dir_path)

    # write all key-value pairs in log_str_dict to the Tensorboard
    for ke, va in log_str_dict.items():
        sum_writer.add_text(ke, va)

    with open(os.path.join(log_dir_path, "metadata"), "wb") as f:
        dill.dump(log_str_dict, f)

    imutil.save_images(cond_imgs, os.path.join(log_dir_path, "input_images"))

    gens = g.forward(Z0)
    gens_cpu = gens.to(torch.device("cpu"))
    imutil.save_images(gens_cpu, os.path.join(log_dir_path, "prior_images"))
    del gens
    del gens_cpu
    # import pdb; pdb.set_trace()
    # Get a better Z
    Z = multi_restarts_refiner(Z0)

    # Try to plot (in Tensorboard) extracted features as images if possible
    log.l().info(
        'Attempting to plot extracted features as images. Will skip if this does not work'
    )
    try:
        # if args.extractor_type == 'hed':
        feat_out = extractor.forward(cond_imgs)
        # import pdb; pdb.set_trace()
        feature_size = int(np.sqrt(feat_out.shape[1]))
        feat_out = feat_out.view(feat_out.shape[0], 1, feature_size,
                                 feature_size)
        gens_cpu = feat_out.to(torch.device("cpu"))
        imutil.save_images(gens_cpu, os.path.join(log_dir_path,
                                                  "input_feature"))
        arranged_init_imgs = torchvision.utils.make_grid(gens_cpu,
                                                         nrow=2,
                                                         normalize=True)
        sum_writer.add_image("Init_feature", arranged_init_imgs)
        del feat_out
    except Exception as err:
        log.l().info(err)
        log.l().info("unable to plot feature as image")
    # if args.w_intp
    # import pdb; pdb.set_trace()

    imutil.save_images(cond_imgs, os.path.join(log_dir_path, "input_images"))

    # optimizer
    optimizer = torch.optim.Adam([Z],
                                 lr=args.lr)  # ,momentum=0.99,nesterov=True)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000,2000,3000], gamma=0.1)
    # optimizer = torch.optim.LBFGS([Z]) # --> LBFGS doesn't really converge, we could try other optimizer as well
    # Solve the kernel moment matching problem
    kmain.pt_gkmm(
        g,
        cond_imgs,
        extractor,
        k,
        Z,
        optimizer,
        z_penalty=z_penalty,
        sum_writer=sum_writer,
        device=device,
        tensor_type=tensor_type,
        n_opt_iter=n_opt_iter,
        seed=seed,
        texture=texture,
        input_weights=input_weights,
        img_log_steps=img_log_steps,
        log_img_dir=log_dir_path,
    )
    print('Finished, results location : {}'.format(log_dir_path))
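
As a closing note, the --kernel / --kparams pair in this script maps directly to a kernel constructor via k_map. A minimal sketch of the same mapping (the parameter value 10.0 is illustrative only; the meaning of each parameter depends on the chosen kernel class):

from cadgan import kernel

k_map = {"linear": kernel.PTKLinear, "gauss": kernel.PTKGauss, "imq": kernel.PTKIMQ}
# e.g. passing --kernel gauss --kparams 10.0 is equivalent to:
k = k_map["gauss"](*[10.0])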