def make_lars_graph(target_dist,  # pylint: disable=invalid-name
                    K,
                    batch_size,
                    eval_batch_size,
                    lr,
                    mlp_layers,
                    dtype=tf.float32):
  """Construct the training graph for LARS.

  Args:
    target_dist: Object with a `sample(n)` method; training data is drawn via
      `target_dist.sample(batch_size)` and eval data via
      `target_dist.sample(eval_batch_size)`.
    K: Forwarded to `lars.SimpleLARS` as `K`.
    batch_size: Number of training samples requested from `target_dist`.
    eval_batch_size: Number of eval samples requested from `target_dist`.
    lr: Learning rate for `tf.train.AdamOptimizer`.
    mlp_layers: Forwarded to `lars.SimpleLARS` as `accept_fn_layers`.
    dtype: Forwarded to `lars.SimpleLARS` as `dtype`.

  Returns:
    A tuple `(loss, train_op, global_step)`:
      loss: negative mean of the training `log_p` (the quantity minimized).
      train_op: applies the Adam update; it has control dependencies on the
        update ops returned by both `model.log_prob` calls.
      global_step: the global step tensor incremented by `train_op`.
  """
  model = lars.SimpleLARS(
      K=K, data_dim=2, accept_fn_layers=mlp_layers, dtype=dtype)

  train_data = target_dist.sample(batch_size)
  log_p, ema_op = model.log_prob(train_data)
  test_data = target_dist.sample(eval_batch_size)
  eval_log_p, eval_ema_op = model.log_prob(test_data)

  # Build the mean log-prob and the loss once and reuse them, rather than
  # creating a separate reduce_mean op for the gradient, the summary and the
  # return value. The returned loss is then exactly the tensor optimized.
  mean_log_p = tf.reduce_mean(log_p)
  loss = -mean_log_p

  global_step = tf.train.get_or_create_global_step()
  opt = tf.train.AdamOptimizer(lr)
  grads = opt.compute_gradients(loss)
  # Force both update ops returned by model.log_prob (for the train and the
  # eval batch) to run before each parameter update.
  with tf.control_dependencies([ema_op, eval_ema_op]):
    apply_grads_op = opt.apply_gradients(grads, global_step=global_step)

  # NOTE(review): tf.squeeze has no axis, so it drops every size-1 dimension.
  # This presumably assumes accept_fn returns shape [N, 1]; confirm that a
  # size-1 batch dimension can never reach it.
  density_image_summary(
      lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
      FLAGS.density_num_bins, "energy/lars")
  tf.summary.scalar("elbo", mean_log_p)
  tf.summary.scalar("eval_elbo", tf.reduce_mean(eval_log_p))
  return loss, apply_grads_op, global_step
def _parse_layer_sizes(sizes_str):
  """Parse a comma-separated string of integers (e.g. "20,20") into a list."""
  return [int(x.strip()) for x in sizes_str.split(",")]


def _run_and_save_density_plot(plot_op):
  """Evaluate `plot_op` in a session restored from FLAGS.logdir and save it.

  The evaluated value is written with `np.save` to `<FLAGS.logdir>/density`.
  `np.save` does not append ".npy" when given a file object, so the file name
  is exactly "density".
  """
  with tf.train.SingularMonitoredSession(
      checkpoint_dir=FLAGS.logdir) as sess:
    plot = sess.run(plot_op)
  # np.save emits binary .npy bytes, so open the file in binary mode.
  with tf.io.gfile.GFile(os.path.join(FLAGS.logdir, "density"), "wb") as out:
    np.save(out, plot)


def _make_sampling_model(algo, energy_fn_layers, proposal):
  """Build the sample-based model selected by `algo`.

  Args:
    algo: One of "nis", "his" or "rejection_sampling".
    energy_fn_layers: List of hidden-layer sizes for the model's networks.
    proposal: Proposal distribution passed to the model.

  Returns:
    The constructed model object.

  Raises:
    ValueError: If `algo` is not one of the supported values.
  """
  if algo == "nis":
    tf.logging.info("Running NIS")
    return nis.NIS(
        K=FLAGS.K,
        data_dim=[2],
        energy_hidden_sizes=energy_fn_layers,
        proposal=proposal)
  if algo == "his":
    tf.logging.info("Running HIS")
    return his.FullyConnectedHIS(
        T=FLAGS.his_t,
        data_dim=[2],
        energy_hidden_sizes=energy_fn_layers,
        q_hidden_sizes=energy_fn_layers,
        init_step_size=FLAGS.his_stepsize,
        learn_stepsize=FLAGS.his_learn_stepsize,
        init_alpha=FLAGS.his_alpha,
        learn_temps=FLAGS.his_learn_alpha,
        proposal=proposal)
  if algo == "rejection_sampling":
    tf.logging.info("Running rejection sampling")
    return rejection_sampling.RejectionSampling(
        T=FLAGS.K,
        data_dim=[2],
        energy_hidden_sizes=energy_fn_layers,
        proposal=proposal)
  # Previously an unknown algo left `model` unbound and failed later with an
  # opaque UnboundLocalError; fail loudly with the offending value instead.
  raise ValueError("Unknown algo: %r" % (algo,))


def main(unused_argv):
  """Produce a density plot or sample-density summary for FLAGS.algo.

  - "density": evaluates FLAGS.target's `log_prob` via make_density_summary
    and saves the result to `<FLAGS.logdir>/density`.
  - "lars": rebuilds a lars.SimpleLARS model, evaluates
    `accept_fn(x) + proposal.log_prob(x)` via make_density_summary in a
    session restored from FLAGS.logdir, and saves it the same way.
  - "nis" / "his" / "rejection_sampling": rebuilds the model, draws
    FLAGS.batch_size samples per batch in a session restored from
    FLAGS.logdir, and passes them to make_sample_density_summary.

  Raises:
    ValueError: If FLAGS.algo is not a supported value.
  """
  g = tf.Graph()
  with g.as_default():
    energy_fn_layers = _parse_layer_sizes(FLAGS.energy_fn_sizes)

    if FLAGS.algo == "density":
      target = dists.get_target_distribution(FLAGS.target)
      plot = make_density_summary(target.log_prob, num_bins=FLAGS.num_bins)
      _run_and_save_density_plot(plot)
    elif FLAGS.algo == "lars":
      tf.logging.info("Running LARS")
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = lars.SimpleLARS(
          K=FLAGS.K,
          data_dim=[2],
          accept_fn_layers=energy_fn_layers,
          proposal=proposal)
      plot = make_density_summary(
          lambda x: tf.squeeze(model.accept_fn(x)) + model.proposal.log_prob(x),
          num_bins=FLAGS.num_bins)
      _run_and_save_density_plot(plot)
    else:
      proposal = base.get_independent_normal([2], FLAGS.proposal_variance)
      model = _make_sampling_model(FLAGS.algo, energy_fn_layers, proposal)
      samples = model.sample(FLAGS.batch_size)
      with tf.train.SingularMonitoredSession(
          checkpoint_dir=FLAGS.logdir) as sess:
        make_sample_density_summary(
            sess,
            samples,
            max_samples_per_batch=FLAGS.batch_size,
            num_samples=FLAGS.num_samples,
            num_bins=FLAGS.num_bins)