def Eval(exp, batch_size=256, total_batch=4096):
  """Evaluates the experiment in batches using the tuned HMC parameters and saves averaged stats."""
  log_dir = FLAGS.neutra_log_dir
  tf.io.gfile.makedirs(log_dir)
  tune_outputs = utils.load_json(
      os.path.join(log_dir, FLAGS.tune_outputs_name))
  if isinstance(tune_outputs, dict):
    tune_outputs = neutra.TuneOutputs(**tune_outputs)
  results = []
  for i in range(total_batch // batch_size):
    logging.info("Evaluating batch %d", i)
    res = exp.Eval(
        test_num_leapfrog_steps=tune_outputs.num_leapfrog_steps,
        test_step_size=tune_outputs.step_size,
        batch_size=batch_size,
    )

    def to_numpy(t):
      # Convert eager tensors to NumPy so the results can be JSON-serialized.
      if isinstance(t, tf.Tensor):
        return t.numpy()
      else:
        return t

    res = tf.nest.map_structure(to_numpy, res)
    results.append(res)
  neutra_stats = neutra.AverageStats(results)
  utils.save_json(neutra_stats,
                  os.path.join(log_dir, "neutra_stats" + FLAGS.eval_suffix))
def Benchmark(exp):
  """Benchmarks sampling with the tuned HMC parameters and saves the results."""
  log_dir = FLAGS.neutra_log_dir
  tf.io.gfile.makedirs(log_dir)
  tune_outputs = utils.load_json(
      os.path.join(log_dir, FLAGS.tune_outputs_name))
  if isinstance(tune_outputs, dict):
    tune_outputs = neutra.TuneOutputs(**tune_outputs)
  logging.info("Benchmarking")
  benchmark = exp.Benchmark(
      test_num_leapfrog_steps=tune_outputs.num_leapfrog_steps,
      test_step_size=tune_outputs.step_size,
      test_num_steps=100,
      test_batch_size=16384 * 8,
  )
  utils.save_json(benchmark,
                  os.path.join(log_dir, "benchmark" + FLAGS.eval_suffix))
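
# Usage sketch (not part of the original module): assuming `exp` is an
# experiment object exposing the Eval/Benchmark methods used above, and that
# the absl flags referenced here (neutra_log_dir, tune_outputs_name,
# eval_suffix) are defined elsewhere in this script, these helpers could be
# driven from a main like the following. `MakeExperiment` is a hypothetical
# placeholder for however the experiment is constructed.
#
#   def main(argv):
#     del argv  # Unused.
#     exp = MakeExperiment()
#     Eval(exp)
#     Benchmark(exp)
#
#   if __name__ == "__main__":
#     app.run(main)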