def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds the proposal named by `proposal_type`, then the model named by
  `model_type` on top of that proposal, and optionally wraps the model in a
  `base.SquashedDistribution`.

  Args:
    proposal_type: Name of the proposal to build. One of "bernoulli_vae",
      "conv_bernoulli_vae", "gaussian_vae", "nis", "rejection_sampling",
      "gaussian", "his" or "lars".
    model_type: Name of the model to build. One of "bernoulli_vae",
      "gaussian_vae", "conv_gaussian_vae", "conv_bernoulli_vae", "nis",
      "conv_nis", "his", "lars" or "identity". With "identity" the proposal
      itself is used as the model.
    data_dim: Data dimensionality, forwarded to the model constructors (and
      to the proposal when the model operates directly on data).
    mean: Data mean, forwarded as `data_mean` where noted below.
    global_step: Global step tensor handed to `make_kl_weight` together with
      `FLAGS.anneal_kl_step`.

  Returns:
    The constructed model; wrapped in `base.SquashedDistribution` when
    `FLAGS.squash` is set.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not one of the
      supported names (previously this surfaced as an UnboundLocalError).
  """

  def _parse_sizes(flag_value):
    # Flags hold comma-separated layer widths, e.g. "300, 300".
    return [int(size.strip()) for size in flag_value.split(",")]

  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)
  decoder_hidden_sizes = _parse_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_sizes(FLAGS.energy_hidden_sizes)

  # Models that act directly on data get a proposal over data space; the VAE
  # models get a proposal over their latent space instead.
  if model_type in ["nis", "his", "lars", "conv_nis", "identity"]:
    proposal_data_dim = data_dim
  elif model_type in [
      "bernoulli_vae", "gaussian_vae", "hisvae", "conv_gaussian_vae",
      "conv_bernoulli_vae"
  ]:
    proposal_data_dim = [FLAGS.latent_dim]
  else:
    raise ValueError("Unknown model_type: %r" % (model_type,))

  # Bernoulli VAE proposals get the data mean because they are proposing
  # images. Other proposals don't because they are proposing latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "conv_bernoulli_vae":
    # Note: uses the full data_dim rather than proposal_data_dim.
    proposal = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    proposal = base.get_independent_normal(proposal_data_dim)
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif proposal_type == "lars":
    proposal = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=proposal_data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=None,
        data_mean=None,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  else:
    raise ValueError("Unknown proposal_type: %r" % (proposal_type,))

  # When squashing is enabled the data mean is applied by the
  # SquashedDistribution wrapper at the end, so most models are built without
  # it. The Bernoulli VAE model always receives the mean.
  model_data_mean = None if FLAGS.squash else mean

  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        scale_min=FLAGS.scale_min,
        proposal=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=model_data_mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "conv_nis":
    model = nis.ConvNIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=model_data_mean,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=model_data_mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  elif model_type == "lars":
    model = lars.LARS(
        K=FLAGS.K,
        T=FLAGS.lars_T,
        data_dim=data_dim,
        accept_fn_layers=energy_hidden_sizes,
        proposal=proposal,
        data_mean=model_data_mean,
        ema_decay=0.99,
        is_eval=FLAGS.mode == "eval",
        dtype=tf.float32)
  elif model_type == "identity":
    model = proposal
  else:
    # NOTE: "hisvae" (his.HISVAE) is still accepted above when sizing the
    # proposal, but its model branch is disabled, so it ends up here too.
    raise ValueError(
        "model_type %r has no model implementation" % (model_type,))

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model
def main(unused_argv):
    """Build the training graph selected by FLAGS.algo and run training.

    For "lars" the whole graph comes from `make_lars_graph`. For "nis",
    "rejection_sampling" and "his" a model is constructed, density image
    summaries are registered, and `make_train_graph` builds the loss and
    train op. Training then runs inside a MonitoredTrainingSession until
    the session asks to stop or the step counter exceeds FLAGS.max_steps.

    Args:
        unused_argv: Ignored command-line arguments.

    Raises:
        ValueError: If FLAGS.algo is not one of the supported algorithms
            (previously this surfaced as an UnboundLocalError on `model`).
    """
    g = tf.Graph()
    with g.as_default():
        target = dists.get_target_distribution(
            FLAGS.target,
            nine_gaussians_variance=FLAGS.nine_gaussians_variance)
        # FLAGS.energy_fn_sizes is a comma-separated list of layer widths.
        energy_fn_layers = [
            int(size.strip()) for size in FLAGS.energy_fn_sizes.split(",")
        ]
        if FLAGS.algo == "lars":
            print("Running LARS")
            loss, train_op, global_step = make_lars_graph(
                target_dist=target,
                K=FLAGS.K,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate,
                mlp_layers=energy_fn_layers,
                dtype=tf.float32)
        else:
            if FLAGS.algo == "nis":
                print("Running NIS")
                model = nis.NIS(K=FLAGS.K,
                                data_dim=2,
                                energy_hidden_sizes=energy_fn_layers)
                density_image_summary(
                    lambda x:  # pylint: disable=g-long-lambda
                    (tf.squeeze(model.energy_fn(x)) + model.proposal.log_prob(
                        x)),
                    FLAGS.density_num_bins,
                    "energy/nis")
            elif FLAGS.algo == "rejection_sampling":
                print("Running Rejection Sampling")
                # MLP with tanh hidden layers and a single linear output
                # used as the acceptance logit.
                hidden_layers = [
                    tf.keras.layers.Dense(layer_size, activation="tanh")
                    for layer_size in energy_fn_layers
                ]
                logit_accept_fn = tf.keras.Sequential(
                    hidden_layers +
                    [tf.keras.layers.Dense(1, activation=None)])
                model = rejection_sampling.RejectionSampling(
                    T=FLAGS.K, data_dim=[2], logit_accept_fn=logit_accept_fn)
                density_image_summary(
                    lambda x: tf.squeeze(  # pylint: disable=g-long-lambda
                        tf.log_sigmoid(model.logit_accept_fn(x)),
                        axis=-1) + model.proposal.log_prob(x),
                    FLAGS.density_num_bins,
                    "energy/trs")
            elif FLAGS.algo == "his":
                print("Running HIS")
                model = his.FullyConnectedHIS(
                    T=FLAGS.his_t,
                    data_dim=2,
                    energy_hidden_sizes=energy_fn_layers,
                    q_hidden_sizes=energy_fn_layers,
                    init_step_size=FLAGS.his_stepsize,
                    learn_stepsize=FLAGS.his_learn_stepsize,
                    init_alpha=FLAGS.his_alpha,
                    learn_temps=FLAGS.his_learn_alpha)
                density_image_summary(
                    lambda x: -model.hamiltonian_potential(x),
                    FLAGS.density_num_bins, "energy/his")
                sample_image_summary(model,
                                     "density/his",
                                     num_samples=100000,
                                     num_bins=50)
            else:
                # Fail with a clear message instead of an unbound `model`.
                raise ValueError("Unknown algo: %r" % (FLAGS.algo,))

            loss, train_op, global_step = make_train_graph(
                target_dist=target,
                model=model,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate)

        log_hooks = make_log_hooks(global_step, loss)
        with tf.train.MonitoredTrainingSession(
                master="",
                is_chief=True,
                hooks=log_hooks,
                checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
                save_checkpoint_secs=120,
                save_summaries_steps=FLAGS.summarize_every,
                log_step_count_steps=FLAGS.summarize_every) as sess:
            # -1 guarantees at least one training step is attempted.
            cur_step = -1
            while not sess.should_stop() and cur_step <= FLAGS.max_steps:
                # Run one training step and read back the step counter.
                _, cur_step = sess.run([train_op, global_step])
Exemple #3
0
def main(unused_argv):
  """Produce a density plot or sample histogram for FLAGS.algo.

  For "density" and "lars" a density summary tensor is built, evaluated in a
  SingularMonitoredSession opened on FLAGS.logdir, and saved with `np.save`
  to "<FLAGS.logdir>/density". For "nis", "his" and "rejection_sampling" a
  model is built on an independent normal proposal and its samples are handed
  to `make_sample_density_summary`.

  Args:
    unused_argv: Ignored command-line arguments.

  Raises:
    ValueError: If FLAGS.algo is not one of the supported algorithms
      (previously this surfaced as an UnboundLocalError on `model`).
  """
  g = tf.Graph()
  with g.as_default():
    # FLAGS.energy_fn_sizes is a comma-separated list of layer widths.
    energy_fn_layers = [
        int(size.strip()) for size in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo in ("density", "lars"):
      if FLAGS.algo == "density":
        # Plot the target distribution itself.
        target = dists.get_target_distribution(FLAGS.target)
        plot = make_density_summary(target.log_prob, num_bins=FLAGS.num_bins)
      else:
        tf.logging.info("Running LARS")
        proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
        model = lars.SimpleLARS(
            K=FLAGS.K, data_dim=[2], accept_fn_layers=energy_fn_layers,
            proposal=proposal)
        plot = make_density_summary(
            lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
            num_bins=FLAGS.num_bins)
      # Both branches evaluate the plot tensor and save it the same way.
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        plot = sess.run(plot)
        with tf.io.gfile.GFile(os.path.join(FLAGS.logdir, "density"),
                               "w") as out:
          np.save(out, plot)
    else:
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        tf.logging.info("Running NIS")
        model = nis.NIS(
            K=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      elif FLAGS.algo == "his":
        tf.logging.info("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
      elif FLAGS.algo == "rejection_sampling":
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      else:
        # Fail with a clear message instead of an unbound `model`.
        raise ValueError("Unknown algo: %r" % (FLAGS.algo,))
      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)
Exemple #4
0
def make_model(proposal_type, model_type, data_dim, mean, global_step):
  """Create model graph.

  Builds the proposal named by `proposal_type`, then the model named by
  `model_type` on top of it (VAE models receive it as `prior`, NIS/HIS as
  `proposal`), and optionally wraps the model in a
  `base.SquashedDistribution`.

  Args:
    proposal_type: Name of the proposal to build. One of "bernoulli_vae",
      "gaussian_vae", "nis", "rejection_sampling", "gaussian" or "his".
    model_type: Name of the model to build. One of "bernoulli_vae",
      "gaussian_vae", "conv_gaussian_vae", "conv_bernoulli_vae", "nis" or
      "his".
    data_dim: Data dimensionality, forwarded to the model constructors (and
      to the proposal when the model is "nis" or "his").
    mean: Data mean, forwarded as `data_mean` where noted below.
    global_step: Global step tensor handed to `make_kl_weight` together with
      `FLAGS.anneal_kl_step`.

  Returns:
    The constructed model; wrapped in `base.SquashedDistribution` when
    `FLAGS.squash` is set.

  Raises:
    ValueError: If `proposal_type` or `model_type` is not one of the
      supported names (previously this surfaced as an UnboundLocalError).
  """

  def _parse_sizes(flag_value):
    # Flags hold comma-separated layer widths, e.g. "300, 300".
    return [int(size.strip()) for size in flag_value.split(",")]

  kl_weight = make_kl_weight(global_step, FLAGS.anneal_kl_step)
  decoder_hidden_sizes = _parse_sizes(FLAGS.decoder_hidden_sizes)
  q_hidden_sizes = _parse_sizes(FLAGS.q_hidden_sizes)
  energy_hidden_sizes = _parse_sizes(FLAGS.energy_hidden_sizes)

  # NIS/HIS models act directly on data, so their proposal lives in data
  # space; the VAE models get a proposal over their latent space instead.
  if model_type in ["nis", "his"]:
    proposal_data_dim = data_dim
  elif model_type in [
      "bernoulli_vae", "gaussian_vae", "hisvae", "conv_gaussian_vae",
      "conv_bernoulli_vae"
  ]:
    proposal_data_dim = FLAGS.latent_dim
  else:
    raise ValueError("Unknown model_type: %r" % (model_type,))

  # Bernoulli VAE proposal gets the data mean because it is proposing images.
  # Other proposals don't because they are proposing latent states.
  if proposal_type == "bernoulli_vae":
    proposal = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        reparameterize_sample=FLAGS.reparameterize_proposal,
        temperature=FLAGS.gst_temperature,
        dtype=tf.float32)
  elif proposal_type == "gaussian_vae":
    # Deliberately built without data_mean (see comment above).
    proposal = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=proposal_data_dim,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif proposal_type == "nis":
    proposal = nis.NIS(
        K=FLAGS.K,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        dtype=tf.float32)
  elif proposal_type == "rejection_sampling":
    # Acceptance logit: MLP over the energy layer sizes plus one linear unit.
    logit_accept_fn = functools.partial(
        base.mlp,
        layer_sizes=energy_hidden_sizes + [1],
        final_activation=None,
        name="rejection_sampling/energy_fn_mlp")
    # NOTE(review): proposal_data_dim is wrapped in a list here; that
    # presumes it is a scalar (as it is for the VAE models) -- confirm the
    # "nis"/"his" case where it equals data_dim.
    proposal = rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=[proposal_data_dim],
        logit_accept_fn=logit_accept_fn,
        dtype=tf.float32)
  elif proposal_type == "gaussian":
    # Zero-mean, unit-scale diagonal Gaussian.
    proposal = tfd.MultivariateNormalDiag(
        loc=tf.zeros([proposal_data_dim], dtype=tf.float32),
        scale_diag=tf.ones([proposal_data_dim], dtype=tf.float32))
  elif proposal_type == "his":
    proposal = his.FullyConnectedHIS(
        T=FLAGS.his_T,
        data_dim=proposal_data_dim,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  else:
    raise ValueError("Unknown proposal_type: %r" % (proposal_type,))

  # When squashing is enabled the data mean is applied by the
  # SquashedDistribution wrapper at the end, so most models are built without
  # it. The Bernoulli VAE model always receives the mean.
  model_data_mean = None if FLAGS.squash else mean

  if model_type == "bernoulli_vae":
    model = vae.BernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        reparameterize_sample=False,
        dtype=tf.float32)
  elif model_type == "gaussian_vae":
    model = vae.GaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        decoder_hidden_sizes=decoder_hidden_sizes,
        decoder_nn_scale=FLAGS.vae_decoder_nn_scale,
        q_hidden_sizes=q_hidden_sizes,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_gaussian_vae":
    model = vae.ConvGaussianVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        base_depth=FLAGS.base_depth,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "conv_bernoulli_vae":
    model = vae.ConvBernoulliVAE(
        latent_dim=FLAGS.latent_dim,
        data_dim=data_dim,
        data_mean=model_data_mean,
        scale_min=FLAGS.scale_min,
        prior=proposal,
        kl_weight=kl_weight,
        dtype=tf.float32)
  elif model_type == "nis":
    model = nis.NIS(
        K=FLAGS.K,
        data_dim=data_dim,
        data_mean=model_data_mean,
        energy_hidden_sizes=energy_hidden_sizes,
        proposal=proposal,
        reparameterize_proposal_samples=FLAGS.reparameterize_proposal,
        dtype=tf.float32)
  elif model_type == "his":
    model = his.FullyConnectedHIS(
        proposal=proposal,
        T=FLAGS.his_T,
        data_dim=data_dim,
        data_mean=model_data_mean,
        energy_hidden_sizes=energy_hidden_sizes,
        q_hidden_sizes=q_hidden_sizes,
        learn_temps=FLAGS.learn_his_temps,
        learn_stepsize=FLAGS.learn_his_stepsize,
        init_alpha=FLAGS.his_init_alpha,
        init_step_size=FLAGS.his_init_stepsize,
        dtype=tf.float32)
  else:
    # NOTE: "hisvae" (his.HISVAE) is still accepted above when sizing the
    # proposal, but its model branch is disabled, so it ends up here.
    raise ValueError(
        "model_type %r has no model implementation" % (model_type,))

  if FLAGS.squash:
    model = base.SquashedDistribution(distribution=model, data_mean=mean)

  return model