def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
                          num_averaging_runs, step):
  """Evaluates the model in the given TF hub module.

  Args:
    module_spec: string, path to a TF hub module.
    eval_tasks: List of objects that inherit from EvalTask.
    use_tpu: Whether to use TPUs.
    num_averaging_runs: Determines how many times each metric is computed.
    step: Name of the step being evaluated.

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = dataset.eval_test_samples
  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  result_dict = {}
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      def sample_from_generator():
        """Creates the graph for sampling images."""
        generator = hub.Module(
            module_spec,
            name="gen_module",
            tags={"gen", "bs{}".format(batch_size)})
        logging.info("Generator inputs: %s", generator.get_input_info_dict())
        z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
        z = z_generator(shape=[batch_size, z_dim])
        if "labels" in generator.get_input_info_dict():
          # Conditional GAN.
          assert dataset.num_classes
          if FLAGS.force_label is None:
            labels = tf.random.uniform(
                [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
          else:
            labels = tf.constant(
                FLAGS.force_label, shape=[batch_size], dtype=tf.int32)
          inputs = dict(z=z, labels=labels)
        else:
          # Unconditional GAN.
          assert "labels" not in generator.get_input_info_dict()
          inputs = dict(z=z)
        return generator(inputs=inputs, as_dict=True)["generated"]

      if use_tpu:
        generated = tf.contrib.tpu.rewrite(sample_from_generator)
      else:
        generated = sample_from_generator()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")
      if not tf.io.gfile.exists(save_model_accu_path):
        if _update_bn_accumulators(sess, generated, num_accu_examples=204800):
          saver = tf.train.Saver()
          checkpoint_path = saver.save(sess, save_path=save_model_accu_path)
          logging.info("Exported generator with accumulated batch stats to "
                       "%s.", checkpoint_path)

      if not eval_tasks:
        logging.error("Task list is empty, returning.")
        return {}

      for i in range(num_averaging_runs):
        logging.info("Generating fake data set %d/%d.", i + 1,
                     num_averaging_runs)
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        fake_dsets.append(fake_dset)

        # Hacking this in here for speed for now.
        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step)

        logging.info("Computing inception features for generated data %d/%d.",
                     i + 1, num_averaging_runs)
        activations, logits = eval_utils.inception_transform_np(
            fake_dset.images, batch_size)
        fake_dset.set_inception_features(
            activations=activations, logits=logits)
        fake_dset.set_num_examples(num_test_examples)
        if i != 0:
          # Free up some memory by releasing additional fake data samples.
          # For ImageNet128 50k images are ~9 GiB. This will blow up metrics
          # (such as fractal dimension) if num_averaging_runs > 1.
          fake_dset.discard_images()

  real_dset = eval_utils.EvalDataSample(
      eval_utils.get_real_images(
          dataset=dataset, num_examples=num_test_examples))
  logging.info("Getting Inception features for real images.")
  real_dset.activations, _ = eval_utils.inception_transform_np(
      real_dset.images, batch_size)
  real_dset.set_num_examples(num_test_examples)

  # Run all the tasks and update the result dictionary with the task
  # statistics.
  result_dict = {}
  for task in eval_tasks:
    task_results_dicts = [
        task.run_after_session(fake_dset, real_dset)
        for fake_dset in fake_dsets
    ]
    # Average the score for each key.
    result_statistics = {}
    for key in task_results_dicts[0].keys():
      scores_for_key = np.array([d[key] for d in task_results_dicts])
      mean, std = np.mean(scores_for_key), np.std(scores_for_key)
      scores_as_string = "_".join([str(x) for x in scores_for_key])
      result_statistics[key + "_mean"] = mean
      result_statistics[key + "_std"] = std
      result_statistics[key + "_list"] = scores_as_string
    logging.info("Computed results for task %s: %s", task, result_statistics)
    result_dict.update(result_statistics)
  return result_dict
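

# Example usage of `evaluate_tfhub_module` (a minimal sketch, not part of
# this module: the module path, the task list, and the flag setup below are
# assumptions that depend on the surrounding codebase):
#
#   FLAGS.eval_batch_size = 64
#   results = evaluate_tfhub_module(
#       module_spec="/path/to/tfhub_module",
#       eval_tasks=[my_eval_task],  # any EvalTask instance, e.g. an FID task
#       use_tpu=False,
#       num_averaging_runs=3,
#       step="step_100000")
#   for name, value in sorted(results.items()):
#     logging.info("%s: %s", name, value)
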
def generate_tfhub_module(module_spec, use_tpu, step):
  """Generates samples from the model in the given TF hub module.

  Args:
    module_spec: string, path to a TF hub module.
    use_tpu: Whether to use TPUs.
    step: Name of the step being evaluated.

  Returns:
    Nothing.

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = FLAGS.num_samples
  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      def create_generator(force_label=None):
        """Returns a function that builds the image sampling graph."""

        def sample_from_generator():
          """Creates the graph for sampling images."""
          generator = hub.Module(
              module_spec,
              name="gen_module",
              tags={"gen", "bs{}".format(batch_size)})
          logging.info("Generator inputs: %s",
                       generator.get_input_info_dict())
          z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
          z = z_generator(shape=[batch_size, z_dim])
          if "labels" in generator.get_input_info_dict():
            # Conditional GAN.
            assert dataset.num_classes
            if force_label is None:
              labels = tf.random.uniform(
                  [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
            else:
              labels = tf.constant(
                  force_label, shape=[batch_size], dtype=tf.int32)
            inputs = dict(z=z, labels=labels)
          else:
            # Unconditional GAN.
            assert "labels" not in generator.get_input_info_dict()
            inputs = dict(z=z)
          return generator(inputs=inputs, as_dict=True)["generated"]

        return sample_from_generator

      if use_tpu:
        generated = tf.contrib.tpu.rewrite(create_generator())
      else:
        generated = create_generator()()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")
      saver = tf.train.Saver()
      if _update_bn_accumulators(
          sess, generated, num_accu_examples=FLAGS.num_accu_examples):
        checkpoint_path = saver.save(sess, save_path=save_model_accu_path)
        logging.info("Exported generator with accumulated batch stats to "
                     "%s.", checkpoint_path)

      logging.info("Generating fake data set.")
      fake_dset = eval_utils.EvalDataSample(
          eval_utils.sample_fake_dataset(sess, generated, num_batches))
      save_examples_lib.SaveExamplesTask().run_after_session(
          fake_dset, None, step)

      if FLAGS.force_label is not None:
        # Rebuild the sampling graph with the forced label and refresh the
        # batch norm accumulators before sampling again.
        if use_tpu:
          generated = tf.contrib.tpu.rewrite(
              create_generator(FLAGS.force_label))
        else:
          generated = create_generator(FLAGS.force_label)()

        tf.global_variables_initializer().run()
        _update_bn_accumulators(
            sess, generated, num_accu_examples=FLAGS.num_accu_examples)

        logging.info("Generating fake data set with forced label.")
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step, force_label=FLAGS.force_label)
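

# Example usage of `generate_tfhub_module` (a minimal sketch; the flag
# values and module path are assumptions):
#
#   FLAGS.num_samples = 10000
#   FLAGS.eval_batch_size = 64
#   FLAGS.num_accu_examples = 204800
#   FLAGS.force_label = None
#   generate_tfhub_module(
#       module_spec="/path/to/tfhub_module",
#       use_tpu=False,
#       step="step_100000")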