def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
                          num_averaging_runs, step):
  """Evaluate model at given checkpoint_path.

  Args:
    module_spec: string, path to a TF hub module.
    eval_tasks: List of objects that inherit from EvalTask.
    use_tpu: Whether to use TPUs.
    num_averaging_runs: Determines how many times each metric is computed.
    step: Name of the step being evaluated.

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If generator output has any NaNs.
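
  Example (illustrative sketch; the module path is hypothetical and
  `my_task` stands for any configured EvalTask instance):
    results = evaluate_tfhub_module(
        module_spec="/path/to/tfhub_module",
        eval_tasks=[my_task],
        use_tpu=False,
        num_averaging_runs=3,
        step=10000)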
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = dataset.eval_test_samples

  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  result_dict = {}
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())
      def sample_from_generator():
        """Create graph for sampling images."""
        generator = hub.Module(
            module_spec,
            name="gen_module",
            tags={"gen", "bs{}".format(batch_size)})
        logging.info("Generator inputs: %s", generator.get_input_info_dict())
        z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
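        # Draw a batch of latent vectors; the seeds fixed above keep the
        # draws reproducible across evaluations.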
        z = z_generator(shape=[batch_size, z_dim])
        if "labels" in generator.get_input_info_dict():
          # Conditional GAN.
          assert dataset.num_classes

          if FLAGS.force_label is None:
            labels = tf.random.uniform(
                [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
          else:
            labels = tf.constant(
                FLAGS.force_label, shape=[batch_size], dtype=tf.int32)

          inputs = dict(z=z, labels=labels)
        else:
          # Unconditional GAN.
          assert "labels" not in generator.get_input_info_dict()
          inputs = dict(z=z)
        return generator(inputs=inputs, as_dict=True)["generated"]
      
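      # On TPUs the sampling function must be compiled with tpu.rewrite;
      # on CPU/GPU it can be called directly to build the sampling graph.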
      if use_tpu:
        generated = tf.contrib.tpu.rewrite(sample_from_generator)
      else:
        generated = sample_from_generator()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")

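      # Cache a checkpoint with accumulated ("standing") batch norm
      # statistics; later runs restore it instead of re-accumulating.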
      if not tf.io.gfile.exists(save_model_accu_path + ".index"):
        if _update_bn_accumulators(sess, generated, num_accu_examples=204800):
          saver = tf.train.Saver()
          checkpoint_path = saver.save(
              sess,
              save_path=save_model_accu_path)
          logging.info("Exported generator with accumulated batch stats to "
                       "%s.", checkpoint_path)
      else:
        # Restore previously accumulated statistics instead of recomputing.
        tf.train.Saver().restore(sess, save_model_accu_path)
      if not eval_tasks:
        logging.error("Task list is empty, returning.")
        return result_dict

      for i in range(num_averaging_runs):
        logging.info("Generating fake data set %d/%d.", i+1, num_averaging_runs)
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        fake_dsets.append(fake_dset)

        # Save example images inline here (rather than as a separate task)
        # to avoid a second sampling pass; temporary shortcut for speed.
        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step)

        logging.info("Computing inception features for generated data %d/%d.",
                     i+1, num_averaging_runs)
        activations, logits = eval_utils.inception_transform_np(
            fake_dset.images, batch_size)
        fake_dset.set_inception_features(
            activations=activations, logits=logits)
        fake_dset.set_num_examples(num_test_examples)
        if i != 0:
          # Free up memory by discarding raw images for all but the first
          # fake data sample (for ImageNet128, 50k images are ~9 GiB). Note
          # that metrics which need raw images (such as fractal dimension)
          # will break for the discarded samples when num_averaging_runs > 1.
          fake_dset.discard_images()

  real_dset = eval_utils.EvalDataSample(
      eval_utils.get_real_images(
          dataset=dataset, num_examples=num_test_examples))
  logging.info("Getting Inception features for real images.")
  real_dset.activations, _ = eval_utils.inception_transform_np(
      real_dset.images, batch_size)
  real_dset.set_num_examples(num_test_examples)

  # Run all the tasks and update the result dictionary with the task statistics.
  for task in eval_tasks:
    task_results_dicts = [
        task.run_after_session(fake_dset, real_dset)
        for fake_dset in fake_dsets
    ]
    # Average the score for each key.
    result_statistics = {}
    for key in task_results_dicts[0].keys():
      scores_for_key = np.array([d[key] for d in task_results_dicts])
      mean, std = np.mean(scores_for_key), np.std(scores_for_key)
      scores_as_string = "_".join([str(x) for x in scores_for_key])
      result_statistics[key + "_mean"] = mean
      result_statistics[key + "_std"] = std
      result_statistics[key + "_list"] = scores_as_string
    logging.info("Computed results for task %s: %s", task, result_statistics)

    result_dict.update(result_statistics)
  return result_dict


def generate_tfhub_module(module_spec, use_tpu, step):
  """Generate from model at given checkpoint_path.

  Args:
    module_spec: string, path to a TF hub module.
    use_tpu: Whether to use TPUs.
    step: Name of the step being evaluated.

  Returns:
    Nothing.

  Raises:
    NanFoundError: If generator output has any NaNs.
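
  Example (illustrative sketch; the module path is hypothetical):
    generate_tfhub_module(
        module_spec="/path/to/tfhub_module",
        use_tpu=False,
        step=10000)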
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = FLAGS.num_samples

  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      def create_generator(force_label=None):
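        # Return a zero-argument graph-building callable so the same code can
        # be handed to tpu.rewrite (which expects a callable) or invoked
        # directly; force_label, if set, pins every sample to that class.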
        def sample_from_generator():
          """Create graph for sampling images."""
          generator = hub.Module(
              module_spec,
              name="gen_module",
              tags={"gen", "bs{}".format(batch_size)})
          logging.info("Generator inputs: %s", generator.get_input_info_dict())
          z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
          z = z_generator(shape=[batch_size, z_dim])
          if "labels" in generator.get_input_info_dict():
            # Conditional GAN.
            assert dataset.num_classes

            if force_label is None:
              labels = tf.random.uniform(
                  [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
            else:
              labels = tf.constant(
                  force_label, shape=[batch_size], dtype=tf.int32)

            inputs = dict(z=z, labels=labels)
          else:
            # Unconditional GAN.
            assert "labels" not in generator.get_input_info_dict()
            inputs = dict(z=z)
          return generator(inputs=inputs, as_dict=True)["generated"]
        return sample_from_generator

      if use_tpu:
        generated = tf.contrib.tpu.rewrite(create_generator())
      else:
        generated = create_generator()()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")
      saver = tf.train.Saver()

      if _update_bn_accumulators(sess, generated,
                                 num_accu_examples=FLAGS.num_accu_examples):
        checkpoint_path = saver.save(sess, save_path=save_model_accu_path)
        logging.info("Exported generator with accumulated batch stats to "
                     "%s.", checkpoint_path)

      logging.info("Generating fake data set")
      fake_dset = eval_utils.EvalDataSample(
          eval_utils.sample_fake_dataset(sess, generated, num_batches))

      save_examples_lib.SaveExamplesTask().run_after_session(
          fake_dset, None, step)

      if FLAGS.force_label is not None:
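        # Rebuild the sampling graph with the label pinned to
        # FLAGS.force_label. This creates a fresh generator instance, so
        # variables are re-initialized and BN statistics re-accumulated below.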
        if use_tpu:
          generated = tf.contrib.tpu.rewrite(create_generator(FLAGS.force_label))
        else:
          generated = create_generator(FLAGS.force_label)()

        tf.global_variables_initializer().run()

        _update_bn_accumulators(sess, generated,
                                num_accu_examples=FLAGS.num_accu_examples)

        logging.info("Generating fake data set with forced label")
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))

        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step, force_label=FLAGS.force_label)