def run_train():
    """Build the training graph and optimize the model until the step limit.

    Prints the flags, then constructs (inside a fresh graph) the input
    pipeline, the model, the scalar summaries and an Adam update that
    minimizes the negated batch mean of `model.log_prob`. The update is run
    in a `MonitoredTrainingSession` that checkpoints and summarizes into
    `<FLAGS.logdir>/<exp_name()>`. The loop ends once the fetched global step
    exceeds `FLAGS.max_steps` or the session asks to stop.
    """
    print_flags()
    with tf.Graph().as_default():
        step_var = tf.train.get_or_create_global_step()
        batch, data_mean, _ = datasets.get_dataset(
            FLAGS.dataset,
            batch_size=FLAGS.batch_size,
            split=FLAGS.split,
            repeat=True,
            shuffle=True)
        # Per-example shape: everything after the leading (batch) dimension.
        example_shape = batch.get_shape().as_list()[1:]
        model = make_model(
            FLAGS.proposal, FLAGS.model, example_shape, data_mean, step_var)
        per_example_bound = model.log_prob(batch)
        sample_summary(model, example_shape)

        # Remaining graph pieces: objective, learning rate, optimizer.
        mean_bound = tf.reduce_mean(per_example_bound)
        tf.summary.scalar("elbo", mean_bound)

        base_lr = FLAGS.learning_rate
        if not FLAGS.decay_lr:
            learning_rate = base_lr
        else:
            # Single drop to a third of the base rate at one million steps.
            learning_rate = tf.train.piecewise_constant(
                step_var, [int(1e6)], [base_lr, base_lr / 3.])
        tf.summary.scalar("learning_rate", learning_rate)

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        grads_and_vars = optimizer.compute_gradients(-mean_bound)
        apply_op = optimizer.apply_gradients(
            grads_and_vars, global_step=step_var)

        # Models exposing `post_train_op` get that op chained after the
        # optimizer update; all others train with the update alone.
        train_op = apply_op
        if hasattr(model, "post_train_op"):
            with tf.control_dependencies([apply_op]):
                train_op = model.post_train_op()

        hooks = make_log_hooks(step_var, mean_bound)
        checkpoint_dir = os.path.join(FLAGS.logdir, exp_name())
        every = FLAGS.summarize_every
        session = tf.train.MonitoredTrainingSession(
            master="",
            is_chief=True,
            hooks=hooks,
            checkpoint_dir=checkpoint_dir,
            save_checkpoint_steps=every * 2,
            save_summaries_steps=every,
            log_step_count_steps=every)
        with session as sess:
            current = -1
            while True:
                if current > FLAGS.max_steps or sess.should_stop():
                    break
                current = sess.run([train_op, step_var])[1]
def run_eval():
    """Run the eval loop, scoring every new checkpoint on each requested split.

    `FLAGS.split` is a comma-separated list drawn from train/test/valid.
    `FLAGS.num_iwae_samples` is a comma-separated list of integers holding
    either one entry (applied to every split) or one entry per split. A
    sample count above 1 is reported under the tag `<split>_iwae_<n>`,
    otherwise under `<split>_elbo`.

    The function builds one initializable dataset and one summed
    `model.log_prob` tensor per split, all sharing a single model. It then
    loops forever: it waits on the checkpoint directory, skips (after a 30 s
    sleep) a step it has already evaluated, and otherwise writes one scalar
    summary per split to `<FLAGS.logdir>/<exp_name()>/eval`.

    Side effects: sets `FLAGS.anneal_kl_step` to -1. The loop has no exit
    condition of its own; it only ends through an exception or interrupt, at
    which point the summary writer is closed so queued events reach disk.

    Raises:
      AssertionError: if a split name is unknown or the number of IWAE sample
        counts matches neither 1 nor the number of splits.
    """
    print_flags()
    graph = tf.Graph()
    with graph.as_default():
        # If running eval, do not anneal the KL.
        FLAGS.anneal_kl_step = -1
        global_step = tf.train.get_or_create_global_step()

        # Resolved once; the same directory is polled on every loop pass.
        checkpoint_dir = os.path.join(FLAGS.logdir, exp_name())
        summary_writer = tf.summary.FileWriter(
            os.path.join(checkpoint_dir, "eval"), flush_secs=15, max_queue=100)
        try:
            # Strip whitespace so "train, test" parses like "train,test",
            # matching how the sample counts are parsed below.
            splits = [name.strip() for name in FLAGS.split.split(",")]
            for split in splits:
                assert split in ("train", "test", "valid"), (
                    "Unknown split %r; expected one of train, test, valid."
                    % split)

            num_iwae_samples = [
                int(x.strip()) for x in FLAGS.num_iwae_samples.split(",")
            ]
            assert len(num_iwae_samples) in (1, len(splits)), (
                "num_iwae_samples must have 1 entry or one per split "
                "(got %d entries for %d splits)."
                % (len(num_iwae_samples), len(splits)))
            if len(num_iwae_samples) == 1:
                # Broadcast the single sample count to every split.
                num_iwae_samples = num_iwae_samples * len(splits)
            bound_names = [
                "iwae_%d" % ns if ns > 1 else "elbo" for ns in num_iwae_samples
            ]

            # One (split, bound name, iterator, batch size, summed bound)
            # tuple per split; the model is built once and shared.
            eval_targets = []
            model = None
            lars_Z_op = None  # pylint: disable=invalid-name
            lars_Z_ph = None  # pylint: disable=invalid-name
            for split, num_samples, bound_name in zip(
                    splits, num_iwae_samples, bound_names):
                data_batch, mean, itr = datasets.get_dataset(
                    FLAGS.dataset,
                    batch_size=FLAGS.batch_size,
                    split=split,
                    repeat=False,
                    shuffle=False,
                    initializable=True)
                if model is None:
                    data_dim = data_batch.get_shape().as_list()[1:]
                    model = make_model(FLAGS.proposal, FLAGS.model, data_dim,
                                       mean, global_step)
                # Dynamic batch size: with repeat=False the final batch may
                # be smaller than FLAGS.batch_size.
                batch_size = tf.shape(data_batch)[0]
                elbo_sum = tf.reduce_sum(
                    model.log_prob(data_batch, num_samples=num_samples))
                eval_targets.append(
                    (split, bound_name, itr, batch_size, elbo_sum))

            use_lars = FLAGS.model == "lars" or FLAGS.proposal == "lars"
            if use_lars:
                lars_Z_op, lars_Z_ph = make_lars_Z_ops(model)  # pylint: disable=invalid-name

            saver = tf.train.Saver()
            prev_evaluated_step = -1
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            with tf.train.SingularMonitoredSession(config=config) as sess:
                while True:
                    # NOTE(review): presumably blocks until a checkpoint
                    # exists and restores it into `sess` -- confirm against
                    # wait_for_checkpoint.
                    wait_for_checkpoint(saver, sess, checkpoint_dir)
                    step = sess.run(global_step)
                    tf.logging.info("Model restored from step %d.", step)
                    if step == prev_evaluated_step:
                        tf.logging.info(
                            "Already evaluated checkpoint at step %d, sleeping",
                            step)
                        time.sleep(30)
                        continue

                    Z_estimate = (  # pylint: disable=invalid-name
                        estimate_Z_lars(lars_Z_op, sess) if use_lars else None)
                    for (split, bound_name, itr, batch_size,
                         elbo_sum) in eval_targets:
                        # Rewind the non-repeating dataset for this pass.
                        sess.run(itr.initializer)
                        avg_elbo = average_elbo_over_dataset(
                            elbo_sum,
                            batch_size,
                            sess,
                            Z_estimate=Z_estimate,
                            Z_estimate_ph=lars_Z_ph)
                        value = tf.Summary.Value(
                            tag="%s_%s" % (split, bound_name),
                            simple_value=avg_elbo)
                        summary_writer.add_summary(
                            tf.Summary(value=[value]), global_step=step)
                        tf.logging.info("Step %d, %s %s: %f", step, split,
                                        bound_name, avg_elbo)
                    # Push this checkpoint's results to disk right away
                    # instead of waiting on the writer's periodic flush.
                    summary_writer.flush()
                    prev_evaluated_step = step
        finally:
            # The loop above only exits via an exception or interrupt; close
            # the writer so any queued events are not lost.
            summary_writer.close()