示例#1
0
def Eval(exp, sess, batch_size=256, total_batch=4096):
  """Runs batched evaluation with the saved tuning outputs and saves averages.

  Loads the tuning outputs JSON from the log directory, calls
  `exp.Eval(..., p_accept_only=True)` `total_batch // batch_size` times,
  averages every leaf of the returned structure across those calls, and writes
  the two top-level components of the averaged structure as JSON files
  ("neutra_stats" and "p_accept", each suffixed with `FLAGS.eval_suffix`) in
  the log directory.

  Leaves whose flattened path contains the substring "ess" are combined with a
  harmonic mean along axis 0; every other leaf uses an arithmetic mean along
  axis 0.

  Args:
    exp: Experiment object providing `Eval` and the `test_*` feed keys.
    sess: Session object passed through to `exp.Eval`.
    batch_size: Value fed as `exp.test_chain_batch_size` on each call. Must be
      positive.
    total_batch: Overall budget; `total_batch // batch_size` calls are made, so
      any remainder is dropped. Must be at least `batch_size`.

  Raises:
    ValueError: If `batch_size` is not positive or `total_batch` is smaller
      than `batch_size` (no evaluation call would be made, leaving nothing to
      average).
  """
  # Validate before touching the filesystem so a bad configuration fails fast
  # with a clear message instead of an IndexError / ZeroDivisionError later on.
  if batch_size <= 0:
    raise ValueError("batch_size must be positive, got %r" % (batch_size,))
  num_runs = total_batch // batch_size
  if num_runs < 1:
    raise ValueError(
        "total_batch (%r) must be at least batch_size (%r)" %
        (total_batch, batch_size))

  log_dir = FLAGS.neutra_log_dir

  with tf.gfile.Open(os.path.join(log_dir, FLAGS.tune_outputs_name)) as f:
    tune_outputs = neutra.TuneOutputs(**simplejson.load(f))

  results = []
  for i in range(num_runs):
    tf.logging.info("Evaluating batch %d", i)
    feed = {
        exp.test_num_leapfrog_steps: tune_outputs.num_leapfrog_steps,
        exp.test_step_size: tune_outputs.step_size,
        exp.test_chain_batch_size: batch_size,
    }
    res = exp.Eval(sess, feed=feed, p_accept_only=True)
    results.append(res)

  def classify(path):
    """Picks the cross-run averaging function for a flattened path string."""
    if "ess" in path:
      # Harmonic mean across runs.
      return lambda x: 1. / np.mean(1. / np.array(x), 0)
    else:
      return lambda x: np.mean(x, 0)

  nest = tf.contrib.framework.nest
  # One averaging function per leaf, in the same order that `flatten` uses.
  avg_type = [
      classify("".join(str(p) for p in path))
      for path in nest.yield_flat_paths(results[0])
  ]
  flat_results = [nest.flatten(r) for r in results]
  # Transpose from per-run lists of leaves to per-leaf tuples of run values.
  trans_results = zip(*flat_results)
  trans_mean_results = [avg(r) for avg, r in zip(avg_type, trans_results)]
  neutra_stats, p_accept = nest.pack_sequence_as(
      results[0], trans_mean_results)

  utils.SaveJSON(
      neutra_stats, os.path.join(log_dir, "neutra_stats" + FLAGS.eval_suffix))
  utils.SaveJSON(
      p_accept, os.path.join(log_dir, "p_accept" + FLAGS.eval_suffix))
示例#2
0
def Train(exp, sess):
  """Optionally trains the bijector, then tunes, saving results as JSON.

  When the experiment's global step evaluates to 0, `exp.TrainBijector` is run
  and its two results are written to "q_stats" and "secs_per_train_step" (each
  suffixed with `FLAGS.eval_suffix`). `exp.Tune` is then always run with
  `exp.test_num_steps` fed as 500, and its result is written to
  `FLAGS.tune_outputs_name`. All files go under `FLAGS.neutra_log_dir`.

  Args:
    exp: Experiment object providing `global_step`, `TrainBijector`, `Tune` and
      `test_num_steps`.
    sess: Session object used to evaluate the global step and passed to the
      experiment's methods.
  """
  out_dir = FLAGS.neutra_log_dir

  def _in_out_dir(basename):
    """Returns `basename` joined onto the log directory."""
    return os.path.join(out_dir, basename)

  # NOTE(review): a zero global step presumably means nothing has been trained
  # or restored yet, so training is skipped otherwise -- confirm.
  if sess.run(exp.global_step) == 0:
    tf.logging.info("Training")
    bijector_stats, train_step_secs = exp.TrainBijector(sess)
    utils.SaveJSON(bijector_stats, _in_out_dir("q_stats" + FLAGS.eval_suffix))
    utils.SaveJSON(
        train_step_secs,
        _in_out_dir("secs_per_train_step" + FLAGS.eval_suffix))

  tf.logging.info("Tuning")
  tuned = exp.Tune(sess, feed={exp.test_num_steps: 500})
  utils.SaveJSON(tuned, _in_out_dir(FLAGS.tune_outputs_name))
示例#3
0
def Benchmark(exp, sess):
  """Benchmarks the experiment with the saved tuning outputs.

  Reads the tuning outputs JSON from `FLAGS.neutra_log_dir`, calls
  `exp.Benchmark` with the tuned leapfrog step count and step size and with
  `exp.test_num_steps` fed as 100, and writes the returned value as JSON to
  "secs_per_hmc_step" (suffixed with `FLAGS.eval_suffix`) in the same
  directory.

  Args:
    exp: Experiment object providing `Benchmark` and the `test_*` feed keys.
    sess: Session object passed through to `exp.Benchmark`.
  """
  out_dir = FLAGS.neutra_log_dir
  tune_path = os.path.join(out_dir, FLAGS.tune_outputs_name)

  with tf.gfile.Open(tune_path) as tune_file:
    tuned = neutra.TuneOutputs(**simplejson.load(tune_file))

  tf.logging.info("Benchmarking")
  hmc_step_secs = exp.Benchmark(
      sess,
      feed={
          exp.test_num_leapfrog_steps: tuned.num_leapfrog_steps,
          exp.test_step_size: tuned.step_size,
          exp.test_num_steps: 100,
      })

  result_path = os.path.join(out_dir, "secs_per_hmc_step" + FLAGS.eval_suffix)
  utils.SaveJSON(hmc_step_secs, result_path)