Пример #1
0
 def generator(self, z, is_training, reuse=False):
     """Build the generator network for the configured architecture.

     Dispatches on ``self.architecture`` to the matching builder in
     ``dcgan_architecture`` / ``resnet_architecture`` (or the superclass).

     Args:
       z: Latent input forwarded unchanged to the selected builder
         (presumably a ``[batch_size, z_dim]`` noise tensor -- confirm
         with callers).
       is_training: Forwarded unchanged to the selected builder.
       reuse: Forwarded unchanged to the selected builder.

     Returns:
       Whatever the selected architecture builder returns (presumably the
       generated image tensor -- confirm against the builders).

     Raises:
       ValueError: If the architecture is RESNET5_ARCH and
         ``self.output_height != self.output_width``.
       NotImplementedError: If ``self.architecture`` matches none of the
         supported architectures.
     """
     if self.architecture == consts.INFOGAN_ARCH:
         # super(AbstractGANWithPenalty, self) begins the MRO lookup *after*
         # AbstractGANWithPenalty, so any generator defined on that class
         # itself is bypassed in favour of the next one up the hierarchy.
         return super(AbstractGANWithPenalty,
                      self).generator(z, is_training, reuse)
     elif self.architecture == consts.DCGAN_ARCH:
         return dcgan_architecture.generator(z, self.batch_size,
                                             self.output_height,
                                             self.output_width, self.c_dim,
                                             is_training, reuse)
     elif self.architecture == consts.RESNET5_ARCH:
         # Only output_height is forwarded (as output_shape), so non-square
         # outputs cannot be honoured. Check explicitly rather than with
         # `assert`, which is stripped when Python runs with -O.
         if self.output_height != self.output_width:
             raise ValueError(
                 "RESNET5 architecture requires square outputs, got "
                 "height=%s, width=%s." % (self.output_height,
                                           self.output_width))
         return resnet_architecture.resnet5_generator(
             z,
             is_training=is_training,
             reuse=reuse,
             colors=self.c_dim,
             output_shape=self.output_height)
     elif self.architecture == consts.RESNET_STL:
         return resnet_architecture.resnet_stl_generator(
             z, is_training=is_training, reuse=reuse, colors=self.c_dim)
     elif self.architecture == consts.RESNET107_ARCH:
         return resnet_architecture.resnet107_generator(
             z, is_training=is_training, reuse=reuse, colors=self.c_dim)
     elif self.architecture == consts.RESNET_CIFAR:
         return resnet_architecture.resnet_cifar_generator(
             z, is_training=is_training, reuse=reuse, colors=self.c_dim)
     elif self.architecture == consts.SNDCGAN_ARCH:
         return dcgan_architecture.sn_generator(z, self.batch_size,
                                                self.output_height,
                                                self.output_width,
                                                self.c_dim, is_training,
                                                reuse)
     else:
         raise NotImplementedError("Architecture %s not implemented." %
                                   self.architecture)
 def testResnet107GeneratorRuns(self):
   """Smoke-test resnet107_generator: it builds, runs, and has the right shape.

   Feeds a batch of random normal noise through the generator and checks the
   evaluated output has shape (batch_size, 128, 128, 3).
   """
   # allow_soft_placement lets TensorFlow fall back to an available device
   # when an op has no kernel for the requested one.
   config = tf.ConfigProto(allow_soft_placement=True)
   # Start from an empty graph so variables from other tests don't collide.
   tf.reset_default_graph()
   batch_size = 8
   z_dim = 64
   with tf.Session(config=config) as sess:
     z = tf.random_normal([batch_size, z_dim])
     g = resnet_arch.resnet107_generator(
         noise=z, is_training=True, reuse=False, colors=3)
     sess.run(tf.global_variables_initializer())
     output = sess.run(g)
     # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
     self.assertEqual(output.shape, (batch_size, 128, 128, 3))