def testDefaultGeneratorWithConditionalBatchNorm(self):
    """Checks the CIFAR ResNet generator built with conditional batch norm.

    Builds `resnet_cifar.Generator` in a fresh graph with
    `arch_ops.conditional_batch_norm` as its normalization function, feeds it
    an all-zero noise batch and one-hot labels, and verifies that:
      * the generated images have shape [8, 32, 32, 3], and
      * the trainable variables match `expected_variables` exactly, i.e. the
        same names and shapes in the same order.

    For each normalization layer, the only trainable variables expected are
    the label-conditioned gamma/beta projection kernels of shape
    [num_classes, channels].
    """
    with tf.Graph().as_default():
        # Batch size 8, 32x32x3 images, 10 classes.
        z = tf.zeros((8, 128))
        y = tf.one_hot(tf.ones((8, ), dtype=tf.int32), 10)
        generator = resnet_cifar.Generator(
            image_shape=(32, 32, 3),
            batch_norm_fn=arch_ops.conditional_batch_norm)
        fake_images = generator(z, y=y, is_training=True, reuse=False)
        self.assertEqual(fake_images.shape.as_list(), [8, 32, 32, 3])
        expected_variables = [
            # Name and shape.
            ("generator/fc_noise/kernel:0", [128, 4096]),
            ("generator/fc_noise/bias:0", [4096]),
            ("generator/B1/up_conv_shortcut/kernel:0", [3, 3, 256, 256]),
            ("generator/B1/up_conv_shortcut/bias:0", [256]),
            ("generator/B1/bn1/condition/gamma/kernel:0", [10, 256]),
            ("generator/B1/bn1/condition/beta/kernel:0", [10, 256]),
            ("generator/B1/up_conv1/kernel:0", [3, 3, 256, 256]),
            ("generator/B1/up_conv1/bias:0", [256]),
            ("generator/B1/bn2/condition/gamma/kernel:0", [10, 256]),
            ("generator/B1/bn2/condition/beta/kernel:0", [10, 256]),
            ("generator/B1/same_conv2/kernel:0", [3, 3, 256, 256]),
            ("generator/B1/same_conv2/bias:0", [256]),
            ("generator/B2/up_conv_shortcut/kernel:0", [3, 3, 256, 256]),
            ("generator/B2/up_conv_shortcut/bias:0", [256]),
            ("generator/B2/bn1/condition/gamma/kernel:0", [10, 256]),
            ("generator/B2/bn1/condition/beta/kernel:0", [10, 256]),
            ("generator/B2/up_conv1/kernel:0", [3, 3, 256, 256]),
            ("generator/B2/up_conv1/bias:0", [256]),
            ("generator/B2/bn2/condition/gamma/kernel:0", [10, 256]),
            ("generator/B2/bn2/condition/beta/kernel:0", [10, 256]),
            ("generator/B2/same_conv2/kernel:0", [3, 3, 256, 256]),
            ("generator/B2/same_conv2/bias:0", [256]),
            ("generator/B3/up_conv_shortcut/kernel:0", [3, 3, 256, 256]),
            ("generator/B3/up_conv_shortcut/bias:0", [256]),
            ("generator/B3/bn1/condition/gamma/kernel:0", [10, 256]),
            ("generator/B3/bn1/condition/beta/kernel:0", [10, 256]),
            ("generator/B3/up_conv1/kernel:0", [3, 3, 256, 256]),
            ("generator/B3/up_conv1/bias:0", [256]),
            ("generator/B3/bn2/condition/gamma/kernel:0", [10, 256]),
            ("generator/B3/bn2/condition/beta/kernel:0", [10, 256]),
            ("generator/B3/same_conv2/kernel:0", [3, 3, 256, 256]),
            ("generator/B3/same_conv2/bias:0", [256]),
            ("generator/final_norm/condition/gamma/kernel:0", [10, 256]),
            ("generator/final_norm/condition/beta/kernel:0", [10, 256]),
            ("generator/final_conv/kernel:0", [3, 3, 256, 3]),
            ("generator/final_conv/bias:0", [3]),
        ]
        actual_variables = [(v.name, v.shape.as_list())
                            for v in tf.trainable_variables()]
        for a in actual_variables:
            logging.info(a)
        # Compare the complete lists in a single assertion. On failure,
        # unittest reports the first differing entry and lists any extra or
        # missing variables, instead of only a bare length mismatch.
        self.assertEqual(actual_variables, expected_variables)
# Example no. 2
# 0
 def testResNetCifar(self, image_shape):
   """Checks that the CIFAR ResNet generator/discriminator pair builds.

   Args:
     image_shape: Image shape handed to both the generator constructor and
       `assertArchitectureBuilds` (presumably supplied by a parameterization
       decorator outside this view -- confirm).
   """
   generator = resnet_cifar.Generator(image_shape=image_shape)
   discriminator = resnet_cifar.Discriminator()
   self.assertArchitectureBuilds(
       gen=generator, disc=discriminator, image_shape=image_shape)