def autoencoder_autoregressive():
  """Autoregressive autoencoder model.

  Builds on ``basic.basic_autoencoder()`` and registers three new
  hyperparameters:
    * ``autoregressive_forget_base`` = False
    * ``autoregressive_mode`` = "conv3"
    * ``autoregressive_dropout`` = 0.4

  NOTE(review): this file also defines a later ``autoencoder_autoregressive``
  (mode "none", plus ``autoregressive_decode_steps`` and
  ``autoregressive_eval_pure_autoencoder``). If both definitions live in the
  same module, the later one replaces this one at import time -- confirm
  which variant is intended.

  Returns:
    The object produced by ``basic.basic_autoencoder()``, with the new
    hyperparameters added to it.
  """
  hparams = basic.basic_autoencoder()
  autoregressive_defaults = (
      ("autoregressive_forget_base", False),
      ("autoregressive_mode", "conv3"),
      ("autoregressive_dropout", 0.4),
  )
  # Registered in the same order as the original explicit calls.
  for hparam_name, default_value in autoregressive_defaults:
    hparams.add_hparam(hparam_name, default_value)
  return hparams
def basic_discrete_autoencoder():
  """Basic discrete autoencoder model.

  Starts from ``basic.basic_autoencoder()`` and then:
    * overrides ``hidden_size`` (128), ``bottleneck_size`` (512) and
      ``bottleneck_warmup_steps`` (3000);
    * adds a new ``discretize_warmup_steps`` hyperparameter (5000).

  NOTE(review): this file also defines a later ``basic_discrete_autoencoder``
  with different values. If both definitions live in the same module, the
  later one replaces this one at import time -- confirm which is intended.

  Returns:
    The object produced by ``basic.basic_autoencoder()``, modified as above.
  """
  hparams = basic.basic_autoencoder()
  for attr_name, attr_value in (("hidden_size", 128),
                                ("bottleneck_size", 512),
                                ("bottleneck_warmup_steps", 3000)):
    setattr(hparams, attr_name, attr_value)
  hparams.add_hparam("discretize_warmup_steps", 5000)
  return hparams
def autoencoder_autoregressive():
  """Autoregressive autoencoder model.

  Builds on ``basic.basic_autoencoder()`` and registers five new
  hyperparameters:
    * ``autoregressive_forget_base`` = False
    * ``autoregressive_mode`` = "none"
    * ``autoregressive_dropout`` = 0.4
    * ``autoregressive_decode_steps`` = 0
    * ``autoregressive_eval_pure_autoencoder`` = False

  NOTE(review): this file contains an earlier definition with the same name;
  when both are in one module, this later definition is the one that remains
  bound after import.

  Returns:
    The object produced by ``basic.basic_autoencoder()``, with the new
    hyperparameters added to it.
  """
  hparams = basic.basic_autoencoder()
  new_hparams = [
      ("autoregressive_forget_base", False),
      ("autoregressive_mode", "none"),
      ("autoregressive_dropout", 0.4),
      ("autoregressive_decode_steps", 0),
      ("autoregressive_eval_pure_autoencoder", False),
  ]
  for key, initial in new_hparams:
    hparams.add_hparam(key, initial)
  return hparams
def sliced_gan():
  """Basic parameters for a sliced GAN.

  Starts from ``basic.basic_autoencoder()`` and then:
    * overrides ``hidden_size`` (128), ``batch_size`` (128),
      ``weight_decay`` (1e-6) and ``bottleneck_bits`` (128);
    * adds ``discriminator_batchnorm`` (True) and ``num_sliced_vecs`` (4096).

  Returns:
    The object produced by ``basic.basic_autoencoder()``, modified as above.
  """
  hparams = basic.basic_autoencoder()
  overrides = (
      ("hidden_size", 128),
      ("batch_size", 128),
      ("weight_decay", 1e-6),
      ("bottleneck_bits", 128),
  )
  for field, setting in overrides:
    setattr(hparams, field, setting)
  # Presumably the number of random slicing directions used by the
  # sliced-GAN loss -- TODO confirm against the model code.
  hparams.add_hparam("discriminator_batchnorm", True)
  hparams.add_hparam("num_sliced_vecs", 4096)
  return hparams
def basic_discrete_autoencoder():
  """Basic discrete autoencoder model.

  Starts from ``basic.basic_autoencoder()`` and then:
    * overrides ``num_hidden_layers`` (5), ``hidden_size`` (64),
      ``bottleneck_size`` (4096), ``bottleneck_noise`` (0.1) and
      ``bottleneck_warmup_steps`` (3000);
    * adds a new ``discretize_warmup_steps`` hyperparameter (5000).

  NOTE(review): this file contains an earlier definition with the same name;
  when both are in one module, this later definition is the one that remains
  bound after import.

  Returns:
    The object produced by ``basic.basic_autoencoder()``, modified as above.
  """
  hparams = basic.basic_autoencoder()
  architecture = (
      ("num_hidden_layers", 5),
      ("hidden_size", 64),
      ("bottleneck_size", 4096),
      ("bottleneck_noise", 0.1),
      ("bottleneck_warmup_steps", 3000),
  )
  for name, value in architecture:
    setattr(hparams, name, value)
  hparams.add_hparam("discretize_warmup_steps", 5000)
  return hparams
def residual_autoencoder():
  """Residual autoencoder model.

  Starts from ``basic.basic_autoencoder()`` and then:
    * sets the optimizer to "Adam" with a "constant * linear_warmup"
      learning-rate schedule (constant 0.0001, 500 warmup steps);
    * overrides ``dropout`` (0.05), ``num_hidden_layers`` (5),
      ``hidden_size`` (64) and ``max_hidden_size`` (1024);
    * adds six ``residual_*`` hyperparameters (layer count, kernel
      height/width, filter multiplier, dropout, separable-conv flag).

  Returns:
    The object produced by ``basic.basic_autoencoder()``, modified as above.
  """
  hparams = basic.basic_autoencoder()

  # Existing hyperparameters overridden on the base set, in original order.
  base_overrides = (
      ("optimizer", "Adam"),
      ("learning_rate_constant", 0.0001),
      ("learning_rate_warmup_steps", 500),
      ("learning_rate_schedule", "constant * linear_warmup"),
      ("dropout", 0.05),
      ("num_hidden_layers", 5),
      ("hidden_size", 64),
      ("max_hidden_size", 1024),
  )
  for attr, val in base_overrides:
    setattr(hparams, attr, val)

  # New residual-block hyperparameters, registered in original order.
  residual_hparams = (
      ("num_residual_layers", 2),
      ("residual_kernel_height", 3),
      ("residual_kernel_width", 3),
      ("residual_filter_multiplier", 2.0),
      ("residual_dropout", 0.2),
      # Stored as int(True) == 1 rather than a bool; presumably deliberate
      # for hparam type handling -- keep as-is unless confirmed otherwise.
      ("residual_use_separable_conv", int(True)),
  )
  for attr, val in residual_hparams:
    hparams.add_hparam(attr, val)
  return hparams