def build_4x4_discriminator():
    """Build the 4x4 base discriminator.

    The ``alpha`` input is not used at this resolution; it exists only so
    every discriminator in the progression shares the same input signature.

    Returns:
        tf.keras.Model mapping [image, alpha] -> realness score.
    """
    image_in = tf.keras.layers.Input((4, 4, 3))
    alpha = tf.keras.layers.Input((1), name='input_alpha')  # unused at 4x4

    # Project the RGB image into the 512-channel feature space.
    from_rgb = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(512, kernel_size=1, strides=1, padding='same',
                               activation=tf.nn.leaky_relu,
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='from_rgb_{}x{}'.format(4, 4))
    features = from_rgb(image_in)

    # Extra 1x1 conv keeping the channel count at 512.
    up_channel = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(512, kernel_size=1, strides=1, padding='same',
                               activation=tf.nn.leaky_relu,
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='conv2d_up_channel')
    features = up_channel(features)

    # Shared discriminator output head.
    score = discriminator_block(features)
    return tf.keras.Model(inputs=[image_in, alpha], outputs=score)
def discriminator_block(x):
    """Shared discriminator output head.

    Applies minibatch standard deviation, a 3x3 'same' conv, a 4x4 'valid'
    conv (collapsing the 4x4 spatial map), then a single-unit dense layer.

    Args:
        x: feature tensor, expected at the 4x4 resolution stage.

    Returns:
        Tensor with one scalar score per sample.
    """
    features = mt.MinibatchSTDDEV()(x)

    conv_3x3 = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(512, 3, strides=1, padding='same',
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='d_output_conv2d_1')
    features = tf.keras.layers.LeakyReLU()(conv_3x3(features))

    conv_4x4 = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(512, 4, strides=1, padding='valid',
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='d_output_conv2d_2')
    features = tf.keras.layers.LeakyReLU()(conv_4x4(features))

    flat = tf.keras.layers.Flatten()(features)
    score_out = mt.EqualizeLearningRate(
        tf.keras.layers.Dense(1, kernel_initializer=kernel_initializer,
                              bias_initializer='zeros'),
        name='d_output_dense')
    return score_out(flat)
def build_8x8_discriminator():
    """Build the 8x8 fade-in discriminator.

    Blends a downsampled path through the previous 4x4 from-RGB layer
    (weighted by ``1 - alpha``) with the new 8x8 path (weighted by
    ``alpha``) before the shared output block.

    Returns:
        tf.keras.Model mapping [image, alpha] -> realness score.
    """
    fade_in_channel = 512
    image_in = tf.keras.layers.Input((8, 8, 3))
    alpha = tf.keras.layers.Input((1), name='input_alpha')
    downsample = tf.keras.layers.AveragePooling2D(pool_size=2)

    ########################
    # Left branch in the paper: downsample, then the previous 4x4 from-RGB.
    ########################
    previous_from_rgb = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(512, kernel_size=1, strides=1, padding='same',
                               activation=tf.nn.leaky_relu,
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='from_rgb_{}x{}'.format(4, 4))
    old_path = previous_from_rgb(downsample(image_in))
    old_path = tf.keras.layers.Multiply()([1 - alpha, old_path])

    ########################
    # Right branch in the paper: new 8x8 from-RGB plus a downsample block.
    ########################
    from_rgb = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(512, kernel_size=1, strides=1, padding='same',
                               activation=tf.nn.leaky_relu,
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='from_rgb_{}x{}'.format(8, 8))
    new_path = from_rgb(image_in)

    ########################
    # Fade in block
    ########################
    new_path = mt.downsample_block(new_path, filters1=512,
                                   filters2=fade_in_channel, kernel_size=3,
                                   strides=1, padding='same',
                                   activation=tf.nn.leaky_relu,
                                   name='Down_{}x{}'.format(8, 8))
    new_path = tf.keras.layers.Multiply()([alpha, new_path])

    ########################
    # Stable block: blend the branches, then the shared head.
    ########################
    merged = tf.keras.layers.Add()([old_path, new_path])
    score = discriminator_block(merged)
    return tf.keras.Model(inputs=[image_in, alpha], outputs=score)
def build_8x8_generator(noise_dim=NOISE_DIM):
    '''
        8 * 8 Generator

        During fade-in, alpha ramps from 0 to 1: the upsampled previous
        (4x4) RGB output is weighted by (1 - alpha) and the new 8x8 RGB
        output by alpha, mirroring build_8x8_discriminator where alpha
        likewise weights the new-resolution branch.

        Returns a tf.keras.Model mapping [noise, alpha] -> 8x8x3 image.
    '''
    # Initial block
    inputs = tf.keras.layers.Input(noise_dim)
    x = generator_input_block(inputs)
    alpha = tf.keras.layers.Input((1), name='input_alpha')
    ########################
    # Fade in block
    ########################
    # x: new 8x8 features; up_x: plain upsample of the 4x4 features.
    x, up_x = mt.upsample_block(x, in_filters=512, filters=512, kernel_size=3,
                                strides=1, padding='same',
                                activation=tf.nn.leaky_relu,
                                name='Up_{}x{}'.format(8, 8))
    previous_to_rgb = mt.EqualizeLearningRate(tf.keras.layers.Conv2D(
        3, kernel_size=1, strides=1, padding='same',
        activation=output_activation, kernel_initializer=kernel_initializer,
        bias_initializer='zeros'), name='to_rgb_{}x{}'.format(4, 4))
    to_rgb = mt.EqualizeLearningRate(tf.keras.layers.Conv2D(
        3, kernel_size=1, strides=1, padding='same',
        activation=output_activation, kernel_initializer=kernel_initializer,
        bias_initializer='zeros'), name='to_rgb_{}x{}'.format(8, 8))
    ########################
    # Left branch in the paper: previous resolution, fades OUT (1 - alpha).
    # BUGFIX: the branches were swapped — the new 8x8 output was weighted
    # by (1 - alpha) and the old path by alpha, the opposite of the
    # discriminator's fade-in convention.
    ########################
    l_x = previous_to_rgb(up_x)
    l_x = tf.keras.layers.Multiply()([1 - alpha, l_x])
    ########################
    # Right branch in the paper: new resolution, fades IN (alpha).
    ########################
    r_x = to_rgb(x)
    r_x = tf.keras.layers.Multiply()([alpha, r_x])
    combined = tf.keras.layers.Add()([l_x, r_x])
    model = tf.keras.Model(inputs=[inputs, alpha], outputs=combined)
    return model
def generator_input_block(x):
    """Map a latent vector to a 4x4x512 feature map.

    Dense projection, pixel-norm + LeakyReLU, reshape to 4x4x512, then a
    3x3 conv followed by another pixel-norm + LeakyReLU.

    Args:
        x: latent noise tensor.

    Returns:
        Feature tensor of shape (batch, 4, 4, 512).
    """
    dense = mt.EqualizeLearningRate(
        tf.keras.layers.Dense(4 * 4 * 512,
                              kernel_initializer=kernel_initializer,
                              bias_initializer='zeros'),
        name='g_input_dense')
    features = dense(x)
    features = mt.PixelNormalization()(features)
    features = tf.keras.layers.LeakyReLU()(features)
    features = tf.keras.layers.Reshape((4, 4, 512))(features)

    conv = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(512, 3, strides=1, padding='same',
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='g_input_conv2d')
    features = conv(features)
    features = mt.PixelNormalization()(features)
    features = tf.keras.layers.LeakyReLU()(features)
    return features
def build_4x4_generator(noise_dim=NOISE_DIM):
    """Build the 4x4 base generator.

    The ``alpha`` input is not used at this resolution; it is declared only
    so every generator in the progression exposes the same inputs.

    Args:
        noise_dim: dimensionality of the latent noise vector.

    Returns:
        tf.keras.Model mapping [noise, alpha] -> 4x4x3 RGB image.
    """
    # Initial block
    latent_in = tf.keras.layers.Input(noise_dim)
    features = generator_input_block(latent_in)
    alpha = tf.keras.layers.Input((1), name='input_alpha')  # unused at 4x4

    # Project features to RGB with the model's output activation.
    to_rgb = mt.EqualizeLearningRate(
        tf.keras.layers.Conv2D(3, kernel_size=1, strides=1, padding='same',
                               activation=output_activation,
                               kernel_initializer=kernel_initializer,
                               bias_initializer='zeros'),
        name='to_rgb_{}x{}'.format(4, 4))
    rgb_out = to_rgb(features)
    return tf.keras.Model(inputs=[latent_in, alpha], outputs=rgb_out)