Exemple #1
0
def plot_all_traversals_grid(model,
                             latent_dim,
                             input_size,
                             input_shape,
                             num_sample=1,
                             figsize=(30, 30),
                             filename=None,
                             conditional_target=None,
                             device='cpu'):
    """
    Generates a grid of images for all latent dimensions, where each row corresponds to a traversal along a latent
    dimension.

    For every index in ``range(latent_dim)`` a traversal is requested from ``tra.traverse_continous_line``, reshaped
    to ``(num_sample, *input_shape)`` and passed through the model in reverse direction (``rev=True``). The
    generated batches are concatenated and shown with ``imshow``.

    :param model: INN used for sampling; it is moved to ``device`` and switched to eval mode
    :param latent_dim: dimension of the latent space; one traversal (grid row) is generated per latent dimension
    :param input_size: total number of elements in the input of the INN
    :param input_shape: shape of the input for the INN; must have exactly three entries
    :param num_sample: number of samples to generate per traversal; also used as the number of images per grid row
    :param figsize: the size of the generated plot
    :param filename: file name under which the plot will be saved. (optional)
    :param conditional_target: forwarded unchanged to ``tra.traverse_continous_line`` (optional)
    :param device: device the model is evaluated on
    :return: None
    :raises ValueError: if ``input_shape`` does not have exactly three entries
    """

    # Validate first so that an invalid call does not move the model or change its train/eval mode.
    if len(input_shape) != 3:
        raise ValueError("input_shape must be 3-dimensional")

    model.to(device)
    model.eval()

    rows = []
    for idx in range(latent_dim):
        latent_samples = tra.traverse_continous_line(
            latent_dim,
            input_size,
            idx,
            num_sample,
            False,
            lat_img=None,
            conditional_target=conditional_target)

        latent_samples = latent_samples.view(num_sample, *input_shape)

        generated = model(latent_samples.to(device), rev=True)

        # Only the values are plotted, so drop the autograd graph of each row right away instead of keeping all
        # of them alive until the whole grid has been assembled.
        rows.append(generated.detach())

    grid = torch.cat(rows)

    imshow(torchvision.utils.make_grid(grid, num_sample), figsize, filename)
Exemple #2
0
def plot_latent_traversal_line(model, latent_dim, input_size, input_shape, idx, num_sample=1, figsize=(30, 30),
                               filename=None, dataset=None, conditional_target=None, device='cpu'):
    """
    Generates an image traversal through a latent dimension.

    The traversal is requested from ``tra.traverse_continous_line``, reshaped to ``(num_sample, *input_shape)`` and
    passed through the model in reverse direction (``rev=True``). The result is shown with ``imshow``.

    :param model: INN used for sampling; it is moved to ``device`` and switched to eval mode
    :param latent_dim: dimension of the latent space
    :param input_size: total number of elements in the input of the INN
    :param input_shape: shape of the input for the INN; must have two or three entries
    :param idx: Index of a continuous latent dimension to traverse. If None, no latent is traversed and all latent
    dimensions are randomly sampled or kept fixed.
    :param num_sample: number of samples to generate; also used as the number of images per grid row
    :param figsize: the size of the generated plot
    :param filename: file name under which the plot will be saved. (optional)
    :param dataset: dataset to draw images from for which the latent traversal will be created. If given, one batch
    of ``num_sample`` images is loaded via ``dl.get_loader``, encoded by the model and handed to the traversal
    helper as ``lat_img``. (optional)
    :param conditional_target: forwarded unchanged to ``tra.traverse_continous_line`` (optional)
    :param device: device the model is evaluated on
    :return: None
    :raises ValueError: if ``input_shape`` has neither two nor three entries
    """

    # Fail fast: reject an unsupported shape before loading data or running the model.
    if len(input_shape) not in (2, 3):
        raise ValueError("input_shape is neither 2- nor 3-dimensional")

    model.to(device)
    model.eval()

    if dataset is not None:
        loader = dl.get_loader(dataset, num_sample)

        img, _ = next(iter(loader))

        # The model lives on `device`, so the batch has to be moved there as well; otherwise the forward pass
        # fails with a device mismatch for any device other than the one the loader yields tensors on.
        lat_img = model(img.to(device))
        # Flatten to (batch, features). The latents are handed to the traversal helper on the CPU, which is a
        # no-op for device='cpu'. NOTE(review): assumes the helper builds its tensors on the CPU - confirm.
        lat_img = lat_img.view(lat_img.size(0), -1).cpu()
    else:
        lat_img = None

    latent_samples = tra.traverse_continous_line(latent_dim, input_size, idx, num_sample, False, lat_img,
                                                 conditional_target=conditional_target)

    # input_shape was validated above to have two or three entries.
    latent_samples = latent_samples.view(num_sample, *input_shape)

    generate = model(latent_samples.to(device), rev=True)

    imshow(torchvision.utils.make_grid(generate.detach(), num_sample), figsize, filename)