def _build_generator(self):
    """Assemble the upsampling generator network.

    The latent vector is reshaped into a 4 x 4 feature map with
    ``self.latent_dim // 16`` channels. It then passes through two
    transposed convolutions (256 and 128 filters, each followed by ReLU).
    A final transposed convolution produces ``FLAGS.out_dim`` channels with
    no activation applied. Every transposed convolution uses kernel size 4
    and stride 2. Presumably each one doubles the spatial size if
    ``layers.Conv`` pads 'same' — TODO confirm.

    Note: the reshape only fits if ``latent_dim`` is divisible by 16;
    otherwise the flattened size will not match (integer division truncates).

    Returns:
        Whatever ``self.sequential`` builds from the layer list.
    """
    seed_channels = self.latent_dim // 16
    stack = [tf.keras.layers.Reshape([4, 4, seed_channels])]
    # Hidden upsampling stages: transposed conv followed by ReLU.
    for width in (256, 128):
        stack.append(layers.Conv(width, ks=4, strides=2, transpose=True))
        stack.append(layers.Act('relu'))
    # Output stage: no activation, so the caller receives raw outputs.
    stack.append(layers.Conv(FLAGS.out_dim, ks=4, strides=2, transpose=True))
    return self.sequential(stack)
def _build_inference(self):
    """Assemble the inference (recognition) network.

    The network has two stride-2 convolutions (32 and 64 filters, kernel
    size 3, each followed by ReLU). These are followed by a flatten and a
    dense layer with ``2 * latent_dim`` units. Presumably the two halves
    are the parameters of the latent posterior (e.g. mean and log-variance)
    — TODO confirm against the caller that splits this output.

    Returns:
        Whatever ``self.sequential`` builds from the layer list.
    """
    latent = self.latent_dim
    stack = []
    for width in (32, 64):
        stack += [layers.Conv(width, ks=3, strides=2), layers.Act('relu')]
    stack.append(tf.keras.layers.Flatten())
    stack.append(layers.Dense(latent + latent))
    return self.sequential(stack)
def _build_encoder(self):
    """Assemble the convolutional encoder.

    The encoder has four convolutions with 16, 64, 256 and 512 filters, all
    with kernel size 4 and stride 2. A ``layers.Norm()`` is inserted between
    consecutive convolutions; the final convolution is left un-normalized.

    NOTE(review): unlike the other builders, no ``layers.Act`` appears here,
    so this stack has no explicit nonlinearity between convolutions. Confirm
    whether ``layers.Norm`` / ``layers.Conv`` apply an activation internally;
    if not, this is probably an oversight.

    Returns:
        Whatever ``self.sequential`` builds from the layer list.
    """
    widths = (16, 64, 256, 512)
    stack = []
    for idx, width in enumerate(widths):
        if idx:
            # Normalize the previous conv's output before the next conv.
            stack.append(layers.Norm())
        stack.append(layers.Conv(width, ks=4, strides=2))
    return self.sequential(stack)
def _build_generator(self):
    """Assemble the dense-seeded upsampling generator network.

    The layers are applied in this order:

    1. A dense layer with 7 * 7 * 32 units, followed by ReLU.
    2. A reshape to a 7 x 7 x 32 feature map.
    3. Two stride-2 transposed convolutions (64 and 32 filters, kernel
       size 3, each followed by ReLU).
    4. A final stride-1 transposed convolution producing ``FLAGS.out_dim``
       channels, with no activation.

    Returns:
        Whatever ``self.sequential`` builds from the layer list.
    """
    seed_side, seed_channels = 7, 32
    stack = [
        layers.Dense(seed_side * seed_side * seed_channels),
        layers.Act('relu'),
        tf.keras.layers.Reshape([seed_side, seed_side, seed_channels]),
    ]
    for width in (64, 32):
        stack.extend([
            layers.Conv(width, ks=3, strides=2, transpose=True),
            layers.Act('relu'),
        ])
    # Stride 1: changes channel count only; output is left un-activated.
    stack.append(layers.Conv(FLAGS.out_dim, ks=3, strides=1, transpose=True))
    return self.sequential(stack)
def __init__(self, filters, ks, strides=1, preact=False, use_norm=True,
             use_act=True, use_bias=False, last_norm=False, transpose=False):
    """Create the sub-layers of a convolution block.

    Args:
        filters: output channels of the convolution; also forwarded to the
            base-class constructor.
        ks: convolution kernel size.
        strides: convolution stride; also forwarded to the base class.
        preact: stored as ``self.preact``; not read in this constructor
            (presumably selects the op ordering in the forward pass —
            TODO confirm).
        use_norm: when true, create ``self.bn``.
        use_act: when true, create ``self.act`` (``layers.Act()`` with its
            default arguments).
        use_bias: forwarded to the convolution.
        last_norm: when true, create an additional ``self.last_bn``.
        transpose: forwarded to the convolution (transposed conv if true).
    """
    super(BasicBlock, self).__init__(filters, strides=strides)
    # Configuration flags.
    self.preact = preact
    self.use_norm = use_norm
    self.use_act = use_act
    self.last_norm = last_norm

    # Sub-layers are created in the order bn, act, conv, last_bn. Keep this
    # order: in Keras, auto-generated layer names and weight ordering follow
    # creation / attribute-assignment order.
    if self.use_norm:
        # Drop the learned scale when the configured activation is ReLU-like.
        # Presumably this relies on the batch-norm guidance that scaling is
        # redundant before a ReLU. Assumes layers.Norm wraps batch norm —
        # TODO confirm.
        relu_like = 'relu' in FLAGS.conv_act.lower()
        self.bn = layers.Norm(scale=not relu_like)
    if self.use_act:
        self.act = layers.Act()
    self.conv = layers.Conv(
        filters, ks, strides=strides, use_bias=use_bias, transpose=transpose)
    if self.last_norm:
        self.last_bn = layers.Norm()