def test_frn_serialization(self):
  """Check that an FRN layer's config fully describes the layer.

  Builds an FRN layer with a non-default `reg_epsilon`, reconstructs a second
  layer of the same class from the first layer's `get_config()` dict, and
  asserts the two configs are equal.
  """
  original = frn.FRN(reg_epsilon=1.0e-4)
  original_config = original.get_config()
  # Reconstruct purely from the config dict; any state missing from the
  # config would show up as a difference between the two configs.
  rebuilt = type(original)(**original_config)
  rebuilt_config = rebuilt.get_config()
  self.assertEqual(
      original_config,
      rebuilt_config,
      msg='Serialization does not capture all state.')
def resnet_layer(inputs,
                 filters,
                 kernel_size=3,
                 strides=1,
                 activation=None,
                 pfac=None,
                 use_frn=False,
                 use_bias=True):
  """2D Convolution-Normalization-Activation stack builder.

  Applies a 'same'-padded, he_normal-initialized Conv2D layer, then either a
  Filter Response Normalization (FRN) layer or a BatchNormalization layer,
  then an optional activation layer.

  Args:
    inputs: tf.Tensor.
    filters: Number of filters for Conv2D.
    kernel_size: Kernel dimensions for Conv2D.
    strides: Stride dimensions for Conv2D.
    activation: tf.keras.activations.Activation, or None to skip the
      activation layer.
    pfac: prior.PriorFactory object. It is called on the Conv2D layer (and on
      the FRN layer when `use_frn` is True), and the layer it returns is the
      one applied. If None, those layers are applied unwrapped.
    use_frn: if True, use a Filter Response Normalization (FRN) layer;
      otherwise use tf.keras.layers.BatchNormalization.
    use_bias: if True, use biases in Conv layers.

  Returns:
    tf.Tensor.
  """
  if pfac is None:
    # The signature advertises pfac=None, but calling None would raise a
    # TypeError; fall back to applying the layers as-is.
    wrap_layer = lambda layer: layer
  else:
    wrap_layer = pfac

  x = inputs
  logging.info('Applying conv layer.')
  x = wrap_layer(tf.keras.layers.Conv2D(
      filters,
      kernel_size=kernel_size,
      strides=strides,
      padding='same',
      kernel_initializer='he_normal',
      use_bias=use_bias))(x)

  if use_frn:
    x = wrap_layer(frn.FRN())(x)
  else:
    # BatchNormalization is intentionally not passed through pfac, matching
    # the original behavior.
    x = tf.keras.layers.BatchNormalization()(x)

  if activation is not None:
    x = tf.keras.layers.Activation(activation)(x)
  return x