# Example #1 (score: 0)
def _save_density_plot(plot):
  """Evaluate a density-plot tensor in a session on FLAGS.logdir and save it.

  Opens a `tf.train.SingularMonitoredSession` with `checkpoint_dir` set to
  FLAGS.logdir, runs `plot`, and writes the resulting value with `np.save`
  to FLAGS.logdir/density.

  Args:
    plot: The tensor returned by `make_density_summary`.
  """
  with tf.train.SingularMonitoredSession(
      checkpoint_dir=FLAGS.logdir) as sess:
    plot_value = sess.run(plot)
    # np.save writes raw bytes, so open the output in binary mode.
    with tf.io.gfile.GFile(os.path.join(FLAGS.logdir, "density"),
                           "wb") as out:
      np.save(out, plot_value)


def main(unused_argv):
  """Build a 2-D model selected by FLAGS.algo and produce density outputs.

  Dispatches on FLAGS.algo:
    * "density": builds a density plot of the target distribution's
      `log_prob` and saves it to FLAGS.logdir/density.
    * "lars": builds a `lars.SimpleLARS` model and plots the squeezed
      `accept_fn` output plus the proposal log-prob, then saves it the same
      way.
    * "nis", "his", "rejection_sampling": builds the corresponding model,
      draws `FLAGS.batch_size` samples per batch and passes them to
      `make_sample_density_summary`.

  Args:
    unused_argv: Command-line arguments (ignored).

  Raises:
    ValueError: If FLAGS.algo is not one of the values listed above.
  """
  g = tf.Graph()
  with g.as_default():
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]
    if FLAGS.algo == "density":
      target = dists.get_target_distribution(FLAGS.target)
      plot = make_density_summary(target.log_prob, num_bins=FLAGS.num_bins)
      _save_density_plot(plot)
    elif FLAGS.algo == "lars":
      tf.logging.info("Running LARS")
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = lars.SimpleLARS(
          K=FLAGS.K, data_dim=[2], accept_fn_layers=energy_fn_layers,
          proposal=proposal)
      plot = make_density_summary(
          lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
          num_bins=FLAGS.num_bins)
      _save_density_plot(plot)
    elif FLAGS.algo in ("nis", "his", "rejection_sampling"):
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        tf.logging.info("Running NIS")
        model = nis.NIS(
            K=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      elif FLAGS.algo == "his":
        tf.logging.info("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
      else:  # "rejection_sampling"
        tf.logging.info("Running Rejection Sampling")
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K, data_dim=[2], energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)
    else:
      # Fail loudly instead of hitting an UnboundLocalError on `model`.
      raise ValueError("Unknown algo: %s" % FLAGS.algo)
# Example #2 (score: 0)
def main(unused_argv):
    """Train a model on the 2-D target distribution selected by FLAGS.

    Builds the target with `dists.get_target_distribution`, then dispatches
    on FLAGS.algo:
      * "lars": builds the training graph with `make_lars_graph`.
      * "nis", "rejection_sampling", "his": constructs the model, registers
        density image summaries (HIS also registers a sample image summary),
        and builds the training graph with `make_train_graph`.

    Training then runs in a `tf.train.MonitoredTrainingSession` that
    checkpoints to FLAGS.logdir/exp_name() every 120 seconds. It stops when
    the fetched global step reaches FLAGS.max_steps or a hook requests a stop.

    Args:
        unused_argv: Command-line arguments (ignored).

    Raises:
        ValueError: If FLAGS.algo is not "lars", "nis", "rejection_sampling"
            or "his".
    """
    g = tf.Graph()
    with g.as_default():
        target = dists.get_target_distribution(
            FLAGS.target,
            nine_gaussians_variance=FLAGS.nine_gaussians_variance)
        energy_fn_layers = [
            int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
        ]
        if FLAGS.algo == "lars":
            print("Running LARS")
            loss, train_op, global_step = make_lars_graph(
                target_dist=target,
                K=FLAGS.K,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate,
                mlp_layers=energy_fn_layers,
                dtype=tf.float32)
        else:
            if FLAGS.algo == "nis":
                print("Running NIS")
                model = nis.NIS(K=FLAGS.K,
                                data_dim=2,
                                energy_hidden_sizes=energy_fn_layers)
                density_image_summary(
                    lambda x:  # pylint: disable=g-long-lambda
                    (tf.squeeze(model.energy_fn(x)) + model.proposal.log_prob(
                        x)),
                    FLAGS.density_num_bins,
                    "energy/nis")
            elif FLAGS.algo == "rejection_sampling":
                print("Running Rejection Sampling")
                logit_accept_fn = tf.keras.Sequential([
                    tf.keras.layers.Dense(layer_size, activation="tanh")
                    for layer_size in energy_fn_layers
                ] + [tf.keras.layers.Dense(1, activation=None)])
                model = rejection_sampling.RejectionSampling(
                    T=FLAGS.K, data_dim=[2], logit_accept_fn=logit_accept_fn)
                density_image_summary(
                    lambda x: tf.squeeze(  # pylint: disable=g-long-lambda
                        tf.log_sigmoid(model.logit_accept_fn(x)),
                        axis=-1) + model.proposal.log_prob(x),
                    FLAGS.density_num_bins,
                    "energy/trs")
            elif FLAGS.algo == "his":
                print("Running HIS")
                model = his.FullyConnectedHIS(
                    T=FLAGS.his_t,
                    data_dim=2,
                    energy_hidden_sizes=energy_fn_layers,
                    q_hidden_sizes=energy_fn_layers,
                    init_step_size=FLAGS.his_stepsize,
                    learn_stepsize=FLAGS.his_learn_stepsize,
                    init_alpha=FLAGS.his_alpha,
                    learn_temps=FLAGS.his_learn_alpha)
                density_image_summary(
                    lambda x: -model.hamiltonian_potential(x),
                    FLAGS.density_num_bins, "energy/his")
                sample_image_summary(model,
                                     "density/his",
                                     num_samples=100000,
                                     num_bins=50)
            else:
                # Without this, `model` would be unbound below and the script
                # would die with an UnboundLocalError.
                raise ValueError("Unknown algo: %s" % FLAGS.algo)

            loss, train_op, global_step = make_train_graph(
                target_dist=target,
                model=model,
                batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.eval_batch_size,
                lr=FLAGS.learning_rate)

        log_hooks = make_log_hooks(global_step, loss)
        with tf.train.MonitoredTrainingSession(
                master="",
                is_chief=True,
                hooks=log_hooks,
                checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
                save_checkpoint_secs=120,
                save_summaries_steps=FLAGS.summarize_every,
                log_step_count_steps=FLAGS.summarize_every) as sess:
            # -1 guarantees at least one step. Stop as soon as the fetched
            # global step reaches max_steps, so training does not run an
            # extra step past the configured limit.
            cur_step = -1
            while not sess.should_stop() and cur_step < FLAGS.max_steps:
                _, cur_step = sess.run([train_op, global_step])