def compute_accuracy_loss(sess,
                          gan,
                          test_images,
                          max_train_examples=50000,
                          num_repeat=5):
  """Compute discriminator's accuracy and loss on a given dataset.

  Args:
    sess: tf.Session object.
    gan: Any AbstractGAN instance.
    test_images: numpy array with test images.
    max_train_examples: How many "train" examples to get from the dataset.
                        In each round, some of them will be randomly selected
                        to evaluate train set accuracy.
    num_repeat: How many times to repeat the computation.
                The mean of all the results is reported.
  Returns:
    Dict[Text, float] with all the computed scores.

  Raises:
    ValueError: If the number of test_images is greater than the number of
                training images returned by the dataset.
  """
  logging.info("Evaluating training and test accuracy...")
  train_images = eval_utils.get_real_images(
      dataset=datasets.get_dataset(),
      num_examples=max_train_examples,
      split="train",
      failure_on_insufficient_examples=False)
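  # get_real_images may return fewer than max_train_examples for small
  # datasets (failure_on_insufficient_examples=False), so verify that there
  # are at least as many training images as test images.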
  if train_images.shape[0] < test_images.shape[0]:
    raise ValueError("num_train %d must be at least num_test %d." %
                     (train_images.shape[0], test_images.shape[0]))

  num_batches = int(np.floor(test_images.shape[0] / gan.batch_size))
  if num_batches * gan.batch_size < test_images.shape[0]:
    logging.error("Ignoring the last %d test samples that do not fill a "
                  "complete batch of size %d.",
                  test_images.shape[0] - num_batches * gan.batch_size,
                  gan.batch_size)

  # Each key accumulates one scalar per repetition; the per-key means are
  # taken after the loop.
  ret = {
      "train_accuracy": [],
      "test_accuracy": [],
      "fake_accuracy": [],
      "train_d_loss": [],
      "test_d_loss": []
  }

  for _ in range(num_repeat):
    # Draw a training subset of the same size as the test set; sampling
    # without replacement is safe because of the size check above.
    idx = np.random.choice(
        train_images.shape[0], test_images.shape[0], replace=False)
    bs = gan.batch_size
    train_subset = train_images[idx]
    train_predictions, test_predictions, fake_predictions = [], [], []
    train_d_losses, test_d_losses = [], []

    for i in range(num_batches):
      z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
      start_idx = i * bs
      end_idx = start_idx + bs
      test_batch = test_images[start_idx : end_idx]
      train_batch = train_subset[start_idx : end_idx]

      test_prediction, test_d_loss, fake_images = sess.run(
          [gan.discriminator_output, gan.d_loss, gan.fake_images],
          feed_dict={
              gan.inputs: test_batch, gan.z: z_sample
          })
      train_prediction, train_d_loss = sess.run(
          [gan.discriminator_output, gan.d_loss],
          feed_dict={
              gan.inputs: train_batch,
              gan.z: z_sample
          })
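      # gan.discriminator_output is assumed to be a tuple of tensors whose
      # first entry holds the per-example probabilities; the [0] indexing
      # below keeps only that entry.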
      fake_prediction = sess.run(
          gan.discriminator_output,
          feed_dict={gan.inputs: fake_images})[0]

      train_predictions.append(train_prediction[0])
      test_predictions.append(test_prediction[0])
      fake_predictions.append(fake_prediction)
      train_d_losses.append(train_d_loss)
      test_d_losses.append(test_d_loss)

    # A real image counts as correctly classified when the discriminator
    # outputs >= 0.5, a fake image when it outputs < 0.5.
    train_predictions = [x >= 0.5 for x in train_predictions]
    test_predictions = [x >= 0.5 for x in test_predictions]
    fake_predictions = [x < 0.5 for x in fake_predictions]

    ret["train_accuracy"].append(np.array(train_predictions).mean())
    ret["test_accuracy"].append(np.array(test_predictions).mean())
    ret["fake_accuracy"].append(np.array(fake_predictions).mean())
    ret["train_d_loss"].append(np.mean(train_d_losses))
    ret["test_d_loss"].append(np.mean(test_d_losses))

  for key in ret:
    ret[key] = np.mean(ret[key])

  return ret
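

# A minimal usage sketch for compute_accuracy_loss (illustrative only: `gan`,
# `saver`, and `checkpoint_path` are placeholders, not names defined in this
# module):
#
#   with tf.Session() as sess:
#     saver.restore(sess, checkpoint_path)
#     test_images = eval_utils.get_real_images(
#         dataset=datasets.get_dataset(), num_examples=10000, split="test")
#     scores = compute_accuracy_loss(sess, gan, test_images)
#     logging.info("Test accuracy: %.3f", scores["test_accuracy"])
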
def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
                          num_averaging_runs, step):
  """Evaluate model at given checkpoint_path.

  Args:
    module_spec: string, path to a TF hub module.
    eval_tasks: List of objects that inherit from EvalTask.
    use_tpu: Whether to use TPUs.
    num_averaging_runs: Determines how many times each metric is computed.
    step: The training step being evaluated.

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = dataset.eval_test_samples

  batch_size = FLAGS.eval_batch_size
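  # Round up so that at least num_test_examples fake images are generated;
  # any surplus from the final batch is assumed to be trimmed downstream.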
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())
      def sample_from_generator():
        """Create graph for sampling images."""
        generator = hub.Module(
            module_spec,
            name="gen_module",
            tags={"gen", "bs{}".format(batch_size)})
        logging.info("Generator inputs: %s", generator.get_input_info_dict())
        z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
        z = z_generator(shape=[batch_size, z_dim])
        if "labels" in generator.get_input_info_dict():
          # Conditional GAN.
          assert dataset.num_classes

          if FLAGS.force_label is None:
            labels = tf.random.uniform(
                [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
          else:
            labels = tf.constant(
                FLAGS.force_label, shape=[batch_size], dtype=tf.int32)

          inputs = dict(z=z, labels=labels)
        else:
          # Unconditional GAN.
          assert "labels" not in generator.get_input_info_dict()
          inputs = dict(z=z)
        return generator(inputs=inputs, as_dict=True)["generated"]
      
      # On TPUs, compile the sampling function via tpu.rewrite; otherwise
      # build the sampling graph directly in the default graph.
      if use_tpu:
        generated = tf.contrib.tpu.rewrite(sample_from_generator)
      else:
        generated = sample_from_generator()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")

      if not tf.io.gfile.exists(save_model_accu_path):
        # Run the generator repeatedly so the batch-norm accumulators converge
        # to stable ("standing") statistics before exporting the checkpoint.
        if _update_bn_accumulators(sess, generated, num_accu_examples=204800):
          saver = tf.train.Saver()
          checkpoint_path = saver.save(
              sess,
              save_path=save_model_accu_path)
          logging.info("Exported generator with accumulated batch stats to "
                       "%s.", checkpoint_path)
      if not eval_tasks:
        logging.error("Task list is empty, returning.")
        return {}

      for i in range(num_averaging_runs):
        logging.info("Generating fake data set %d/%d.", i + 1,
                     num_averaging_runs)
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        fake_dsets.append(fake_dset)

        # Hacking this in here for speed for now.
        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step)

        logging.info("Computing inception features for generated data %d/%d.",
                     i + 1, num_averaging_runs)
        activations, logits = eval_utils.inception_transform_np(
            fake_dset.images, batch_size)
        fake_dset.set_inception_features(
            activations=activations, logits=logits)
        fake_dset.set_num_examples(num_test_examples)
        if i != 0:
          # Free up memory by discarding the raw fake images; for ImageNet128,
          # 50k images are ~9 GiB. Metrics that need the raw images (such as
          # fractal dimension) therefore cannot be computed for these extra
          # runs when num_averaging_runs > 1.
          fake_dset.discard_images()

  real_dset = eval_utils.EvalDataSample(
      eval_utils.get_real_images(
          dataset=dataset, num_examples=num_test_examples))
  logging.info("Getting Inception features for real images.")
  real_dset.activations, _ = eval_utils.inception_transform_np(
      real_dset.images, batch_size)
  real_dset.set_num_examples(num_test_examples)

  # Run all the tasks and update the result dictionary with the task statistics.
  result_dict = {}
  for task in eval_tasks:
    task_results_dicts = [
        task.run_after_session(fake_dset, real_dset)
        for fake_dset in fake_dsets
    ]
    # Average the score for each key.
    result_statistics = {}
    for key in task_results_dicts[0].keys():
      scores_for_key = np.array([d[key] for d in task_results_dicts])
      mean, std = np.mean(scores_for_key), np.std(scores_for_key)
      scores_as_string = "_".join([str(x) for x in scores_for_key])
      result_statistics[key + "_mean"] = mean
      result_statistics[key + "_std"] = std
      result_statistics[key + "_list"] = scores_as_string
    logging.info("Computed results for task %s: %s", task, result_statistics)

    result_dict.update(result_statistics)
  return result_dict
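

# A minimal usage sketch for evaluate_tfhub_module (illustrative only; the
# module path, task list, and step are placeholders):
#
#   eval_tasks = [...]  # objects inheriting from EvalTask, e.g. an FID task.
#   results = evaluate_tfhub_module(
#       "/tmp/tfhub_module", eval_tasks, use_tpu=False, num_averaging_runs=3,
#       step=10000)
#
# For a hypothetical "fid" task averaged over three runs, `results` would
# contain entries such as
#   {"fid_mean": 24.3, "fid_std": 0.45, "fid_list": "24.0_24.1_24.8"}.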