def fixed(ctx, batch_size, filename, rows, cols, plot, verbose, **kwargs):
    """View/save images of dataset given a fixed latent factor."""
    # ctx.obj["dataset"] is the key handed to disentangled.dataset.get() and
    # it also names the per-dataset gin file. Extra **kwargs are ignored.
    dataset_key = ctx.obj["dataset"]
    add_gin(ctx, "config", ["evaluate/dataset/{}.gin".format(dataset_key)])
    parse(ctx)

    # plot and filename are always bound; rows/cols only override the
    # configured values when they were actually supplied.
    bindings = {
        "disentangled.visualize.show.output.show_plot": plot,
        "disentangled.visualize.show.output.filename": filename,
    }
    if rows is not None:
        bindings["disentangled.visualize.fixed_factor_data.rows"] = rows
    if cols is not None:
        bindings["disentangled.visualize.fixed_factor_data.cols"] = cols
    with gin.unlock_config():
        for parameter, value in bindings.items():
            gin.bind_parameter(parameter, value)

    num_values_per_factor = disentangled.dataset.get(dataset_key).num_values_per_factor
    supervised = disentangled.dataset.get(dataset_key).supervised()
    fixed_factor_batches, _ = disentangled.metric.utils.fixed_factor_dataset(
        supervised, batch_size, num_values_per_factor
    )
    # rows/cols are passed as gin.REQUIRED so gin supplies them.
    disentangled.visualize.fixed_factor_data(
        fixed_factor_batches,
        rows=gin.REQUIRED,
        cols=gin.REQUIRED,
        verbose=verbose,
    )
def gui(ctx, **kwargs):
    """Qualitative interactive evaluation of disentanglement."""
    # Reads ctx.obj["model"] and ctx.obj["dataset"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["evaluate/visual/gui.gin"])
    parse(ctx, set_seed=True)

    dataset = ctx.obj["dataset"].pipeline()
    disentangled.visualize.gui(ctx.obj["model"], dataset)
def latent1d(ctx, rows, cols, plot, filename, **kwargs):
    """Latent space traversal in 1D."""
    # Reads ctx.obj["model"] and ctx.obj["dataset"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["evaluate/visual/latent1d.gin"])
    parse(ctx, set_seed=True)

    # show_plot is always bound; the remaining overrides apply only when a
    # value was given. Note rows -> "dimensions" and cols -> "steps".
    overrides = {"disentangled.visualize.show.output.show_plot": plot}
    optional = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.traversal1d.dimensions", rows),
        ("disentangled.visualize.traversal1d.steps", cols),
    )
    overrides.update(
        (parameter, value) for parameter, value in optional if value is not None
    )
    with gin.unlock_config():
        for parameter, value in overrides.items():
            gin.bind_parameter(parameter, value)

    batches = ctx.obj["dataset"].pipeline()
    # Every gin.REQUIRED argument is supplied by gin.
    disentangled.visualize.traversal1d(
        ctx.obj["model"],
        batches,
        dimensions=gin.REQUIRED,
        offset=gin.REQUIRED,
        skip_batches=gin.REQUIRED,
        steps=gin.REQUIRED,
    )
def loglikelihood(ctx, **kwargs):
    """The logarithmic likelihood of the input given the representation."""
    # Reads ctx.obj["model"] and ctx.obj["model_str"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["metric/loglikelihood.gin"])
    parse(ctx, set_seed=True)

    # dataset is passed as gin.REQUIRED so gin supplies it.
    metric = disentangled.metric.loglikelihood(
        ctx.obj["model"],
        dataset=gin.REQUIRED,
    )
    # A leftover breakpoint() used to sit here; it dropped every run into the
    # debugger before the metric was logged, so it has been removed.
    disentangled.metric.log_metric(
        metric,
        name=ctx.obj["model_str"],
        metric_name=gin.REQUIRED,
    )
def collapsed(ctx, **kwargs):
    """Number of latent dimensions collapsed to prior."""
    # Reads ctx.obj["model"] and ctx.obj["model_str"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["metric/collapsed.gin"])
    parse(ctx, set_seed=True)

    # Arguments passed as gin.REQUIRED are supplied by gin.
    model = ctx.obj["model"]
    score = disentangled.metric.collapsed(
        model, dataset=gin.REQUIRED, tolerance=gin.REQUIRED
    )

    label = ctx.obj["model_str"]
    disentangled.metric.log_metric(score, name=label, metric_name=gin.REQUIRED)
def mig(ctx, **kwargs):
    """Mutual Information GAP. Quantitative disentanglement metric."""
    # Reads ctx.obj["model"] and ctx.obj["model_str"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["metric/mig.gin"])
    parse(ctx, set_seed=True)

    # Arguments passed as gin.REQUIRED are supplied by gin.
    model = ctx.obj["model"]
    score = disentangled.metric.mutual_information_gap(
        model, dataset=gin.REQUIRED, encoding_dist=gin.REQUIRED
    )

    label = ctx.obj["model_str"]
    disentangled.metric.log_metric(score, metric_name=gin.REQUIRED, name=label)
def dmig(ctx, **kwargs):
    """Discrete Mutual Information Gap. Quantitative disentanglement metric."""
    # Reads ctx.obj["model"] and ctx.obj["model_str"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["metric/dmig.gin"])
    parse(ctx, set_seed=True)

    # Arguments passed as gin.REQUIRED are supplied by gin.
    model = ctx.obj["model"]
    score = disentangled.metric.discrete_mutual_information_gap(
        model,
        dataset=gin.REQUIRED,
        points=gin.REQUIRED,
        batch_size=gin.REQUIRED,
    )

    label = ctx.obj["model_str"]
    disentangled.metric.log_metric(score, metric_name=gin.REQUIRED, name=label)
def gini_index(ctx, **kwargs):
    """Quantitative sparsity metric of representation."""
    # Reads ctx.obj["model"] and ctx.obj["model_str"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["metric/gini.gin"])
    parse(ctx, set_seed=True)

    # Arguments passed as gin.REQUIRED are supplied by gin.
    model = ctx.obj["model"]
    score = disentangled.metric.gini_index(
        model,
        dataset=gin.REQUIRED,
        points=gin.REQUIRED,
        batch_size=gin.REQUIRED,
        tolerance=gin.REQUIRED,
    )

    label = ctx.obj["model_str"]
    disentangled.metric.log_metric(score, name=label, metric_name=gin.REQUIRED)
def factorvae_score(ctx, **kwargs):
    """The FactorVAE-score. Quantitative disentanglement metric."""
    # Reads ctx.obj["model"] and ctx.obj["model_str"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["metric/factorvae_score.gin"])
    parse(ctx, set_seed=True)

    # Arguments passed as gin.REQUIRED are supplied by gin.
    model = ctx.obj["model"]
    score = disentangled.metric.factorvae_score(
        model,
        dataset=gin.REQUIRED,
        training_points=gin.REQUIRED,
        test_points=gin.REQUIRED,
        tolerance=gin.REQUIRED,
    )

    label = ctx.obj["model_str"]
    disentangled.metric.log_metric(score, name=label, metric_name=gin.REQUIRED)
def evaluate(ctx, model, path, **kwargs):
    """Interface for evaluating models.

    Quantitative and qualitative evaluation methods are provided.
    """
    # `model` is a "METHOD/DATASET" string: it is split on "/" into exactly
    # two parts below. Extra **kwargs are ignored.
    ctx.obj["model"] = disentangled.model.utils.load(path, model)
    ctx.obj["model_str"] = model

    method, dataset = model.split("/")
    ctx.obj["dataset"] = disentangled.dataset.get(dataset)
    ctx.obj["method_str"] = method
    ctx.obj["dataset_str"] = dataset

    # Register gin parameters and config files, in this order.
    gin_additions = (
        ("gin_param", "HP_SWEEP_VALUES=None"),
        ("gin_param", f"log_metric.path='{path}'"),
        ("config", "evaluate/evaluate.gin"),
        ("config", f"evaluate/dataset/{dataset}.gin"),
        ("config", f"evaluate/model/{method}.gin"),
    )
    for kind, entry in gin_additions:
        add_gin(ctx, kind, [entry])
def latent2d(ctx, rows, cols, plot, filename, **kwargs):
    """Latent space traversal in 2D."""
    # Reads ctx.obj["model"] and ctx.obj["dataset"]; extra **kwargs are ignored.
    add_gin(ctx, "config", ["evaluate/visual/latent2d.gin"])
    parse(ctx, set_seed=True)

    # plot and filename are always bound (filename even when it is None);
    # rows/cols only override the configured values when given.
    bindings = {
        "disentangled.visualize.show.output.show_plot": plot,
        "disentangled.visualize.show.output.filename": filename,
    }
    if rows is not None:
        bindings["disentangled.visualize.traversal2d.rows"] = rows
    if cols is not None:
        bindings["disentangled.visualize.traversal2d.cols"] = cols
    with gin.unlock_config():
        for parameter, value in bindings.items():
            gin.bind_parameter(parameter, value)

    batches = ctx.obj["dataset"].pipeline()
    disentangled.visualize.traversal2d(ctx.obj["model"], batches)
def train(
    ctx,
    model,
    config,
    gin_param,
    gin_file,
    path,
    checkpoint,
    overwrite=False,
    print_operative=False,
):
    """Interface for training models by using syntax MODEL/DATASET.

    Training is executed with default config values found in config/train directory.
    See config files or print operative configuration for configurable parameters.

    \b
    Example of modifying default config:
        disentangled train BetaVAE/DSprites -p iterations=1000

    \b
    Available Models:
        BetaVAE
        BetaSVAE
        BetaTCVAE
        FactorVAE
    Available Datasets:
        Shapes3d
        DSprites
        CelebA
    """
    # `config`, `gin_param` and `gin_file` are accepted but not read directly
    # in this function.
    add_gin(ctx, 'config', ['train/{}.gin'.format(model)], insert=True)

    if overwrite is True:
        # Fixed: a missing comma made this `'gin_param'[...]`, which indexes a
        # str with a list and raises TypeError whenever overwrite was set.
        add_gin(ctx, 'gin_param', ['disentangled.model.utils.save.overwrite=True'])

    parse(ctx)

    # model/dataset/iterations are passed as gin.REQUIRED so gin supplies them.
    disentangled.training.run_training(
        model=gin.REQUIRED,
        dataset=gin.REQUIRED,
        iterations=gin.REQUIRED,
        path=path,
        checkpoint=checkpoint,
    )

    if print_operative:
        print(gin.operative_config_str())
def examples(ctx, filename, rows, cols, plot, **kwargs):
    """View or save example images of dataset."""
    # ctx.obj["dataset"] is the key handed to disentangled.dataset.get() and
    # it also names the per-dataset gin file. Extra **kwargs are ignored.
    dataset_key = ctx.obj["dataset"]
    add_gin(ctx, "config", ["evaluate/dataset/{}.gin".format(dataset_key)])
    parse(ctx)

    # plot and filename are always bound; rows/cols only override the
    # configured values when they were actually supplied.
    bindings = {
        "disentangled.visualize.show.output.show_plot": plot,
        "disentangled.visualize.show.output.filename": filename,
    }
    if rows is not None:
        bindings["disentangled.visualize.data.rows"] = rows
    if cols is not None:
        bindings["disentangled.visualize.data.cols"] = cols
    with gin.unlock_config():
        for parameter, value in bindings.items():
            gin.bind_parameter(parameter, value)

    batches = disentangled.dataset.get(dataset_key).pipeline()
    # rows/cols are passed as gin.REQUIRED so gin supplies them.
    disentangled.visualize.data(batches, rows=gin.REQUIRED, cols=gin.REQUIRED)
def cli(ctx, **kwargs):
    """Train and evaluate disentangled representation learning models."""
    # Registers the base config file on the context; extra **kwargs are ignored.
    base_configs = ["config.gin"]
    add_gin(ctx, "config", base_configs)