Example #1
def testLoadingTriangles(self):
    with tf.Graph().as_default():
        iterator = gan_lib.load_dataset("triangles").batch(
            32).make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            image, label = sess.run(iterator)
            self.assertEqual(image.shape, (32, 28, 28, 1))
            self.assertEqual(label.shape, (32,))
            self.assertEqual(label[4], 3)
    with tf.Graph().as_default():
        iterator = gan_lib.load_dataset(
            "triangles",
            split_name="test").make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            image, label = sess.run(iterator)
            self.assertEqual(image.shape, (28, 28, 1))
            self.assertEqual(label.shape, ())
    with tf.Graph().as_default():
        iterator = gan_lib.load_dataset(
            "triangles",
            split_name="val").make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            image, label = sess.run(iterator)
            self.assertEqual(image.shape, (28, 28, 1))
            self.assertEqual(label.shape, ())
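The pattern in this test recurs throughout the examples below: build the input pipeline inside a fresh tf.Graph(), convert the tf.data.Dataset to a one-shot iterator, and fetch tensors with sess.run(). A minimal self-contained sketch of the same pattern (TF 1.x API), with a synthetic in-memory dataset standing in for gan_lib.load_dataset:

import tensorflow as tf

with tf.Graph().as_default():
    # Synthetic stand-in: 100 single-channel 28x28 images with integer labels.
    images = tf.zeros((100, 28, 28, 1))
    labels = tf.zeros((100,), dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    get_next = dataset.batch(32).make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        image, label = sess.run(get_next)
        print(image.shape, label.shape)  # (32, 28, 28, 1) (32,)

Dropping the .batch(32) call makes the iterator yield one unbatched example per sess.run(), which is why the split_name="test" and split_name="val" checks above expect shapes (28, 28, 1) and ().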
Example #2
def testLoadingMnist(self):
    with tf.Graph().as_default():
        dataset = gan_lib.load_dataset("mnist")
        iterator = dataset.make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            image, label = sess.run(iterator)
            self.assertEqual(image.shape, (28, 28, 1))
            self.assertEqual(label.shape, ())
Example #3
 def testLoadingTriangles(self):
   with tf.Graph().as_default():
     iterator = gan_lib.load_dataset("triangles").batch(
         32).make_one_shot_iterator().get_next()
     with tf.Session() as sess:
       image, label = sess.run(iterator)
       self.assertEqual(image.shape, (32, 28, 28, 1))
       self.assertEqual(label.shape, (32,))
       self.assertEqual(label[4], 3)
   with tf.Graph().as_default():
     iterator = gan_lib.load_dataset(
         "triangles", split_name="test").make_one_shot_iterator().get_next()
     with tf.Session() as sess:
       image, label = sess.run(iterator)
       self.assertEqual(image.shape, (28, 28, 1))
       self.assertEqual(label.shape, ())
   with tf.Graph().as_default():
     iterator = gan_lib.load_dataset(
         "triangles", split_name="validation").make_one_shot_iterator(
             ).get_next()
     with tf.Session() as sess:
       image, label = sess.run(iterator)
       self.assertEqual(image.shape, (28, 28, 1))
       self.assertEqual(label.shape, ())
Example #4
def GetRealImages(dataset,
                  split_name,
                  num_examples,
                  failure_on_insufficient_examples=True):
    """Get num_examples images from the given dataset/split."""
    # Multithreading and buffering could improve training speed by 20%, but
    # they consume more memory. For evaluation we use a single thread and no
    # buffer to avoid using too much memory.
    dataset_content = gan_lib.load_dataset(
        dataset,
        split_name=split_name,
        num_threads=1,
        buffer_size=0)
    # Get real images from the dataset. In the case of a 1-channel
    # dataset (like mnist) convert it to 3 channels.
    data_x = []
    with tf.Graph().as_default():
        get_next = dataset_content.make_one_shot_iterator().get_next()
        with tf.train.MonitoredTrainingSession() as sess:
            for i in range(num_examples):
                try:
                    data_x.append(sess.run(get_next[0]))
                except tf.errors.OutOfRangeError:
                    logging.error(
                        "Reached the end of dataset. Read: %d samples.", i)
                    break

    real_images = np.array(data_x)
    if real_images.shape[0] != num_examples:
        if failure_on_insufficient_examples:
            raise ValueError("Not enough examples in the dataset %s: %d / %d" %
                             (dataset, real_images.shape[0], num_examples))
        else:
            logging.error("Not enough examples in the dataset %s: %d / %d", dataset,
                          real_images.shape[0], num_examples)

    real_images *= 255.0
    return real_images
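Since GetRealImages raises on a short read by default, a caller can rely on the returned array having exactly the requested leading dimension. A hypothetical usage sketch, assuming this module and its gan_lib dependency are importable and that the "mnist" dataset from the earlier examples is available:

# Hypothetical call; "mnist" and the (28, 28, 1) image shape come from the
# loading examples above.
real_images = GetRealImages("mnist", split_name="test", num_examples=256)
print(real_images.shape)  # (256, 28, 28, 1)
# With failure_on_insufficient_examples=False, a too-small split is only
# logged, and the returned array may have fewer than num_examples rows.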
Example #5
def TrainGILBO(gan, sess, outdir, checkpoint_path, dataset, options):
    """Build and train GILBO model.

  Args:
    gan: GAN object.
    sess: tf.Session.
    outdir: Output directory. A pickle file will be written there.
    checkpoint_path: Path where gan's checkpoints are written. Only used to
                     ensure that GILBO files are written to a unique
                     subdirectory of outdir.
    dataset: Name of dataset used to train the GAN.
    options: Options dictionary.

  Returns:
    mean_eval_info: Mean GILBO computed over a large number of images generated
                    by the trained GAN or VAE.
    mean_train_consistency: Mean consistency of the trained GILBO model with
                            data from the training set.
    mean_eval_consistency: Same consistency measure for the trained model with
                           data from the validation set.
    mean_self_consistency: Same consistency measure for the trained model with
                           data generated by the trained model itself.
    See the GILBO paper for an explanation of these metrics.

  Raises:
    ValueError: If the GAN has uninitialized variables.
  """
    uninitialized = sess.run(tf.report_uninitialized_variables())
    if uninitialized:
        raise ValueError('Model has uninitialized variables!\n%r' %
                         uninitialized)

    outdir = os.path.join(outdir, checkpoint_path.replace('/', '_'))
    if isinstance(gan, VAE):
        gan_type = 'VAE'
    else:
        gan_type = 'GAN'

    tf.gfile.MakeDirs(outdir)
    with tf.variable_scope('gilbo'):
        ones = tf.ones((gan.batch_size, gan.z_dim))
        # Get a distribution for the prior, depending on whether the model is a VAE
        # or a GAN.
        if gan_type == 'VAE':
            z_dist = ds.Independent(ds.Normal(0.0 * ones, ones), 1)
        else:
            z_dist = ds.Independent(ds.Uniform(-ones, ones), 1)

        z_sample = z_dist.sample()

        epsneg = np.finfo('float32').epsneg

        if gan_type == 'VAE':
            ganz_clip = gan.z
        else:
            # Clip samples from the GAN uniform prior because the Beta distribution
            # doesn't include the top endpoint and has issues with the bottom
            # endpoint.
            ganz_clip = tf.clip_by_value(gan.z, -(1 - epsneg), 1 - epsneg)

        # Get generated images from the model.
        fake_images = gan.fake_images

        # Build the regressor distribution that encodes images back to predicted
        # samples from the prior.
        with tf.variable_scope('regressor'):
            z_pred_dist = _Regressor(fake_images, gan.z_dim, gan_type)

        # Capture the parameters of the distributions for later analysis.
        if gan_type == 'VAE':
            dist_p1 = z_pred_dist.distribution.loc
            dist_p2 = z_pred_dist.distribution.scale
        else:
            dist_p1 = z_pred_dist.distribution.distribution.concentration0
            dist_p2 = z_pred_dist.distribution.distribution.concentration1

        # info and avg_info compute the GILBO.
        info = z_pred_dist.log_prob(ganz_clip) - z_dist.log_prob(ganz_clip)
        avg_info = tf.reduce_mean(info)

        # Set up training of the GILBO model.
        lr = options.get('gilbo_learning_rate', 4e-4)
        learning_rate = tf.get_variable('learning_rate',
                                        initializer=lr,
                                        trainable=False)
        gilbo_step = tf.get_variable('gilbo_step',
                                     dtype=tf.int32,
                                     initializer=0,
                                     trainable=False)
        opt = tf.train.AdamOptimizer(learning_rate)

        regressor_vars = tf.contrib.framework.get_variables('gilbo/regressor')
        train_op = opt.minimize(-info, var_list=regressor_vars)

    # Initialize the variables we just created.
    uninitialized = plist(tf.report_uninitialized_variables().eval())
    uninitialized_vars = uninitialized.apply(
        tf.contrib.framework.get_variables_by_name)._[0]
    tf.variables_initializer(uninitialized_vars).run()

    saver = tf.train.Saver(uninitialized_vars, max_to_keep=1)
    try:
        checkpoint_path = tf.train.latest_checkpoint(outdir)
        saver.restore(sess, checkpoint_path)
    except ValueError:
        # Failing to restore just indicates that we don't have a valid checkpoint,
        # so we will just start training a fresh GILBO model.
        pass

    _TrainGILBO(sess, gan, saver, learning_rate, gilbo_step, z_sample,
                avg_info, z_pred_dist, train_op, gan_type, outdir, options)

    mean_eval_info = _EvalGILBO(sess, gan, z_sample, avg_info, dist_p1,
                                dist_p2, fake_images, outdir, options)

    # Collect encoded distributions on the training and eval set in order to do
    # kl-nearest-neighbors on generated samples and measure consistency.
    train_images = gan_lib.load_dataset(dataset, split_name='train').apply(
        tf.contrib.data.batch_and_drop_remainder(
            gan.batch_size)).make_one_shot_iterator().get_next()[0]
    train_images = tf.reshape(train_images, fake_images.shape)

    eval_images = gan_lib.load_dataset(dataset, split_name='test').apply(
        tf.contrib.data.batch_and_drop_remainder(
            gan.batch_size)).make_one_shot_iterator().get_next()[0]
    eval_images = tf.reshape(eval_images, fake_images.shape)

    mean_train_consistency = _RunGILBOConsistency(train_images,
                                                  'train',
                                                  extract_input_images=0,
                                                  save_consistency_images=20,
                                                  num_batches=5,
                                                  **locals())
    mean_eval_consistency = _RunGILBOConsistency(eval_images,
                                                 'eval',
                                                 extract_input_images=0,
                                                 save_consistency_images=20,
                                                 num_batches=5,
                                                 **locals())
    mean_self_consistency = _RunGILBOConsistency(fake_images,
                                                 'self',
                                                 extract_input_images=20,
                                                 save_consistency_images=20,
                                                 num_batches=5,
                                                 **locals())

    return (mean_eval_info, mean_train_consistency, mean_eval_consistency,
            mean_self_consistency)
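What `info` measures: for prior p(z) and learned regressor q(z|x), the GILBO is E_z[log q(z | G(z)) - log p(z)], a lower bound on the mutual information between the latent z and the generated image G(z), and train_op maximizes it over the regressor's parameters by minimizing -info. A minimal sketch of just this computation, assuming TensorFlow Probability as the source of the `ds`-style distributions and a fixed-scale Normal as a stand-in for the trained _Regressor output:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
batch_size, z_dim = 16, 64
ones = tf.ones((batch_size, z_dim))

p_z = tfd.Independent(tfd.Uniform(-ones, ones), 1)  # GAN prior, as above
z = p_z.sample()
# Stand-in for the regressor distribution q(z | G(z)); TrainGILBO builds this
# from the generated images via _Regressor and then trains it to maximize
# the bound.
q_z = tfd.Independent(tfd.Normal(loc=z, scale=0.1 * ones), 1)
gilbo = tf.reduce_mean(q_z.log_prob(z) - p_z.log_prob(z))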
Example #6
def RunCheckpointEval(checkpoint_path, task_workdir, options, inception_graph):
    """Evaluate model at given checkpoint_path."""

    # Make sure that the same latent variables are used for each evaluation.
    np.random.seed(42)

    checkpoint_dir = os.path.join(task_workdir, "checkpoint")
    result_dir = os.path.join(task_workdir, "result")
    gan_log_dir = os.path.join(task_workdir, "logs")

    gan_type = options["gan_type"]
    if gan_type not in SUPPORTED_GANS:
        raise ValueError("Gan type %s is not supported." % gan_type)
    dataset = options["dataset"]

    dataset_content = gan_lib.load_dataset(dataset, split_name="test")
    num_test_examples = FLAGS.num_test_examples

    if num_test_examples % INCEPTION_BATCH != 0:
        logging.info("Rounding down number of examples to fit inception batch.")
        num_test_examples -= num_test_examples % INCEPTION_BATCH

    # Get real images from the dataset. In the case of a 1-channel
    # dataset (like mnist) convert it to 3 channels.

    data_x = []
    with tf.Graph().as_default():
        get_next = dataset_content.make_one_shot_iterator().get_next()
        with tf.Session() as sess:
            for _ in range(num_test_examples):
                data_x.append(sess.run(get_next[0]))

    real_images = np.array(data_x)
    if real_images.shape[0] != num_test_examples:
        raise ValueError("Not enough examples in the dataset.")

    if real_images.shape[3] == 1:
        real_images = np.tile(real_images, [1, 1, 1, 3])
    real_images *= 255.0
    logging.info("Real data processed.")

    # Get fake images from the generator.
    samples = []
    logging.info("Running eval on checkpoint path: %s", checkpoint_path)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            gan = gan_lib.create_gan(gan_type=gan_type,
                                     dataset=dataset,
                                     sess=sess,
                                     dataset_content=dataset_content,
                                     options=options,
                                     checkpoint_dir=checkpoint_dir,
                                     result_dir=result_dir,
                                     gan_log_dir=gan_log_dir)

            gan.build_model(is_training=False)

            tf.global_variables_initializer().run()
            saver = tf.train.Saver()
            saver.restore(sess, checkpoint_path)

            # Make sure we generate at least as many examples as are in the
            # test set.
            num_batches = int(np.ceil(num_test_examples / gan.batch_size))
            for _ in range(num_batches):
                z_sample = gan.z_generator(gan.batch_size, gan.z_dim)
                feed_dict = {gan.z: z_sample}
                x = sess.run(gan.fake_images, feed_dict=feed_dict)
                # If NaNs were generated, ignore this checkpoint and assign a
                # very high FID score which we handle specially later.
                if np.isnan(x).any():
                    logging.error(
                        "Detected NaN in fake_images! Returning NaN.")
                    return NAN_DETECTED, NAN_DETECTED
                samples.append(x)

    fake_images = np.concatenate(samples, axis=0)
    # Adjust the number of fake images to the number of images in the test set.
    fake_images = fake_images[:num_test_examples, :, :, :]
    # In case we use a 1-channel dataset (like mnist) - convert it to 3 channels.
    if fake_images.shape[3] == 1:
        fake_images = np.tile(fake_images, [1, 1, 1, 3])
    fake_images *= 255.0
    logging.info("Fake data processed, computing inception score.")
    inception_score = GetInceptionScore(fake_images, inception_graph)
    logging.info("Inception score computed: %.3f", inception_score)

    assert fake_images.shape == real_images.shape

    fid_score = ComputeTFGanFIDScore(fake_images, real_images, inception_graph)

    logging.info("Frechet Inception Distance for checkpoint %s is %.3f",
                 checkpoint_path, fid_score)
    return inception_score, fid_score
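The INCEPTION_BATCH adjustment above rounds num_test_examples down so that the real and fake images divide evenly into inception batches. A small worked example (the value 50 is illustrative only; the real constant is defined elsewhere in this module):

INCEPTION_BATCH = 50  # illustrative; the module defines the real value
num_test_examples = 1024
num_test_examples -= num_test_examples % INCEPTION_BATCH
print(num_test_examples)  # 1000 -> exactly 20 inception batches of 50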