Пример #1
0
 def discriminator(self, x, is_training, reuse=False):
     """Builds the discriminator selected by `self.architecture`.

     Each supported architecture constant is dispatched to its builder
     function. The INFOGAN architecture delegates to the next class in the MRO
     after `AbstractGANWithPenalty`.

     Args:
       x: Input to the discriminator (presumably a batch of images; confirm
         against callers).
       is_training: Forwarded to every builder except the SNDCGAN one, which
         does not accept it.
       reuse: Forwarded to the selected builder (presumably the variable-scope
         reuse flag; confirm).

     Returns:
       Whatever the selected architecture builder returns, unchanged.

     Raises:
       ValueError: If the architecture is SNDCGAN and
         `self.discriminator_normalization` is neither spectral norm nor
         no normalization.
       NotImplementedError: If `self.architecture` is not supported.
     """
     if self.architecture == consts.INFOGAN_ARCH:
         return super(AbstractGANWithPenalty,
                      self).discriminator(x, is_training, reuse)
     elif self.architecture == consts.DCGAN_ARCH:
         return dcgan_architecture.discriminator(
             x, self.batch_size, is_training,
             self.discriminator_normalization, reuse)
     elif self.architecture == consts.RESNET5_ARCH:
         return resnet_architecture.resnet5_discriminator(
             x, is_training, self.discriminator_normalization, reuse)
     elif self.architecture == consts.RESNET107_ARCH:
         return resnet_architecture.resnet107_discriminator(
             x, is_training, self.discriminator_normalization, reuse)
     elif self.architecture == consts.RESNET_CIFAR:
         return resnet_architecture.resnet_cifar_discriminator(
             x, is_training, self.discriminator_normalization, reuse)
     elif self.architecture == consts.RESNET_STL:
         return resnet_architecture.resnet_stl_discriminator(
             x, is_training, self.discriminator_normalization, reuse)
     elif self.architecture == consts.SNDCGAN_ARCH:
         # Explicit check rather than `assert`: asserts are stripped under
         # `python -O`, which would let an unsupported normalization through
         # silently (it would be treated as "no spectral norm" below).
         supported = (consts.SPECTRAL_NORM, consts.NO_NORMALIZATION)
         if self.discriminator_normalization not in supported:
             raise ValueError(
                 "Architecture %s supports only %s or %s discriminator "
                 "normalization, got %s." %
                 (self.architecture, consts.SPECTRAL_NORM,
                  consts.NO_NORMALIZATION, self.discriminator_normalization))
         return dcgan_architecture.sn_discriminator(
             x,
             self.batch_size,
             reuse,
             use_sn=self.discriminator_normalization == consts.SPECTRAL_NORM)
     else:
         raise NotImplementedError("Architecture %s not implemented." %
                                   self.architecture)
 def testResnet5DiscriminatorRuns(self):
   """Checks that resnet5_discriminator builds and runs on random input.

   Builds the discriminator on random 128x128 3-channel inputs with spectral
   normalization, evaluates it once, and checks that the first returned tensor
   has shape (batch_size, 1).
   """
   config = tf.ConfigProto(allow_soft_placement=True)
   tf.reset_default_graph()
   batch_size = 8
   with tf.Session(config=config) as sess:
     images = tf.random_normal([batch_size, 128, 128, 3])
     # The builder returns three values; only the first is checked here.
     out, _, _ = resnet_arch.resnet5_discriminator(
         images, is_training=True, discriminator_normalization="spectral_norm",
         reuse=False)
     tf.global_variables_initializer().run()
     output = sess.run(out)
     # assertEqual, not the deprecated assertEquals alias (removed in
     # Python 3.12).
     self.assertEqual(output.shape, (batch_size, 1))