def generator(noise, mode):
  """TF-GAN compatible generator function.

  Samples one class id per example uniformly at random, then runs the
  class-conditional generator network on (noise, class ids).

  Args:
    noise: Noise tensor. Its leading dimension is used as the batch size;
      the remaining dimensions are whatever `gen_module.generator` expects
      (presumably (batch_size, z_dim) -- confirm against gen_module).
    mode: A `tf.estimator.ModeKeys` value. The generator network is called
      with `training=True` only when `mode == ModeKeys.TRAIN`.

  Returns:
    In PREDICT mode, the generated image tensor (shape compatible with
    [None, 128, 128, 3]). In every other mode, a dict with keys 'images'
    (the generated images) and 'labels' (the sampled class ids, shape
    [batch_size], as returned by `eval_lib.print_debug_statistics`).
  """
  batch_size = tf.shape(input=noise)[0]
  is_train = (mode == tf.estimator.ModeKeys.TRAIN)

  # Some label trickery: all-zero logits give every class equal probability,
  # so this draws one uniformly random class id per example.
  gen_class_logits = tf.zeros((batch_size, hparams.num_classes))
  gen_class_ints = tf.random.categorical(logits=gen_class_logits, num_samples=1)
  # (batch_size, 1) -> (batch_size,); squeeze only the sample axis.
  gen_sparse_class = tf.squeeze(gen_class_ints, -1)
  gen_sparse_class.shape.assert_is_compatible_with([None])

  if hparams.debug_params.fake_nets:
    # Debug path: constant-zero images scaled by a dummy variable (presumably
    # so the generator still creates a variable -- confirm). `tf.get_variable`
    # is not available in the TF2 namespace; the compat.v1 alias works under
    # both TF1 and TF2.
    gen_imgs = tf.zeros([batch_size, 128, 128, 3]) * tf.compat.v1.get_variable(
        'dummy_g', initializer=2.0)
    generator_vars = ()
  else:
    gen_imgs, generator_vars = gen_module.generator(
        noise,
        gen_sparse_class,
        hparams.gf_dim,
        hparams.num_classes,
        training=is_train)

  # Print debug statistics and log the generated variables.
  gen_imgs, gen_sparse_class = eval_lib.print_debug_statistics(
      gen_imgs, gen_sparse_class, 'generator',
      hparams.tpu_params.use_tpu_estimator)
  eval_lib.log_and_summarize_variables(
      generator_vars, 'gvars', hparams.tpu_params.use_tpu_estimator)
  gen_imgs.shape.assert_is_compatible_with([None, 128, 128, 3])

  if mode == tf.estimator.ModeKeys.PREDICT:
    return gen_imgs
  return {'images': gen_imgs, 'labels': gen_sparse_class}
def generator(noise_and_lbls, mode):
  """TF-GAN compatible generator function.

  Runs the class-conditional generator network on caller-supplied noise and
  labels.

  Args:
    noise_and_lbls: A dict with key 'z' (the noise tensor; its leading
      dimension is used as the batch size) and key 'labels' (the conditioning
      labels passed to `gen_module.generator`).
    mode: A `tf.estimator.ModeKeys` value. The generator network is called
      with `training=True` only when `mode == ModeKeys.TRAIN`.

  Returns:
    In PREDICT mode with `FLAGS.gen_images_with_margins` unset, the generated
    image tensor (shape compatible with
    [None, FLAGS.image_size, FLAGS.image_size, 3]). Otherwise a dict with keys
    'images' and 'labels', both as returned by
    `eval_lib.print_debug_statistics`.
  """
  noise, labs = noise_and_lbls['z'], noise_and_lbls['labels']
  batch_size = tf.shape(input=noise)[0]
  is_train = (mode == tf.estimator.ModeKeys.TRAIN)
  image_size = flags.FLAGS.image_size
  # Deliberately no `labs.shape.assert_is_compatible_with([None])` check:
  # rank-1 (sparse) labels are not guaranteed on the gen_images path.

  if hparams.debug_params.fake_nets:
    # Debug path: constant-zero images scaled by a dummy variable (presumably
    # so the generator still creates a variable -- confirm).
    gen_imgs = tf.zeros(
        [batch_size, image_size, image_size, 3]) * tf.compat.v1.get_variable(
            'dummy_g', initializer=2.0)
    generator_vars = ()
  else:
    gen_imgs, generator_vars = gen_module.generator(
        noise, labs, hparams.gf_dim, hparams.num_classes, training=is_train)

  # Print debug statistics and log the generated variables. Keep the labels
  # returned by print_debug_statistics rather than the input tensor: any
  # debug printing it attaches to them would otherwise never run (presumably
  # it wraps them in a print op off-TPU -- confirm in eval_lib).
  gen_imgs, labs = eval_lib.print_debug_statistics(
      gen_imgs, labs, 'generator', hparams.tpu_params.use_tpu_estimator)
  eval_lib.log_and_summarize_variables(
      generator_vars, 'gvars', hparams.tpu_params.use_tpu_estimator)
  gen_imgs.shape.assert_is_compatible_with(
      [None, image_size, image_size, 3])

  if (mode == tf.estimator.ModeKeys.PREDICT and
      not flags.FLAGS.gen_images_with_margins):
    return gen_imgs
  return {'images': gen_imgs, 'labels': labs}
def test_generator_shapes_and_ranges(self):
  """Tests the generator.

  Make sure the image shapes and pixel value ranges are as expected: a batch
  of 10 images of shape (128, 128, 3) with values in [-1, 1], and the
  returned variables as a list.
  """
  if tf.executing_eagerly():
    # `compute_spectral_norm` doesn't work when executing eagerly. Report the
    # test as skipped instead of letting it silently pass.
    self.skipTest('`compute_spectral_norm` does not support eager execution.')

  batch_size = 10
  num_classes = 1000
  zs = tf.random.normal((batch_size, 128))

  # All-zero logits: draw one uniformly random class id per example.
  # `tf.multinomial` is TF1-only; `tf.random.categorical` is its replacement.
  gen_class_logits = tf.zeros((batch_size, num_classes))
  gen_class_ints = tf.random.categorical(logits=gen_class_logits, num_samples=1)
  # Squeeze only the sample axis so the labels stay rank-1 ([batch_size])
  # even when batch_size == 1.
  gen_sparse_class = tf.squeeze(gen_class_ints, -1)

  images, var_list = generator.generator(
      zs, gen_sparse_class, gf_dim=32, num_classes=num_classes)

  # Use the session as a context manager so it is closed after evaluation.
  with tf.compat.v1.train.MonitoredTrainingSession() as sess:
    images_np = sess.run(images)

  self.assertEqual((batch_size, 128, 128, 3), images_np.shape)
  self.assertAllInRange(images_np, -1.0, 1.0)
  self.assertIsInstance(var_list, list)