Example #1
def make_lars_graph(
        target_dist,  # pylint: disable=invalid-name
        K,
        batch_size,
        eval_batch_size,
        lr,
        mlp_layers,
        dtype=tf.float32):
    """Construct the LARS training graph; returns (loss, train_op, global_step)."""
    model = lars.SimpleLARS(K=K,
                            data_dim=2,
                            accept_fn_layers=mlp_layers,
                            dtype=dtype)

    # Draw training and evaluation batches from the target distribution and
    # score them under the model. log_prob also returns an op that updates the
    # model's internal moving-average statistics.
    train_data = target_dist.sample(batch_size)
    log_p, ema_op = model.log_prob(train_data)
    test_data = target_dist.sample(eval_batch_size)
    eval_log_p, eval_ema_op = model.log_prob(test_data)

    # Maximize the mean training log-probability with Adam, forcing the
    # moving-average updates to run before the gradients are applied.
    global_step = tf.train.get_or_create_global_step()
    opt = tf.train.AdamOptimizer(lr)
    grads = opt.compute_gradients(-tf.reduce_mean(log_p))
    with tf.control_dependencies([ema_op, eval_ema_op]):
        apply_grads_op = opt.apply_gradients(grads, global_step=global_step)

    # Summaries: the model's unnormalized log-density (acceptance function plus
    # proposal log-density) as an image, and the train/eval objectives.
    density_image_summary(
        lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
        FLAGS.density_num_bins, "energy/lars")
    tf.summary.scalar("elbo", tf.reduce_mean(log_p))
    tf.summary.scalar("eval_elbo", tf.reduce_mean(eval_log_p))
    return -tf.reduce_mean(log_p), apply_grads_op, global_step
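
As a usage sketch, the graph above can be driven by a standard TF1 training loop. Everything below is illustrative only: the target distribution (a tensorflow_probability normal, assumed here), the hyperparameter values, and the checkpoint path are not taken from the script, and FLAGS plus the density_image_summary helper must already be defined by the surrounding module.

import tensorflow_probability as tfp

# Hypothetical 2-D target; the real script builds targets via
# dists.get_target_distribution (see Example #2 below).
target = tfp.distributions.MultivariateNormalDiag(loc=[0., 0.], scale_diag=[1., 1.])

loss, train_op, global_step = make_lars_graph(
    target_dist=target,
    K=1024,  # illustrative values, not the script's defaults
    batch_size=128,
    eval_batch_size=1000,
    lr=3e-4,
    mlp_layers=[20, 20])

with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/lars_example") as sess:
  for _ in range(1000):
    loss_np, _, step_np = sess.run([loss, train_op, global_step])
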
Example #2
def main(unused_argv):
  # FLAGS, np, os, and the model modules (dists, base, lars, nis, his,
  # rejection_sampling) are imported at the top of the surrounding script.
  g = tf.Graph()
  with g.as_default():
    # Hidden-layer sizes for the energy network, e.g. --energy_fn_sizes=20,20.
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo == "density":
      # Evaluate the true target density on a grid and save it to disk.
      target = dists.get_target_distribution(FLAGS.target)
      plot = make_density_summary(target.log_prob, num_bins=FLAGS.num_bins)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        plot = sess.run(plot)
        with tf.io.gfile.GFile(os.path.join(FLAGS.logdir, "density"),
                               "w") as out:
          np.save(out, plot)
    elif FLAGS.algo == "lars":
      # Evaluate the unnormalized density learned by LARS (acceptance function
      # plus proposal log-density), restoring weights from the checkpoint dir.
      tf.logging.info("Running LARS")
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = lars.SimpleLARS(
          K=FLAGS.K, data_dim=[2], accept_fn_layers=energy_fn_layers,
          proposal=proposal)
      plot = make_density_summary(
          lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
          num_bins=FLAGS.num_bins)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        plot = sess.run(plot)
        with tf.io.gfile.GFile(os.path.join(FLAGS.logdir, "density"),
                               "w") as out:
          np.save(out, plot)
    else:
      # The remaining algorithms are visualized by drawing samples and
      # histogramming them rather than by evaluating a density.
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        tf.logging.info("Running NIS")
        model = nis.NIS(
            K=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      elif FLAGS.algo == "his":
        tf.logging.info("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
      elif FLAGS.algo == "rejection_sampling":
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)
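
make_density_summary and make_sample_density_summary are helpers defined elsewhere in this script and are not reproduced above. As a rough sketch only (not the script's actual implementation; the grid limits and default bin count below are assumptions), the former can be thought of as evaluating a log-density function over a fixed 2-D grid:

def make_density_summary(log_density_fn, num_bins=100, lims=(-2., 2.)):
  """Sketch: evaluate log_density_fn on a square grid as a [num_bins, num_bins] image."""
  xs = tf.linspace(lims[0], lims[1], num_bins)
  grid_x, grid_y = tf.meshgrid(xs, xs)
  # Flatten the grid into a [num_bins**2, 2] batch of 2-D points.
  grid = tf.stack([tf.reshape(grid_x, [-1]), tf.reshape(grid_y, [-1])], axis=-1)
  log_density = log_density_fn(grid)
  return tf.reshape(log_density, [num_bins, num_bins])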