Example #1
    def download_pretrain(self, output_path='', **opt):
        if output_path == '':
            output_path = glo.prob_model_folder(
                'mnist_dcgan/mnist_dcgan_ep{}_bs{}.pt'.format(40, 64))
        if not os.path.exists(output_path):
            gdd.download_file_from_google_drive(
                file_id='1KOi9b8JSBXc7hx9P8Azbtkhd7fZOeZxc', dest_path=output_path)

        use_cuda = torch.cuda.is_available()
        # without CUDA, remap tensors that were saved on a GPU to CPU storage on load
        load_options = {} if use_cuda else {'map_location': lambda storage, loc: storage}
        self.load(output_path, **load_options)
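
Examples #1, #3, and #4 share the same CPU-fallback loading idiom. A minimal, self-contained sketch of just that idiom (the checkpoint filename is illustrative, not taken from the source):

    import torch

    use_cuda = torch.cuda.is_available()
    # map_location remaps tensors that were saved on a GPU onto CPU storage
    load_options = {} if use_cuda else {'map_location': lambda storage, loc: storage}
    # state = torch.load('checkpoint.pt', **load_options)  # illustrative path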
Example #2
def main():

    parser = argparse.ArgumentParser(description="Train a DCGAN on CIFAR10")
    parser.add_argument("--n_epochs", type=int, default=25, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=2 ** 5, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument(
        "--sample_interval",
        type=int,
        default=100,
        help="interval between image sampling. The number refers to the number of minibatch updates.",
    )
    parser.add_argument(
        "--save_model_interval", type=int, default=10, help="Save the generator once every this many epochs."
    )
    parser.add_argument("--prob_model_dir", type=str, help="interval between image sampling")
    parser.add_argument(
        "--classes", type=int, help="a list of integers (0-9) denoting the classes to consider", nargs="+"
    )

    # --------------------------------
    args = parser.parse_args()

    # op is a dict
    op = vars(args)
    if op["classes"] is None:
        # classes not specified => consider all classes
        op["classes"] = list(range(10))

    classes = sorted(op["classes"])
    cls_str = "".join(map(str, classes))
    if op["prob_model_dir"] is None:
        # use the list of classes to name the prob_model_dir
        prob_model_dir_name = "cifar10_c{}-dcgan".format(cls_str)
        op["prob_model_dir"] = glo.prob_model_folder(prob_model_dir_name)

    log.l().info("Options used: ")
    pprint.pprint(op)

    dcgan = DCGAN(**op)
    model_fname = "cifar10_c{}-dcgan-ep{}_bs{}.pt".format(cls_str, op["n_epochs"], op["batch_size"])
    model_fpath = os.path.join(dcgan.prob_model_dir, model_fname)

    # train
    log.l().info("Starting training a DCGAN on CIFAR10")
    dcgan.train()

    # save the generator
    g = dcgan.generator
    log.l().info("Saving the trained model to: {}".format(model_fpath))
    g.save(model_fpath)
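
Usage note: with the parser above, a run restricted to a subset of classes could be launched as, e.g., python <script>.py --n_epochs 25 --batch_size 32 --classes 3 5 7 (the script's filename is not shown in this snippet). When --classes is omitted, all ten classes are used; the model directory name is derived from the sorted class list, e.g. classes 3, 5, 7 yield cifar10_c357-dcgan.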
Example #3
    def download_pretrain(self, output='', **opt):
        if output == '':
            output = glo.prob_model_folder('mnist_cnn/mnist_cnn_ep40_s1.pt')

        if not os.path.exists(output):
            gdd.download_file_from_google_drive(
                file_id='1wYJX_w3J5Fzxc5E4DCMPunWRKikLvk5F', dest_path=output)
        use_cuda = torch.cuda.is_available()
        load_options = {} if use_cuda else {
            'map_location': lambda storage, loc: storage
        }

        self.load(output, **load_options)
Example #4
    def download_pretrain(self, output_path='', **opt):
        if output_path == '':
            output_path = glo.prob_model_folder(
                'dcgan_colormnist/colormnist/netG_epoch_{}.pth'.format(24))
        if not os.path.exists(output_path):
            gdd.download_file_from_google_drive(
                file_id='149uouxlGyhAOPKrGTVNzxXV0p7IjKqX-',
                dest_path=output_path)

        use_cuda = torch.cuda.is_available()
        load_options = {} if use_cuda else {
            'map_location': lambda storage, loc: storage
        }
        self.load(output_path, **load_options)


Example #5
    def __init__(
        self,
        prob_model_dir=glo.prob_model_folder("cifar10_dcgan"),
        data_dir=glo.data_file("cifar10"),
        use_cuda=True,
        n_epochs=30,
        batch_size=2 ** 6,
        lr=0.0002,
        b1=0.5,
        b2=0.999,
        latent_dim=100,
        sample_interval=200,
        save_model_interval=10,
        classes=list(range(10)),
        **op,
    ):
        """
        n_epochs: number of epochs of training
        batch_size: size of the batches
        lr: adam: learning rate
        b1: adam: decay of first order momentum of gradient
        b2: adam: decay of second order momentum of gradient
        latent_dim: dimensionality of the latent space
        sample_interval: interval between image sampling
        save_model_interval: save the generator once every this many epochs
        """
        os.makedirs(prob_model_dir, exist_ok=True)
        self.prob_model_dir = prob_model_dir
        self.data_dir = data_dir
        self.use_cuda = use_cuda
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.latent_dim = latent_dim
        self.sample_interval = sample_interval
        self.save_model_interval = save_model_interval
        self.classes = classes
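
A minimal construction sketch (assuming the defining module is imported; DCGAN is the class this constructor belongs to, matching its use in Example #2):

    # keyword overrides mirror the constructor defaults above; unrecognized
    # keys are absorbed by **op, which is how Example #2 passes vars(args)
    dcgan = DCGAN(n_epochs=5, batch_size=64, classes=[0, 1])
    dcgan.train()  # Example #2 then calls train() and saves dcgan.generator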
Example #6
    def __init__(
        self,
        prob_model_dir=glo.prob_model_folder("fashion_dcgan"),
        data_dir=glo.data_file("/home/wgondal/cadgan/data/fashion"),
        use_cuda=True,
        n_epochs=30,
        batch_size=2 ** 6,
        lr=0.0002,
        b1=0.5,
        b2=0.999,
        n_cpu=4,
        latent_dim=100,
        sample_interval=400,
    ):
        """
        n_epochs: number of epochs of training
        batch_size: size of the batches
        lr: adam: learning rate
        b1: adam: decay of first order momentum of gradient
        b2: adam: decay of second order momentum of gradient
        n_cpu: number of cpu threads to use during batch generation
        latent_dim: dimensionality of the latent space
        sample_interval: interval between image sampling
        """
        print(prob_model_dir)
        os.makedirs(prob_model_dir, exist_ok=True)
        self.prob_model_dir = prob_model_dir
        self.data_dir = data_dir
        self.use_cuda = use_cuda
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.n_cpu = n_cpu
        self.latent_dim = latent_dim
        self.sample_interval = sample_interval
Example #7
def main():
    # Training settings
    parser = argparse.ArgumentParser(
        description=
        'PyTorch GKMM. Some paths are relative to the "(share_path)/prob_models/". See settings.ini for (share_path).'
    )

    parser.add_argument(
        "--extractor_type",
        type=str,
        default="vgg",
        required=True,
        help="The feature extractor. The saved object should be a torch.nn.Module "
        "representing a feature extractor. Currently supported: "
        "[vgg | vgg_face | alexnet_365 | resnet18_365 | resnet50_365 | hed | mnist_cnn | pixel]",
    )
    parser.add_argument(
        "--extractor_layers",
        nargs="+",
        default=["4", "9", "18", "27"],
        help="Indices of the layers to include. Only for VGG feature extractors. "
        "Default: [4, 9, 18, 27]",
    )
    parser.add_argument(
        "--texture",
        type=float,
        default=0,
        help="Use the texture (Gram matrix) of the extracted features. Default: 0")
    parser.add_argument(
        "--depth_process",
        nargs="?",
        choices=["avg", "max", "no"],
        default="no",
        help="Processing module to run on the output from \
            each filter in the specified layer(s).",
    )
    parser.add_argument(
        "--g_path",
        type=str,
        required=True,
        help="Relative path \
            (relative to (share_path)/prob_models) to the file that can be loaded \
            to get a cadgan.gen.PTNoiseTransformer representing an image generator.",
    )
    parser.add_argument(
        "--g_type",
        type=str,
        default="celebAHQ.yaml",
        help="Generator type based on the data it is trained for.")
    parser.add_argument(
        "--g_min",
        type=float,
        help="The minimum value of the pixel output from the generator.",
        required=True)
    parser.add_argument(
        "--g_max",
        type=float,
        help="The maximum value of the pixel output from the generator.",
        required=True)
    parser.add_argument(
        "--logdir",
        type=str,
        required=True,
        help="full path to the folder to contain Tensorboard log files")
    parser.add_argument("--device",
                        nargs="?",
                        choices=["cpu", "gpu"],
                        default="cpu",
                        help="Device to use for computation.")
    parser.add_argument("--n_sample",
                        type=int,
                        default=16,
                        metavar="n",
                        help="Number of images to generate")
    parser.add_argument("--n_opt_iter",
                        type=int,
                        default=500,
                        help="Number of optimization iterations")

    parser.add_argument("--lr",
                        type=float,
                        default=0.001,
                        metavar="LR",
                        help="learning rate (for the optimizer)")
    parser.add_argument("--n_init_resample",
                        type=float,
                        default=1,
                        help="number of time to resample z for the heuristic")
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        metavar="S",
        help=
        "Random seed. Among others, this affects the initialization of the noise vectors of the generator in the optimization.",
    )
    parser.add_argument(
        "--img_log_steps",
        type=int,
        default=10,
        metavar="N",
        help=
        "how many optimization iterations to wait before logging generated images",
    )
    parser.add_argument("--img_size",
                        type=int,
                        default=224,
                        help="image size nxn default 256")
    # parser.add_argument('--data_dir', type=str,
    #        default='mnist/', help='Relative path (relative to the data folder) \
    #        containing Mnist training data. Mnist data will be downloaded if \
    #        not existed already.')
    # parser.add_argument('--cond', nargs='+', type=int, dest='cond',
    #        action='append', required=True, help='Digit label and number of images from that label to condition on. For example, "--cond 3 4" means 4 images of digit 3. --cond can be used multiple times. For instance, use --cond 1 2 --cond 3 1 to condition on 2 digits of 1, and 1 digit of 3')
    parser.add_argument("--cond_path",
                        type=str,
                        required=True,
                        help="Path to imgs for conditioning")
    parser.add_argument(
        "--kernel",
        nargs="?",
        required=True,
        choices=["linear", "gauss", "imq"],
        help=
        "choice of kernel to put on top of extracted features.  May need to specify also --kparams.",
    )
    parser.add_argument(
        "--kparams",
        nargs="*",
        type=float,
        dest="kparams",
        default=[],
        help=
        "A list of kernel parameters (float). Semantic of parameters depends on the chosen kernel",
    )

    parser.add_argument(
        "--w_input",
        nargs="+",
        default=[],
        help=
        "weight of the input, must be equal to the number of cond images and sum to 1. if none specified, equal weights will be used.",
    )

    img_transform = target_transform()
    args = parser.parse_args()
    print("Training options: ")
    args_dict = vars(args)
    pprint.pprint(args_dict, width=5)

    # ---------------------------------

    # Check that --texture is used together with the extractor options
    if args.texture and (not args.extractor_layers or not args.extractor_type):
        parser.error(
            "--texture requires both --extractor_layers and --extractor_type to be specified."
        )

    # True to use GPU
    use_cuda = args.device == "gpu" and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    tensor_type = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    torch.set_default_tensor_type(tensor_type)

    # load option depends on whether GPU is used
    device_load_options = {} if use_cuda else {
        "map_location": lambda storage, loc: storage
    }

    # initialize the noise vectors for the generator
    # Set the random seed
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    n_sample = args.n_sample

    if args.g_type.endswith(".yaml"):
        # sample a stack of noise vectors
        latent_dim = 256
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        # Loading Configs for LarsGAN
        yaml_folder = os.path.dirname(ganstab.configs.__file__)
        yaml_config_path = os.path.join(yaml_folder, args.g_type)
        config = load_config(yaml_config_path)

        # load generator
        nlabels = config["data"]["nlabels"]
        out_dir = config["training"]["out_dir"]
        checkpoint_dir = os.path.join(out_dir, "chkpts")

        generator = build_generator(config)

        # Put models on GPU if needed
        generator = generator.to(device)
        # for celebA HQ generator,
        # if args.g_type == 'celebAHQ.yaml':
        #    generator.add_resize(args.img_size)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)
        # Logger
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

        # Register modules to checkpoint
        checkpoint_io.register_modules(generator=generator)
        # Test generator
        if config["test"]["use_model_average"]:
            generator_test = copy.deepcopy(generator)
            checkpoint_io.register_modules(generator_test=generator_test)
        else:
            generator_test = generator

        # Loading Generator
        ydist = get_ydist(nlabels, device=device)

        full_g_path = glo.prob_model_folder(args.g_path)
        if not os.path.exists(full_g_path):
            # download the Lars pre-trained model file if it does not exist
            print(
                "Generator file does not exist: {}\n"
                "Downloading a pretrained model. Please wait ...".format(
                    full_g_path),
                end='')

            dict_url = {
                'lsun_bedroom.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bedroom-df4e7dd2.pt',
                'lsun_bridge.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bridge-82887d22.pt',
                'celebAHQ.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celebahq-baab46b2.pt',
                'lsun_tower.yaml':
                'https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_tower-1af5e570.pt'
            }

            assert args.g_type in dict_url, \
                'g_type {} is not supported'.format(args.g_type)
            url = dict_url[args.g_type]
            r = requests.get(url)
            os.makedirs(os.path.dirname(full_g_path), exist_ok=True)
            with open(full_g_path, 'wb') as f:
                f.write(r.content)

            print('done')
        load_options = {} if use_cuda else {
            "map_location": lambda storage, loc: storage
        }
        it = checkpoint_io.load(full_g_path, **load_options)

    elif args.g_type == "mnist_dcgan":
        # TODO: should probably reorganize these
        latent_dim = 100
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        full_g_path = glo.prob_model_folder(args.g_path)
        # load option depends on whether GPU is used
        load_options = {} if use_cuda else {
            "map_location": lambda storage, loc: storage
        }

        generator = mnist_dcgan.Generator()
        if os.path.exists(full_g_path):
            generator.load(full_g_path)
        else:
            print(
                "Generator file does not exist: {}\n"
                "Downloading a pretrained model ...".format(full_g_path))
            generator.download_pretrain(output=full_g_path)

        generator = generator.to(device)

        generator_test = generator
        ydist = None

    elif args.g_type == "colormnist_dcgan":
        # TODO: should probably reorganize these
        latent_dim = 100
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        full_g_path = glo.prob_model_folder(args.g_path)
        generator = cmnist_dcgan.Generator()
        if os.path.exists(full_g_path):
            generator.load(full_g_path)
        else:
            print(
                "Generator file does not exist: {}\n"
                "Downloading a pretrained model ...".format(full_g_path))
            generator.download_pretrain(output=full_g_path)

        generator = generator.to(device)

        generator_test = generator
        ydist = None

    # Noise distribution is Gaussian. Unlikely that the magnitude of the
    # coordinate is above the bound.
    z_penalty = kmain.TPNull()  # kmain.TPSymLogBarrier(bound=4.2, scale=1e-4)
    args_dict["zpen"] = z_penalty

    # output range of the generator (according to what the user specifies)
    g_range = (args.g_min, args.g_max)

    # Sanity check. Check that the specified g-range is plausible
    g_out_uncontrolled = Generator(ydist=ydist,
                                   generator=generator_test.to(device))

    temp_sample = g_out_uncontrolled.forward(Z0)
    kmain.pixel_values_check(temp_sample, g_range, "Generator's samples")

    extractor_in_size = args.img_size

    # transform the output range of g to (0,1)
    g = nn.Sequential(
        g_out_uncontrolled,
        nn.AdaptiveAvgPool2d((extractor_in_size, extractor_in_size)),
        gen.LinearRangeTransform(from_range=g_range, to_range=(0, 1)),
    )
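    # note: --depth_process also offers a "max" choice, but only "no" and "avg"
    # are mapped below, so "--depth_process max" would raise a KeyError here.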
    depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
    feature_size = 128

    if args.texture == 1:
        post_process = nn.Sequential(depth_process_map[args.depth_process],
                                     GramMatrix())
    else:
        post_process = nn.Sequential(depth_process_map[args.depth_process])

    # Loading Extractor
    if args.extractor_type == "vgg":
        extractor_layers = [int(i) for i in args.extractor_layers]
        extractor = ext.VGG19(layers=extractor_layers,
                              layer_postprocess=post_process)
    elif args.extractor_type == "vgg_face":
        extractor_layers = [int(i) for i in args.extractor_layers]
        extractor = ext.VGG19_face(layers=extractor_layers,
                                   layer_postprocess=post_process)
    elif args.extractor_type == "alexnet_365":
        extractor = ext.AlexNet_365()
    elif args.extractor_type == "resnet18_365":
        extractor = ext.ResNet18_365()
    elif args.extractor_type == "resnet50_365":
        extractor = ext.ResNet50_365(n_remove_last_layers=2,
                                     layer_postprocess=post_process)
    elif args.extractor_type == "hed":
        # extractor_in_size = 256
        extractor = ext.HED(device=device, resize=feature_size)
    elif args.extractor_type == "hed_color":
        #stacking feature from HED and tiny image to get both edge and color information
        hed = ext.HED(device=device, resize=feature_size)
        tiny = ext.TinyImage(device=device, grid_size=(10, 10))
        extractor = ext.StackModule(device=device,
                                    module_list=[hed, tiny],
                                    weights=[0.01, 0.99])
    elif args.extractor_type == "hed_vgg":
        #stacking feature from HED and vgg feature to get both edge and high level vgg information
        feature_size = 128
        hed = ext.HED(device=device, resize=feature_size)
        extractor_layers = [int(i) for i in args.extractor_layers]
        vgg = ext.VGG19(layers=extractor_layers,
                        layer_postprocess=post_process)
        extractor = ext.StackModule(device=device,
                                    module_list=[hed, vgg],
                                    weights=[0.99, 0.01])
    elif args.extractor_type == "hed_color_vgg":
        #stacking feature from HED, tiny image, and vgg feature to get edge, color, and high level vgg information
        feature_size = 128
        hed = ext.HED(device=device, resize=feature_size)
        extractor_layers = [int(i) for i in args.extractor_layers]
        vgg = ext.VGG19(layers=extractor_layers,
                        layer_postprocess=post_process)
        tiny = ext.TinyImage(device=device, grid_size=(10, 10))
        extractor = ext.StackModule(device=device,
                                    module_list=[hed, vgg, tiny],
                                    weights=[0.005, 0.005, 0.99])
    elif args.extractor_type == "color":
        extractor = ext.TinyImage(device=device, grid_size=(128, 128))
    elif args.extractor_type == "color_count":
        # to use with Waleed color mnist only:
        # the purpose is to count color based on the template, currently not working as expected.
        prototypes = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0],
                                   [1, 0, 1], [0.4, 0.2, 0]])
        extractor = ext.SoftCountPixels(prototypes=prototypes,
                                        gwidth2=0.3,
                                        device=device,
                                        tensor_type=tensor_type)
    elif args.extractor_type == "mnist_cnn":
        depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
        if args.texture == 1:
            post_process = nn.Sequential(depth_process_map[args.depth_process],
                                         GramMatrix())
        else:
            post_process = nn.Sequential(depth_process_map[args.depth_process])
        extractor = ext.MnistCNN(device="cuda" if use_cuda else "cpu",
                                 layer_postprocess=post_process,
                                 layer=int(args.extractor_layers[0]))
    elif args.extractor_type == "mnist_cnn_digit_layer":
        #using the last layer of MNIST CNN (digit classification)
        depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
        if args.texture == 1:
            post_process = nn.Sequential(depth_process_map[args.depth_process],
                                         GramMatrix())
        else:
            post_process = nn.Sequential(depth_process_map[args.depth_process])
        extractor = ext.MnistCNN(device="cuda" if use_cuda else "cpu",
                                 layer_postprocess=post_process,
                                 layer=3)
    elif args.extractor_type == "mnist_cnn_digit_layer_color":
        # using the last layer of MNIST CNN (digit classification) stacking with color information from tiny image
        depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
        if args.texture == 1:
            post_process = nn.Sequential(depth_process_map[args.depth_process],
                                         GramMatrix())
        else:
            post_process = nn.Sequential(depth_process_map[args.depth_process])
        mnistcnn = ext.MnistCNN(device="cuda" if use_cuda else "cpu",
                                layer_postprocess=post_process,
                                layer=3)
        color = ext.MaxColor(device=device)
        extractor = ext.StackModule(device=device,
                                    module_list=[mnistcnn, color],
                                    weights=[1, 99])
    elif args.extractor_type == "pixel":
        #raw pixel as feature
        extractor = ext.Identity(
            flatten=True,
            slice_dim=0 if args.g_type == "mnist_dcgan" else None)
    else:
        raise ValueError("Unknown extractor type. Check --extractor_type")

    if use_cuda:
        extractor = extractor.cuda()
    assert isinstance(extractor, torch.nn.Module)

    print("Summary of the extractor:")
    try:
        torchsummary.summary(extractor,
                             input_size=(3, extractor_in_size,
                                         extractor_in_size))
    except Exception:
        log.l().info(
            "Exception occurred when getting a summary of the extractor")

    # run a forward pass through the extractor just to test
    tmp_extracted = extractor(g(Z0[[0]]))
    n_features = torch.prod(torch.tensor(tmp_extracted.shape))
    print("Number of extracted features = {}".format(n_features))
    del tmp_extracted

    def load_multiple_images(list_imgs):
        # load, resize, and transform each conditioning image, then stack them
        imgs = []
        for path_img in list_imgs:
            loaded = imutil.load_resize_image(path_img,
                                              extractor_in_size).copy()
            imgs.append(img_transform(loaded).unsqueeze(0).type(tensor_type))
        return torch.cat(imgs)

    if not os.path.isdir(glo.data_file(args.cond_path)):
        # read list of imgs if it's a text file
        if args.cond_path.endswith(".txt"):
            img_txt_path = glo.data_file(args.cond_path)
            with open(img_txt_path, "r") as f:
                data = f.readlines()

            list_imgs = [
                glo.data_file(x.strip()) for x in data if len(x.strip()) != 0
            ]
            if not list_imgs:
                raise ValueError(
                    "Empty list of images to condition on. Make sure that {} is valid"
                    .format(img_txt_path))

            cond_imgs = load_multiple_images(list_imgs)
        elif args.cond_path.endswith(".png") or args.cond_path.endswith(
                ".jpg"):
            path_img = glo.data_file(args.cond_path)
            loaded = imutil.load_resize_image(path_img,
                                              extractor_in_size).copy()
            cond_imgs = img_transform(loaded).unsqueeze(0).type(
                tensor_type)  # .to(device)
        else:
            raise ValueError(
                'Unsupported input type at {} (supported: a folder, or a text file listing images)'
                .format(glo.data_file(args.cond_path)))
    else:
        # using all images in the folder
        list_imgs = glob.glob(os.path.join(glo.data_file(args.cond_path), "*"))
        cond_imgs = load_multiple_images(list_imgs)

    cond_imgs = cond_imgs.to(device).type(tensor_type)

    # kernel on top of the extracted features
    k_map = {
        "linear": kernel.PTKLinear,
        "gauss": kernel.PTKGauss,
        "imq": kernel.PTKIMQ
    }
    kernel_key = args.kernel
    kernel_params = args.kparams
    k_constructor = k_map[kernel_key]
    # construct the chosen kernel with the specified parameters
    k = k_constructor(*kernel_params)
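    # e.g. "--kernel gauss --kparams 10.0" would construct kernel.PTKGauss(10.0);
    # how each float is interpreted depends on the chosen kernel class.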

    # texture flag
    texture = args.texture
    # run the kernel moment matching optimization
    n_opt_iter = args.n_opt_iter
    logdir = args.logdir
    print("LOGDIR: ", logdir)

    # dictionary containing key-value pairs for experimental settings.
    log_str_dict = dict((ke, str(va)) for (ke, va) in args_dict.items())

    # logdir is just a parent folder.
    # Form the actual file name by concatenating the values of all
    # hyperparameters used.
    log_str_dict2 = copy.deepcopy(log_str_dict)

    now = datetime.datetime.now()
    time_str = "{:02}.{:02}.{}_{:02}{:02}{:02}".format(now.day, now.month,
                                                       now.year, now.hour,
                                                       now.minute, now.second)
    log_str_dict2["t"] = time_str
    util.translate_keys(
        log_str_dict2,
        {
            "cond_path": "co",
            "data_dir": "dat",
            "depth_process": "dp",
            "extractor_path": "ep",
            "extractor_type": "et",
            "extractor_layers": "el",
            "g_type": "gt",
            "kernel": "k",
            "kparams": "kp",
            "n_opt_iter": "it",
            "n_sample": "n",
            "seed": "s",
            "texture": "te",
        },
    )

    parameters_str = util.dict_to_string(
        log_str_dict2,
        exclude=[
            "device", "img_log_steps", "logdir", "g_min", "g_max", "g_path",
            "t"
        ],
        entry_sep="-",
        kv_sep="_",
    )
    img_log_steps = args.img_log_steps
    logdir_fname = util.clean_filename(parameters_str, replace="/\\[]")
    log_dir_path = glo.result_folder(os.path.join(logdir, logdir_fname))

    # multiple restarts to refine the drawn Z. This is just a heuristic
    # so we start (hopefully) from a good initial point.
    k_img = kernel.PTKFuncCompose(k, f=extractor)
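    # PTKFuncCompose composes the kernel with the feature map, so the image
    # kernel is k_img(x, y) = k(extractor(x), extractor(y))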
    # multi_restarts_refiner = kmain.ZRMMDMultipleRestarts(
    #         g, z_sampler=f_noise, k=k_img, X=cond_imgs,
    #         n_restarts=100,
    #         n_sample=Z0.shape[0],
    #         )

    tmp_gen = g(Z0)
    assert (tmp_gen.shape[-1] == extractor_in_size
            and tmp_gen.shape[-2] == extractor_in_size)
    del tmp_gen

    if len(args.w_input) == 0:
        input_weights = None
    else:
        assert cond_imgs.shape[0] == len(args.w_input), \
            "the number of input weights must equal the number of input images"
        # use torch.tensor here: the legacy torch.Tensor constructor does not
        # accept CUDA devices, so torch.Tensor(..., device=device) can fail
        input_weights = torch.tensor([float(x) for x in args.w_input],
                                     device=device).type(tensor_type)

    # A heuristic to pick good Z to start the optimization
    multi_restarts_refiner = kmain.ZRMMDIterGreedy(
        g,
        z_sampler=f_noise,
        k=k_img,
        X=cond_imgs,
        # number of times to draw each z_i; 1 relies on the latent optimization alone
        n_draws=int(args.n_init_resample),
        n_sample=Z0.shape[0],
        device=device,
        tensor_type=tensor_type,
        input_weights=input_weights,
    )

    # Summary writer for Tensorboard logging
    sum_writer = SummaryWriter(log_dir=log_dir_path)

    # write all key-value pairs in log_str_dict to the Tensorboard
    for ke, va in log_str_dict.items():
        sum_writer.add_text(ke, va)

    with open(os.path.join(log_dir_path, "metadata"), "wb") as f:
        dill.dump(log_str_dict, f)

    imutil.save_images(cond_imgs, os.path.join(log_dir_path, "input_images"))

    gens = g.forward(Z0)
    gens_cpu = gens.to(torch.device("cpu"))
    imutil.save_images(gens_cpu, os.path.join(log_dir_path, "prior_images"))
    del gens
    del gens_cpu
    # Get a better Z
    Z = multi_restarts_refiner(Z0)

    # Try to plot (in Tensorboard) extracted features as images if possible
    log.l().info(
        'Attempting to plot extracted features as images. Will skip if this does not work'
    )
    try:
        # if args.extractor_type == 'hed':
        feat_out = extractor.forward(cond_imgs)
        feature_size = int(np.sqrt(feat_out.shape[1]))
        feat_out = feat_out.view(feat_out.shape[0], 1, feature_size,
                                 feature_size)
        gens_cpu = feat_out.to(torch.device("cpu"))
        imutil.save_images(gens_cpu, os.path.join(log_dir_path,
                                                  "input_feature"))
        arranged_init_imgs = torchvision.utils.make_grid(gens_cpu,
                                                         nrow=2,
                                                         normalize=True)
        sum_writer.add_image("Init_feature", arranged_init_imgs)
        del feat_out
    except Exception as err:
        log.l().info(err)
        log.l().info("unable to plot feature as image")

    # optimizer
    optimizer = torch.optim.Adam([Z], lr=args.lr)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000,2000,3000], gamma=0.1)
    # optimizer = torch.optim.LBFGS([Z])  # LBFGS doesn't really converge here; other optimizers could be tried
    # Solve the kernel moment matching problem
    kmain.pt_gkmm(
        g,
        cond_imgs,
        extractor,
        k,
        Z,
        optimizer,
        z_penalty=z_penalty,
        sum_writer=sum_writer,
        device=device,
        tensor_type=tensor_type,
        n_opt_iter=n_opt_iter,
        seed=seed,
        texture=texture,
        input_weights=input_weights,
        img_log_steps=img_log_steps,
        log_img_dir=log_dir_path,
    )
    print('Finished. Results saved to: {}'.format(log_dir_path))
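
Usage note: --extractor_type, --g_path, --g_min, --g_max, --logdir, --cond_path, and --kernel are all required. An illustrative invocation (the script filename and all values are placeholders, not taken from the source):

    python <script>.py --extractor_type vgg --g_path <generator.pt> \
        --g_type celebAHQ.yaml --g_min -1 --g_max 1 --logdir gkmm_logs \
        --cond_path <folder_or_list.txt> --kernel linear

Tensorboard logs and generated images are written to a run folder under glo.result_folder(logdir), named from the hyperparameter values.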