def _norm_relu(input_tensor, norm):
  """Applies normalization and ReLU activation to an input tensor.

  Args:
    input_tensor: The `tf.Tensor` to apply the block to.
    norm: A `NormLayer` specifying the type of normalization layer used.

  Returns:
    A `tf.Tensor`.
  """
  if tf.keras.backend.image_data_format() == 'channels_last':
    channel_axis = 3
  else:
    channel_axis = 1

  if norm is NormLayer.group_norm:
    x = tfa_norms.GroupNormalization(axis=channel_axis)(input_tensor)
  elif norm is NormLayer.batch_norm:
    x = tf.keras.layers.BatchNormalization(
        axis=channel_axis,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON)(input_tensor)
  else:
    raise ValueError('The norm argument must be of type `NormLayer`.')
  return tf.keras.layers.Activation('relu')(x)


def _shortcut(input_tensor, residual, norm):
  """Computes the output of a shortcut block between an input and residual.

  More specifically, this block takes `input_tensor` and adds it to
  `residual`. If `input_tensor` is not the same shape as `residual`, then we
  first apply an appropriately-sized convolutional layer to alter its shape
  to that of `residual` and normalize via `norm` before adding it to
  `residual`.

  Args:
    input_tensor: The `tf.Tensor` to apply the block to.
    residual: A `tf.Tensor` added to `input_tensor` after it has been passed
      through a convolution and normalization.
    norm: A `NormLayer` specifying the type of normalization layer used.

  Returns:
    A `tf.Tensor`.
  """
  input_shape = tf.keras.backend.int_shape(input_tensor)
  residual_shape = tf.keras.backend.int_shape(residual)

  if tf.keras.backend.image_data_format() == 'channels_last':
    row_axis = 1
    col_axis = 2
    channel_axis = 3
  else:
    channel_axis = 1
    row_axis = 2
    col_axis = 3

  stride_width = int(round(input_shape[row_axis] / residual_shape[row_axis]))
  stride_height = int(round(input_shape[col_axis] / residual_shape[col_axis]))
  equal_channels = input_shape[channel_axis] == residual_shape[channel_axis]

  shortcut = input_tensor
  # Use a 1-by-1 kernel if the strides are greater than 1, or if the input
  # and residual tensors have different numbers of channels.
  if stride_width > 1 or stride_height > 1 or not equal_channels:
    shortcut = tf.keras.layers.Conv2D(
        filters=residual_shape[channel_axis],
        kernel_size=(1, 1),
        strides=(stride_width, stride_height),
        padding='valid',
        use_bias=False,
        kernel_initializer='he_normal',
        kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY))(
            shortcut)
    if norm is NormLayer.group_norm:
      shortcut = tfa_norms.GroupNormalization(axis=channel_axis)(shortcut)
    elif norm is NormLayer.batch_norm:
      shortcut = tf.keras.layers.BatchNormalization(
          axis=channel_axis,
          momentum=BATCH_NORM_DECAY,
          epsilon=BATCH_NORM_EPSILON)(shortcut)
    else:
      raise ValueError('The norm argument must be of type `NormLayer`.')

  return tf.keras.layers.add([shortcut, residual])


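# The two helpers above compose into a standard pre-activation residual
# block: Norm -> ReLU -> Conv, twice, then a (possibly projected) skip
# connection. The sketch below is illustrative only and not part of this
# module's API; it assumes the module-level `NormLayer`, `tfa_norms`, and
# `L2_WEIGHT_DECAY` used above, and the name `_example_basic_block` is
# hypothetical.
def _example_basic_block(input_tensor, filters, strides, norm):
  """Illustrative pre-activation residual block built from the helpers."""
  x = _norm_relu(input_tensor, norm=norm)
  x = tf.keras.layers.Conv2D(
      filters=filters,
      kernel_size=(3, 3),
      strides=strides,
      padding='same',
      use_bias=False,
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY))(x)
  x = _norm_relu(x, norm=norm)
  x = tf.keras.layers.Conv2D(
      filters=filters,
      kernel_size=(3, 3),
      strides=(1, 1),
      padding='same',
      use_bias=False,
      kernel_initializer='he_normal',
      kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY))(x)
  # `_shortcut` projects `input_tensor` to the shape of `x` when needed.
  return _shortcut(input_tensor, x, norm=norm)

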
def _norm_relu(input_tensor, norm='group'):
  """Helper function to make a Norm -> ReLU block."""
  if tf.keras.backend.image_data_format() == 'channels_last':
    channel_axis = 3
  else:
    channel_axis = 1

  if norm == 'group':
    x = tfa_norms.GroupNormalization(axis=channel_axis)(input_tensor)
  else:
    x = tf.keras.layers.BatchNormalization(
        axis=channel_axis,
        momentum=BATCH_NORM_DECAY,
        epsilon=BATCH_NORM_EPSILON)(input_tensor)
  return tf.keras.layers.Activation('relu')(x)


def _shortcut(input_tensor, residual, norm='group', seed=0):
  """Adds a shortcut between input and the residual."""
  input_shape = tf.keras.backend.int_shape(input_tensor)
  residual_shape = tf.keras.backend.int_shape(residual)

  if tf.keras.backend.image_data_format() == 'channels_last':
    row_axis = 1
    col_axis = 2
    channel_axis = 3
  else:
    channel_axis = 1
    row_axis = 2
    col_axis = 3

  stride_width = int(round(input_shape[row_axis] / residual_shape[row_axis]))
  stride_height = int(round(input_shape[col_axis] / residual_shape[col_axis]))
  equal_channels = input_shape[channel_axis] == residual_shape[channel_axis]

  shortcut = input_tensor
  # 1x1 conv if shape is different, else identity.
  if stride_width > 1 or stride_height > 1 or not equal_channels:
    shortcut = tf.keras.layers.Conv2D(
        filters=residual_shape[channel_axis],
        kernel_size=(1, 1),
        strides=(stride_width, stride_height),
        padding='valid',
        use_bias=False,
        kernel_initializer=tf.keras.initializers.HeNormal(seed=seed),
        kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY))(
            shortcut)
    if norm == 'group':
      shortcut = tfa_norms.GroupNormalization(axis=channel_axis)(shortcut)
    else:
      shortcut = tf.keras.layers.BatchNormalization(
          axis=channel_axis,
          momentum=BATCH_NORM_DECAY,
          epsilon=BATCH_NORM_EPSILON)(shortcut)

  return tf.keras.layers.add([shortcut, residual])


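# Shape-inference sketch for the seeded variant above (illustrative only;
# the function name is hypothetical): adding a 32x32x64 input to a
# 16x16x128 residual makes `_shortcut` infer a stride of round(32 / 16) = 2
# on both spatial axes and insert a seeded 1x1 projection convolution, so
# both branches are 16x16x128 before the add.
def _example_projection_shortcut():
  """Builds a tiny model exercising the projection path of `_shortcut`."""
  inputs = tf.keras.Input(shape=(32, 32, 64))
  residual = tf.keras.layers.Conv2D(
      filters=128, kernel_size=3, strides=2, padding='same')(inputs)
  outputs = _shortcut(inputs, residual, norm='group', seed=42)
  # Output shape: (None, 16, 16, 128).
  return tf.keras.Model(inputs=inputs, outputs=outputs)

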
def create_discriminator(
    downsampling_blocks_num_channels: Sequence[Sequence[int]] = ((64, 128),
                                                                 (128, 128),
                                                                 (256, 256),
                                                                 (512, 512)),
    relu_leakiness: float = 0.2,
    kernel_initializer: Optional[_KerasInitializer] = None,
    use_fan_in_scaled_kernels: bool = True,
    use_layer_normalization: bool = False,
    use_intermediate_inputs: bool = False,
    use_antialiased_bilinear_downsampling: bool = False,
    name: str = 'progressive_gan_discriminator'):
  """Creates a Keras model for the discriminator architecture.

  This architecture is implemented according to the paper "Progressive
  Growing of GANs for Improved Quality, Stability, and Variation"
  https://arxiv.org/abs/1710.10196

  The intermediate outputs can optionally be given as inputs for the
  architecture of "MSG-GAN: Multi-Scale Gradient GAN for Stable Image
  Synthesis" https://arxiv.org/abs/1903.06048

  Args:
    downsampling_blocks_num_channels: The number of channels in the
      downsampling blocks; for each block, the number of channels of the
      first and second convolution is specified.
    relu_leakiness: Slope of the negative part of the leaky relu.
    kernel_initializer: Initializer of the kernel. If `None`, a
      `TruncatedNormal` initializer is used.
    use_fan_in_scaled_kernels: If `True`, rescales the kernels by the scale
      factor of the He initializer, which implements the equalized learning
      rate.
    use_layer_normalization: Whether layer normalization layers should be
      inserted into the network.
    use_intermediate_inputs: If `True`, the model expects a list of
      `tf.Tensor`s with increasing resolution, starting with the
      starting_size up to the final resolution, as input.
    use_antialiased_bilinear_downsampling: If `True`, the downsampling
      operation is anti-aliased bilinear downsampling with a [1, 3, 3, 1]
      tent kernel. If `False`, standard bilinear downsampling, i.e. average
      pooling, is used ([1, 1] tent kernel).
    name: The name of the Keras model.

  Returns:
    The generated discriminator Keras model.
""" if kernel_initializer is None: kernel_initializer = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=1.0) if use_intermediate_inputs: inputs = tuple( tf.keras.Input(shape=(None, None, 3)) for _ in range(len(downsampling_blocks_num_channels) + 1)) tensor = inputs[-1] else: input_tensor = tf.keras.Input(shape=(None, None, 3)) tensor = input_tensor tensor = from_rgb(tensor, use_fan_in_scaled_kernel=use_fan_in_scaled_kernels, num_channels=downsampling_blocks_num_channels[0][0], kernel_initializer=kernel_initializer, relu_leakiness=relu_leakiness) if use_layer_normalization: tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor) for index, (channels_1, channels_2) in enumerate(downsampling_blocks_num_channels): tensor = create_conv_layer( use_fan_in_scaled_kernel=use_fan_in_scaled_kernels, filters=channels_1, kernel_size=3, strides=1, padding='same', kernel_initializer=kernel_initializer)(tensor) tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor) if use_layer_normalization: tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor) tensor = create_conv_layer( use_fan_in_scaled_kernel=use_fan_in_scaled_kernels, filters=channels_2, kernel_size=3, strides=1, padding='same', kernel_initializer=kernel_initializer)(tensor) tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor) if use_layer_normalization: tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor) if use_antialiased_bilinear_downsampling: tensor = keras_layers.Blur2D()(tensor) tensor = tf.keras.layers.AveragePooling2D()(tensor) if use_intermediate_inputs: tensor = tf.keras.layers.Concatenate()( [inputs[-index - 2], tensor]) tensor = create_conv_layer( use_fan_in_scaled_kernel=use_fan_in_scaled_kernels, filters=downsampling_blocks_num_channels[-1][1], kernel_size=3, strides=1, padding='same', kernel_initializer=kernel_initializer)(tensor) tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor) if use_layer_normalization: tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor) tensor = create_conv_layer( use_fan_in_scaled_kernel=use_fan_in_scaled_kernels, filters=downsampling_blocks_num_channels[-1][1], kernel_size=4, strides=1, padding='valid', kernel_initializer=kernel_initializer)(tensor) tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor) if use_layer_normalization: tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor) tensor = create_conv_layer( use_fan_in_scaled_kernel=use_fan_in_scaled_kernels, multiplier=1.0, filters=1, kernel_size=1, kernel_initializer=kernel_initializer)(tensor) tensor = tf.keras.layers.Reshape((-1, ))(tensor) if use_intermediate_inputs: return tf.keras.Model(inputs=inputs, outputs=tensor, name=name) else: return tf.keras.Model(inputs=input_tensor, outputs=tensor, name=name)