Example #1
def run_train():
  """Run the training loop."""
  print_flags()
  g = tf.Graph()
  with g.as_default():
    global_step = tf.train.get_or_create_global_step()
    data_batch, mean, _ = datasets.get_dataset(
        FLAGS.dataset,
        batch_size=FLAGS.batch_size,
        split=FLAGS.split,
        repeat=True,
        shuffle=True)
    data_dim = data_batch.get_shape().as_list()[1:]
    model = make_model(FLAGS.proposal, FLAGS.model, data_dim, mean, global_step)
    elbo = model.log_prob(data_batch)
    sample_summary(model, data_dim)
    # Finish constructing the graph
    elbo_avg = tf.reduce_mean(elbo)
    tf.summary.scalar("elbo", elbo_avg)
    # Optionally cut the learning rate to a third after one million steps.
    if FLAGS.decay_lr:
      lr = tf.train.piecewise_constant(
          global_step, [int(1e6)],
          [FLAGS.learning_rate, FLAGS.learning_rate / 3.])
    else:
      lr = FLAGS.learning_rate
    tf.summary.scalar("learning_rate", lr)
    opt = tf.train.AdamOptimizer(learning_rate=lr)
    # Maximize the ELBO by minimizing its negation.
    grads = opt.compute_gradients(-elbo_avg)
    opt_op = opt.apply_gradients(grads, global_step=global_step)
    # Some models require updates after the training step
    if hasattr(model, "post_train_op"):
      with tf.control_dependencies([opt_op]):
        train_op = model.post_train_op()
    else:
      train_op = opt_op
    log_hooks = make_log_hooks(global_step, elbo_avg)
    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=os.path.join(FLAGS.logdir, exp_name()),
        save_checkpoint_steps=FLAGS.summarize_every * 2,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      # Run training steps until max_steps is reached or the session stops.
      cur_step = -1
      while cur_step <= FLAGS.max_steps and not sess.should_stop():
        _, cur_step = sess.run([train_op, global_step])
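
For orientation, below is a minimal, hypothetical entry point showing how run_train() could be wired up in TF1 style. The flag set is partial, and the names, defaults, and main wrapper are assumptions for illustration; they are not part of the original example.

import tensorflow as tf

# Hypothetical flag definitions matching the names run_train() reads above.
flags = tf.app.flags
flags.DEFINE_string("dataset", "mnist", "Dataset to train on.")
flags.DEFINE_string("split", "train", "Dataset split to use.")
flags.DEFINE_string("model", "vae", "Model to train.")
flags.DEFINE_string("proposal", "gaussian", "Proposal distribution.")
flags.DEFINE_string("logdir", "/tmp/logdir", "Directory for checkpoints and summaries.")
flags.DEFINE_integer("batch_size", 128, "Training batch size.")
flags.DEFINE_float("learning_rate", 3e-4, "Adam learning rate.")
flags.DEFINE_boolean("decay_lr", True, "Whether to decay the learning rate.")
flags.DEFINE_integer("max_steps", int(1e7), "Maximum number of training steps.")
flags.DEFINE_integer("summarize_every", 1000, "Steps between summaries and checkpoints.")
FLAGS = flags.FLAGS


def main(unused_argv):
  run_train()


if __name__ == "__main__":
  tf.app.run(main)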
Example #2
def run_eval():
  """Runs the eval loop."""
  print_flags()
  g = tf.Graph()
  with g.as_default():
    # If running eval, do not anneal the KL.
    FLAGS.anneal_kl_step = -1
    global_step = tf.train.get_or_create_global_step()
    summary_dir = os.path.join(FLAGS.logdir, exp_name(), "eval")
    summary_writer = tf.summary.FileWriter(
        summary_dir, flush_secs=15, max_queue=100)

    splits = FLAGS.split.split(",")
    for split in splits:
      assert split in ["train", "test", "valid"]

    # Parse the IWAE sample counts; a single value is shared across all splits.
    num_iwae_samples = [
        int(x.strip()) for x in FLAGS.num_iwae_samples.split(",")
    ]
    assert len(num_iwae_samples) == 1 or len(num_iwae_samples) == len(splits)
    if len(num_iwae_samples) == 1:
      num_iwae_samples = num_iwae_samples * len(splits)

    # With a single sample the bound is the ELBO; otherwise it is the IWAE bound.
    bound_names = []
    for ns in num_iwae_samples:
      if ns > 1:
        bound_names.append("iwae_%d" % ns)
      else:
        bound_names.append("elbo")

    itrs = []
    batch_sizes = []
    elbos = []
    model = None
    lars_Z_op = None  # pylint: disable=invalid-name
    lars_Z_ph = None  # pylint: disable=invalid-name
    # Build one evaluation pipeline per split, sharing a single model instance.
    for split, num_samples in zip(splits, num_iwae_samples):
      data_batch, mean, itr = datasets.get_dataset(
          FLAGS.dataset,
          batch_size=FLAGS.batch_size,
          split=split,
          repeat=False,
          shuffle=False,
          initializable=True)
      itrs.append(itr)
      batch_sizes.append(tf.shape(data_batch)[0])
      if model is None:
        data_dim = data_batch.get_shape().as_list()[1:]
        model = make_model(FLAGS.proposal, FLAGS.model, data_dim, mean,
                           global_step)
      elbos.append(
          tf.reduce_sum(model.log_prob(data_batch, num_samples=num_samples)))
      if FLAGS.model == "lars" or FLAGS.proposal == "lars":
        lars_Z_op, lars_Z_ph = make_lars_Z_ops(model)  # pylint: disable=invalid-name

    saver = tf.train.Saver()
    prev_evaluated_step = -1
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.train.SingularMonitoredSession(config=config) as sess:
      while True:
        # Block until a new checkpoint appears, then restore it.
        wait_for_checkpoint(saver, sess, os.path.join(FLAGS.logdir, exp_name()))
        step = sess.run(global_step)
        tf.logging.info("Model restored from step %d." % step)
        if step == prev_evaluated_step:
          tf.logging.info("Already evaluated checkpoint at step %d, sleeping" %
                          step)
          time.sleep(30)
          continue

        # For LARS models, estimate the normalizing constant Z once per checkpoint.
        Z_estimate = (estimate_Z_lars(lars_Z_op, sess)  # pylint: disable=invalid-name
                      if FLAGS.model == "lars" or FLAGS.proposal == "lars"
                      else None)

        # Evaluate each split and write the resulting bound to TensorBoard.
        for i in range(len(splits)):
          sess.run(itrs[i].initializer)
          avg_elbo = average_elbo_over_dataset(
              elbos[i],
              batch_sizes[i],
              sess,
              Z_estimate=Z_estimate,
              Z_estimate_ph=lars_Z_ph)
          value = tf.Summary.Value(
              tag="%s_%s" % (splits[i], bound_names[i]), simple_value=avg_elbo)
          summary = tf.Summary(value=[value])
          summary_writer.add_summary(summary, global_step=step)
          tf.logging.info("Step %d, %s %s: %f" %
                          (step, splits[i], bound_names[i], avg_elbo))
        prev_evaluated_step = step
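
The helper average_elbo_over_dataset used in the eval loop is not shown above. Below is a plausible sketch, assuming the elbo tensor is the per-batch summed bound, the non-repeating iterator raises tf.errors.OutOfRangeError when a split is exhausted, and the Z_estimate feed only applies to LARS models; the exact signature and behavior are assumptions, not the original implementation.

import tensorflow as tf


def average_elbo_over_dataset(elbo, batch_size, sess,
                              Z_estimate=None, Z_estimate_ph=None):  # pylint: disable=invalid-name
  """Hypothetical sketch: averages a per-batch summed bound over a full split."""
  feed_dict = {}
  if Z_estimate is not None and Z_estimate_ph is not None:
    # For LARS models, feed the pre-computed estimate of the normalizing constant.
    feed_dict[Z_estimate_ph] = Z_estimate
  total_bound = 0.0
  total_examples = 0
  while True:
    try:
      batch_bound, n = sess.run([elbo, batch_size], feed_dict=feed_dict)
    except tf.errors.OutOfRangeError:
      # The initializable, non-repeating iterator signals the end of the split.
      break
    total_bound += batch_bound
    total_examples += n
  return total_bound / float(total_examples)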