Example #1
def _norm_relu(input_tensor, norm):
    """Applies normalization and ReLU activation to an input tensor.

  Args:
    input_tensor: The `tf.Tensor` to apply the block to.
    norm: A `NormLayer` specifying the type of normalization layer used.

  Returns:
    A `tf.Tensor`.
  """
    if tf.keras.backend.image_data_format() == 'channels_last':
        channel_axis = 3
    else:
        channel_axis = 1

    if norm is NormLayer.group_norm:
        x = tfa_norms.GroupNormalization(axis=channel_axis)(input_tensor)
    elif norm is NormLayer.batch_norm:
        x = tf.keras.layers.BatchNormalization(
            axis=channel_axis,
            momentum=BATCH_NORM_DECAY,
            epsilon=BATCH_NORM_EPSILON)(input_tensor)
    else:
        raise ValueError('The norm argument must be of type `NormLayer`.')
    return tf.keras.layers.Activation('relu')(x)
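A minimal usage sketch for the helper above. The snippet relies on module-level names (`tfa_norms`, `NormLayer`, `BATCH_NORM_DECAY`, `BATCH_NORM_EPSILON`) that are not shown here; the imports, enum, and constant values below are illustrative assumptions, not the original module's definitions.

import enum

import tensorflow as tf
import tensorflow_addons.layers as tfa_norms  # provides GroupNormalization

# Illustrative stand-ins for the module-level constants and enum.
BATCH_NORM_DECAY = 0.997
BATCH_NORM_EPSILON = 1e-5


class NormLayer(enum.Enum):
    group_norm = 'group_norm'
    batch_norm = 'batch_norm'


# Conv -> Norm -> ReLU: the helper appends normalization and a ReLU activation.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(64, 3, padding='same', use_bias=False)(inputs)
x = _norm_relu(x, norm=NormLayer.group_norm)
model = tf.keras.Model(inputs, x)
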
Example #2
def _shortcut(input_tensor, residual, norm):
    """Computes the output of a shortcut block between an input and residual.

  More specifically, this block takes `input` and adds it to `residual`. If
  `input` is not the same shape as `residual`, then we first apply an
  appropriately-sized convolutional layer to alter its shape to that of
  `residual` and normalize via `norm` before adding it to `residual`.

  Args:
    input_tensor: The `tf.Tensor` to apply the block to.
    residual: A `tf.Tensor` added to `input_tensor` after it has been passed
      through a convolution and normalization.
    norm: A `NormLayer` specifying the type of normalization layer used.

  Returns:
    A `tf.Tensor`.
  """
    input_shape = tf.keras.backend.int_shape(input_tensor)
    residual_shape = tf.keras.backend.int_shape(residual)

    if tf.keras.backend.image_data_format() == 'channels_last':
        row_axis = 1
        col_axis = 2
        channel_axis = 3
    else:
        channel_axis = 1
        row_axis = 2
        col_axis = 3

    stride_width = int(round(input_shape[row_axis] / residual_shape[row_axis]))
    stride_height = int(round(input_shape[col_axis] /
                              residual_shape[col_axis]))
    equal_channels = input_shape[channel_axis] == residual_shape[channel_axis]

    shortcut = input_tensor
    # Use a 1-by-1 kernel if the strides are greater than 1, or if the input
    # and residual tensors have different numbers of channels.
    if stride_width > 1 or stride_height > 1 or not equal_channels:
        shortcut = tf.keras.layers.Conv2D(
            filters=residual_shape[channel_axis],
            kernel_size=(1, 1),
            strides=(stride_width, stride_height),
            padding='valid',
            use_bias=False,
            kernel_initializer='he_normal',
            kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY))(
                shortcut)

        if norm is NormLayer.group_norm:
            shortcut = tfa_norms.GroupNormalization(
                axis=channel_axis)(shortcut)
        elif norm is NormLayer.batch_norm:
            shortcut = tf.keras.layers.BatchNormalization(
                axis=channel_axis,
                momentum=BATCH_NORM_DECAY,
                epsilon=BATCH_NORM_EPSILON)(shortcut)
        else:
            raise ValueError('The norm argument must be of type `NormLayer`.')

    return tf.keras.layers.add([shortcut, residual])
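A sketch of how `_shortcut` might be wired into a residual block, under the same illustrative assumptions as the sketch after Example #1, plus an assumed `L2_WEIGHT_DECAY` constant; the channel counts and stride are hypothetical.

L2_WEIGHT_DECAY = 1e-4  # illustrative stand-in for the module-level constant

# The residual branch halves the spatial size and widens 64 -> 128 channels,
# so _shortcut inserts a strided 1x1 projection plus normalization before
# the two tensors are added.
inputs = tf.keras.Input(shape=(32, 32, 64))
residual = tf.keras.layers.Conv2D(
    128, 3, strides=2, padding='same', use_bias=False)(inputs)
outputs = _shortcut(inputs, residual, norm=NormLayer.batch_norm)
block = tf.keras.Model(inputs, outputs)  # output shape: (None, 16, 16, 128)
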
Example #3
def _norm_relu(input_tensor, norm='group'):
    """Helper function to make a Norm -> ReLU block."""
    if tf.keras.backend.image_data_format() == 'channels_last':
        channel_axis = 3
    else:
        channel_axis = 1

    if norm == 'group':
        x = tfa_norms.GroupNormalization(axis=channel_axis)(input_tensor)
    else:
        x = tf.keras.layers.BatchNormalization(
            axis=channel_axis,
            momentum=BATCH_NORM_DECAY,
            epsilon=BATCH_NORM_EPSILON)(input_tensor)
    return tf.keras.layers.Activation('relu')(x)
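Unlike the enum-based variant in Example #1, this version does not validate `norm`: any value other than 'group' falls through to batch normalization. A short sketch, where `conv_output` stands for any hypothetical 4-D feature tensor:

x = _norm_relu(conv_output, norm='group')  # GroupNormalization + ReLU
y = _norm_relu(conv_output, norm='batch')  # BatchNormalization + ReLU
z = _norm_relu(conv_output, norm='typo')   # also falls back to BatchNormalization
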
Example #4
def _shortcut(input_tensor, residual, norm='group', seed=0):
    """Adds a shortcut between input and the residual."""
    input_shape = tf.keras.backend.int_shape(input_tensor)
    residual_shape = tf.keras.backend.int_shape(residual)

    if tf.keras.backend.image_data_format() == 'channels_last':
        row_axis = 1
        col_axis = 2
        channel_axis = 3
    else:
        channel_axis = 1
        row_axis = 2
        col_axis = 3

    stride_width = int(round(input_shape[row_axis] / residual_shape[row_axis]))
    stride_height = int(round(input_shape[col_axis] /
                              residual_shape[col_axis]))
    equal_channels = input_shape[channel_axis] == residual_shape[channel_axis]

    shortcut = input_tensor
    # Use a 1x1 convolution if the shapes differ; otherwise use the identity.
    if stride_width > 1 or stride_height > 1 or not equal_channels:
        shortcut = tf.keras.layers.Conv2D(
            filters=residual_shape[channel_axis],
            kernel_size=(1, 1),
            strides=(stride_width, stride_height),
            padding='valid',
            use_bias=False,
            kernel_initializer=tf.keras.initializers.HeNormal(seed=seed),
            kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY))(
                shortcut)

        if norm == 'group':
            shortcut = tfa_norms.GroupNormalization(
                axis=channel_axis)(shortcut)
        else:
            shortcut = tf.keras.layers.BatchNormalization(
                axis=channel_axis,
                momentum=BATCH_NORM_DECAY,
                epsilon=BATCH_NORM_EPSILON)(shortcut)

    return tf.keras.layers.add([shortcut, residual])
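This variant differs from Example #2 in that `norm` is a plain string ('group', or anything else for batch normalization) and `seed` is forwarded to the `HeNormal` kernel initializer of the 1x1 projection, making its initialization reproducible. A minimal call, reusing the hypothetical `inputs` and `residual` tensors from the sketch after Example #2:

outputs = _shortcut(inputs, residual, norm='batch', seed=42)
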
def create_discriminator(
        downsampling_blocks_num_channels: Sequence[Sequence[int]] = (
            (64, 128), (128, 128), (256, 256), (512, 512)),
        relu_leakiness: float = 0.2,
        kernel_initializer: Optional[_KerasInitializer] = None,
        use_fan_in_scaled_kernels: bool = True,
        use_layer_normalization: bool = False,
        use_intermediate_inputs: bool = False,
        use_antialiased_bilinear_downsampling: bool = False,
        name: str = 'progressive_gan_discriminator'):
    """Creates a Keras model for the discriminator architecture.

  This architecture is implemented according to the paper "Progressive growing
  of GANs for Improved Quality, Stability, and Variation"
  https://arxiv.org/abs/1710.10196
  The intermediate outputs can optionally be given as input for the architecture
  of "MSG-GAN: Multi-Scale Gradient GAN for Stable Image Synthesis"
  https://arxiv.org/abs/1903.06048

  Args:
    downsampling_blocks_num_channels: The number of channels in the downsampling
      blocks for each block the number of channels for the first and second
      convolution are specified.
    relu_leakiness: Slope of the negative part of the leaky relu.
    kernel_initializer: Initializer of the kernel. If none TruncatedNormal is
      used.
    use_fan_in_scaled_kernels: This rescales the kernels using the scale factor
      from the he initializer, which implements the equalized learning rate.
    use_layer_normalization: If layer normalization layers should be inserted to
      the network.
    use_intermediate_inputs: If true the model expects a list of tf.Tensors with
      increasing resolution starting with the starting_size up to the final
      resolution as input.
    use_antialiased_bilinear_downsampling: If true the downsampling operation is
      ani-aliased bilinear downsampling with a [1, 3, 3, 1] tent kernel. If
      false standard bilinear downsampling, i.e. average pooling is used ([1, 1]
      tent kernel).
    name: The name of the Keras model.

  Returns:
    The generated discriminator keras model.
  """
    if kernel_initializer is None:
        kernel_initializer = tf.keras.initializers.TruncatedNormal(mean=0.0,
                                                                   stddev=1.0)

    if use_intermediate_inputs:
        inputs = tuple(
            tf.keras.Input(shape=(None, None, 3))
            for _ in range(len(downsampling_blocks_num_channels) + 1))
        tensor = inputs[-1]
    else:
        input_tensor = tf.keras.Input(shape=(None, None, 3))
        tensor = input_tensor

    tensor = from_rgb(tensor,
                      use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
                      num_channels=downsampling_blocks_num_channels[0][0],
                      kernel_initializer=kernel_initializer,
                      relu_leakiness=relu_leakiness)
    if use_layer_normalization:
        tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)

    for index, (channels_1,
                channels_2) in enumerate(downsampling_blocks_num_channels):
        tensor = create_conv_layer(
            use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
            filters=channels_1,
            kernel_size=3,
            strides=1,
            padding='same',
            kernel_initializer=kernel_initializer)(tensor)
        tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
        if use_layer_normalization:
            tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)
        tensor = create_conv_layer(
            use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
            filters=channels_2,
            kernel_size=3,
            strides=1,
            padding='same',
            kernel_initializer=kernel_initializer)(tensor)
        tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
        if use_layer_normalization:
            tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)
        if use_antialiased_bilinear_downsampling:
            tensor = keras_layers.Blur2D()(tensor)
        tensor = tf.keras.layers.AveragePooling2D()(tensor)

        if use_intermediate_inputs:
            tensor = tf.keras.layers.Concatenate()(
                [inputs[-index - 2], tensor])

    tensor = create_conv_layer(
        use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
        filters=downsampling_blocks_num_channels[-1][1],
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=kernel_initializer)(tensor)
    tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
    if use_layer_normalization:
        tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)

    tensor = create_conv_layer(
        use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
        filters=downsampling_blocks_num_channels[-1][1],
        kernel_size=4,
        strides=1,
        padding='valid',
        kernel_initializer=kernel_initializer)(tensor)
    tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
    if use_layer_normalization:
        tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)

    tensor = create_conv_layer(
        use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
        multiplier=1.0,
        filters=1,
        kernel_size=1,
        kernel_initializer=kernel_initializer)(tensor)
    tensor = tf.keras.layers.Reshape((-1, ))(tensor)

    if use_intermediate_inputs:
        return tf.keras.Model(inputs=inputs, outputs=tensor, name=name)
    else:
        return tf.keras.Model(inputs=input_tensor, outputs=tensor, name=name)
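A usage sketch, assuming the surrounding module provides the helpers this function relies on (`from_rgb`, `create_conv_layer`, `keras_layers.Blur2D`, `tfa_normalizations`). With the default four downsampling blocks the network pools the input by a factor of 16, so the final 4x4 'valid' convolution reduces a 64x64 image to a single logit; the batch size and resolutions below are illustrative.

import tensorflow as tf

# Single-scale Progressive GAN discriminator.
discriminator = create_discriminator()
images = tf.random.normal([8, 64, 64, 3])
logits = discriminator(images)  # shape (8, 1): one realism score per image

# MSG-GAN style: the intermediate resolutions are fed as additional inputs,
# ordered from the lowest (4x4) up to the full (64x64) resolution.
msg_discriminator = create_discriminator(use_intermediate_inputs=True)
multi_scale = [
    tf.random.normal([8, 4 * 2**i, 4 * 2**i, 3]) for i in range(5)
]
msg_logits = msg_discriminator(multi_scale)  # shape (8, 1)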