Beispiel #1
0
 def test_frn_serialization(self):
     """Checks that an FRN layer rebuilt from its config reports the same config."""
     original = frn.FRN(reg_epsilon=1.0e-4)
     original_config = original.get_config()
     # Reconstruct a fresh layer of the same class from the exported config,
     # then export again so the two configs can be compared.
     rebuilt_config = type(original)(**original_config).get_config()
     self.assertEqual(
         original_config,
         rebuilt_config,
         msg='Serialization does not capture all state.')
Beispiel #2
0
  def resnet_layer(inputs,
                   filters,
                   kernel_size=3,
                   strides=1,
                   activation=None,
                   pfac=None,
                   use_frn=False,
                   use_bias=True):
    """2D Convolution-Normalization-Activation stack builder.

    Applies a Conv2D layer, then either Filter Response Normalization (FRN)
    or Batch Normalization, then an optional activation layer.

    Args:
      inputs: tf.Tensor.
      filters: Number of filters for Conv2D.
      kernel_size: Kernel dimensions for Conv2D.
      strides: Stride dimensions for Conv2D.
      activation: Value passed to tf.keras.layers.Activation (e.g. an
        activation name or callable). If None, no activation layer is added.
      pfac: prior.PriorFactory object. It is called on the Conv2D layer (and
        on the FRN layer when use_frn is True) and must return a callable
        layer. If None, those layers are used as-is.
      use_frn: if True, use a Filter Response Normalization (FRN) layer;
        otherwise use tf.keras.layers.BatchNormalization.
      use_bias: if True, use biases in Conv layers.

    Returns:
      tf.Tensor.
    """
    def wrap(layer):
      """Applies the prior factory to `layer`, or returns it unchanged."""
      # The signature defaults pfac to None; calling None directly would
      # raise TypeError, so treat None as "no prior wrapping".
      return layer if pfac is None else pfac(layer)

    x = inputs
    logging.info('Applying conv layer.')
    x = wrap(tf.keras.layers.Conv2D(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer='he_normal',
        use_bias=use_bias))(x)

    if use_frn:
      x = wrap(frn.FRN())(x)
    else:
      # NOTE(review): BatchNormalization is not passed through pfac, unlike
      # the FRN branch -- confirm this asymmetry is intended.
      x = tf.keras.layers.BatchNormalization()(x)
    if activation is not None:
      x = tf.keras.layers.Activation(activation)(x)
    return x