def main(unused_argv):
  """Builds the graph for `FLAGS.algo` and writes its density/sample summary.

  The work done depends on `FLAGS.algo`:

  * "density": passes the target distribution's `log_prob` to
    `make_density_summary`, runs the result and saves it with `np.save` to
    the file "density" inside `FLAGS.logdir`.
  * "lars": builds a `lars.SimpleLARS` model and saves the summary of
    `accept_fn(x) + proposal.log_prob(x)` to the same file.
  * "nis", "his", "rejection_sampling": builds the corresponding model,
    requests `FLAGS.batch_size` samples from it and hands them to
    `make_sample_density_summary`.

  Every session is opened with `checkpoint_dir=FLAGS.logdir`.
  NOTE(review): presumably so that trained variables are restored from the
  latest checkpoint in that directory -- confirm against the training script.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If `FLAGS.algo` is not one of the algorithms listed above.
  """
  sampler_algos = ("nis", "his", "rejection_sampling")
  known_algos = ("density", "lars") + sampler_algos
  # Fail early with a readable message; previously an unknown algorithm fell
  # into the sampler branch and crashed later because `model` was never bound.
  if FLAGS.algo not in known_algos:
    raise ValueError(
        "Unknown algo %r; expected one of: %s." %
        (FLAGS.algo, ", ".join(known_algos)))

  g = tf.Graph()
  with g.as_default():
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]

    if FLAGS.algo in sampler_algos:
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      if FLAGS.algo == "nis":
        tf.logging.info("Running NIS")
        model = nis.NIS(
            K=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)
      elif FLAGS.algo == "his":
        tf.logging.info("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha,
            proposal=proposal)
      else:  # "rejection_sampling" (guaranteed by the validation above).
        tf.logging.info("Running Rejection Sampling")
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K,
            data_dim=[2],
            energy_hidden_sizes=energy_fn_layers,
            proposal=proposal)

      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)
      return

    # Remaining algorithms ("density" and "lars") both produce a `plot` op
    # that is evaluated and saved in exactly the same way below.
    if FLAGS.algo == "density":
      target = dists.get_target_distribution(FLAGS.target)
      plot = make_density_summary(target.log_prob, num_bins=FLAGS.num_bins)
    else:  # "lars"
      tf.logging.info("Running LARS")
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = lars.SimpleLARS(
          K=FLAGS.K,
          data_dim=[2],
          accept_fn_layers=energy_fn_layers,
          proposal=proposal)
      plot = make_density_summary(
          lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
          num_bins=FLAGS.num_bins)

    with tf.train.SingularMonitoredSession(
        checkpoint_dir=FLAGS.logdir) as sess:
      plot_values = sess.run(plot)
      # np.save is given an open file object, so the output file is named
      # exactly "density" (no ".npy" suffix is appended).
      density_path = os.path.join(FLAGS.logdir, "density")
      with tf.io.gfile.GFile(density_path, "w") as out:
        np.save(out, plot_values)
def main(unused_argv):
  """Builds the training graph for `FLAGS.algo` and runs the training loop.

  Supported values of `FLAGS.algo`:

  * "lars": the loss, train op and global step come from `make_lars_graph`.
  * "nis", "rejection_sampling", "his": the corresponding model is built,
    an energy image summary is registered through `density_image_summary`,
    and the loss, train op and global step come from `make_train_graph`.

  Training runs inside a `tf.train.MonitoredTrainingSession` whose
  checkpoint directory is `FLAGS.logdir/<exp_name()>`. The loop ends when
  the session asks to stop or once the fetched global step exceeds
  `FLAGS.max_steps`.

  Args:
    unused_argv: Ignored.

  Raises:
    ValueError: If `FLAGS.algo` is not one of the algorithms listed above.
  """
  known_algos = ("lars", "nis", "rejection_sampling", "his")
  # Fail early with a readable message; previously an unknown algorithm
  # crashed later because `model` was never bound before make_train_graph.
  if FLAGS.algo not in known_algos:
    raise ValueError(
        "Unknown algo %r; expected one of: %s." %
        (FLAGS.algo, ", ".join(known_algos)))

  g = tf.Graph()
  with g.as_default():
    target = dists.get_target_distribution(
        FLAGS.target, nine_gaussians_variance=FLAGS.nine_gaussians_variance)
    energy_fn_layers = [
        int(x.strip()) for x in FLAGS.energy_fn_sizes.split(",")
    ]

    if FLAGS.algo == "lars":
      print("Running LARS")
      loss, train_op, global_step = make_lars_graph(
          target_dist=target,
          K=FLAGS.K,
          batch_size=FLAGS.batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          lr=FLAGS.learning_rate,
          mlp_layers=energy_fn_layers,
          dtype=tf.float32)
    else:
      if FLAGS.algo == "nis":
        print("Running NIS")
        model = nis.NIS(
            K=FLAGS.K, data_dim=2, energy_hidden_sizes=energy_fn_layers)
        density_image_summary(
            lambda x:  # pylint: disable=g-long-lambda
            (tf.squeeze(model.energy_fn(x)) + model.proposal.log_prob(x)),
            FLAGS.density_num_bins,
            "energy/nis")
      elif FLAGS.algo == "rejection_sampling":
        print("Running Rejection Sampling")
        # MLP with tanh hidden layers and a single linear output unit, used
        # as the acceptance logit of the rejection sampler.
        hidden_layers = [
            tf.keras.layers.Dense(layer_size, activation="tanh")
            for layer_size in energy_fn_layers
        ]
        logit_accept_fn = tf.keras.Sequential(
            hidden_layers + [tf.keras.layers.Dense(1, activation=None)])
        model = rejection_sampling.RejectionSampling(
            T=FLAGS.K, data_dim=[2], logit_accept_fn=logit_accept_fn)
        density_image_summary(
            lambda x: tf.squeeze(  # pylint: disable=g-long-lambda
                tf.log_sigmoid(model.logit_accept_fn(x)), axis=-1) +
            model.proposal.log_prob(x),
            FLAGS.density_num_bins,
            "energy/trs")
      else:  # "his" (guaranteed by the validation above).
        print("Running HIS")
        model = his.FullyConnectedHIS(
            T=FLAGS.his_t,
            data_dim=2,
            energy_hidden_sizes=energy_fn_layers,
            q_hidden_sizes=energy_fn_layers,
            init_step_size=FLAGS.his_stepsize,
            learn_stepsize=FLAGS.his_learn_stepsize,
            init_alpha=FLAGS.his_alpha,
            learn_temps=FLAGS.his_learn_alpha)
        density_image_summary(
            lambda x: -model.hamiltonian_potential(x),
            FLAGS.density_num_bins,
            "energy/his")
        sample_image_summary(
            model, "density/his", num_samples=100000, num_bins=50)

      loss, train_op, global_step = make_train_graph(
          target_dist=target,
          model=model,
          batch_size=FLAGS.batch_size,
          eval_batch_size=FLAGS.eval_batch_size,
          lr=FLAGS.learning_rate)

    log_hooks = make_log_hooks(global_step, loss)
    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      # Keep stepping until the session requests a stop or the fetched
      # global step has moved past FLAGS.max_steps.
      while not sess.should_stop() and cur_step <= FLAGS.max_steps:
        _, cur_step = sess.run([train_op, global_step])