Beispiel #1
0
    def generator(noise, mode):
        """TF-GAN compatible generator function.

        Samples one class id per example, then produces images either with
        `gen_module.generator` or, when `hparams.debug_params.fake_nets` is
        set, with an all-zero stand-in.

        Args:
          noise: Noise tensor; its leading dimension is used as the batch size.
          mode: A `tf.estimator.ModeKeys` value. `TRAIN` is forwarded to the
            generator network as `training=True`.

        Returns:
          In `PREDICT` mode, the generated image tensor (asserted compatible
          with `[None, 128, 128, 3]`). Otherwise a dict with keys 'images'
          and 'labels' (the sampled sparse class ids).
        """
        batch_size = tf.shape(input=noise)[0]
        is_train = (mode == tf.estimator.ModeKeys.TRAIN)

        # Some label trickery: all-zero logits make `categorical` draw class
        # ids uniformly from [0, num_classes).
        gen_class_logits = tf.zeros((batch_size, hparams.num_classes))
        gen_class_ints = tf.random.categorical(logits=gen_class_logits,
                                               num_samples=1)
        # Squeeze only the num_samples axis so a batch of one stays rank 1.
        gen_sparse_class = tf.squeeze(gen_class_ints, -1)
        gen_sparse_class.shape.assert_is_compatible_with([None])

        if hparams.debug_params.fake_nets:
            # Zero images multiplied by a dummy variable (presumably so the
            # generator still owns a variable in debug mode). `tf.compat.v1`
            # is used because `tf.get_variable` no longer exists in TF2.
            gen_imgs = tf.zeros([batch_size, 128, 128, 3]) * (
                tf.compat.v1.get_variable('dummy_g', initializer=2.0))
            generator_vars = ()
        else:
            gen_imgs, generator_vars = gen_module.generator(
                noise,
                gen_sparse_class,
                hparams.gf_dim,
                hparams.num_classes,
                training=is_train)
        # Print debug statistics and log the generated variables. The
        # returned tensors are kept so downstream ops use the hooked values.
        gen_imgs, gen_sparse_class = eval_lib.print_debug_statistics(
            gen_imgs, gen_sparse_class, 'generator',
            hparams.tpu_params.use_tpu_estimator)
        eval_lib.log_and_summarize_variables(
            generator_vars, 'gvars', hparams.tpu_params.use_tpu_estimator)
        gen_imgs.shape.assert_is_compatible_with([None, 128, 128, 3])

        if mode == tf.estimator.ModeKeys.PREDICT:
            return gen_imgs
        return {'images': gen_imgs, 'labels': gen_sparse_class}
Beispiel #2
0
    def generator(noise_and_lbls, mode):
        """TF-GAN compatible generator function.

        Produces images conditioned on caller-supplied labels, either with
        `gen_module.generator` or, when `hparams.debug_params.fake_nets` is
        set, with an all-zero stand-in.

        Args:
          noise_and_lbls: Dict holding the noise tensor under 'z' and the
            class labels under 'labels'. The noise's leading dimension is used
            as the batch size.
          mode: A `tf.estimator.ModeKeys` value. `TRAIN` is forwarded to the
            generator network as `training=True`.

        Returns:
          The generated image tensor when `mode` is `PREDICT` and
          `flags.FLAGS.gen_images_with_margins` is false; otherwise a dict
          with keys 'images' and 'labels'.
        """
        noise, labs = noise_and_lbls['z'], noise_and_lbls['labels']
        batch_size = tf.shape(input=noise)[0]
        is_train = (mode == tf.estimator.ModeKeys.TRAIN)
        image_size = flags.FLAGS.image_size

        # No rank-1 shape assertion on `labs`: per the original author's note
        # it does not hold when generating images (gen_images).

        if hparams.debug_params.fake_nets:
            # Zero images multiplied by a dummy variable (presumably so the
            # generator still owns a variable in debug mode).
            gen_imgs = tf.zeros([batch_size, image_size, image_size, 3]) * (
                tf.compat.v1.get_variable('dummy_g', initializer=2.0))
            generator_vars = ()
        else:
            gen_imgs, generator_vars = gen_module.generator(
                noise,
                labs,
                hparams.gf_dim,
                hparams.num_classes,
                training=is_train)
        # Print debug statistics and log the generated variables. Both
        # returned tensors are used, so the labels handed downstream are the
        # ones that passed through the debug hook (previously discarded).
        # NOTE(review): assumes print_debug_statistics returns labels equal in
        # value to its input — confirm in eval_lib.
        gen_imgs, gen_sparse_class = eval_lib.print_debug_statistics(
            gen_imgs, labs, 'generator', hparams.tpu_params.use_tpu_estimator)
        eval_lib.log_and_summarize_variables(
            generator_vars, 'gvars', hparams.tpu_params.use_tpu_estimator)
        gen_imgs.shape.assert_is_compatible_with(
            [None, image_size, image_size, 3])

        if (mode == tf.estimator.ModeKeys.PREDICT and
                not flags.FLAGS.gen_images_with_margins):
            return gen_imgs
        return {'images': gen_imgs, 'labels': gen_sparse_class}
Beispiel #3
0
  def test_generator_shapes_and_ranges(self):
    """Tests the generator.

    Make sure the image shapes and pixel value ranges are as expected.
    """
    if tf.executing_eagerly():
      # `compute_spectral_norm` doesn't work when executing eagerly. Report a
      # skip instead of silently passing.
      self.skipTest('`compute_spectral_norm` does not support eager mode.')
    batch_size = 10
    num_classes = 1000
    zs = tf.random.normal((batch_size, 128))
    # All-zero logits sample class ids uniformly. `tf.random.categorical`
    # replaces the deprecated `tf.multinomial`, which is gone in TF2.
    gen_class_logits = tf.zeros((batch_size, num_classes))
    gen_class_ints = tf.random.categorical(gen_class_logits, 1)
    # Squeeze only the sample axis so the labels stay rank 1 for any batch.
    gen_sparse_class = tf.squeeze(gen_class_ints, -1)
    images, var_list = generator.generator(
        zs, gen_sparse_class, gf_dim=32, num_classes=num_classes)
    # The context manager closes the session even if `run` raises.
    with tf.compat.v1.train.MonitoredTrainingSession() as sess:
      images_np = sess.run(images)
    self.assertEqual((batch_size, 128, 128, 3), images_np.shape)
    self.assertAllInRange(images_np, -1.0, 1.0)
    self.assertIsInstance(var_list, list)