def make_model(proposal_type, model_type, data_dim, mean, global_step):
    """Create model graph.

    Builds the proposal selected by `proposal_type`, then builds the model
    selected by `model_type` on top of that proposal. Hyperparameters are read
    from `FLAGS`.

    Args:
      proposal_type: Name of the proposal to build. One of "bernoulli_vae",
        "conv_bernoulli_vae", "gaussian_vae", "nis", "rejection_sampling",
        "gaussian", "his" or "lars".
      model_type: Name of the model to build. One of "bernoulli_vae",
        "gaussian_vae", "conv_gaussian_vae", "conv_bernoulli_vae", "nis",
        "conv_nis", "his", "lars" or "identity". "identity" returns the
        proposal itself as the model.
      data_dim: Forwarded as `data_dim` to the model. Also used as the
        proposal's `data_dim` for the non-VAE model types.
      mean: Forwarded as `data_mean` (see the per-branch handling below) and
        to `base.SquashedDistribution` when `FLAGS.squash` is set.
      global_step: Forwarded to `make_kl_weight` together with
        `FLAGS.anneal_kl_step`.

    Returns:
      The constructed model, wrapped in `base.SquashedDistribution` when
      `FLAGS.squash` is set.

    Raises:
      ValueError: If `model_type` or `proposal_type` is not supported.
    """
    # For these models the proposal is built with the same `data_dim` as the
    # model itself.
    data_space_models = ("nis", "his", "lars", "conv_nis", "identity")
    # For these models the proposal is built with `[FLAGS.latent_dim]`.
    latent_space_models = (
        "bernoulli_vae",
        "gaussian_vae",
        "conv_gaussian_vae",
        "conv_bernoulli_vae",
    )
    supported_models = data_space_models + latent_space_models
    supported_proposals = (
        "bernoulli_vae",
        "conv_bernoulli_vae",
        "gaussian_vae",
        "nis",
        "rejection_sampling",
        "gaussian",
        "his",
        "lars",
    )

    # Fail fast, before any graph ops are created. Without these checks an
    # unsupported name falls through every branch below and surfaces later as
    # an UnboundLocalError on `proposal_data_dim`, `proposal` or `model`.
    # NOTE: "hisvae" (his.HISVAE) is not wired up here, so it is rejected too.
    if model_type not in supported_models:
        raise ValueError(
            "Unsupported model_type {!r}; expected one of {}.".format(
                model_type, ", ".join(supported_models)))
    if proposal_type not in supported_proposals:
        raise ValueError(
            "Unsupported proposal_type {!r}; expected one of {}.".format(
                proposal_type, ", ".join(supported_proposals)))

    def _parse_sizes(flag_value):
        """Parse a comma-separated string of layer sizes into a list of ints."""
        return [int(size.strip()) for size in flag_value.split(",")]

    kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)

    decoder_hidden_sizes = _parse_sizes(FLAGS.decoder_hidden_sizes)
    q_hidden_sizes = _parse_sizes(FLAGS.q_hidden_sizes)
    energy_hidden_sizes = _parse_sizes(FLAGS.energy_hidden_sizes)

    if model_type in data_space_models:
        proposal_data_dim = data_dim
    else:
        proposal_data_dim = [FLAGS.latent_dim]

    # Bernoulli VAE proposals get the data mean because they are proposing
    # images. Other proposals don't because they are proposing latent states.
    if proposal_type == "bernoulli_vae":
        proposal = vae.BernoulliVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=proposal_data_dim,
            data_mean=mean,
            decoder_hidden_sizes=decoder_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            kl_weight=kl_weight,
            reparameterize_sample=FLAGS.reparameterize_proposal,
            temperature=FLAGS.gst_temperature,
            dtype=tf.float32)
    elif proposal_type == "conv_bernoulli_vae":
        # NOTE(review): this branch uses `data_dim`, not `proposal_data_dim`,
        # unlike every other proposal -- confirm this is intentional.
        proposal = vae.ConvBernoulliVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=mean,
            scale_min=FLAGS.scale_min,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif proposal_type == "gaussian_vae":
        proposal = vae.GaussianVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=proposal_data_dim,
            decoder_hidden_sizes=decoder_hidden_sizes,
            decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif proposal_type == "nis":
        proposal = nis.NIS(
            K=FLAGS.K,
            data_dim=proposal_data_dim,
            energy_hidden_sizes=energy_hidden_sizes,
            dtype=tf.float32)
    elif proposal_type == "rejection_sampling":
        proposal = rejection_sampling.RejectionSampling(
            T=FLAGS.K,
            data_dim=proposal_data_dim,
            energy_hidden_sizes=energy_hidden_sizes,
            dtype=tf.float32)
    elif proposal_type == "gaussian":
        proposal = base.get_independent_normal(proposal_data_dim)
    elif proposal_type == "his":
        proposal = his.FullyConnectedHIS(
            T=FLAGS.his_T,
            data_dim=proposal_data_dim,
            energy_hidden_sizes=energy_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            learn_temps=FLAGS.learn_his_temps,
            learn_stepsize=FLAGS.learn_his_stepsize,
            init_alpha=FLAGS.his_init_alpha,
            init_step_size=FLAGS.his_init_stepsize,
            dtype=tf.float32)
    elif proposal_type == "lars":
        proposal = lars.LARS(
            K=FLAGS.K,
            T=FLAGS.lars_T,
            data_dim=proposal_data_dim,
            accept_fn_layers=energy_hidden_sizes,
            proposal=None,
            data_mean=None,
            ema_decay=0.99,
            is_eval=FLAGS.mode == "eval",
            dtype=tf.float32)

    # When squashing, the mean is applied by the SquashedDistribution wrapper
    # at the end, so most models are built without it.
    model_mean = None if FLAGS.squash else mean

    if model_type == "bernoulli_vae":
        # Unlike the other models, this one always receives `mean`.
        model = vae.BernoulliVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=mean,
            decoder_hidden_sizes=decoder_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            proposal=proposal,
            kl_weight=kl_weight,
            reparameterize_sample=False,
            dtype=tf.float32)
    elif model_type == "gaussian_vae":
        model = vae.GaussianVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=model_mean,
            decoder_hidden_sizes=decoder_hidden_sizes,
            decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            proposal=proposal,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif model_type == "conv_gaussian_vae":
        model = vae.ConvGaussianVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=model_mean,
            scale_min=FLAGS.scale_min,
            proposal=proposal,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif model_type == "conv_bernoulli_vae":
        model = vae.ConvBernoulliVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=model_mean,
            scale_min=FLAGS.scale_min,
            proposal=proposal,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif model_type == "nis":
        model = nis.NIS(
            K=FLAGS.K,
            data_dim=data_dim,
            data_mean=model_mean,
            energy_hidden_sizes=energy_hidden_sizes,
            proposal=proposal,
            reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
            dtype=tf.float32)
    elif model_type == "conv_nis":
        model = nis.ConvNIS(
            K=FLAGS.K,
            data_dim=data_dim,
            data_mean=model_mean,
            proposal=proposal,
            reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
            dtype=tf.float32)
    elif model_type == "his":
        model = his.FullyConnectedHIS(
            proposal=proposal,
            T=FLAGS.his_T,
            data_dim=data_dim,
            data_mean=model_mean,
            energy_hidden_sizes=energy_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            learn_temps=FLAGS.learn_his_temps,
            learn_stepsize=FLAGS.learn_his_stepsize,
            init_alpha=FLAGS.his_init_alpha,
            init_step_size=FLAGS.his_init_stepsize,
            dtype=tf.float32)
    elif model_type == "lars":
        model = lars.LARS(
            K=FLAGS.K,
            T=FLAGS.lars_T,
            data_dim=data_dim,
            accept_fn_layers=energy_hidden_sizes,
            proposal=proposal,
            data_mean=model_mean,
            ema_decay=0.99,
            is_eval=FLAGS.mode == "eval",
            dtype=tf.float32)
    elif model_type == "identity":
        model = proposal

    if FLAGS.squash:
        model = base.SquashedDistribution(distribution=model, data_mean=mean)

    return model
def make_model(proposal_type, model_type, data_dim, mean, global_step):
    """Create model graph.

    Builds the proposal selected by `proposal_type`, then builds the model
    selected by `model_type` on top of it. VAE models receive it as `prior`;
    NIS and HIS models receive it as `proposal`. Hyperparameters are read
    from `FLAGS`.

    Args:
      proposal_type: Name of the proposal to build. One of "bernoulli_vae",
        "gaussian_vae", "nis", "rejection_sampling", "gaussian" or "his".
      model_type: Name of the model to build. One of "bernoulli_vae",
        "gaussian_vae", "conv_gaussian_vae", "conv_bernoulli_vae", "nis" or
        "his".
      data_dim: Forwarded as `data_dim` to the model. Also used as the
        proposal's dimension for the "nis" and "his" model types.
      mean: Forwarded as `data_mean` (see the per-branch handling below) and
        to `base.SquashedDistribution` when `FLAGS.squash` is set.
      global_step: Forwarded to `make_kl_weight` together with
        `FLAGS.anneal_kl_step`.

    Returns:
      The constructed model, wrapped in `base.SquashedDistribution` when
      `FLAGS.squash` is set.

    Raises:
      ValueError: If `model_type` or `proposal_type` is not supported.
    """
    # For these models the proposal is built with the model's own `data_dim`.
    data_space_models = ("nis", "his")
    # For these models the proposal is built with `FLAGS.latent_dim`.
    latent_space_models = (
        "bernoulli_vae",
        "gaussian_vae",
        "conv_gaussian_vae",
        "conv_bernoulli_vae",
    )
    supported_models = data_space_models + latent_space_models
    supported_proposals = (
        "bernoulli_vae",
        "gaussian_vae",
        "nis",
        "rejection_sampling",
        "gaussian",
        "his",
    )

    # Fail fast, before any graph ops are created. Without these checks an
    # unsupported name falls through every branch below and surfaces later as
    # an UnboundLocalError on `proposal_data_dim`, `proposal` or `model`.
    # NOTE: "hisvae" (his.HISVAE) is not wired up here, so it is rejected too.
    if model_type not in supported_models:
        raise ValueError(
            "Unsupported model_type {!r}; expected one of {}.".format(
                model_type, ", ".join(supported_models)))
    if proposal_type not in supported_proposals:
        raise ValueError(
            "Unsupported proposal_type {!r}; expected one of {}.".format(
                proposal_type, ", ".join(supported_proposals)))

    def _parse_sizes(flag_value):
        """Parse a comma-separated string of layer sizes into a list of ints."""
        return [int(size.strip()) for size in flag_value.split(",")]

    kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)

    decoder_hidden_sizes = _parse_sizes(FLAGS.decoder_hidden_sizes)
    q_hidden_sizes = _parse_sizes(FLAGS.q_hidden_sizes)
    energy_hidden_sizes = _parse_sizes(FLAGS.energy_hidden_sizes)

    if model_type in data_space_models:
        proposal_data_dim = data_dim
    else:
        proposal_data_dim = FLAGS.latent_dim

    # Bernoulli VAE proposals get the data mean because they are proposing
    # images. Other proposals don't because they are proposing latent states.
    if proposal_type == "bernoulli_vae":
        proposal = vae.BernoulliVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=proposal_data_dim,
            data_mean=mean,
            decoder_hidden_sizes=decoder_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            kl_weight=kl_weight,
            reparameterize_sample=FLAGS.reparameterize_proposal,
            temperature=FLAGS.gst_temperature,
            dtype=tf.float32)
    elif proposal_type == "gaussian_vae":
        proposal = vae.GaussianVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=proposal_data_dim,
            decoder_hidden_sizes=decoder_hidden_sizes,
            decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif proposal_type == "nis":
        proposal = nis.NIS(
            K=FLAGS.K,
            data_dim=proposal_data_dim,
            energy_hidden_sizes=energy_hidden_sizes,
            dtype=tf.float32)
    elif proposal_type == "rejection_sampling":
        # The acceptance logit is an MLP over the energy layers with one
        # extra single-unit output layer and no final activation.
        logit_accept_fn = functools.partial(
            base.mlp,
            layer_sizes=energy_hidden_sizes + [1],
            final_activation=None,
            name="rejection_sampling/energy_fn_mlp")
        # NOTE(review): `proposal_data_dim` is wrapped in a list here and in
        # the "gaussian" branch; this presumably expects a scalar dimension --
        # confirm behavior when the model is "nis"/"his" and `data_dim` is
        # itself a list.
        proposal = rejection_sampling.RejectionSampling(
            T=FLAGS.K,
            data_dim=[proposal_data_dim],
            logit_accept_fn=logit_accept_fn,
            dtype=tf.float32)
    elif proposal_type == "gaussian":
        proposal = tfd.MultivariateNormalDiag(
            loc=tf.zeros([proposal_data_dim], dtype=tf.float32),
            scale_diag=tf.ones([proposal_data_dim], dtype=tf.float32))
    elif proposal_type == "his":
        proposal = his.FullyConnectedHIS(
            T=FLAGS.his_T,
            data_dim=proposal_data_dim,
            energy_hidden_sizes=energy_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            learn_temps=FLAGS.learn_his_temps,
            learn_stepsize=FLAGS.learn_his_stepsize,
            init_alpha=FLAGS.his_init_alpha,
            init_step_size=FLAGS.his_init_stepsize,
            dtype=tf.float32)

    # When squashing, the mean is applied by the SquashedDistribution wrapper
    # at the end, so most models are built without it.
    model_mean = None if FLAGS.squash else mean

    if model_type == "bernoulli_vae":
        # Unlike the other models, this one always receives `mean`.
        model = vae.BernoulliVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=mean,
            decoder_hidden_sizes=decoder_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            prior=proposal,
            kl_weight=kl_weight,
            reparameterize_sample=False,
            dtype=tf.float32)
    elif model_type == "gaussian_vae":
        model = vae.GaussianVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=model_mean,
            decoder_hidden_sizes=decoder_hidden_sizes,
            decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
            q_hidden_sizes=q_hidden_sizes,
            scale_min=FLAGS.scale_min,
            prior=proposal,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif model_type == "conv_gaussian_vae":
        model = vae.ConvGaussianVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=model_mean,
            base_depth=FLAGS.base_depth,
            scale_min=FLAGS.scale_min,
            prior=proposal,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif model_type == "conv_bernoulli_vae":
        model = vae.ConvBernoulliVAE(
            latent_dim=FLAGS.latent_dim,
            data_dim=data_dim,
            data_mean=model_mean,
            scale_min=FLAGS.scale_min,
            prior=proposal,
            kl_weight=kl_weight,
            dtype=tf.float32)
    elif model_type == "nis":
        model = nis.NIS(
            K=FLAGS.K,
            data_dim=data_dim,
            data_mean=model_mean,
            energy_hidden_sizes=energy_hidden_sizes,
            proposal=proposal,
            reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
            dtype=tf.float32)
    elif model_type == "his":
        model = his.FullyConnectedHIS(
            proposal=proposal,
            T=FLAGS.his_T,
            data_dim=data_dim,
            data_mean=model_mean,
            energy_hidden_sizes=energy_hidden_sizes,
            q_hidden_sizes=q_hidden_sizes,
            learn_temps=FLAGS.learn_his_temps,
            learn_stepsize=FLAGS.learn_his_stepsize,
            init_alpha=FLAGS.his_init_alpha,
            init_step_size=FLAGS.his_init_stepsize,
            dtype=tf.float32)

    if FLAGS.squash:
        model = base.SquashedDistribution(distribution=model, data_mean=mean)

    return model