def compute_accuracy_loss(sess,
                          gan,
                          test_images,
                          max_train_examples=50000,
                          num_repeat=5):
  """Compute the discriminator's accuracy and loss on a given dataset.

  Args:
    sess: tf.Session object.
    gan: Any AbstractGAN instance.
    test_images: numpy array with test images.
    max_train_examples: How many "train" examples to get from the dataset.
      In each round, some of them will be randomly selected to evaluate
      train set accuracy.
    num_repeat: How many times to repeat the computation. The mean of all
      the results is reported.

  Returns:
    Dict[Text, float] with all the computed scores.

  Raises:
    ValueError: If the number of test_images is greater than the number of
      training images returned by the dataset.
  """
  logging.info("Evaluating training and test accuracy...")
  train_images = eval_utils.get_real_images(
      dataset=datasets.get_dataset(),
      num_examples=max_train_examples,
      split="train",
      failure_on_insufficient_examples=False)
  if train_images.shape[0] < test_images.shape[0]:
    raise ValueError("num_train %d must be larger than num_test %d." %
                     (train_images.shape[0], test_images.shape[0]))

  num_batches = int(np.floor(test_images.shape[0] / gan.batch_size))
  if num_batches * gan.batch_size < test_images.shape[0]:
    logging.error("Ignoring the last batch with %d samples / %d epoch size.",
                  test_images.shape[0] - num_batches * gan.batch_size,
                  gan.batch_size)

  ret = {
      "train_accuracy": [],
      "test_accuracy": [],
      "fake_accuracy": [],
      "train_d_loss": [],
      "test_d_loss": []
  }

  for _ in range(num_repeat):
    idx = np.random.choice(train_images.shape[0], test_images.shape[0])
    bs = gan.batch_size
    train_subset = [train_images[i] for i in idx]
    train_predictions, test_predictions, fake_predictions = [], [], []
    train_d_losses, test_d_losses = [], []

    for i in range(num_batches):
      z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
      start_idx = i * bs
      end_idx = start_idx + bs
      test_batch = test_images[start_idx:end_idx]
      train_batch = train_subset[start_idx:end_idx]

      test_prediction, test_d_loss, fake_images = sess.run(
          [gan.discriminator_output, gan.d_loss, gan.fake_images],
          feed_dict={
              gan.inputs: test_batch,
              gan.z: z_sample
          })
      train_prediction, train_d_loss = sess.run(
          [gan.discriminator_output, gan.d_loss],
          feed_dict={
              gan.inputs: train_batch,
              gan.z: z_sample
          })
      fake_prediction = sess.run(
          gan.discriminator_output,
          feed_dict={gan.inputs: fake_images})[0]

      train_predictions.append(train_prediction[0])
      test_predictions.append(test_prediction[0])
      fake_predictions.append(fake_prediction)
      train_d_losses.append(train_d_loss)
      test_d_losses.append(test_d_loss)

    # Binarize the discriminator outputs: real images should score >= 0.5,
    # generated images should score < 0.5.
    train_predictions = [x >= 0.5 for x in train_predictions]
    test_predictions = [x >= 0.5 for x in test_predictions]
    fake_predictions = [x < 0.5 for x in fake_predictions]

    ret["train_accuracy"].append(np.array(train_predictions).mean())
    ret["test_accuracy"].append(np.array(test_predictions).mean())
    ret["fake_accuracy"].append(np.array(fake_predictions).mean())
    ret["train_d_loss"].append(np.mean(train_d_losses))
    ret["test_d_loss"].append(np.mean(test_d_losses))

  # Report the mean over all repetitions for every score.
  for key in ret:
    ret[key] = np.mean(ret[key])

  return ret
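
# A minimal usage sketch for compute_accuracy_loss, assuming a trained
# AbstractGAN whose graph already lives in `sess`. The helper below is
# hypothetical (not part of this module), and the `split="test"` argument
# assumes eval_utils.get_real_images supports a test split the same way
# the "train" call above does.
def _example_compute_accuracy_loss(sess, gan, num_test_examples=10000):
  """Hypothetical driver showing a typical compute_accuracy_loss call."""
  # Pull a test set the same way compute_accuracy_loss pulls its train set.
  test_images = eval_utils.get_real_images(
      dataset=datasets.get_dataset(),
      num_examples=num_test_examples,  # assumed test split size
      split="test",
      failure_on_insufficient_examples=False)
  scores = compute_accuracy_loss(sess, gan, test_images)
  # `scores` holds the repetition-averaged "train_accuracy", "test_accuracy",
  # "fake_accuracy", "train_d_loss" and "test_d_loss" values.
  return scores
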
def evaluate_tfhub_module(module_spec, eval_tasks, use_tpu,
                          num_averaging_runs, step):
  """Evaluate the given TF hub module on all provided eval_tasks.

  Args:
    module_spec: string, path to a TF hub module.
    eval_tasks: List of objects that inherit from EvalTask.
    use_tpu: Whether to use TPUs.
    num_averaging_runs: Determines how many times each metric is computed.
    step: Name of the step being evaluated.

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If generator output has any NaNs.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)
  dataset = datasets.get_dataset()
  num_test_examples = dataset.eval_test_samples
  batch_size = FLAGS.eval_batch_size
  num_batches = int(np.ceil(num_test_examples / batch_size))

  # Load and update the generator.
  fake_dsets = []
  with tf.Graph().as_default():
    tf.set_random_seed(42)
    with tf.Session() as sess:
      if use_tpu:
        sess.run(tf.contrib.tpu.initialize_system())

      def sample_from_generator():
        """Create graph for sampling images."""
        generator = hub.Module(
            module_spec,
            name="gen_module",
            tags={"gen", "bs{}".format(batch_size)})
        logging.info("Generator inputs: %s", generator.get_input_info_dict())
        z_dim = generator.get_input_info_dict()["z"].get_shape()[1].value
        z = z_generator(shape=[batch_size, z_dim])
        if "labels" in generator.get_input_info_dict():
          # Conditional GAN.
          assert dataset.num_classes
          if FLAGS.force_label is None:
            labels = tf.random.uniform(
                [batch_size], maxval=dataset.num_classes, dtype=tf.int32)
          else:
            labels = tf.constant(
                FLAGS.force_label, shape=[batch_size], dtype=tf.int32)
          inputs = dict(z=z, labels=labels)
        else:
          # Unconditional GAN.
          assert "labels" not in generator.get_input_info_dict()
          inputs = dict(z=z)
        return generator(inputs=inputs, as_dict=True)["generated"]

      if use_tpu:
        generated = tf.contrib.tpu.rewrite(sample_from_generator)
      else:
        generated = sample_from_generator()

      tf.global_variables_initializer().run()

      save_model_accu_path = os.path.join(module_spec, "model-with-accu.ckpt")
      if not tf.io.gfile.exists(save_model_accu_path):
        if _update_bn_accumulators(sess, generated, num_accu_examples=204800):
          saver = tf.train.Saver()
          checkpoint_path = saver.save(sess, save_path=save_model_accu_path)
          logging.info("Exported generator with accumulated batch stats to "
                       "%s.", checkpoint_path)
      if not eval_tasks:
        logging.error("Task list is empty, returning.")
        return

      for i in range(num_averaging_runs):
        logging.info("Generating fake data set %d/%d.", i + 1,
                     num_averaging_runs)
        fake_dset = eval_utils.EvalDataSample(
            eval_utils.sample_fake_dataset(sess, generated, num_batches))
        fake_dsets.append(fake_dset)

        # Hacking this in here for speed for now.
        save_examples_lib.SaveExamplesTask().run_after_session(
            fake_dset, None, step)

        logging.info("Computing inception features for generated data %d/%d.",
                     i + 1, num_averaging_runs)
        activations, logits = eval_utils.inception_transform_np(
            fake_dset.images, batch_size)
        fake_dset.set_inception_features(
            activations=activations, logits=logits)
        fake_dset.set_num_examples(num_test_examples)
        if i != 0:
          # Free up some memory by releasing additional fake data samples.
          # For ImageNet128 50k images are ~9 GiB. This will blow up metrics
          # (such as fractal dimension) if num_averaging_runs > 1.
          fake_dset.discard_images()

  real_dset = eval_utils.EvalDataSample(
      eval_utils.get_real_images(
          dataset=dataset, num_examples=num_test_examples))
  logging.info("Getting Inception features for real images.")
  real_dset.activations, _ = eval_utils.inception_transform_np(
      real_dset.images, batch_size)
  real_dset.set_num_examples(num_test_examples)

  # Run all the tasks and update the result dictionary with the task
  # statistics.
  result_dict = {}
  for task in eval_tasks:
    task_results_dicts = [
        task.run_after_session(fake_dset, real_dset)
        for fake_dset in fake_dsets
    ]
    # Average the score for each key.
    result_statistics = {}
    for key in task_results_dicts[0].keys():
      scores_for_key = np.array([d[key] for d in task_results_dicts])
      mean, std = np.mean(scores_for_key), np.std(scores_for_key)
      scores_as_string = "_".join([str(x) for x in scores_for_key])
      result_statistics[key + "_mean"] = mean
      result_statistics[key + "_std"] = std
      result_statistics[key + "_list"] = scores_as_string
    logging.info("Computed results for task %s: %s", task, result_statistics)
    result_dict.update(result_statistics)
  return result_dict
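
# Illustrative sketch of driving evaluate_tfhub_module for one exported
# module. Everything here is a placeholder: `my_metrics.MyEvalTask` stands in
# for any EvalTask subclass, and the module path, run count, and step number
# are made-up values, not defaults from this codebase.
def _example_evaluate_module():
  """Hypothetical driver for a single evaluation step."""
  eval_tasks = [my_metrics.MyEvalTask()]  # any EvalTask subclasses
  results = evaluate_tfhub_module(
      module_spec="/tmp/hub/model-12345",  # path to an exported TF hub module
      eval_tasks=eval_tasks,
      use_tpu=False,
      num_averaging_runs=3,
      step=12345)
  # For every key a task reports, `results` contains "<key>_mean",
  # "<key>_std" and "<key>_list" aggregated over the 3 averaging runs.
  return results
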