Example #1
# FLAGS and eval_lib are defined in the surrounding module; a sketch of that
# wiring follows this example.
def main(_):
    # Bundle the command-line flags into an HParams namedtuple, then run the
    # MNIST GAN evaluation loop.
    hparams = eval_lib.HParams(FLAGS.checkpoint_dir, FLAGS.eval_dir,
                               FLAGS.dataset_dir, FLAGS.num_images_generated,
                               FLAGS.eval_real_images, FLAGS.noise_dims,
                               FLAGS.max_number_of_evaluations,
                               FLAGS.write_to_disk)
    eval_lib.evaluate(hparams, run_eval_loop=True)
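A minimal sketch of the wiring this entry point assumes: absl flag definitions matching the FLAGS attributes used above, the eval_lib import, and an app.run entry point. The import path, flag defaults, and help strings are illustrative assumptions rather than the original source (the defaults loosely mirror the values used in Example #2 below).

from absl import app
from absl import flags

from tensorflow_gan.examples.mnist import eval_lib  # assumed import path

flags.DEFINE_string('checkpoint_dir', '/tmp/mnist/',
                    'Directory with the trained GAN checkpoints.')
flags.DEFINE_string('eval_dir', '/tmp/mnist/',
                    'Directory where evaluation summaries are written.')
flags.DEFINE_string('dataset_dir', None, 'Location of the MNIST data.')
flags.DEFINE_integer('num_images_generated', 1000,
                     'Number of images to generate per evaluation.')
flags.DEFINE_boolean('eval_real_images', False,
                     'Score real images instead of generated ones.')
flags.DEFINE_integer('noise_dims', 64,
                     'Dimensionality of the generator noise vector.')
flags.DEFINE_integer('max_number_of_evaluations', None,
                     'Stop the evaluation loop after this many runs.')
flags.DEFINE_boolean('write_to_disk', True,
                     'Whether to write summaries to disk.')

FLAGS = flags.FLAGS

if __name__ == '__main__':
  app.run(main)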
Example #2
  def test_build_graph(self, eval_real_images, mock_util, mock_provide_data):
    hparams = eval_lib.HParams(
        checkpoint_dir='/tmp/mnist/',
        eval_dir='/tmp/mnist/',
        dataset_dir=None,
        num_images_generated=1000,
        eval_real_images=eval_real_images,
        noise_dims=64,
        classifier_filename=None,
        max_number_of_evaluations=None,
        write_to_disk=True)

    # Mock input pipeline.
    bs = hparams.num_images_generated
    mock_imgs = np.zeros([bs, 28, 28, 1], dtype=np.float32)
    mock_lbls = np.concatenate((np.ones([bs, 1], dtype=np.int32),
                                np.zeros([bs, 9], dtype=np.int32)),
                               axis=1)
    mock_provide_data.return_value = (mock_imgs, mock_lbls)

    # Mock expensive eval metrics.
    mock_util.mnist_frechet_distance.return_value = 1.0
    mock_util.mnist_score.return_value = 0.0

    eval_lib.evaluate(hparams, run_eval_loop=False)
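The eval_real_images, mock_util, and mock_provide_data arguments are injected by decorators on the enclosing test class: a parameterized decorator supplies eval_real_images, and mock.patch decorators supply the two mocks. Below is a minimal sketch of that wiring, assuming absl's parameterized library and patch targets inside eval_lib; the class name, patch targets, and import path are assumptions. The test passes run_eval_loop=False so it only needs to build the graph, whereas the binary in Example #1 runs the full evaluation loop.

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from unittest import mock

from tensorflow_gan.examples.mnist import eval_lib  # assumed import path


class EvalTest(tf.test.TestCase, parameterized.TestCase):

  # mock.patch decorators are applied bottom-up, so the bottom patch (util)
  # supplies the first injected mock (mock_util); the parameterized value
  # fills eval_real_images ahead of the injected mocks.
  @parameterized.parameters((True,), (False,))
  @mock.patch.object(eval_lib.data_provider, 'provide_data')
  @mock.patch.object(eval_lib, 'util')
  def test_build_graph(self, eval_real_images, mock_util, mock_provide_data):
    ...  # Body as shown above (np is used there to build the mock data).


if __name__ == '__main__':
  tf.test.main()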