def testLoadingTriangles(self):
    """Verify example shapes (and one known label) for the triangles dataset."""
    # Default split, batched by 32: images and labels gain a leading batch axis.
    with tf.Graph().as_default():
        batched = gan_lib.load_dataset("triangles").batch(32)
        next_batch = batched.make_one_shot_iterator().get_next()
        with tf.Session() as session:
            images, labels = session.run(next_batch)
            self.assertEqual(images.shape, (32, 28, 28, 1))
            self.assertEqual(labels.shape, (32,))
            self.assertEqual(labels[4], 3)

    # Unbatched "test" and "val" splits: one image and a scalar label each.
    for split in ("test", "val"):
        with tf.Graph().as_default():
            split_data = gan_lib.load_dataset("triangles", split_name=split)
            next_example = split_data.make_one_shot_iterator().get_next()
            with tf.Session() as session:
                single_image, single_label = session.run(next_example)
                self.assertEqual(single_image.shape, (28, 28, 1))
                self.assertEqual(single_label.shape, ())
def testLoadingMnist(self):
    """Verify that one unbatched mnist example has the expected shapes."""
    with tf.Graph().as_default():
        mnist = gan_lib.load_dataset("mnist")
        one_shot = mnist.make_one_shot_iterator()
        next_example = one_shot.get_next()
        with tf.Session() as session:
            pixels, digit = session.run(next_example)
            # A single 28x28 one-channel image with a scalar label.
            self.assertEqual(pixels.shape, (28, 28, 1))
            self.assertEqual(digit.shape, ())
def testLoadingMnist(self):
    """Load one example from the mnist dataset and check its shapes."""
    with tf.Graph().as_default():
        next_element = (
            gan_lib.load_dataset("mnist").make_one_shot_iterator().get_next())
        with tf.Session() as session:
            fetched = session.run(next_element)
            example_image, example_label = fetched
            # Image is (height, width, channels); the label is a scalar.
            self.assertEqual(example_image.shape, (28, 28, 1))
            self.assertEqual(example_label.shape, ())
def testLoadingTriangles(self):
    """Verify example shapes (and one known label) for the triangles dataset."""
    # Default split, batched by 32: images and labels gain a leading batch axis.
    with tf.Graph().as_default():
        batched = gan_lib.load_dataset("triangles").batch(32)
        next_batch = batched.make_one_shot_iterator().get_next()
        with tf.Session() as session:
            images, labels = session.run(next_batch)
            self.assertEqual(images.shape, (32, 28, 28, 1))
            self.assertEqual(labels.shape, (32,))
            self.assertEqual(labels[4], 3)

    # Unbatched "test" and "validation" splits: one image and a scalar label.
    for split in ("test", "validation"):
        with tf.Graph().as_default():
            split_data = gan_lib.load_dataset("triangles", split_name=split)
            next_example = split_data.make_one_shot_iterator().get_next()
            with tf.Session() as session:
                single_image, single_label = session.run(next_example)
                self.assertEqual(single_image.shape, (28, 28, 1))
                self.assertEqual(single_label.shape, ())
def GetRealImages(dataset, split_name, num_examples,
                  failure_on_insufficient_examples=True):
    """Get num_examples images from the given dataset/split.

    Args:
      dataset: Dataset name, forwarded to gan_lib.load_dataset.
      split_name: Name of the split to read from.
      num_examples: Number of images to read.
      failure_on_insufficient_examples: If True, raise ValueError when fewer
        than num_examples images could be read. If False, only log an error
        and return the images that were read.

    Returns:
      A numpy array stacking the images that were read, multiplied by 255.0.
      No channel conversion is performed here.

    Raises:
      ValueError: If fewer than num_examples images were read and
        failure_on_insufficient_examples is True.
    """
    # Multithread and buffer could improve the training speed by 20%, however
    # it consumes more memory. In evaluation, we used single thread without
    # buffer to avoid using too much memory.
    dataset_content = gan_lib.load_dataset(
        dataset, split_name=split_name, num_threads=1, buffer_size=0)

    # Read the real images one at a time from the dataset.
    data_x = []
    with tf.Graph().as_default():
        get_next = dataset_content.make_one_shot_iterator().get_next()
        with tf.train.MonitoredTrainingSession() as sess:
            for i in range(num_examples):
                try:
                    data_x.append(sess.run(get_next[0]))
                except tf.errors.OutOfRangeError:
                    # Pass the count as a logging argument (lazy formatting),
                    # matching the other logging call in this function.
                    logging.error(
                        "Reached the end of dataset. Read: %d samples.", i)
                    break

    real_images = np.array(data_x)
    if real_images.shape[0] != num_examples:
        if failure_on_insufficient_examples:
            raise ValueError(
                "Not enough examples in the dataset %s: %d / %d" %
                (dataset, real_images.shape[0], num_examples))
        else:
            logging.error("Not enough examples in the dataset %s: %d / %d",
                          dataset, real_images.shape[0], num_examples)

    # In-place scaling; presumably the images are floats in [0, 1] -- TODO
    # confirm against gan_lib.load_dataset.
    real_images *= 255.0
    return real_images
def TrainGILBO(gan, sess, outdir, checkpoint_path, dataset, options):
    """Build and train GILBO model.

    Args:
      gan: GAN object.
      sess: tf.Session.
      outdir: Output directory. A pickle file will be written there.
      checkpoint_path: Path where gan's checkpoints are written. Only used to
        ensure that GILBO files are written to a unique subdirectory of outdir.
      dataset: Name of dataset used to train the GAN.
      options: Options dictionary.

    Returns:
      mean_eval_info: Mean GILBO computed over a large number of images
        generated by the trained GAN or VAE.
      mean_train_consistency: Mean consistency of the trained GILBO model with
        data from the training set.
      mean_eval_consistency: Same consistency measure for the trained model
        with data from the validation set.
      mean_self_consistency: Same consistency measure for the trained model
        with data generated by the trained model itself.
      See the GILBO paper for an explanation of these metrics.

    Raises:
      ValueError: If the GAN has uninitialized variables.
    """
    # NOTE: this function forwards **locals() to _RunGILBOConsistency below,
    # so every local name defined here is part of that call. Do not rename,
    # add or remove locals without checking _RunGILBOConsistency's signature.
    uninitialized = sess.run(tf.report_uninitialized_variables())
    # sess.run returns a numpy array of variable names. Test its length
    # explicitly: the truth value of an array with more than one element is
    # ambiguous and would raise numpy's own error instead of the one below.
    if len(uninitialized) > 0:
        raise ValueError(
            'Model has uninitialized variables!\n%r' % uninitialized)

    # Derive a per-checkpoint subdirectory name from the checkpoint path.
    outdir = os.path.join(outdir, checkpoint_path.replace('/', '_'))
    if isinstance(gan, VAE):
        gan_type = 'VAE'
    else:
        gan_type = 'GAN'

    tf.gfile.MakeDirs(outdir)
    with tf.variable_scope('gilbo'):
        ones = tf.ones((gan.batch_size, gan.z_dim))
        # Get a distribution for the prior, depending on whether the model is
        # a VAE or a GAN.
        if gan_type == 'VAE':
            z_dist = ds.Independent(ds.Normal(0.0 * ones, ones), 1)
        else:
            z_dist = ds.Independent(ds.Uniform(-ones, ones), 1)

        z_sample = z_dist.sample()

        epsneg = np.finfo('float32').epsneg

        if gan_type == 'VAE':
            ganz_clip = gan.z
        else:
            # Clip samples from the GAN uniform prior because the Beta
            # distribution doesn't include the top endpoint and has issues
            # with the bottom endpoint
            ganz_clip = tf.clip_by_value(gan.z, -(1 - epsneg), 1 - epsneg)

        # Get generated images from the model.
        fake_images = gan.fake_images

        # Build the regressor distribution that encodes images back to
        # predicted samples from the prior.
        with tf.variable_scope('regressor'):
            z_pred_dist = _Regressor(fake_images, gan.z_dim, gan_type)

        # Capture the parameters of the distributions for later analysis.
        if gan_type == 'VAE':
            dist_p1 = z_pred_dist.distribution.loc
            dist_p2 = z_pred_dist.distribution.scale
        else:
            dist_p1 = z_pred_dist.distribution.distribution.concentration0
            dist_p2 = z_pred_dist.distribution.distribution.concentration1

        # info and avg_info compute the GILBO.
        info = z_pred_dist.log_prob(ganz_clip) - z_dist.log_prob(ganz_clip)
        avg_info = tf.reduce_mean(info)

        # Set up training of the GILBO model.
        lr = options.get('gilbo_learning_rate', 4e-4)
        learning_rate = tf.get_variable(
            'learning_rate', initializer=lr, trainable=False)
        gilbo_step = tf.get_variable('gilbo_step', dtype=tf.int32,
                                     initializer=0, trainable=False)
        opt = tf.train.AdamOptimizer(learning_rate)

        # Only the regressor's variables are passed to the optimizer.
        regressor_vars = tf.contrib.framework.get_variables('gilbo/regressor')
        train_op = opt.minimize(-info, var_list=regressor_vars)

    # Initialize the variables we just created.
    uninitialized = plist(tf.report_uninitialized_variables().eval())
    uninitialized_vars = uninitialized.apply(
        tf.contrib.framework.get_variables_by_name)._[0]
    tf.variables_initializer(uninitialized_vars).run()

    saver = tf.train.Saver(uninitialized_vars, max_to_keep=1)
    try:
        # NOTE(review): this rebinds checkpoint_path from the GAN checkpoint
        # path to whatever tf.train.latest_checkpoint finds in outdir; the
        # rebound value is also what **locals() forwards below.
        checkpoint_path = tf.train.latest_checkpoint(outdir)
        saver.restore(sess, checkpoint_path)
    except ValueError:
        # Failing to restore just indicates that we don't have a valid
        # checkpoint, so we will just start training a fresh GILBO model.
        pass
    _TrainGILBO(sess, gan, saver, learning_rate, gilbo_step, z_sample,
                avg_info, z_pred_dist, train_op, gan_type, outdir, options)

    mean_eval_info = _EvalGILBO(sess, gan, z_sample, avg_info, dist_p1,
                                dist_p2, fake_images, outdir, options)

    # Collect encoded distributions on the training and eval set in order to do
    # kl-nearest-neighbors on generated samples and measure consistency.
    train_images = gan_lib.load_dataset(dataset, split_name='train').apply(
        tf.contrib.data.batch_and_drop_remainder(
            gan.batch_size)).make_one_shot_iterator().get_next()[0]
    train_images = tf.reshape(train_images, fake_images.shape)

    eval_images = gan_lib.load_dataset(dataset, split_name='test').apply(
        tf.contrib.data.batch_and_drop_remainder(
            gan.batch_size)).make_one_shot_iterator().get_next()[0]
    eval_images = tf.reshape(eval_images, fake_images.shape)

    # Each call receives all locals defined so far as keyword arguments.
    mean_train_consistency = _RunGILBOConsistency(train_images, 'train',
                                                  extract_input_images=0,
                                                  save_consistency_images=20,
                                                  num_batches=5, **locals())
    mean_eval_consistency = _RunGILBOConsistency(eval_images, 'eval',
                                                 extract_input_images=0,
                                                 save_consistency_images=20,
                                                 num_batches=5, **locals())
    mean_self_consistency = _RunGILBOConsistency(fake_images, 'self',
                                                 extract_input_images=20,
                                                 save_consistency_images=20,
                                                 num_batches=5, **locals())

    return (mean_eval_info, mean_train_consistency, mean_eval_consistency,
            mean_self_consistency)
def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
    """Evaluate model at given checkpoint_path.

    Reads real images from the "test" split, generates the same number of
    fake images from the restored model, and scores the fake images.

    Args:
      checkpoint_path: Checkpoint to restore the model from.
      task_workdir: Directory under which the "checkpoint", "result" and
        "logs" subdirectory paths are formed.
      options: Options dictionary; "gan_type" and "dataset" are read here and
        the whole dictionary is forwarded to gan_lib.create_gan.
      inception_graph: Forwarded to GetInceptionScore and
        ComputeTFGanFIDScore.

    Returns:
      A tuple (inception_score, fid_score), or (NAN_DETECTED, NAN_DETECTED)
      if the generator produced NaNs.

    Raises:
      ValueError: If the gan type is not supported, or if the dataset has
        fewer examples than requested.
    """
    # Make sure that the same latent variables are used for each evaluation.
    np.random.seed(42)

    checkpoint_dir = os.path.join(task_workdir, "checkpoint")
    result_dir = os.path.join(task_workdir, "result")
    gan_log_dir = os.path.join(task_workdir, "logs")

    gan_type = options["gan_type"]
    if gan_type not in SUPPORTED_GANS:
        raise ValueError("Gan type %s is not supported." % gan_type)

    dataset = options["dataset"]
    dataset_content = gan_lib.load_dataset(dataset, split_name="test")
    num_test_examples = FLAGS.num_test_examples

    if num_test_examples % INCEPTION_BATCH != 0:
        # Despite the wording of the message, the count is rounded DOWN to a
        # multiple of INCEPTION_BATCH.
        logging.info("Padding number of examples to fit inception batch.")
        num_test_examples -= num_test_examples % INCEPTION_BATCH

    # Get real images from the dataset. In the case of a 1-channel
    # dataset (like mnist) convert it to 3 channels.
    data_x = []
    with tf.Graph().as_default():
        get_next = dataset_content.make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            for _ in range(num_test_examples):
                try:
                    data_x.append(sess.run(get_next[0]))
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted early: stop reading so the count
                    # check below reports it as a ValueError.
                    break

    real_images = np.array(data_x)
    if real_images.shape[0] != num_test_examples:
        raise ValueError("Not enough examples in the dataset.")

    if real_images.shape[3] == 1:
        real_images = np.tile(real_images, [1, 1, 1, 3])
    real_images *= 255.0
    logging.info("Real data processed.")

    # Get Fake images from the generator.
    samples = []
    logging.info("Running eval on checkpoint path: %s", checkpoint_path)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            gan = gan_lib.create_gan(gan_type=gan_type,
                                     dataset=dataset,
                                     sess=sess,
                                     dataset_content=dataset_content,
                                     options=options,
                                     checkpoint_dir=checkpoint_dir,
                                     result_dir=result_dir,
                                     gan_log_dir=gan_log_dir)
            gan.build_model(is_training=False)

            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            # Make sure we have >= examples as in the test set. Divide as
            # floats so the ceiling is not defeated by integer division.
            num_batches = int(
                np.ceil(float(num_test_examples) / gan.batch_size))
            for _ in range(num_batches):
                z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
                feed_dict = {gan.z: z_sample}
                x = sess.run(gan.fake_images, feed_dict=feed_dict)
                # If NaNs were generated, ignore this checkpoint and assign a
                # very high FID score which we handle specially later.
                if np.isnan(x).any():
                    logging.error(
                        "Detected NaN in fake_images! Returning NaN.")
                    return NAN_DETECTED, NAN_DETECTED
                samples.append(x)

    fake_images = np.concatenate(samples, axis=0)
    # Adjust the number of fake images to the number of images in the test set.
    fake_images = fake_images[:num_test_examples, :, :, :]
    # In case we use a 1-channel dataset (like mnist) - convert it to 3 channel.
    if fake_images.shape[3] == 1:
        fake_images = np.tile(fake_images, [1, 1, 1, 3])
    fake_images *= 255.0

    logging.info("Fake data processed, computing inception score.")
    inception_score = GetInceptionScore(fake_images, inception_graph)
    logging.info("Inception score computed: %.3f", inception_score)

    assert fake_images.shape == real_images.shape

    fid_score = ComputeTFGanFIDScore(fake_images, real_images, inception_graph)

    logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
                 checkpoint_path, fid_score)
    return inception_score, fid_score
def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
    """Evaluate model at given checkpoint_path.

    Reads real images from the "test" split, generates the same number of
    fake images from the restored model, and scores the fake images.

    Args:
      checkpoint_path: Checkpoint to restore the model from.
      task_workdir: Directory under which the "checkpoint", "result" and
        "logs" subdirectory paths are formed.
      options: Options dictionary; "gan_type" and "dataset" are read here and
        the whole dictionary is forwarded to gan_lib.create_gan.
      inception_graph: Forwarded to GetInceptionScore and
        ComputeTFGanFIDScore.

    Returns:
      A tuple (inception_score, fid_score), or (NAN_DETECTED, NAN_DETECTED)
      if the generator produced NaNs.

    Raises:
      ValueError: If the gan type is not supported, or if the dataset has
        fewer examples than requested.
    """
    # Make sure that the same latent variables are used for each evaluation.
    np.random.seed(42)

    checkpoint_dir = os.path.join(task_workdir, "checkpoint")
    result_dir = os.path.join(task_workdir, "result")
    gan_log_dir = os.path.join(task_workdir, "logs")

    gan_type = options["gan_type"]
    if gan_type not in SUPPORTED_GANS:
        raise ValueError("Gan type %s is not supported." % gan_type)

    dataset = options["dataset"]
    dataset_content = gan_lib.load_dataset(dataset, split_name="test")
    num_test_examples = FLAGS.num_test_examples

    if num_test_examples % INCEPTION_BATCH != 0:
        # Despite the wording of the message, the count is rounded DOWN to a
        # multiple of INCEPTION_BATCH.
        logging.info("Padding number of examples to fit inception batch.")
        num_test_examples -= num_test_examples % INCEPTION_BATCH

    # Get real images from the dataset. In the case of a 1-channel
    # dataset (like mnist) convert it to 3 channels.
    data_x = []
    with tf.Graph().as_default():
        get_next = dataset_content.make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            for _ in range(num_test_examples):
                try:
                    data_x.append(sess.run(get_next[0]))
                except tf.errors.OutOfRangeError:
                    # Dataset exhausted early: stop reading so the count
                    # check below reports it as a ValueError.
                    break

    real_images = np.array(data_x)
    if real_images.shape[0] != num_test_examples:
        raise ValueError("Not enough examples in the dataset.")

    if real_images.shape[3] == 1:
        real_images = np.tile(real_images, [1, 1, 1, 3])
    real_images *= 255.0
    logging.info("Real data processed.")

    # Get Fake images from the generator.
    samples = []
    logging.info("Running eval on checkpoint path: %s", checkpoint_path)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            gan = gan_lib.create_gan(
                gan_type=gan_type,
                dataset=dataset,
                sess=sess,
                dataset_content=dataset_content,
                options=options,
                checkpoint_dir=checkpoint_dir,
                result_dir=result_dir,
                gan_log_dir=gan_log_dir)
            gan.build_model(is_training=False)

            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            # Make sure we have >= examples as in the test set. Divide as
            # floats so the ceiling is not defeated by integer division.
            num_batches = int(
                np.ceil(float(num_test_examples) / gan.batch_size))
            for _ in range(num_batches):
                z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
                feed_dict = {gan.z: z_sample}
                x = sess.run(gan.fake_images, feed_dict=feed_dict)
                # If NaNs were generated, ignore this checkpoint and assign a
                # very high FID score which we handle specially later.
                if np.isnan(x).any():
                    logging.error("Detected NaN in fake_images! Returning NaN.")
                    return NAN_DETECTED, NAN_DETECTED
                samples.append(x)

    fake_images = np.concatenate(samples, axis=0)
    # Adjust the number of fake images to the number of images in the test set.
    fake_images = fake_images[:num_test_examples, :, :, :]
    # In case we use a 1-channel dataset (like mnist) - convert it to 3 channel.
    if fake_images.shape[3] == 1:
        fake_images = np.tile(fake_images, [1, 1, 1, 3])
    fake_images *= 255.0

    logging.info("Fake data processed, computing inception score.")
    inception_score = GetInceptionScore(fake_images, inception_graph)
    logging.info("Inception score computed: %.3f", inception_score)

    assert fake_images.shape == real_images.shape

    fid_score = ComputeTFGanFIDScore(fake_images, real_images, inception_graph)

    logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
                 checkpoint_path, fid_score)
    return inception_score, fid_score