def latent1d(ctx, rows, cols, plot, filename, **kwargs):
    """Latent space traversal in 1D."""
    add_gin(ctx, "config", ["evaluate/visual/latent1d.gin"])
    parse(ctx, set_seed=True)

    # Optional CLI overrides: each gin binding is applied only when the
    # corresponding option was given (i.e. is not None), in this order.
    optional_bindings = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.traversal1d.dimensions", rows),
        ("disentangled.visualize.traversal1d.steps", cols),
    )

    with gin.unlock_config():
        # show_plot is always bound, even when plot is falsy.
        gin.bind_parameter("disentangled.visualize.show.output.show_plot",
                           plot)
        for binding, value in optional_bindings:
            if value is not None:
                gin.bind_parameter(binding, value)

    pipeline = ctx.obj["dataset"].pipeline()
    # gin.REQUIRED defers these arguments to the parsed gin configuration.
    disentangled.visualize.traversal1d(
        ctx.obj["model"],
        pipeline,
        dimensions=gin.REQUIRED,
        offset=gin.REQUIRED,
        skip_batches=gin.REQUIRED,
        steps=gin.REQUIRED,
    )
def fixed(ctx, batch_size, filename, rows, cols, plot, verbose, **kwargs):
    """View/save images of dataset given a fixed latent factor."""
    # ctx.obj['dataset'] holds the dataset *name*; it selects both the gin
    # file and the dataset object below.
    dataset_name = ctx.obj['dataset']
    add_gin(ctx, 'config', ['evaluate/dataset/{}.gin'.format(dataset_name)])
    parse(ctx)

    with gin.unlock_config():
        gin.bind_parameter('disentangled.visualize.show.output.show_plot',
                           plot)
        gin.bind_parameter('disentangled.visualize.show.output.filename',
                           filename)
        if rows is not None:
            gin.bind_parameter('disentangled.visualize.fixed_factor_data.rows',
                               rows)
        if cols is not None:
            gin.bind_parameter('disentangled.visualize.fixed_factor_data.cols',
                               cols)

    # Look the dataset up once. The original called get() twice. It also
    # overloaded the name `dataset` for both the name and the pipeline.
    dataset_obj = disentangled.dataset.get(dataset_name)
    num_values_per_factor = dataset_obj.num_values_per_factor
    supervised = dataset_obj.supervised()

    # Named `fixed_data` so it does not shadow this command's own name.
    fixed_data, _ = disentangled.metric.utils.fixed_factor_dataset(
        supervised, batch_size, num_values_per_factor)

    # rows/cols come from gin (CLI overrides were bound above).
    disentangled.visualize.fixed_factor_data(fixed_data,
                                             rows=gin.REQUIRED,
                                             cols=gin.REQUIRED,
                                             verbose=verbose)
def gui(ctx, **kwargs):
    """Qualitative interactive evaluation of disentanglement."""
    # Docstring typo fixed ("Qualitiative"); it is shown as CLI help text.
    add_gin(ctx, "config", ["evaluate/visual/gui.gin"])
    parse(ctx, set_seed=True)
    dataset = ctx.obj["dataset"].pipeline()

    disentangled.visualize.gui(ctx.obj["model"], dataset)
def visual(ctx, rows, cols, plot, filename, **kwargs):
    """Qualitative evaluation of output."""
    # Unlike sibling commands, no extra gin file is added here; only the
    # configuration already registered on ctx is parsed.
    parse(ctx, set_seed=True)

    # Bindings that are only applied when the option was provided.
    optional_bindings = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.reconstructed.rows", rows),
        ("disentangled.visualize.reconstructed.cols", cols),
    )

    with gin.unlock_config():
        gin.bind_parameter("disentangled.visualize.show.output.show_plot",
                           plot)
        for binding, value in optional_bindings:
            if value is not None:
                gin.bind_parameter(binding, value)

    pipeline = ctx.obj["dataset"].pipeline()
    disentangled.visualize.reconstructed(
        ctx.obj["model"],
        pipeline,
        rows=gin.REQUIRED,
        cols=gin.REQUIRED,
    )
def loglikelihood(ctx, **kwargs):
    """The logarithmic likelihood of the input given the representation."""
    add_gin(ctx, "config", ["metric/loglikelihood.gin"])
    parse(ctx, set_seed=True)

    # `dataset` is supplied by the gin config loaded above.
    metric = disentangled.metric.loglikelihood(
        ctx.obj["model"],
        dataset=gin.REQUIRED,
    )
    # A leftover debugging `breakpoint()` was removed here. It stopped every
    # run in the interactive debugger before the metric was logged.
    disentangled.metric.log_metric(metric,
                                   name=ctx.obj["model_str"],
                                   metric_name=gin.REQUIRED)
def collapsed(ctx, **kwargs):
    """Number of latent dimensions collapsed to prior."""
    add_gin(ctx, "config", ["metric/collapsed.gin"])
    parse(ctx, set_seed=True)

    # dataset and tolerance are resolved from metric/collapsed.gin.
    model = ctx.obj["model"]
    result = disentangled.metric.collapsed(model,
                                           dataset=gin.REQUIRED,
                                           tolerance=gin.REQUIRED)

    model_str = ctx.obj["model_str"]
    disentangled.metric.log_metric(result,
                                   name=model_str,
                                   metric_name=gin.REQUIRED)
def mig(ctx, **kwargs):
    """Mutual Information GAP.
    Quantitative disentanglement metric."""
    add_gin(ctx, "config", ["metric/mig.gin"])
    parse(ctx, set_seed=True)

    # dataset and encoding_dist are resolved from metric/mig.gin.
    model = ctx.obj["model"]
    result = disentangled.metric.mutual_information_gap(
        model, dataset=gin.REQUIRED, encoding_dist=gin.REQUIRED)

    model_str = ctx.obj["model_str"]
    disentangled.metric.log_metric(result,
                                   metric_name=gin.REQUIRED,
                                   name=model_str)
def dmig(ctx, **kwargs):
    """Discrete Mutual Information Gap.
    Quantitative disentanglement metric."""
    add_gin(ctx, "config", ["metric/dmig.gin"])
    parse(ctx, set_seed=True)

    # dataset, points and batch_size are resolved from metric/dmig.gin.
    model = ctx.obj["model"]
    result = disentangled.metric.discrete_mutual_information_gap(
        model,
        dataset=gin.REQUIRED,
        points=gin.REQUIRED,
        batch_size=gin.REQUIRED,
    )

    model_str = ctx.obj["model_str"]
    disentangled.metric.log_metric(result,
                                   metric_name=gin.REQUIRED,
                                   name=model_str)
def gini_index(ctx, **kwargs):
    """Quantitative sparsity metric of representation."""
    add_gin(ctx, "config", ["metric/gini.gin"])
    parse(ctx, set_seed=True)

    # All metric hyper-parameters are resolved from metric/gini.gin.
    model = ctx.obj["model"]
    result = disentangled.metric.gini_index(model,
                                            dataset=gin.REQUIRED,
                                            points=gin.REQUIRED,
                                            batch_size=gin.REQUIRED,
                                            tolerance=gin.REQUIRED)

    model_str = ctx.obj["model_str"]
    disentangled.metric.log_metric(result,
                                   name=model_str,
                                   metric_name=gin.REQUIRED)
def factorvae_score(ctx, **kwargs):
    """The FactorVAE-score.
    Quantitative disentanglement metric."""
    add_gin(ctx, "config", ["metric/factorvae_score.gin"])
    parse(ctx, set_seed=True)

    # dataset, point counts and tolerance come from
    # metric/factorvae_score.gin.
    model = ctx.obj["model"]
    result = disentangled.metric.factorvae_score(model,
                                                 dataset=gin.REQUIRED,
                                                 training_points=gin.REQUIRED,
                                                 test_points=gin.REQUIRED,
                                                 tolerance=gin.REQUIRED)

    model_str = ctx.obj["model_str"]
    disentangled.metric.log_metric(result,
                                   name=model_str,
                                   metric_name=gin.REQUIRED)
def latent2d(ctx, rows, cols, plot, filename, **kwargs):
    """Latent space traversal in 2D."""
    add_gin(ctx, "config", ["evaluate/visual/latent2d.gin"])
    parse(ctx, set_seed=True)

    # Grid-size overrides, bound only when given on the command line.
    grid_bindings = (
        ("disentangled.visualize.traversal2d.rows", rows),
        ("disentangled.visualize.traversal2d.cols", cols),
    )

    with gin.unlock_config():
        # show_plot and filename are always bound (filename may be None).
        gin.bind_parameter("disentangled.visualize.show.output.show_plot",
                           plot)
        gin.bind_parameter("disentangled.visualize.show.output.filename",
                           filename)
        for binding, value in grid_bindings:
            if value is not None:
                gin.bind_parameter(binding, value)

    pipeline = ctx.obj["dataset"].pipeline()
    disentangled.visualize.traversal2d(ctx.obj["model"], pipeline)
def train(ctx,
          model,
          config,
          gin_param,
          gin_file,
          path,
          checkpoint,
          overwrite=False,
          print_operative=False):
    """Interface for training models by using syntax MODEL/DATASET.
    Training is executed with default config values found in config/train directory.
    See config files or print operative configuration for configurable parameters.

    \b
    Example of modifying default config:
        disentangled train BetaVAE/DSprites -p iterations=1000

    \b
    Available Models:
        BetaVAE
        BetaSVAE
        BetaTCVAE
        FactorVAE
    Available Datasets:
        Shapes3d
        DSprites
        CelebA
    """
    # The MODEL/DATASET config file is inserted so that user-supplied
    # configs and params can still override it (insert=True).
    add_gin(ctx, 'config', ['train/{}.gin'.format(model)], insert=True)

    if overwrite:
        # Fix: the original was missing a comma and evaluated
        # 'gin_param'['...'] (str indexed by str), which raises TypeError
        # whenever --overwrite was passed.
        add_gin(ctx,
                'gin_param',
                ['disentangled.model.utils.save.overwrite=True'])
    parse(ctx)

    disentangled.training.run_training(model=gin.REQUIRED,
                                       dataset=gin.REQUIRED,
                                       iterations=gin.REQUIRED,
                                       path=path,
                                       checkpoint=checkpoint)

    # The operative config is only fully known after training has run.
    if print_operative:
        print(gin.operative_config_str())
# Example #13
def examples(ctx, filename, rows, cols, plot, **kwargs):
    """View or save example images of dataset."""
    # ctx.obj['dataset'] is the dataset name; it picks the gin file and
    # the dataset pipeline.
    name = ctx.obj['dataset']

    add_gin(ctx, 'config', ['evaluate/dataset/{}.gin'.format(name)])
    parse(ctx)

    # Grid-size overrides, bound only when given on the command line.
    grid_bindings = (
        ('disentangled.visualize.data.rows', rows),
        ('disentangled.visualize.data.cols', cols),
    )

    with gin.unlock_config():
        gin.bind_parameter('disentangled.visualize.show.output.show_plot',
                           plot)
        gin.bind_parameter('disentangled.visualize.show.output.filename',
                           filename)
        for binding, value in grid_bindings:
            if value is not None:
                gin.bind_parameter(binding, value)

    pipeline = disentangled.dataset.get(name).pipeline()

    disentangled.visualize.data(pipeline, rows=gin.REQUIRED, cols=gin.REQUIRED)