def RunCheckpointEval(checkpoint_path, task_workdir, options, tasks_to_run):
  """Evaluate model at given checkpoint_path.

  Args:
    checkpoint_path: string, path to the single checkpoint to evaluate.
    task_workdir: directory, where results and logs can be written.
    options: Dict[Text, Text] with all parameters for the current trial.
    tasks_to_run: List of objects that inherit from EvalTask.

  Returns:
    Dict[Text, float] with all the computed results.

  Raises:
    NanFoundError: If the GAN output contains any NaNs.
    ValueError: If options["gan_type"] is not supported.
  """
  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)

  checkpoint_dir = os.path.join(task_workdir, "checkpoint")
  result_dir = os.path.join(task_workdir, "result")
  gan_log_dir = os.path.join(task_workdir, "logs")

  gan_type = options["gan_type"]
  if gan_type not in SUPPORTED_GANS:
    raise ValueError("Gan type %s is not supported." % gan_type)
  dataset = options["dataset"]

  dataset_params = params.GetDatasetParameters(dataset)
  dataset_params.update(options)
  num_test_examples = dataset_params.get("eval_test_samples", 10000)

  if num_test_examples % INCEPTION_BATCH != 0:
    logging.info("Padding number of examples to fit inception batch.")
    num_test_examples -= num_test_examples % INCEPTION_BATCH

  real_images = GetRealImages(
      options["dataset"], split_name="test", num_examples=num_test_examples)
  logging.info("Real data processed.")

  result_dict = {}
  # Get fake images from the generator.
  samples = []
  logging.info("Running eval for dataset %s, checkpoint: %s, num_examples: %d",
               dataset, checkpoint_path, num_test_examples)
  with tf.Graph().as_default():
    with tf.Session() as sess:
      gan = gan_lib.create_gan(
          gan_type=gan_type,
          dataset=dataset,
          dataset_content=None,
          options=options,
          checkpoint_dir=checkpoint_dir,
          result_dir=result_dir,
          gan_log_dir=gan_log_dir)

      gan.build_model(is_training=False)
      tf.global_variables_initializer().run()
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)

      # Make sure we have at least as many examples as in the test set.
      num_batches = int(np.ceil(num_test_examples / float(gan.batch_size)))
      for _ in range(num_batches):
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        x = sess.run(gan.fake_images, feed_dict={gan.z: z_sample})
        # If NaNs were generated, ignore this checkpoint and assign a very
        # high FID score which we handle specially later.
        if np.isnan(x).any():
          logging.error("Detected NaN in fake_images! Returning NaN.")
          raise NanFoundError("Detected NaN in fake images.")
        samples.append(x)

      logging.info("Fake data generated, running tasks in session.")
      for task in tasks_to_run:
        result_dict.update(task.RunInSession(options, sess, gan, real_images))

  fake_images = np.concatenate(samples, axis=0)
  # Adjust the number of fake images to the number of images in the test set.
  fake_images = fake_images[:num_test_examples, :, :, :]

  assert fake_images.shape == real_images.shape

  # In case we use a 1-channel dataset (like MNIST), convert it to 3 channels.
  if fake_images.shape[3] == 1:
    fake_images = np.tile(fake_images, [1, 1, 1, 3])
    # Change the real_images' shape too, so that it keeps matching
    # fake_images' shape.
    real_images = np.tile(real_images, [1, 1, 1, 3])

  fake_images *= 255.0

  logging.info("Fake data processed. Starting tasks for checkpoint: %s.",
               checkpoint_path)
  for task in tasks_to_run:
    result_dict.update(task.RunAfterSession(options, fake_images, real_images))
  return result_dict
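# A minimal sketch of the EvalTask interface that RunCheckpointEval above
# relies on. The base class name EvalTask comes from the docstring; the method
# signatures are inferred from the call sites (task.RunInSession and
# task.RunAfterSession). NumFakeImagesTask is a hypothetical example task, not
# part of the original library.
class NumFakeImagesTask(EvalTask):
  """Toy task that only counts the evaluated fake images."""

  def RunInSession(self, options, sess, gan, real_images):
    # Nothing needs the open session for this toy task.
    del options, sess, gan, real_images  # Unused.
    return {}

  def RunAfterSession(self, options, fake_images, real_images):
    del options, real_images  # Unused.
    return {"num_fake_images": float(fake_images.shape[0])}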
def EvalCheckpoint(checkpoint_path, task_workdir, options, out_cp_dir):
  """Evaluate model at given checkpoint_path."""

  # Overwrite batch size.
  options["batch_size"] = FLAGS.batch_size

  checkpoint_dir = os.path.join(task_workdir, "checkpoint")
  result_dir = os.path.join(task_workdir, "result")
  gan_log_dir = os.path.join(task_workdir, "logs")

  dataset_params = params.GetDatasetParameters(options["dataset"])
  dataset_params.update(options)

  # Generate fake images.
  with tf.Graph().as_default() as g:
    with tf.Session() as sess:
      gan = gan_lib.create_gan(
          gan_type=options["gan_type"],
          dataset=options["dataset"],
          dataset_content=None,
          options=options,
          checkpoint_dir=checkpoint_dir,
          result_dir=result_dir,
          gan_log_dir=gan_log_dir)

      gan.build_model(is_training=False)
      tf.global_variables_initializer().run()
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)

      # Compute outputs for MultiGanGeneratorImages.
      if (FLAGS.visualization_type == "multi_image" and
          "MultiGAN" in options["gan_type"]):
        generator_preds_op = GetMultiGANGeneratorsOp(
            g, options["gan_type"], options["architecture"],
            options["aggregate"])
        fetches = [gan.fake_images, generator_preds_op]

        # Construct feed dict.
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        feed_dict = {gan.z: z_sample}

        # Fetch data and save images.
        fake_images, generator_preds = sess.run(fetches, feed_dict=feed_dict)
        SaveMultiGanGeneratorImages(fake_images, generator_preds, out_cp_dir)

      # Compute outputs for MultiGanLatentTraversalImages.
      elif (FLAGS.visualization_type == "multi_latent" and
            "MultiGAN" in options["gan_type"]):
        generator_preds_op = GetMultiGANGeneratorsOp(
            g, options["gan_type"], options["architecture"],
            options["aggregate"])
        fetches = [gan.fake_images, generator_preds_op]

        # Init latent params.
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        directions = np.random.uniform(size=z_sample.shape)
        k_indices = np.random.randint(gan.k, size=gan.batch_size)
        n_steps, step_size = 10, 0.1
        images, gen_preds = [], []

        # Traverse in the latent space of a single component n_steps times
        # and generate the corresponding images.
        for step in range(n_steps + 1):
          new_z = z_sample.copy()
          for i in range(z_sample.shape[0]):
            new_z[i, k_indices[i]] += (
                step * step_size * directions[i, k_indices[i]])

          images_batch, gen_preds_batch = sess.run(fetches, {gan.z: new_z})
          images.append(images_batch)
          gen_preds.append(gen_preds_batch)

        images = np.stack(images, axis=1)
        gen_preds = np.stack(gen_preds, axis=1)
        SaveMultiGanLatentTraversalImages(images, gen_preds, out_cp_dir)

      # Compute outputs for GanLatentTraversalImages.
      elif FLAGS.visualization_type == "latent":
        # Init latent params.
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        directions = np.random.uniform(size=z_sample.shape)
        k_indices = np.random.randint(options.get("k", 1), size=gan.batch_size)
        n_steps, step_size = 5, 0.1
        images = []

        # Traverse in the latent space of a single component n_steps times
        # and generate the corresponding images.
        for step in range(n_steps + 1):
          new_z = z_sample.copy()
          for i in range(z_sample.shape[0]):
            if "MultiGAN" in options["gan_type"]:
              new_z[i, k_indices[i]] += (
                  step * step_size * directions[i, k_indices[i]])
            else:
              new_z[i] += step * step_size * directions[i]

          images_batch = sess.run(gan.fake_images, {gan.z: new_z})
          images.append(images_batch)

        images = np.stack(images, axis=1)
        SaveGanLatentTraversalImages(images, out_cp_dir)
def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
  """Evaluate model at given checkpoint_path."""

  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)

  checkpoint_dir = os.path.join(task_workdir, "checkpoint")
  result_dir = os.path.join(task_workdir, "result")
  gan_log_dir = os.path.join(task_workdir, "logs")

  gan_type = options["gan_type"]
  if gan_type not in SUPPORTED_GANS:
    raise ValueError("Gan type %s is not supported." % gan_type)
  dataset = options["dataset"]

  dataset_params = params.GetDatasetParameters(dataset)
  dataset_params.update(options)
  num_test_examples = dataset_params.get("eval_test_samples", 10000)

  if num_test_examples % INCEPTION_BATCH != 0:
    logging.info("Padding number of examples to fit inception batch.")
    num_test_examples -= num_test_examples % INCEPTION_BATCH

  real_images = GetRealImages(
      options["dataset"], split_name="test", num_examples=num_test_examples)
  logging.info("Real data processed.")

  result_dict = {}
  default_value = -1.0
  # Get fake images from the generator.
  samples = []
  logging.info("Running eval for dataset %s, checkpoint: %s, num_examples: %d",
               dataset, checkpoint_path, num_test_examples)
  with tf.Graph().as_default():
    with tf.Session() as sess:
      gan = gan_lib.create_gan(
          gan_type=gan_type,
          dataset=dataset,
          dataset_content=FakeDatasetContent(),  # This should never be used.
          options=options,
          checkpoint_dir=checkpoint_dir,
          result_dir=result_dir,
          gan_log_dir=gan_log_dir)

      gan.build_model(is_training=False)
      tf.global_variables_initializer().run()
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)

      # Make sure we have at least as many examples as in the test set.
      num_batches = int(np.ceil(num_test_examples / float(gan.batch_size)))
      for _ in range(num_batches):
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        feed_dict = {gan.z: z_sample}
        x = sess.run(gan.fake_images, feed_dict=feed_dict)
        # If NaNs were generated, ignore this checkpoint and assign a very
        # high FID score which we handle specially later.
        if np.isnan(x).any():
          logging.error("Detected NaN in fake_images! Returning NaN.")
          default_value = NAN_DETECTED
          return result_dict, default_value
        samples.append(x)

      if ShouldRunAccuracyLossTrainVsTest(options):
        result_dict = ComputeAccuracyLoss(options, gan, real_images, sess,
                                          result_dict)

  fake_images = np.concatenate(samples, axis=0)
  # Adjust the number of fake images to the number of images in the test set.
  fake_images = fake_images[:num_test_examples, :, :, :]

  assert fake_images.shape == real_images.shape

  # In case we use a 1-channel dataset (like MNIST), convert it to 3 channels.
  if fake_images.shape[3] == 1:
    fake_images = np.tile(fake_images, [1, 1, 1, 3])
    # Change the real_images' shape too, so that it keeps matching
    # fake_images' shape.
    real_images = np.tile(real_images, [1, 1, 1, 3])

  fake_images *= 255.0

  logging.info("Fake data processed, computing inception score.")
  result_dict["inception_score"] = GetInceptionScore(fake_images,
                                                     inception_graph)
  logging.info("Inception score computed: %.3f",
               result_dict["inception_score"])

  result_dict["fid_score"] = ComputeTFGanFIDScore(fake_images, real_images,
                                                  inception_graph)
  logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
               checkpoint_path, result_dict["fid_score"])

  if ShouldRunMultiscaleSSIM(options):
    result_dict["ms_ssim"] = ComputeMultiscaleSSIMScore(fake_images)
    logging.info("MS-SSIM score computed: %.3f", result_dict["ms_ssim"])

  return result_dict, default_value
def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
  """Evaluate model at given checkpoint_path."""

  # Make sure that the same latent variables are used for each evaluation.
  np.random.seed(42)

  checkpoint_dir = os.path.join(task_workdir, "checkpoint")
  result_dir = os.path.join(task_workdir, "result")
  gan_log_dir = os.path.join(task_workdir, "logs")

  gan_type = options["gan_type"]
  if gan_type not in SUPPORTED_GANS:
    raise ValueError("Gan type %s is not supported." % gan_type)
  dataset = options["dataset"]

  dataset_content = gan_lib.load_dataset(dataset, split_name="test")

  num_test_examples = FLAGS.num_test_examples
  if num_test_examples % INCEPTION_BATCH != 0:
    logging.info("Padding number of examples to fit inception batch.")
    num_test_examples -= num_test_examples % INCEPTION_BATCH

  # Get real images from the dataset. In the case of a 1-channel
  # dataset (like MNIST), convert it to 3 channels.
  data_x = []
  with tf.Graph().as_default():
    get_next = dataset_content.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      for _ in range(num_test_examples):
        data_x.append(sess.run(get_next[0]))

  real_images = np.array(data_x)
  if real_images.shape[0] != num_test_examples:
    raise ValueError("Not enough examples in the dataset.")
  if real_images.shape[3] == 1:
    real_images = np.tile(real_images, [1, 1, 1, 3])
  real_images *= 255.0
  logging.info("Real data processed.")

  # Get fake images from the generator.
  samples = []
  logging.info("Running eval on checkpoint path: %s", checkpoint_path)
  with tf.Graph().as_default():
    with tf.Session() as sess:
      gan = gan_lib.create_gan(
          gan_type=gan_type,
          dataset=dataset,
          sess=sess,
          dataset_content=dataset_content,
          options=options,
          checkpoint_dir=checkpoint_dir,
          result_dir=result_dir,
          gan_log_dir=gan_log_dir)

      gan.build_model(is_training=False)
      tf.global_variables_initializer().run()
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)

      # Make sure we have at least as many examples as in the test set.
      num_batches = int(np.ceil(num_test_examples / float(gan.batch_size)))
      for _ in range(num_batches):
        z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
        feed_dict = {gan.z: z_sample}
        x = sess.run(gan.fake_images, feed_dict=feed_dict)
        # If NaNs were generated, ignore this checkpoint and assign a very
        # high FID score which we handle specially later.
        if np.isnan(x).any():
          logging.error("Detected NaN in fake_images! Returning NaN.")
          return NAN_DETECTED, NAN_DETECTED
        samples.append(x)

  fake_images = np.concatenate(samples, axis=0)
  # Adjust the number of fake images to the number of images in the test set.
  fake_images = fake_images[:num_test_examples, :, :, :]
  # In case we use a 1-channel dataset (like MNIST), convert it to 3 channels.
  if fake_images.shape[3] == 1:
    fake_images = np.tile(fake_images, [1, 1, 1, 3])
  fake_images *= 255.0

  logging.info("Fake data processed, computing inception score.")
  inception_score = GetInceptionScore(fake_images, inception_graph)
  logging.info("Inception score computed: %.3f", inception_score)

  assert fake_images.shape == real_images.shape

  fid_score = ComputeTFGanFIDScore(fake_images, real_images, inception_graph)
  logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
               checkpoint_path, fid_score)
  return inception_score, fid_score
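# Hypothetical driver sketch showing how the RunCheckpointEval variant above
# might be invoked for every checkpoint in a training directory. It is not part
# of the original file: LoadInceptionGraph is an assumed helper that returns
# the frozen Inception graph expected by GetInceptionScore and
# ComputeTFGanFIDScore.
def EvalAllCheckpoints(task_workdir, options):
  """Runs RunCheckpointEval for each checkpoint found in task_workdir."""
  checkpoint_dir = os.path.join(task_workdir, "checkpoint")
  inception_graph = LoadInceptionGraph()  # Assumed helper, not shown here.
  scores = {}
  state = tf.train.get_checkpoint_state(checkpoint_dir)
  for checkpoint_path in state.all_model_checkpoint_paths:
    scores[checkpoint_path] = RunCheckpointEval(
        checkpoint_path, task_workdir, options, inception_graph)
  return scores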