Example #1
import tensorflow as tf

# `Conv2D`, `BatchNormalization`, `group`, and `layers.DenseMultihead` are
# module-level helpers defined elsewhere in the file this snippet was taken
# from.
def wide_resnet(input_shape, depth, width_multiplier, num_classes,
                ensemble_size):
  """Builds Wide ResNet with Sparse BatchEnsemble.

  Following Zagoruyko and Komodakis (2016), it accepts a width multiplier on the
  number of filters. Using three groups of residual blocks, the network maps
  spatial features of size 32x32 -> 16x16 -> 8x8.

  Args:
    input_shape: Shape tuple of the input, excluding the batch dimension. It
      must be (ensemble_size, width, height, channels).
    depth: Total number of convolutional layers. "n" in WRN-n-k. It differs from
      He et al. (2015)'s notation which uses the maximum depth of the network
      counting non-conv layers like dense.
    width_multiplier: Integer to multiply the number of typical filters by. "k"
      in WRN-n-k.
    num_classes: Number of output classes.
    ensemble_size: Number of ensemble members.

  Returns:
    tf.keras.Model.
  """
  if (depth - 4) % 6 != 0:
    raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
  num_blocks = (depth - 4) // 6
  input_shape = list(input_shape)
  if ensemble_size != input_shape[0]:
    raise ValueError(
        'The first dimension of input_shape must equal ensemble_size.')
  inputs = tf.keras.layers.Input(shape=input_shape)
  # Fold the ensemble dimension into the channel axis so that every ensemble
  # member's input goes through the same convolutional trunk.
  x = tf.keras.layers.Permute([2, 3, 4, 1])(inputs)
  x = tf.keras.layers.Reshape(input_shape[1:-1] +
                              [input_shape[-1] * ensemble_size])(x)
  x = Conv2D(16, strides=1)(x)
  for strides, filters in zip([1, 2, 2], [16, 32, 64]):
    x = group(
        x,
        filters=filters * width_multiplier,
        strides=strides,
        num_blocks=num_blocks)

  x = BatchNormalization()(x)
  x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = layers.DenseMultihead(
      num_classes,
      kernel_initializer='he_normal',
      activation=None,
      ensemble_size=ensemble_size)(x)
  return tf.keras.Model(inputs=inputs, outputs=x)
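
A minimal usage sketch, assuming CIFAR-style 32x32 RGB inputs and that the module-level helpers above are available (the concrete shapes and sizes here are illustrative assumptions, not taken from the original source):

# Hypothetical usage sketch: builds a WRN-28-10 with 4 ensemble members over
# 32x32 RGB inputs (e.g. CIFAR-10). Shapes are assumptions for illustration.
ensemble_size = 4
model = wide_resnet(
    input_shape=(ensemble_size, 32, 32, 3),
    depth=28,
    width_multiplier=10,
    num_classes=10,
    ensemble_size=ensemble_size)
model.summary()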
Example #2
import tensorflow as tf

# `BATCH_NORM_DECAY`, `BATCH_NORM_EPSILON`, `group`, and
# `layers.DenseMultihead` are module-level constants/helpers defined elsewhere
# in the file this snippet was taken from.
def resnet50(input_shape, num_classes, ensemble_size, width_multiplier=1):
    """Builds a multiheaded ResNet50.

    Using strided conv, pooling, four groups of residual blocks, and pooling,
    the network maps spatial features of size 224x224 -> 112x112 -> 56x56 ->
    28x28 -> 14x14 -> 7x7 (Table 1 of He et al. (2015)).

    Args:
      input_shape: Shape tuple of input excluding batch dimension.
      num_classes: Number of output classes.
      ensemble_size: Number of ensemble members, i.e. the number of heads and
        of per-example inputs.
      width_multiplier: Multiply the number of filters for wide ResNet.

    Returns:
      tf.keras.Model.
    """
    input_shape = list(input_shape)
    inputs = tf.keras.layers.Input(shape=input_shape)
    assert ensemble_size == input_shape[0], (
        'The first dimension of input_shape must equal ensemble_size.')
    # Fold the ensemble dimension into the channel axis.
    x = tf.keras.layers.Permute([2, 3, 4, 1])(inputs)
    x = tf.keras.layers.Reshape(
        list(input_shape[1:-1]) + [input_shape[-1] * ensemble_size])(x)
    x = tf.keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(x)
    x = tf.keras.layers.Conv2D(width_multiplier * 64,
                               kernel_size=7,
                               strides=2,
                               padding='valid',
                               use_bias=False,
                               kernel_initializer='he_normal',
                               name='conv1')(x)
    x = tf.keras.layers.BatchNormalization(momentum=BATCH_NORM_DECAY,
                                           epsilon=BATCH_NORM_EPSILON,
                                           name='bn_conv1')(x)
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D(3, strides=2, padding='same')(x)
    x = group(
        x,
        [width_multiplier * 64, width_multiplier * 64,
         width_multiplier * 256],
        stage=2,
        num_blocks=3,
        strides=1)
    x = group(
        x,
        [width_multiplier * 128, width_multiplier * 128,
         width_multiplier * 512],
        stage=3,
        num_blocks=4,
        strides=2)
    x = group(
        x,
        [width_multiplier * 256, width_multiplier * 256,
         width_multiplier * 1024],
        stage=4,
        num_blocks=6,
        strides=2)
    x = group(
        x,
        [width_multiplier * 512, width_multiplier * 512,
         width_multiplier * 2048],
        stage=5,
        num_blocks=3,
        strides=2)
    x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.DenseMultihead(
        num_classes,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
        ensemble_size=ensemble_size,
        name='fc1000')(x)
    return tf.keras.Model(inputs=inputs, outputs=x, name='resnet50')
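
A corresponding usage sketch for the ResNet50 variant, assuming ImageNet-sized inputs and the module-level helpers above (the shapes and class count are illustrative assumptions):

# Hypothetical usage sketch: one 224x224x3 image per ensemble member and 1000
# output classes. Shapes are assumptions for illustration.
ensemble_size = 4
model = resnet50(
    input_shape=(ensemble_size, 224, 224, 3),
    num_classes=1000,
    ensemble_size=ensemble_size,
    width_multiplier=1)
model.summary()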