def main(_):
    """Build evaluation hyperparameters from flags and run the evaluation loop.

    Args:
      _: Ignored positional argument (presumably the argv list passed by
        ``app.run``; not used here).
    """
    # Keyword arguments keep each flag bound to the right HParams field.
    # A positional call silently misaligns fields if the tuple gains or
    # reorders a field (HParams includes ``classifier_filename`` between
    # ``noise_dims`` and ``max_number_of_evaluations``).
    hparams = eval_lib.HParams(
        checkpoint_dir=FLAGS.checkpoint_dir,
        eval_dir=FLAGS.eval_dir,
        dataset_dir=FLAGS.dataset_dir,
        num_images_generated=FLAGS.num_images_generated,
        eval_real_images=FLAGS.eval_real_images,
        noise_dims=FLAGS.noise_dims,
        # NOTE(review): the flag definitions are not visible here. Fall back
        # to None (the value the test passes) if no such flag is defined.
        # Replace with FLAGS.classifier_filename once the flag is confirmed.
        classifier_filename=getattr(FLAGS, 'classifier_filename', None),
        max_number_of_evaluations=FLAGS.max_number_of_evaluations,
        write_to_disk=FLAGS.write_to_disk)
    eval_lib.evaluate(hparams, run_eval_loop=True)
def test_build_graph(self, eval_real_images, mock_util, mock_provide_data):
    """Check that ``eval_lib.evaluate`` builds its graph without looping.

    Input data and the expensive metric helpers are replaced with fixed
    stand-ins so that only graph construction is exercised
    (``run_eval_loop=False``).

    Args:
      eval_real_images: Value forwarded to ``HParams.eval_real_images``.
      mock_util: Mock whose ``mnist_frechet_distance`` / ``mnist_score``
        return values are pinned below (presumably injected by a
        ``mock.patch`` decorator outside this view).
      mock_provide_data: Mock for the input pipeline; its return value is
        set to a fake ``(images, labels)`` pair.
    """
    hparams = eval_lib.HParams(
        checkpoint_dir='/tmp/mnist/',
        eval_dir='/tmp/mnist/',
        dataset_dir=None,
        num_images_generated=1000,
        eval_real_images=eval_real_images,
        noise_dims=64,
        classifier_filename=None,
        max_number_of_evaluations=None,
        write_to_disk=True)

    # Fake input pipeline: all-zero 28x28x1 images and one-hot labels that
    # all mark class 0 (10 int32 columns, first column set).
    num_generated = hparams.num_images_generated
    fake_images = np.zeros([num_generated, 28, 28, 1], dtype=np.float32)
    fake_labels = np.zeros([num_generated, 10], dtype=np.int32)
    fake_labels[:, 0] = 1
    mock_provide_data.return_value = (fake_images, fake_labels)

    # Pin the metric helpers to constants so they are never actually computed.
    mock_util.mnist_frechet_distance.return_value = 1.0
    mock_util.mnist_score.return_value = 0.0

    eval_lib.evaluate(hparams, run_eval_loop=False)