def get_datasets():
  """Returns a sweep over every data set name.

  Returns:
    The result of `h.sweep` on the "dataset.name" key, with the data set
    names wrapped in `h.categorical`.
  """
  dataset_names = [
      "dsprites_full",
      "color_dsprites",
      "noisy_dsprites",
      "scream_dsprites",
      "smallnorb",
      "cars3d",
      "shapes3d",
  ]
  return h.sweep("dataset.name", h.categorical(dataset_names))
def get_default_models():
  """Builds the default set of models (6 models * 6 hyperparameters = 36).

  Every model family fixes "model.name" and "model.model" and sweeps one
  hyperparameter over six values. Each family is combined with `h.zipit`, and
  the six families are joined with `h.chainit`.

  Returns:
    The result of `h.chainit` over the six per-family configurations.
  """
  # BetaVAE: sweep over vae.beta.
  beta_vae_name = h.fixed("model.name", "beta_vae")
  beta_vae_fn = h.fixed("model.model", "@vae()")
  beta_vae_betas = h.sweep(
      "vae.beta", h.discrete([1., 2., 4., 6., 8., 16.]))
  beta_vae_config = h.zipit([beta_vae_name, beta_vae_betas, beta_vae_fn])

  # AnnealedVAE: sweep over annealed_vae.c_max; threshold and gamma are fixed.
  annealed_name = h.fixed("model.name", "annealed_vae")
  annealed_fn = h.fixed("model.model", "@annealed_vae()")
  annealed_threshold = h.fixed("annealed_vae.iteration_threshold", 100000)
  annealed_c_max = h.sweep(
      "annealed_vae.c_max", h.discrete([5., 10., 25., 50., 75., 100.]))
  annealed_gamma = h.fixed("annealed_vae.gamma", 1000)
  annealed_vae_config = h.zipit([
      annealed_name,
      annealed_c_max,
      annealed_threshold,
      annealed_gamma,
      annealed_fn,
  ])

  # FactorVAE: sweep over factor_vae.gamma; the discriminator is fixed.
  factor_name = h.fixed("model.name", "factor_vae")
  factor_fn = h.fixed("model.model", "@factor_vae()")
  factor_discriminator = h.fixed(
      "discriminator.discriminator_fn", "@fc_discriminator")
  factor_gammas = h.sweep(
      "factor_vae.gamma", h.discrete([10., 20., 30., 40., 50., 100.]))
  factor_vae_config = h.zipit(
      [factor_name, factor_gammas, factor_fn, factor_discriminator])

  # DIP-VAE-I: sweep over dip_vae.lambda_od with lambda_d_factor fixed at 10.
  dip_i_name = h.fixed("model.name", "dip_vae_i")
  dip_i_fn = h.fixed("model.model", "@dip_vae()")
  dip_i_lambda_od = h.sweep(
      "dip_vae.lambda_od", h.discrete([1., 2., 5., 10., 20., 50.]))
  dip_i_lambda_d_factor = h.fixed("dip_vae.lambda_d_factor", 10.)
  dip_i_type = h.fixed("dip_vae.dip_type", "i")
  dip_vae_i_config = h.zipit([
      dip_i_name,
      dip_i_fn,
      dip_i_lambda_od,
      dip_i_lambda_d_factor,
      dip_i_type,
  ])

  # DIP-VAE-II: same sweep as DIP-VAE-I but with lambda_d_factor fixed at 1.
  dip_ii_name = h.fixed("model.name", "dip_vae_ii")
  dip_ii_fn = h.fixed("model.model", "@dip_vae()")
  dip_ii_lambda_od = h.sweep(
      "dip_vae.lambda_od", h.discrete([1., 2., 5., 10., 20., 50.]))
  dip_ii_lambda_d_factor = h.fixed("dip_vae.lambda_d_factor", 1.)
  dip_ii_type = h.fixed("dip_vae.dip_type", "ii")
  dip_vae_ii_config = h.zipit([
      dip_ii_name,
      dip_ii_fn,
      dip_ii_lambda_od,
      dip_ii_lambda_d_factor,
      dip_ii_type,
  ])

  # BetaTCVAE: sweep over beta_tc_vae.beta.
  beta_tc_name = h.fixed("model.name", "beta_tc_vae")
  beta_tc_fn = h.fixed("model.model", "@beta_tc_vae()")
  beta_tc_betas = h.sweep(
      "beta_tc_vae.beta", h.discrete([1., 2., 4., 6., 8., 10.]))
  beta_tc_vae_config = h.zipit([beta_tc_name, beta_tc_fn, beta_tc_betas])

  # The chaining order differs from the construction order above; it is kept
  # exactly as in the original list.
  return h.chainit([
      beta_vae_config,
      factor_vae_config,
      dip_vae_i_config,
      dip_vae_ii_config,
      beta_tc_vae_config,
      annealed_vae_config,
  ])
def get_seeds(num):
  """Returns a sweep over random seeds.

  Args:
    num: Number of seeds; the seeds are the integers 0, 1, ..., num - 1.

  Returns:
    The result of `h.sweep` on the "model.random_seed" key, with the seed
    list wrapped in `h.categorical`.
  """
  seeds = list(range(num))
  return h.sweep("model.random_seed", h.categorical(seeds))
def get_num_latent(sweep):
  """Returns a sweep over the "encoder.num_latent" setting.

  Args:
    sweep: Values to sweep over; passed unchanged to `h.discrete`.

  Returns:
    The result of `h.sweep` on the "encoder.num_latent" key.
  """
  latent_values = h.discrete(sweep)
  return h.sweep("encoder.num_latent", latent_values)