Example #1
def RunCheckpointEval(checkpoint_path, task_workdir, options, tasks_to_run):
    """Evaluate model at given checkpoint_path.

  Args:
    checkpoint_path: string, path to the single checkpoint to evaluate.
    task_workdir: directory, where results and logs can be written.
    options: Dict[Text, Text] with all parameters for the current trial.
    tasks_to_run: List of objects that inherit from EvalTask.

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If gan output has generated any NaNs.
    ValueError: If options["gan_type"] is not supported.
  """

    # Make sure that the same latent variables are used for each evaluation.
    np.random.seed(42)

    checkpoint_dir = os.path.join(task_workdir, "checkpoint")
    result_dir = os.path.join(task_workdir, "result")
    gan_log_dir = os.path.join(task_workdir, "logs")

    gan_type = options["gan_type"]
    if gan_type not in SUPPORTED_GANS:
        raise ValueError("Gan type %s is not supported." % gan_type)

    dataset = options["dataset"]
    dataset_params = params.GetDatasetParameters(dataset)
    dataset_params.update(options)
    num_test_examples = dataset_params.get("eval_test_samples", 10000)

    if num_test_examples % INCEPTION_BATCH != 0:
        logging.info("Padding number of examples to fit inception batch.")
        num_test_examples -= num_test_examples % INCEPTION_BATCH

    real_images = GetRealImages(options["dataset"],
                                split_name="test",
                                num_examples=num_test_examples)
    logging.info("Real data processed.")

    result_dict = {}
    # Get fake images from the generator.
    samples = []
    logging.info(
        "Running eval for dataset %s, checkpoint: %s, num_examples: %d",
        dataset, checkpoint_path, num_test_examples)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            gan = gan_lib.create_gan(gan_type=gan_type,
                                     dataset=dataset,
                                     dataset_content=None,
                                     options=options,
                                     checkpoint_dir=checkpoint_dir,
                                     result_dir=result_dir,
                                     gan_log_dir=gan_log_dir)

            gan.build_model(is_training=False)

            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            # Make sure we generate at least as many examples as the test set.
            num_batches = int(np.ceil(num_test_examples / gan.batch_size))
            for _ in range(num_batches):
                z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
                x = sess.run(gan.fake_images, feed_dict={gan.z: z_sample})
                # If NaNs were generated, abort this checkpoint by raising
                # NanFoundError; the caller handles it specially.
                if np.isnan(x).any():
                    logging.error("Detected NaN in fake_images!")
                    raise NanFoundError("Detected NaN in fake images.")
                samples.append(x)

            print("Fake data generated, running tasks in session.")
            for task in tasks_to_run:
                result_dict.update(
                    task.RunInSession(options, sess, gan, real_images))

    fake_images = np.concatenate(samples, axis=0)
    # Adjust the number of fake images to the number of images in the test set.
    fake_images = fake_images[:num_test_examples, :, :, :]

    assert fake_images.shape == real_images.shape

    # If we use a 1-channel dataset (like MNIST), convert it to 3 channels.
    if fake_images.shape[3] == 1:
        fake_images = np.tile(fake_images, [1, 1, 1, 3])
        # Tile real_images as well so its shape keeps matching fake_images.
        real_images = np.tile(real_images, [1, 1, 1, 3])

    fake_images *= 255.0

    logging.info("Fake data processed. Starting tasks for checkpoint: %s.",
                 checkpoint_path)

    for task in tasks_to_run:
        result_dict.update(
            task.RunAfterSession(options, fake_images, real_images))

    return result_dict
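
A subtlety shared by all of these variants: the sampling loop runs int(np.ceil(num_test_examples / gan.batch_size)) batches and later slices back down to num_test_examples. Under Python 2 without "from __future__ import division" (this is TF 1.x-era code), the "/" floors first, the np.ceil becomes a no-op, and the loop can undershoot. A minimal, version-safe sketch of the intended ceiling division, with illustrative numbers:

import numpy as np

num_test_examples = 9900  # hypothetical, already a multiple of INCEPTION_BATCH
batch_size = 64           # hypothetical gan.batch_size

# Integer ceiling division: behaves identically under Python 2 and 3.
num_batches = -(-num_test_examples // batch_size)
assert num_batches * batch_size >= num_test_examples

# Matches what the code above intends with true division.
assert num_batches == int(np.ceil(float(num_test_examples) / batch_size))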
Example #2
def EvalCheckpoint(checkpoint_path, task_workdir, options, out_cp_dir):
  """Evaluate model at given checkpoint_path."""

  # Overwrite batch size
  options["batch_size"] = FLAGS.batch_size

  checkpoint_dir = os.path.join(task_workdir, "checkpoint")
  result_dir = os.path.join(task_workdir, "result")
  gan_log_dir = os.path.join(task_workdir, "logs")

  dataset_params = params.GetDatasetParameters(options["dataset"])
  dataset_params.update(options)

  # Generate fake images.
  with tf.Graph().as_default() as g:
    with tf.Session() as sess:
      gan = gan_lib.create_gan(
          gan_type=options["gan_type"],
          dataset=options["dataset"],
          dataset_content=None,
          options=options,
          checkpoint_dir=checkpoint_dir,
          result_dir=result_dir,
          gan_log_dir=gan_log_dir)

      gan.build_model(is_training=False)

      tf.global_variables_initializer().run()
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)

      # Compute outputs for MultiGanGeneratorImages.
      if (FLAGS.visualization_type == "multi_image" and
          "MultiGAN" in options["gan_type"]):
        generator_preds_op = GetMultiGANGeneratorsOp(
            g, options["gan_type"], options["architecture"],
            options["aggregate"])

        fetches = [gan.fake_images, generator_preds_op]

        # Construct feed dict
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        feed_dict = {gan.z: z_sample}

        # Fetch data and save images.
        fake_images, generator_preds = sess.run(fetches, feed_dict=feed_dict)
        SaveMultiGanGeneratorImages(fake_images, generator_preds, out_cp_dir)

      # Compute outputs for MultiGanLatentTraversalImages
      elif (FLAGS.visualization_type == "multi_latent" and
            "MultiGAN" in options["gan_type"]):
        generator_preds_op = GetMultiGANGeneratorsOp(
            g, options["gan_type"], options["architecture"],
            options["aggregate"])

        fetches = [gan.fake_images, generator_preds_op]

        # Init latent params
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        directions = np.random.uniform(size=z_sample.shape)
        k_indices = np.random.randint(gan.k, size=gan.batch_size)
        n_steps, step_size = 10, 0.1
        images, gen_preds = [], []

        # Traverse in latent space of a single component n_steps times and
        # generate the corresponding images.
        for step in range(n_steps + 1):
          new_z = z_sample.copy()
          for i in range(z_sample.shape[0]):
            new_z[i, k_indices[i]] += (
                step * step_size * directions[i, k_indices[i]])

          images_batch, gen_preds_batch = sess.run(fetches, {gan.z: new_z})
          images.append(images_batch)
          gen_preds.append(gen_preds_batch)

        images = np.stack(images, axis=1)
        gen_preds = np.stack(gen_preds, axis=1)
        SaveMultiGanLatentTraversalImages(images, gen_preds, out_cp_dir)

      # Compute outputs for GanLatentTraversalImages
      elif FLAGS.visualization_type == "latent":
        # Init latent params.
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        directions = np.random.uniform(size=z_sample.shape)
        k_indices = np.random.randint(options.get("k", 1), size=gan.batch_size)
        n_steps, step_size = 5, 0.1
        images = []

        # Traverse in latent space of a single component n_steps times and
        # generate the corresponding images.
        for step in range(n_steps + 1):
          new_z = z_sample.copy()
          for i in range(z_sample.shape[0]):
            if "MultiGAN" in options["gan_type"]:
              new_z[i, k_indices[i]] += (
                  step * step_size * directions[i, k_indices[i]])
            else:
              new_z[i] += step * step_size * directions[i]

          images_batch = sess.run(gan.fake_images, {gan.z: new_z})
          images.append(images_batch)

        images = np.stack(images, axis=1)
        SaveGanLatentTraversalImages(images, out_cp_dir)
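
The traversal above is plain NumPy bookkeeping: pick one component per sample, then nudge it along a fixed random direction by a growing step. A standalone sketch with hypothetical shapes (no GAN or session required); the batch_size, k and z_dim values are illustrative:

import numpy as np

batch_size, k, z_dim = 4, 3, 8
z_sample = np.random.normal(size=(batch_size, k, z_dim))
directions = np.random.uniform(size=z_sample.shape)
k_indices = np.random.randint(k, size=batch_size)
n_steps, step_size = 10, 0.1

traversal = []
for step in range(n_steps + 1):
    new_z = z_sample.copy()
    for i in range(batch_size):
        # Shift only the k_indices[i]-th component of sample i.
        new_z[i, k_indices[i]] += step * step_size * directions[i, k_indices[i]]
    traversal.append(new_z)

# (batch_size, n_steps + 1, k, z_dim), matching the axis=1 stacking above.
traversal = np.stack(traversal, axis=1)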
Example #3
def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
  """Evaluate model at given checkpoint_path."""

  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)

  checkpoint_dir = os.path.join(task_workdir, "checkpoint")
  result_dir = os.path.join(task_workdir, "result")
  gan_log_dir = os.path.join(task_workdir, "logs")

  gan_type = options["gan_type"]
  if gan_type not in SUPPORTED_GANS:
    raise ValueError("Gan type %s is not supported." % gan_type)

  dataset = options["dataset"]
  dataset_params = params.GetDatasetParameters(dataset)
  dataset_params.update(options)
  num_test_examples = dataset_params.get("eval_test_samples", 10000)

  if num_test_examples % INCEPTION_BATCH != 0:
    logging.info("Padding number of examples to fit inception batch.")
    num_test_examples -= num_test_examples % INCEPTION_BATCH

  real_images = GetRealImages(
      options["dataset"],
      split_name="test",
      num_examples=num_test_examples)
  logging.info("Real data processed.")

  result_dict = {}
  default_value = -1.0
  # Get fake images from the generator.
  samples = []
  logging.info("Running eval for dataset %s, checkpoint: %s, num_examples: %d ",
               dataset, checkpoint_path, num_test_examples)
  with tf.Graph().as_default():
    with tf.Session() as sess:
      gan = gan_lib.create_gan(
          gan_type=gan_type,
          dataset=dataset,
          dataset_content=FakeDatasetContent(),  # This should never be used.
          options=options,
          checkpoint_dir=checkpoint_dir,
          result_dir=result_dir,
          gan_log_dir=gan_log_dir)

      gan.build_model(is_training=False)

      tf.global_variables_initializer().run()
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)

      # Make sure we generate at least as many examples as the test set.
      num_batches = int(np.ceil(num_test_examples / gan.batch_size))
      for _ in range(num_batches):
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        feed_dict = {gan.z: z_sample}
        x = sess.run(gan.fake_images, feed_dict=feed_dict)
        # If NaNs were generated, stop evaluating this checkpoint and return
        # the NAN_DETECTED sentinel, which the caller maps to a very high
        # FID score.
        if np.isnan(x).any():
          logging.error("Detected NaN in fake_images! Returning NAN_DETECTED.")
          default_value = NAN_DETECTED
          return result_dict, default_value
        samples.append(x)

      if ShouldRunAccuracyLossTrainVsTest(options):
        result_dict = ComputeAccuracyLoss(options, gan, real_images,
                                          sess, result_dict)

  fake_images = np.concatenate(samples, axis=0)
  # Adjust the number of fake images to the number of images in the test set.
  fake_images = fake_images[:num_test_examples, :, :, :]

  assert fake_images.shape == real_images.shape

  # In case we use a 1-channel dataset (like mnist) - convert it to 3 channel.
  if fake_images.shape[3] == 1:
    fake_images = np.tile(fake_images, [1, 1, 1, 3])
    # change the real_images' shape too - so that it keep matching
    # fake_images' shape.
    real_images = np.tile(real_images, [1, 1, 1, 3])

  fake_images *= 255.0

  logging.info("Fake data processed, computing inception score.")

  result_dict["inception_score"] = GetInceptionScore(fake_images,
                                                     inception_graph)
  logging.info("Inception score computed: %.3f", result_dict["inception_score"])


  result_dict["fid_score"] = ComputeTFGanFIDScore(fake_images, real_images,
                                                  inception_graph)
  logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
               checkpoint_path, result_dict["fid_score"])

  if ShouldRunMultiscaleSSIM(options):
    result_dict["ms_ssim"] = ComputeMultiscaleSSIMScore(fake_images)
    logging.info("MS-SSIM score computed: %.3f", result_dict["ms_ssim"])

  return result_dict, default_value
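
Unlike Example #1, which raises NanFoundError, this variant signals NaNs through its second return value. A hedged sketch of how a caller might consume that pair; NAN_DETECTED comes from the code above, while the concrete replacement scores are assumptions:

# Hypothetical caller for Example #3 (illustrative only).
result_dict, default_value = RunCheckpointEval(
    checkpoint_path, task_workdir, options, inception_graph)

if default_value == NAN_DETECTED:
    # Per the in-code comment, NaN checkpoints get "a very high FID score
    # which we handle specially later"; the exact values are assumptions.
    result_dict = {"fid_score": 10000.0, "inception_score": 0.0}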
Example #4
def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
    """Evaluate model at given checkpoint_path."""

    # Make sure that the same latent variables are used for each evaluation.
    np.random.seed(42)

    checkpoint_dir = os.path.join(task_workdir, "checkpoint")
    result_dir = os.path.join(task_workdir, "result")
    gan_log_dir = os.path.join(task_workdir, "logs")

    gan_type = options["gan_type"]
    if gan_type not in SUPPORTED_GANS:
        raise ValueError("Gan type %s is not supported." % gan_type)
    dataset = options["dataset"]

    dataset_content = gan_lib.load_dataset(dataset, split_name="test")
    num_test_examples = FLAGS.num_test_examples

    if num_test_examples % INCEPTION_BATCH != 0:
        logging.info("Padding number of examples to fit inception batch.")
        num_test_examples -= num_test_examples % INCEPTION_BATCH

    # Get real images from the dataset. In the case of a 1-channel
    # dataset (like mnist) convert it to 3 channels.

    data_x = []
    with tf.Graph().as_default():
        get_next = dataset_content.make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            for _ in range(num_test_examples):
                data_x.append(sess.run(get_next[0]))

    real_images = np.array(data_x)
    if real_images.shape[0] != num_test_examples:
        raise ValueError("Not enough examples in the dataset.")

    if real_images.shape[3] == 1:
        real_images = np.tile(real_images, [1, 1, 1, 3])
    real_images *= 255.0
    logging.info("Real data processed.")

    # Get fake images from the generator.
    samples = []
    logging.info("Running eval on checkpoint path: %s", checkpoint_path)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            gan = gan_lib.create_gan(gan_type=gan_type,
                                     dataset=dataset,
                                     sess=sess,
                                     dataset_content=dataset_content,
                                     options=options,
                                     checkpoint_dir=checkpoint_dir,
                                     result_dir=result_dir,
                                     gan_log_dir=gan_log_dir)

            gan.build_model(is_training=False)

            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            # Make sure we generate at least as many examples as the test set.
            num_batches = int(np.ceil(num_test_examples / gan.batch_size))
            for _ in range(num_batches):
                z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
                feed_dict = {gan.z: z_sample}
                x = sess.run(gan.fake_images, feed_dict=feed_dict)
                # If NaNs were generated, stop evaluating this checkpoint and
                # return the NAN_DETECTED sentinel, which the caller maps to a
                # very high FID score.
                if np.isnan(x).any():
                    logging.error(
                        "Detected NaN in fake_images! Returning NAN_DETECTED.")
                    return NAN_DETECTED, NAN_DETECTED
                samples.append(x)

    fake_images = np.concatenate(samples, axis=0)
    # Adjust the number of fake images to the number of images in the test set.
    fake_images = fake_images[:num_test_examples, :, :, :]
    # If we use a 1-channel dataset (like MNIST), convert it to 3 channels.
    if fake_images.shape[3] == 1:
        fake_images = np.tile(fake_images, [1, 1, 1, 3])
    fake_images *= 255.0
    logging.info("Fake data processed, computing inception score.")
    inception_score = GetInceptionScore(fake_images, inception_graph)
    logging.info("Inception score computed: %.3f", inception_score)

    assert fake_images.shape == real_images.shape

    fid_score = ComputeTFGanFIDScore(fake_images, real_images, inception_graph)

    logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
                 checkpoint_path, fid_score)
    return inception_score, fid_score
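
Before the Inception-based metrics, every variant converts grayscale batches to RGB with np.tile and rescales from [0, 1] to [0, 255]. A self-contained sketch of those two transforms; the batch shape is illustrative:

import numpy as np

# Hypothetical batch of 1-channel images in [0, 1], e.g. from MNIST.
fake_images = np.random.uniform(size=(8, 28, 28, 1)).astype(np.float32)

# Grayscale -> RGB by repeating the single channel three times.
if fake_images.shape[3] == 1:
    fake_images = np.tile(fake_images, [1, 1, 1, 3])

# The Inception-based scorers used here expect pixel values in [0, 255].
fake_images *= 255.0

assert fake_images.shape == (8, 28, 28, 3)
assert fake_images.min() >= 0.0 and fake_images.max() <= 255.0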