Example #1
0
 def generator(self, z, is_training, reuse=False):
     """Build the generator network selected by ``self.architecture``.

     Compares ``self.architecture`` against the architecture constants in
     ``consts`` and delegates to the matching generator builder.

     Args:
       z: Noise input, forwarded unchanged to the selected builder.
       is_training: Forwarded to the selected builder.
       reuse: Forwarded to the selected builder as its ``reuse`` argument.

     Returns:
       Whatever the selected builder returns (presumably the generated
       image tensor -- TODO confirm against the builders).

     Raises:
       ValueError: For ``consts.RESNET5_ARCH`` when ``self.output_height``
         and ``self.output_width`` differ.
       NotImplementedError: If ``self.architecture`` matches none of the
         supported architectures.
     """
     if self.architecture == consts.INFOGAN_ARCH:
         return super(AbstractGANWithPenalty,
                      self).generator(z, is_training, reuse)
     elif self.architecture == consts.DCGAN_ARCH:
         return dcgan_architecture.generator(z, self.batch_size,
                                             self.output_height,
                                             self.output_width, self.c_dim,
                                             is_training, reuse)
     elif self.architecture == consts.RESNET5_ARCH:
         # resnet5_generator receives a single `output_shape` (the height),
         # so the config must be square. This is an explicit check rather
         # than `assert`, because asserts are stripped under `python -O`.
         # Without it, a non-square config would silently build a generator
         # of the wrong width.
         if self.output_height != self.output_width:
             raise ValueError(
                 "RESNET5 architecture requires square output, got %sx%s." %
                 (self.output_height, self.output_width))
         return resnet_architecture.resnet5_generator(
             z,
             is_training=is_training,
             reuse=reuse,
             colors=self.c_dim,
             output_shape=self.output_height)
     elif self.architecture == consts.RESNET_STL:
         return resnet_architecture.resnet_stl_generator(
             z, is_training=is_training, reuse=reuse, colors=self.c_dim)
     elif self.architecture == consts.RESNET107_ARCH:
         return resnet_architecture.resnet107_generator(
             z, is_training=is_training, reuse=reuse, colors=self.c_dim)
     elif self.architecture == consts.RESNET_CIFAR:
         return resnet_architecture.resnet_cifar_generator(
             z, is_training=is_training, reuse=reuse, colors=self.c_dim)
     elif self.architecture == consts.SNDCGAN_ARCH:
         return dcgan_architecture.sn_generator(z, self.batch_size,
                                                self.output_height,
                                                self.output_width,
                                                self.c_dim, is_training,
                                                reuse)
     else:
         raise NotImplementedError("Architecture %s not implemented." %
                                   self.architecture)
def TestResnet5GeneratorShape(output_shape):
  """Run the resnet5 generator once and pair its output shape with the target.

  Builds ``resnet_arch.resnet5_generator`` on 8 random-normal noise vectors
  of size 64 in a fresh default graph. It then evaluates the generator once
  in a session with soft placement enabled.

  Args:
    output_shape: Side length passed to the generator as ``output_shape``.

  Returns:
    A two-element list ``[actual_shape, expected_shape]``, where
    ``expected_shape`` is ``(8, output_shape, output_shape, 3)``.
  """
  # Start from an empty default graph. The generator is built with
  # reuse=False, so its variables must not already exist.
  tf.reset_default_graph()
  session_config = tf.ConfigProto(allow_soft_placement=True)
  n_samples, noise_dim, side = 8, 64, output_shape
  with tf.Session(config=session_config) as sess:
    noise = tf.random_normal([n_samples, noise_dim])
    images = resnet_arch.resnet5_generator(
        noise=noise,
        is_training=True,
        reuse=False,
        colors=3,
        output_shape=side)
    tf.global_variables_initializer().run()
    # Fetching the single tensor directly yields the same array that
    # sess.run([images])[0] would.
    generated = sess.run(images)
  expected = (n_samples, side, side, 3)
  return [generated.shape, expected]