def build_generator(input_shape=(256, 256, 3), num_blocks=9):
    """Assemble the ResNet-style generator as a functional ``Model``.

    Pipeline: reflection-padded 7x7 stem conv, two stride-2 convs,
    ``num_blocks`` residual blocks, two stride-2 transposed convs, and a
    reflection-padded 7x7 output conv with tanh. Every conv is followed by
    instance norm + ReLU except the output conv.

    Args:
        input_shape: Shape of a single input sample, forwarded to
            ``layers.Input``.
        num_blocks: How many ``_resblock`` stages to stack in the middle.

    Returns:
        A ``Model`` from the input tensor to the tanh-activated output.
    """

    def weight_init():
        # Build a new initializer object for every layer (never shared).
        return RandomNormal(mean=0, stddev=0.02)

    def norm_relu(tensor):
        tensor = InstanceNormalization()(tensor)
        return layers.ReLU()(tensor)

    inputs = layers.Input(input_shape)

    # Stem: pad by 3 on each side so the 7x7 conv runs without zero padding.
    h = ReflectionPadding2D(padding=(3, 3))(inputs)
    h = layers.Conv2D(filters=64, kernel_size=7, strides=1,
                      kernel_initializer=weight_init())(h)
    h = norm_relu(h)

    # Downsampling: two stride-2 convs, 128 then 256 filters.
    for width in (128, 256):
        h = layers.Conv2D(filters=width, kernel_size=3, strides=2,
                          padding='same',
                          kernel_initializer=weight_init())(h)
        h = norm_relu(h)

    # Bottleneck: residual blocks with _resblock's default settings.
    for _ in range(num_blocks):
        h = _resblock(h)

    # Upsampling: two stride-2 transposed convs, 128 then 64 filters.
    for width in (128, 64):
        h = layers.Conv2DTranspose(filters=width, kernel_size=3, strides=2,
                                   padding='same',
                                   kernel_initializer=weight_init())(h)
        h = norm_relu(h)

    # Output head.
    # NOTE(review): output channels are fixed at 3 and do not follow
    # input_shape[-1]; confirm this is intended for non-RGB inputs.
    h = ReflectionPadding2D(padding=(3, 3))(h)
    outputs = layers.Conv2D(filters=3, kernel_size=7, activation='tanh',
                            kernel_initializer=weight_init())(h)
    return Model(inputs=inputs, outputs=outputs)
def _resblock(x0, num_filter=256, kernel_size=3): """Residual block architecture""" x = ReflectionPadding2D()(x0) x = layers.Conv2D(filters=num_filter, kernel_size=kernel_size, kernel_initializer=RandomNormal(mean=0, stddev=0.02))(x) x = InstanceNormalization()(x) x = layers.ReLU()(x) x = ReflectionPadding2D()(x) x = layers.Conv2D(filters=num_filter, kernel_size=kernel_size, kernel_initializer=RandomNormal(mean=0, stddev=0.02))(x) x = InstanceNormalization()(x) x = layers.Add()([x, x0]) return x
def build_discriminator(input_shape=(256, 256, 3)):
    """Assemble the convolutional discriminator as a functional ``Model``.

    Pipeline: three stride-2 4x4 convs (64/128/256 filters), then two
    reflection-padded stride-1 4x4 convs (512 filters, then 1 filter).
    LeakyReLU(0.2) follows every conv except the last, and instance norm is
    applied to every conv except the first and the last.

    Args:
        input_shape: Shape of a single input sample, forwarded to
            ``layers.Input``.

    Returns:
        A ``Model`` whose output is the raw single-channel conv map (no final
        activation).
    """

    def weight_init():
        # Build a new initializer object for every layer (never shared).
        return RandomNormal(mean=0, stddev=0.02)

    image = layers.Input(input_shape)

    # First stride-2 stage: no normalization.
    h = layers.Conv2D(filters=64, kernel_size=4, strides=2, padding='same',
                      kernel_initializer=weight_init())(image)
    h = layers.LeakyReLU(0.2)(h)

    # Further stride-2 stages with instance norm.
    for width in (128, 256):
        h = layers.Conv2D(filters=width, kernel_size=4, strides=2,
                          padding='same',
                          kernel_initializer=weight_init())(h)
        h = InstanceNormalization()(h)
        h = layers.LeakyReLU(0.2)(h)

    # Stride-1 stage on a reflection-padded map ('valid' conv).
    h = ReflectionPadding2D()(h)
    h = layers.Conv2D(filters=512, kernel_size=4, strides=1,
                      kernel_initializer=weight_init())(h)
    h = InstanceNormalization()(h)
    h = layers.LeakyReLU(0.2)(h)

    # Single-channel output map, left without an activation.
    h = ReflectionPadding2D()(h)
    scores = layers.Conv2D(filters=1, kernel_size=4, strides=1,
                           kernel_initializer=weight_init())(h)
    return Model(inputs=image, outputs=scores)
def _residual_block(x0, num_filter, kernel_size=(3, 3), strides=(1, 1)): initializer = tf.random_normal_initializer(0., 0.02) x0_cropped = tf.keras.layers.Cropping2D(cropping=2)(x0) x = tf.keras.layers.Conv2D(filters=num_filter, kernel_size=kernel_size, strides=strides, kernel_initializer=initializer)(x0) x = InstanceNormalization()(x) x = tf.keras.layers.ReLU()(x) x = tf.keras.layers.Conv2D(filters=num_filter, kernel_size=kernel_size, strides=strides, kernel_initializer=initializer)(x) x = InstanceNormalization()(x) x = tf.keras.layers.Add()([x, x0_cropped]) return x
def _downsample(x0, num_filter, kernel_size=(3, 3), strides=(2, 2), padding="same"):
    """Apply a strided conv, then instance norm, then ReLU.

    Args:
        x0: Input tensor.
        num_filter: Number of conv filters.
        kernel_size: Conv kernel size.
        strides: Conv strides; the default (2, 2) reduces spatial size.
        padding: Conv padding mode, forwarded to ``Conv2D``.

    Returns:
        The ReLU-activated, instance-normalized conv output.
    """
    conv = tf.keras.layers.Conv2D(
        filters=num_filter,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
    )
    features = conv(x0)
    features = InstanceNormalization()(features)
    return tf.keras.layers.ReLU()(features)
def _conv_block(x0, num_filter, kernel_size=(9, 9), strides=(1, 1), padding="same", apply_relu=True):
    """Apply a conv and instance norm, optionally followed by ReLU.

    Args:
        x0: Input tensor.
        num_filter: Number of conv filters.
        kernel_size: Conv kernel size (default 9x9).
        strides: Conv strides.
        padding: Conv padding mode, forwarded to ``Conv2D``.
        apply_relu: When true, append a ReLU after the normalization.

    Returns:
        The normalized conv output, ReLU-activated if ``apply_relu``.
    """
    conv = tf.keras.layers.Conv2D(
        filters=num_filter,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
    )
    out = conv(x0)
    out = InstanceNormalization()(out)
    if not apply_relu:
        return out
    return tf.keras.layers.ReLU()(out)