예제 #1
0
def get_datasets():
    """Build a categorical sweep over every supported data set.

    Returns:
      The `h.sweep` over the "dataset.name" key, with one categorical value per
      data set name listed below (in that order).
    """
    dataset_names = [
        "dsprites_full",
        "color_dsprites",
        "noisy_dsprites",
        "scream_dsprites",
        "smallnorb",
        "cars3d",
        "shapes3d",
    ]
    return h.sweep("dataset.name", h.categorical(dataset_names))
예제 #2
0
def get_default_models():
    """Build the default model sweep: 6 model families x 6 settings = 36 models.

    Each family contributes one `h.zipit` config in which a single
    hyperparameter is swept over six values and the remaining settings are
    fixed. The six family configs are concatenated with `h.chainit` in the
    order BetaVAE, FactorVAE, DIP-VAE-I, DIP-VAE-II, BetaTCVAE, AnnealedVAE.

    Returns:
      The result of `h.chainit` over the six per-family configs.
    """
    # BetaVAE: sweeps "vae.beta".
    beta_vae = h.zipit([
        h.fixed("model.name", "beta_vae"),
        h.sweep("vae.beta", h.discrete([1., 2., 4., 6., 8., 16.])),
        h.fixed("model.model", "@vae()"),
    ])

    # FactorVAE: sweeps "factor_vae.gamma" with a fixed discriminator.
    factor_vae = h.zipit([
        h.fixed("model.name", "factor_vae"),
        h.sweep("factor_vae.gamma",
                h.discrete([10., 20., 30., 40., 50., 100.])),
        h.fixed("model.model", "@factor_vae()"),
        h.fixed("discriminator.discriminator_fn", "@fc_discriminator"),
    ])

    # DIP-VAE-I: sweeps "dip_vae.lambda_od"; lambda_d_factor fixed at 10.
    dip_vae_i = h.zipit([
        h.fixed("model.name", "dip_vae_i"),
        h.fixed("model.model", "@dip_vae()"),
        h.sweep("dip_vae.lambda_od",
                h.discrete([1., 2., 5., 10., 20., 50.])),
        h.fixed("dip_vae.lambda_d_factor", 10.),
        h.fixed("dip_vae.dip_type", "i"),
    ])

    # DIP-VAE-II: same lambda_od sweep, but lambda_d_factor fixed at 1.
    dip_vae_ii = h.zipit([
        h.fixed("model.name", "dip_vae_ii"),
        h.fixed("model.model", "@dip_vae()"),
        h.sweep("dip_vae.lambda_od",
                h.discrete([1., 2., 5., 10., 20., 50.])),
        h.fixed("dip_vae.lambda_d_factor", 1.),
        h.fixed("dip_vae.dip_type", "ii"),
    ])

    # BetaTCVAE: sweeps "beta_tc_vae.beta".
    beta_tc_vae = h.zipit([
        h.fixed("model.name", "beta_tc_vae"),
        h.fixed("model.model", "@beta_tc_vae()"),
        h.sweep("beta_tc_vae.beta", h.discrete([1., 2., 4., 6., 8., 10.])),
    ])

    # AnnealedVAE: sweeps "annealed_vae.c_max"; threshold and gamma are fixed.
    annealed_vae = h.zipit([
        h.fixed("model.name", "annealed_vae"),
        h.sweep("annealed_vae.c_max",
                h.discrete([5., 10., 25., 50., 75., 100.])),
        h.fixed("annealed_vae.iteration_threshold", 100000),
        h.fixed("annealed_vae.gamma", 1000),
        h.fixed("model.model", "@annealed_vae()"),
    ])

    family_configs = [
        beta_vae,
        factor_vae,
        dip_vae_i,
        dip_vae_ii,
        beta_tc_vae,
        annealed_vae,
    ]
    return h.chainit(family_configs)
예제 #3
0
def get_seeds(num):
    """Sweep "model.random_seed" over the integers 0 .. num - 1.

    Args:
      num: How many seeds to produce; the seed values are `range(num)`.

    Returns:
      The `h.sweep` over "model.random_seed" with those values as categories.
    """
    seed_values = list(range(num))
    return h.sweep("model.random_seed", h.categorical(seed_values))
예제 #4
0
def get_num_latent(sweep):
    """Sweep "encoder.num_latent" over the given candidate values.

    Args:
      sweep: Candidate values for "encoder.num_latent", passed to `h.discrete`.

    Returns:
      The `h.sweep` over "encoder.num_latent" built from those values.
    """
    latent_values = h.discrete(sweep)
    return h.sweep("encoder.num_latent", latent_values)