Esempio n. 1
0
def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds the proposal selected by `proposal_type`, then the model selected by
  `model_type`, which receives the proposal through its `proposal=` argument
  ("identity" returns the proposal itself). When FLAGS.squash is set, the
  result is wrapped in base.SquashedDistribution with the data mean, and the
  inner model gets `data_mean=None` (except the "bernoulli_vae" model, which
  always gets `mean`).

  Args:
    proposal_type: One of "bernoulli_vae", "conv_bernoulli_vae",
      "gaussian_vae", "nis", "rejection_sampling", "gaussian", "his", "lars".
    model_type: One of "bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
      "conv_bernoulli_vae", "nis", "conv_nis", "his", "lars", "identity".
    data_dim: Data dimensions, forwarded to the model constructor. For
      non-VAE models it is also used as the proposal's dimensions; the
      "conv_bernoulli_vae" proposal always uses it.
    mean: Data mean, forwarded as `data_mean` where a constructor takes one.
    global_step: Forwarded, together with FLAGS.anneal_kl_step, to
      make_kl_weight to build the KL weight given to the VAE constructors.

  Returns:
    The model object, wrapped in base.SquashedDistribution if FLAGS.squash.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not recognized.
  """
  # Keep these tuples in sync with the dispatch chains below.
  proposal_types = ("bernoulli_vae", "conv_bernoulli_vae", "gaussian_vae",
                    "nis", "rejection_sampling", "gaussian", "his", "lars")
  vae_model_types = ("bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
                     "conv_bernoulli_vae")
  model_types = vae_model_types + ("nis", "conv_nis", "his", "lars",
                                   "identity")
  # Validate up front. Without this, an unknown name only fails later, as an
  # UnboundLocalError on `proposal`, `proposal_data_dim` or `model`.
  if proposal_type not in proposal_types:
    raise ValueError("Unknown proposal_type %r; expected one of: %s" %
                     (proposal_type, ", ".join(proposal_types)))
  if model_type not in model_types:
    raise ValueError("Unknown model_type %r; expected one of: %s" %
                     (model_type, ", ".join(model_types)))

  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)

  def _parse_sizes(spec):
    """Parse a comma-separated flag string such as "300,300" into ints."""
    return [int(x.strip()) for x in spec.split(",")]

  decoder_hidden_sizes = _parse_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_sizes(FLAGS.energy_hidden_sizes)

  if model_type in vae_model_types:
    # VAE models get a proposal with the latent dimensionality.
    proposal_data_dim = [FLAGS.latent_dim]
  else:
    # All other models get a proposal with the data dimensionality.
    proposal_data_dim = data_dim

  # Bernoulli VAE proposals get the data mean because they are proposing
  # images. Other proposals don't because they are proposing latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "conv_bernoulli_vae":
    # Sized by data_dim directly, not by proposal_data_dim.
    proposal = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    proposal = base.get_independent_normal(proposal_data_dim)
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif proposal_type == "lars":
    proposal = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=proposal_data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=None,
        data_mean=None,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)

  # When squashing, the data mean is handed to SquashedDistribution below
  # instead of to the inner model (the "bernoulli_vae" model is the exception).
  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "conv_nis":
    model = nis.ConvNIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif model_type == "lars":
    model = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=proposal,
        data_mean=None if FLAGS.squash else mean,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  elif model_type == "identity":
    model = proposal

  # A "hisvae" model is currently disabled. Because of that, "hisvae" is
  # rejected by the model_type validation above. The former branch:
  #  elif model_type == "hisvae":
  #    model = his.HISVAE(
  #        latent_dim=FLAGS.latent_dim,
  #        proposal=proposal,
  #        T=FLAGS.his_T,
  #        data_dim=data_dim,
  #        data_mean=mean,
  #        energy_hidden_sizes=energy_hidden_sizes,
  #        q_hidden_sizes=q_hidden_sizes,
  #        decoder_hidden_sizes=decoder_hidden_sizes,
  #        learn_temps=FLAGS.learn_his_temps,
  #        learn_stepsize=FLAGS.learn_his_stepsize,
  #        init_alpha=FLAGS.his_init_alpha,
  #        init_step_size=FLAGS.his_init_stepsize,
  #        squash=FLAGS.squash,
  #        kl_weight=kl_weight,
  #        dtype=tf.float32)

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model
Esempio n. 2
0
def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds the proposal selected by `proposal_type`, then the model selected by
  `model_type`. VAE models receive the proposal as `prior=`; "nis" and "his"
  models receive it as `proposal=`. When FLAGS.squash is set, the result is
  wrapped in base.SquashedDistribution with the data mean, and the inner model
  gets `data_mean=None` (except the "bernoulli_vae" model, which always gets
  `mean`).

  Args:
    proposal_type: One of "bernoulli_vae", "gaussian_vae", "nis",
      "rejection_sampling", "gaussian", "his".
    model_type: One of "bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
      "conv_bernoulli_vae", "nis", "his".
    data_dim: Data dimensions, forwarded to the model constructor. For "nis"
      and "his" models it is also used as the proposal's dimensions.
    mean: Data mean, forwarded as `data_mean` where a constructor takes one.
    global_step: Forwarded, together with FLAGS.anneal_kl_step, to
      make_kl_weight to build the KL weight given to the VAE constructors.

  Returns:
    The model object, wrapped in base.SquashedDistribution if FLAGS.squash.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not recognized.
  """
  # Keep these tuples in sync with the dispatch chains below.
  proposal_types = ("bernoulli_vae", "gaussian_vae", "nis",
                    "rejection_sampling", "gaussian", "his")
  vae_model_types = ("bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
                     "conv_bernoulli_vae")
  model_types = vae_model_types + ("nis", "his")
  # Validate up front. Without this, an unknown name only fails later, as an
  # UnboundLocalError on `proposal`, `proposal_data_dim` or `model`.
  if proposal_type not in proposal_types:
    raise ValueError("Unknown proposal_type %r; expected one of: %s" %
                     (proposal_type, ", ".join(proposal_types)))
  if model_type not in model_types:
    raise ValueError("Unknown model_type %r; expected one of: %s" %
                     (model_type, ", ".join(model_types)))

  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)

  def _parse_sizes(spec):
    """Parse a comma-separated flag string such as "300,300" into ints."""
    return [int(x.strip()) for x in spec.split(",")]

  decoder_hidden_sizes = _parse_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_sizes(FLAGS.energy_hidden_sizes)

  if model_type in vae_model_types:
    # VAE models get a proposal (their prior) with the latent dimensionality.
    proposal_data_dim = FLAGS.latent_dim
  else:
    # "nis" and "his" models get a proposal with the data dimensionality.
    proposal_data_dim = data_dim

  # Bernoulli VAE proposal gets the data mean because it is proposing images.
  # Other proposals don't because they are proposing latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    # Deliberately built without a data mean (see comment above).
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    logit_accept_fn = functools.partial(
        base.mlp,
        layer_sizes=energy_hidden_sizes + [1],
        final_activation=None,
        name="rejection_sampling/energy_fn_mlp")
    # NOTE(review): wrapping proposal_data_dim in a list presumably assumes it
    # is a scalar (the VAE latent_dim case); for "nis"/"his" models it is
    # data_dim -- confirm that is also a scalar there.
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=[proposal_data_dim],
        logit_accept_fn=logit_accept_fn,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    # NOTE(review): same scalar assumption on proposal_data_dim as above.
    proposal = tfd.MultivariateNormalDiag(
        loc=tf.zeros([proposal_data_dim], dtype=tf.float32),
        scale_diag=tf.ones([proposal_data_dim], dtype=tf.float32))
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)

  # When squashing, the data mean is handed to SquashedDistribution below
  # instead of to the inner model (the "bernoulli_vae" model is the exception).
  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        base_depth=FLAGS.base_depth,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)

  # A "hisvae" model is currently disabled. Because of that, "hisvae" is
  # rejected by the model_type validation above. The former branch:
  #  elif model_type == "hisvae":
  #    model = his.HISVAE(
  #        latent_dim=FLAGS.latent_dim,
  #        proposal=proposal,
  #        T=FLAGS.his_T,
  #        data_dim=data_dim,
  #        data_mean=mean,
  #        energy_hidden_sizes=energy_hidden_sizes,
  #        q_hidden_sizes=q_hidden_sizes,
  #        decoder_hidden_sizes=decoder_hidden_sizes,
  #        learn_temps=FLAGS.learn_his_temps,
  #        learn_stepsize=FLAGS.learn_his_stepsize,
  #        init_alpha=FLAGS.his_init_alpha,
  #        init_step_size=FLAGS.his_init_stepsize,
  #        squash=FLAGS.squash,
  #        kl_weight=kl_weight,
  #        dtype=tf.float32)

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model