示例#1
0
def make_lars_graph(
        target_dist,  # pylint: disable=invalid-name
        K,
        batch_size,
        eval_batch_size,
        lr,
        mlp_layers,
        dtype=tf.float32):
    """Construct the training graph for LARS on 2-D data.

    Args:
      target_dist: Distribution whose `sample(n)` supplies the training and
        evaluation batches.
      K: Forwarded to `lars.LARS` as both `K` and `T`.
      batch_size: Number of training samples drawn from `target_dist`.
      eval_batch_size: Number of evaluation samples drawn from `target_dist`.
      lr: Learning rate for the Adam optimizer.
      mlp_layers: Forwarded to `lars.LARS` as `accept_fn_layers`.
      dtype: Forwarded to `lars.LARS`.

    Returns:
      A `(loss, train_op, global_step)` tuple, where `loss` is the negated
      mean of the model's log-prob on the training batch.
    """
    lars_model = lars.LARS(
        K=K,
        T=K,
        data_dim=[2],
        accept_fn_layers=mlp_layers,
        proposal=base.get_independent_normal([2], FLAGS.proposal_variance),
        dtype=dtype)

    # Sampling/log-prob ops are created in the same order as before so that
    # any graph-seed-derived op seeds are unaffected.
    train_log_p = lars_model.log_prob(target_dist.sample(batch_size))
    eval_log_p = lars_model.log_prob(target_dist.sample(eval_batch_size))

    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(lr)
    mean_train_log_p = tf.reduce_mean(train_log_p)
    loss = -mean_train_log_p
    grads_and_vars = optimizer.compute_gradients(loss)
    apply_op = optimizer.apply_gradients(grads_and_vars,
                                         global_step=global_step)
    # post_train_op must only run after the gradient step has been applied.
    with tf.control_dependencies([apply_op]):
        train_op = lars_model.post_train_op()

    def _energy(x):
        # Plotted surface: squeezed accept_fn output plus proposal log-prob.
        return (tf.squeeze(lars_model.accept_fn(x)) +
                lars_model.proposal.log_prob(x))

    density_image_summary(_energy, FLAGS.density_num_bins, "energy/lars")
    tf.summary.scalar("elbo", mean_train_log_p)
    tf.summary.scalar("eval_elbo", tf.reduce_mean(eval_log_p))
    return loss, train_op, global_step
示例#2
0
def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds a proposal distribution selected by `proposal_type`, then a model
  selected by `model_type` that consumes that proposal.

  Args:
    proposal_type: One of "bernoulli_vae", "conv_bernoulli_vae",
      "gaussian_vae", "nis", "rejection_sampling", "gaussian", "his", "lars".
    model_type: One of "nis", "his", "lars", "conv_nis", "identity" (the
      proposal is built over `data_dim`) or "bernoulli_vae", "gaussian_vae",
      "conv_gaussian_vae", "conv_bernoulli_vae" (the proposal is built over
      `[FLAGS.latent_dim]`). "identity" returns the proposal itself.
    data_dim: Forwarded to the constructors as `data_dim`.
    mean: Forwarded as `data_mean` to several constructors, and to
      `base.SquashedDistribution` when FLAGS.squash is set.
    global_step: Passed to `make_kl_weight` with FLAGS.anneal_kl_step.

  Returns:
    The constructed model, wrapped in `base.SquashedDistribution` when
    FLAGS.squash is set.

  Raises:
    ValueError: If `model_type` or `proposal_type` is not supported.
  """
  data_space_models = ("nis", "his", "lars", "conv_nis", "identity")
  latent_space_models = ("bernoulli_vae", "gaussian_vae", "conv_gaussian_vae",
                         "conv_bernoulli_vae")
  supported_models = data_space_models + latent_space_models
  supported_proposals = ("bernoulli_vae", "conv_bernoulli_vae", "gaussian_vae",
                         "nis", "rejection_sampling", "gaussian", "his", "lars")
  # Fail fast with a clear message; otherwise an unknown name only surfaces
  # later as an UnboundLocalError on `proposal_data_dim`, `proposal` or
  # `model`. "hisvae" is rejected because its model branch is disabled below.
  if model_type not in supported_models:
    raise ValueError("Unsupported model_type %r; expected one of: %s" %
                     (model_type, ", ".join(supported_models)))
  if proposal_type not in supported_proposals:
    raise ValueError("Unsupported proposal_type %r; expected one of: %s" %
                     (proposal_type, ", ".join(supported_proposals)))

  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)

  def _parse_sizes(spec):
    """Parse a comma-separated string of ints into a list of ints."""
    # Empty entries are skipped, so "" yields [] instead of raising.
    return [int(tok) for tok in spec.split(",") if tok.strip()]

  decoder_hidden_sizes = _parse_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_sizes(FLAGS.energy_hidden_sizes)

  # Data-space models get a proposal over `data_dim`; VAE models get a
  # proposal over their latent space.
  if model_type in data_space_models:
    proposal_data_dim = data_dim
  else:
    proposal_data_dim = [FLAGS.latent_dim]

  # Bernoulli VAE proposals get the data mean because they propose images.
  # Other proposals don't because they propose latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "conv_bernoulli_vae":
    proposal = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    proposal = base.get_independent_normal(proposal_data_dim)
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif proposal_type == "lars":
    proposal = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=proposal_data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=None,
        data_mean=None,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)

  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "conv_nis":
    model = nis.ConvNIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=None if FLAGS.squash else mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif model_type == "lars":
    model = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=proposal,
        data_mean=None if FLAGS.squash else mean,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  elif model_type == "identity":
    model = proposal

  # Disabled "hisvae" model branch (rejected by the validation above):
  #  elif model_type == "hisvae":
  #    model = his.HISVAE(
  #        latent_dim=FLAGS.latent_dim,
  #        proposal=proposal,
  #        T=FLAGS.his_T,
  #        data_dim=data_dim,
  #        data_mean=mean,
  #        energy_hidden_sizes=energy_hidden_sizes,
  #        q_hidden_sizes=q_hidden_sizes,
  #        decoder_hidden_sizes=decoder_hidden_sizes,
  #        learn_temps=FLAGS.learn_his_temps,
  #        learn_stepsize=FLAGS.learn_his_stepsize,
  #        init_alpha=FLAGS.his_init_alpha,
  #        init_step_size=FLAGS.his_init_stepsize,
  #        squash=FLAGS.squash,
  #        kl_weight=kl_weight,
  #        dtype=tf.float32)

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model