예제 #1
0
 def test_tlu_serialization(self):
     """Rebuilds a TLU from its get_config() output and compares the configs."""
     tau_reg = tf.keras.regularizers.l2(0.01)
     original = frn.TLU(tau_regularizer=tau_reg)
     original_config = original.get_config()
     # Construct a fresh layer of the same class purely from the config dict;
     # any state not captured by get_config() would make the configs differ.
     rebuilt = type(original)(**original_config)
     self.assertEqual(original_config,
                      rebuilt.get_config(),
                      msg='Serialization does not capture all state.')
예제 #2
0
def build_resnet_v1(input_shape, depth, num_classes, pfac, use_frn=False,
                    use_internal_bias=True):
  """Builds ResNet v1.

  The network is a conv-norm-ReLU stem followed by three stacks of
  `(depth - 2) // 6` two-conv residual blocks each. Stacks use 16, 32 and 64
  filters; the first block of the second and third stacks uses stride 2 and a
  1x1 projection shortcut. The head is AveragePooling2D(8), Flatten and a
  Dense layer with no activation.

  Args:
    input_shape: Shape of a single input example, excluding the batch
      dimension (passed to `tf.keras.layers.Input(shape=...)`).
    depth: ResNet depth; must be of the form 6n+2 with n >= 0
      (e.g. 20, 32, 44).
    num_classes: Number of output classes (units of the final Dense layer).
    pfac: Callable that takes a Keras layer and returns the layer to use in
      its place (e.g. a priorfactory.PriorFactory). It is applied to every
      Conv2D, FRN and TLU layer and to the final Dense layer;
      BatchNormalization layers are not passed through it.
    use_frn: if True, then use Filter Response Normalization (FRN) instead of
      batchnorm, and TLU instead of ReLU after each residual addition.
    use_internal_bias: if True, use biases in all Conv layers.
      If False, only use a bias in the final Dense layer.

  Returns:
    tf.keras.Model whose output is the (activation-free) final Dense layer.

  Raises:
    ValueError: if `depth` is not of the form 6n+2 with n >= 0.
  """
  def resnet_layer(inputs,
                   filters,
                   kernel_size=3,
                   strides=1,
                   activation=None,
                   pfac=None,
                   use_frn=False,
                   use_bias=True):
    """2D Convolution-Normalization-Activation stack builder.

    Args:
      inputs: tf.Tensor.
      filters: Number of filters for Conv2D.
      kernel_size: Kernel dimensions for Conv2D.
      strides: Stride dimensions for Conv2D.
      activation: Activation passed to `tf.keras.layers.Activation`
        (e.g. 'relu'); if None, no activation layer is added.
      pfac: Callable wrapping a Keras layer (same contract as the `pfac`
        argument of `build_resnet_v1`); if None, layers are used unwrapped.
      use_frn: if True, use a Filter Response Normalization (FRN) layer
        instead of BatchNormalization.
      use_bias: if True, use biases in Conv layers.

    Returns:
      tf.Tensor.
    """
    # The default pfac=None used to crash with "'NoneType' object is not
    # callable"; treat it as "no wrapping" instead.
    wrap = pfac if pfac is not None else (lambda layer: layer)

    x = inputs
    logging.info('Applying conv layer.')
    x = wrap(tf.keras.layers.Conv2D(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer='he_normal',
        use_bias=use_bias))(x)

    if use_frn:
      x = wrap(frn.FRN())(x)
    else:
      x = tf.keras.layers.BatchNormalization()(x)
    if activation is not None:
      x = tf.keras.layers.Activation(activation)(x)
    return x

  # Main network code.
  # Validate before deriving the block count. A negative depth such as -4
  # passes the modulo test in Python ((-6) % 6 == 0) and would otherwise
  # silently build a network with no residual blocks.
  if depth < 2 or (depth - 2) % 6 != 0:
    raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')
  num_res_blocks = (depth - 2) // 6
  filters = 16

  logging.info('Starting ResNet build.')
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = resnet_layer(inputs,
                   filters=filters,
                   activation='relu',
                   pfac=pfac,
                   use_frn=use_frn,
                   use_bias=use_internal_bias)
  for stack in range(3):
    for res_block in range(num_res_blocks):
      logging.info('Starting ResNet stack #%d block #%d.', stack, res_block)
      # The first block of every stack except the first one downsamples.
      downsample = stack > 0 and res_block == 0
      strides = 2 if downsample else 1
      y = resnet_layer(x,
                       filters=filters,
                       strides=strides,
                       activation='relu',
                       pfac=pfac,
                       use_frn=use_frn,
                       use_bias=use_internal_bias)
      y = resnet_layer(y,
                       filters=filters,
                       activation=None,
                       pfac=pfac,
                       use_frn=use_frn,
                       use_bias=use_internal_bias)
      if downsample:
        # Linear projection residual shortcut so x matches y's changed
        # spatial size and filter count.
        x = resnet_layer(x,
                         filters=filters,
                         kernel_size=1,
                         strides=strides,
                         activation=None,
                         pfac=pfac,
                         use_frn=use_frn,
                         use_bias=use_internal_bias)
      x = tf.keras.layers.add([x, y])
      if use_frn:
        x = pfac(frn.TLU())(x)
      else:
        x = tf.keras.layers.Activation('relu')(x)
    filters *= 2

  # v1 does not use BN after last shortcut connection-ReLU.
  # pool_size=8 is hard-coded: for a 32x32 input the two stride-2 stages
  # leave an 8x8 feature map, which this pools to 1x1. Other input sizes
  # leave a different grid that Flatten then flattens.
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = pfac(tf.keras.layers.Dense(
      num_classes,
      kernel_initializer='he_normal'))(x)

  logging.info('ResNet successfully built.')
  return tf.keras.models.Model(inputs=inputs, outputs=x)