Code example #1
File: runner_lib.py  Project: nupurkmr9/compare_gan
def _run_eval(module_spec,
              checkpoints,
              task_manager,
              run_config,
              use_tpu,
              num_averaging_runs,
              start_step=0):
    """Evaluates the given checkpoints and add results to a result writer.

  Args:
    module_spec: `ModuleSpec` of the model.
    checkpoints: Generator for for checkpoint paths.
    task_manager: `TaskManager`. init_eval() will be called before adding
      results.
    run_config: `RunConfig` to use. Values for master and tpu_config are
      currently ignored.
    use_tpu: Whether to use TPU for evaluation.
    num_averaging_runs: Determines how many times each metric is computed.
  """
    # By default, we compute FID and Inception scores. Other tasks defined in
    # the metrics folder (such as the one in metrics/kid_score.py) can be added
    # to this list if desired.
    eval_tasks = [
        inception_score_lib.InceptionScoreTask(),
        fid_score_lib.FIDScoreTask()
    ]
    logging.info("eval_tasks: %s", eval_tasks)

    for checkpoint_path in checkpoints:
        # Checkpoint basenames end in "-<global step>"; extract the step number
        # and skip anything at or below start_step.
        step = os.path.basename(checkpoint_path).split("-")[-1]
        if int(step) <= int(start_step):
            continue
        # Export the checkpoint as a TF-Hub module unless it was already exported.
        export_path = os.path.join(run_config.model_dir, "tfhub", str(step))
        if not tf.gfile.Exists(export_path):
            module_spec.export(export_path, checkpoint_path=checkpoint_path)
        default_value = -1.0
        try:
            result_dict = eval_gan_lib.evaluate_tfhub_module(
                export_path,
                eval_tasks,
                use_tpu=use_tpu,
                num_averaging_runs=num_averaging_runs)
        except ValueError as nan_found_error:
            # If NaNs were detected during evaluation, keep an empty result and
            # record the sentinel default value instead of failing the run.
            result_dict = {}
            logging.exception(nan_found_error)
            default_value = eval_gan_lib.NAN_DETECTED

        logging.info(
            "Evaluation result for checkpoint %s: %s (default value: %s)",
            checkpoint_path, result_dict, default_value)
        task_manager.add_eval_result(checkpoint_path, result_dict,
                                     default_value)
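
For context, `checkpoints` is just an iterable of checkpoint paths. Below is a minimal sketch of how `_run_eval` might be driven, assuming TF1's `tf.train.checkpoints_iterator` and already-constructed `module_spec`, `task_manager` and `run_config` objects; the wrapper function and its argument values are hypothetical, not part of the project.

import tensorflow as tf

def run_eval_loop(module_spec, task_manager, run_config):
    # Yields each checkpoint path written under model_dir, waiting up to one
    # hour for a new checkpoint before giving up.
    checkpoints = tf.train.checkpoints_iterator(run_config.model_dir,
                                                timeout=3600)
    _run_eval(module_spec,
              checkpoints,
              task_manager,
              run_config,
              use_tpu=False,
              num_averaging_runs=3,
              start_step=0)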
Code example #2
def main(unused_argv):
    eval_tasks = [
        inception_score_lib.InceptionScoreTask(),
        fid_score_lib.FIDScoreTask()
    ]
    logging.info("eval_tasks: %s", eval_tasks)

    # Fetch and unpack the TF-Hub module from FLAGS.tfhub_url, then evaluate it.
    result_dict = eval_gan_lib.evaluate_tfhub_module(
        module_spec=dnnlib.util.unzip_from_url(FLAGS.tfhub_url),
        eval_tasks=eval_tasks,
        use_tpu=False,
        num_averaging_runs=FLAGS.num_eval_averaging_runs,
        update_bn_accumulators=False,
        use_tags=False,
    )
    logging.info("Evaluation result for checkpoint %s: %s", FLAGS.tfhub_url,
                 result_dict)
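
The returned `result_dict` maps metric names to values; the test in code example #3 below asserts keys of the form `<score>_<stat>` (e.g. `fid_score_mean`). A hypothetical continuation of main() that pulls out the headline numbers, with the key names assumed from that test, might look like:

fid_mean = result_dict.get("fid_score_mean", float("nan"))
is_mean = result_dict.get("inception_score_mean", float("nan"))
logging.info("FID: %s, Inception score: %s", fid_mean, is_mean)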
Code example #3
    def test_end2end_checkpoint(self, architecture):
        """Takes real GAN (trained for 1 step) and evaluate it."""
        if architecture in {c.RESNET_STL_ARCH, c.RESNET30_ARCH}:
            # RESNET_STL_ARCH and RESNET30_ARCH do not support the CIFAR image shape.
            return
        gin.bind_parameter("dataset.name", "cifar10")
        dataset = datasets.get_dataset("cifar10")
        options = {
            "architecture": architecture,
            "z_dim": 120,
            "disc_iters": 1,
            "lambda": 1,
        }
        model_dir = os.path.join(tf.test.get_temp_dir(), self.id())
        tf.logging.info("model_dir: %s" % model_dir)
        run_config = tf.contrib.tpu.RunConfig(model_dir=model_dir)
        gan = ModularGAN(dataset=dataset,
                         parameters=options,
                         conditional="biggan" in architecture,
                         model_dir=model_dir)
        estimator = gan.as_estimator(run_config, batch_size=2, use_tpu=False)
        estimator.train(input_fn=gan.input_fn, steps=1)
        export_path = os.path.join(model_dir, "tfhub")
        checkpoint_path = os.path.join(model_dir, "model.ckpt-1")
        module_spec = gan.as_module_spec()
        module_spec.export(export_path, checkpoint_path=checkpoint_path)

        eval_tasks = [
            fid_score.FIDScoreTask(),
            fractal_dimension.FractalDimensionTask(),
            inception_score.InceptionScoreTask(),
            ms_ssim_score.MultiscaleSSIMTask()
        ]
        result_dict = eval_gan_lib.evaluate_tfhub_module(export_path,
                                                         eval_tasks,
                                                         use_tpu=False,
                                                         num_averaging_runs=1)
        tf.logging.info("result_dict: %s", result_dict)
        for score in [
                "fid_score", "fractal_dimension", "inception_score", "ms_ssim"
        ]:
            for stats in ["mean", "std", "list"]:
                required_key = "%s_%s" % (score, stats)
                self.assertIn(required_key, result_dict,
                              "Missing: %s." % required_key)