コード例 #1
0
def Eval(exp, batch_size=256, total_batch=4096):
    """Evaluate `exp` in batches with the tuned HMC parameters and save stats.

    Loads the tuning outputs stored as ``FLAGS.tune_outputs_name`` inside
    ``FLAGS.neutra_log_dir``, then calls ``exp.Eval`` ``total_batch //
    batch_size`` times using the tuned leapfrog step count and step size.
    Any ``tf.Tensor`` leaves of each result are converted to NumPy, the
    per-batch results are combined with ``neutra.AverageStats``, and the
    combined stats are written to ``neutra_stats<FLAGS.eval_suffix>`` in the
    same directory.

    Args:
      exp: Experiment object exposing an ``Eval`` method.
      batch_size: Value passed as ``batch_size`` to each ``exp.Eval`` call.
        Must be positive.
      total_batch: Total amount to evaluate across all calls. Only whole
        batches are run; a remainder that is not a multiple of
        ``batch_size`` is skipped (a warning is logged).

    Raises:
      ValueError: If ``batch_size`` is not positive, or if it exceeds
        ``total_batch`` (which would otherwise run zero batches and average
        an empty list of results).
    """
    # Validate before touching the filesystem so bad arguments fail fast.
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be positive, got {}".format(batch_size))
    num_batches = total_batch // batch_size
    if num_batches < 1:
        raise ValueError(
            "total_batch ({}) must be at least batch_size ({})".format(
                total_batch, batch_size))
    if total_batch % batch_size:
        logging.warning(
            "total_batch %d is not a multiple of batch_size %d; evaluating "
            "only %d batches (%d total)", total_batch, batch_size,
            num_batches, num_batches * batch_size)

    log_dir = FLAGS.neutra_log_dir
    tf.io.gfile.makedirs(log_dir)

    tune_outputs = utils.load_json(
        os.path.join(log_dir, FLAGS.tune_outputs_name))
    if isinstance(tune_outputs, dict):
        # The loaded JSON may be a plain dict; rebuild the TuneOutputs record
        # so the fields can be read as attributes below.
        tune_outputs = neutra.TuneOutputs(**tune_outputs)

    def to_numpy(t):
        """Return ``t.numpy()`` for tensors; pass other values through."""
        if isinstance(t, tf.Tensor):
            return t.numpy()
        return t

    results = []
    for i in range(num_batches):
        logging.info("Evaluating batch %d", i)
        res = exp.Eval(
            test_num_leapfrog_steps=tune_outputs.num_leapfrog_steps,
            test_step_size=tune_outputs.step_size,
            batch_size=batch_size,
        )
        results.append(tf.nest.map_structure(to_numpy, res))

    neutra_stats = neutra.AverageStats(results)

    utils.save_json(neutra_stats,
                    os.path.join(log_dir, "neutra_stats" + FLAGS.eval_suffix))
コード例 #2
0
def Train(exp):
    """Train `exp` and save the statistics it returns as JSON files.

    The output directory is ``FLAGS.checkpoint_log_dir`` when that flag is
    truthy, otherwise ``FLAGS.neutra_log_dir``. It is created only after
    ``exp.Train()`` returns. ``exp.Train()`` must return a pair. Its first
    element is written to ``q_stats`` and its second to
    ``secs_per_train_step``.
    """
    out_dir = FLAGS.checkpoint_log_dir or FLAGS.neutra_log_dir
    logging.info("Training")
    q_stats, secs_per_step = exp.Train()
    tf.io.gfile.makedirs(out_dir)
    named_outputs = (
        ("q_stats", q_stats),
        ("secs_per_train_step", secs_per_step),
    )
    for filename, value in named_outputs:
        utils.save_json(value, os.path.join(out_dir, filename))
コード例 #3
0
def Benchmark(exp, test_num_steps=100, test_batch_size=16384 * 8):
    """Benchmark `exp` with the tuned HMC parameters and save the result.

    Loads the tuning outputs stored as ``FLAGS.tune_outputs_name`` inside
    ``FLAGS.neutra_log_dir``. It then calls ``exp.Benchmark`` with the tuned
    leapfrog step count and step size and writes the returned value as JSON
    to ``bechmark<FLAGS.eval_suffix>`` in the same directory.

    Args:
      exp: Experiment object exposing a ``Benchmark`` method.
      test_num_steps: Forwarded to ``exp.Benchmark`` as ``test_num_steps``.
        Defaults to 100.
      test_batch_size: Forwarded to ``exp.Benchmark`` as
        ``test_batch_size``. Defaults to ``16384 * 8``.
    """
    log_dir = FLAGS.neutra_log_dir
    tf.io.gfile.makedirs(log_dir)

    tune_outputs = utils.load_json(
        os.path.join(log_dir, FLAGS.tune_outputs_name))
    if isinstance(tune_outputs, dict):
        # The loaded JSON may be a plain dict; rebuild the TuneOutputs record
        # so the fields can be read as attributes below.
        tune_outputs = neutra.TuneOutputs(**tune_outputs)

    logging.info("Benchmarking")
    benchmark = exp.Benchmark(
        test_num_leapfrog_steps=tune_outputs.num_leapfrog_steps,
        test_step_size=tune_outputs.step_size,
        test_num_steps=test_num_steps,
        test_batch_size=test_batch_size,
    )

    # NOTE(review): "bechmark" looks like a misspelling of "benchmark". It is
    # kept unchanged so existing consumers of this output file still find it.
    # Rename only together with every reader of the file.
    utils.save_json(benchmark,
                    os.path.join(log_dir, "bechmark" + FLAGS.eval_suffix))
コード例 #4
0
def TuneObjective(exp):
    """Compute the tuning objective of `exp` and save it as JSON.

    Ensures ``FLAGS.neutra_log_dir`` exists, then writes the value returned
    by ``exp.TuneObjective()`` to the file ``objective`` inside it.
    """
    out_dir = FLAGS.neutra_log_dir
    tf.io.gfile.makedirs(out_dir)
    utils.save_json(exp.TuneObjective(), os.path.join(out_dir, "objective"))