def wide_resnet(input_shape, depth, width_multiplier, num_classes,
                ensemble_size):
  """Builds Wide ResNet with Sparse BatchEnsemble.

  Following Zagoruyko and Komodakis (2016), it accepts a width multiplier on
  the number of filters. Using three groups of residual blocks, the network
  maps spatial features of size 32x32 -> 16x16 -> 8x8.

  The per-member inputs are stacked along the channel axis before the first
  convolution, and the network ends in a `layers.DenseMultihead` layer
  parameterized by `ensemble_size` with no activation.

  Args:
    input_shape: Shape of one example excluding the batch dimension, as a
      sequence of 4 ints: (ensemble_size, spatial_dim_1, spatial_dim_2,
      channels). The first dimension must equal `ensemble_size`.
    depth: Total number of convolutional layers. "n" in WRN-n-k. It differs
      from He et al. (2015)'s notation which uses the maximum depth of the
      network counting non-conv layers like dense.
    width_multiplier: Integer to multiply the number of typical filters by.
      "k" in WRN-n-k.
    num_classes: Number of output classes.
    ensemble_size: Number of ensemble members.

  Returns:
    tf.keras.Model.

  Raises:
    ValueError: If `depth` is not of the form 6n+4, if `input_shape` does not
      have exactly 4 dimensions, or if its first dimension is not
      `ensemble_size`.
  """
  if (depth - 4) % 6 != 0:
    raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
  num_blocks = (depth - 4) // 6

  # Validate the shape before any Keras layers are created so malformed input
  # fails with a clear message instead of deep inside Permute/Reshape.
  input_shape = list(input_shape)
  if len(input_shape) != 4:
    raise ValueError(
        'input_shape must have 4 dimensions (ensemble_size, spatial, spatial, '
        'channels), got {}'.format(input_shape))
  if ensemble_size != input_shape[0]:
    raise ValueError('the first dimension of input_shape must be ensemble_size')
  spatial_1, spatial_2, channels = input_shape[1:]

  inputs = tf.keras.layers.Input(shape=input_shape)
  # Move the ensemble axis last: (E, S1, S2, C) -> (S1, S2, C, E) ...
  x = tf.keras.layers.Permute([2, 3, 4, 1])(inputs)
  # ... then merge it into the channel axis so all members share one trunk.
  x = tf.keras.layers.Reshape([spatial_1, spatial_2, channels * ensemble_size])(x)

  x = Conv2D(16, strides=1)(x)
  for strides, filters in zip([1, 2, 2], [16, 32, 64]):
    x = group(
        x,
        filters=filters * width_multiplier,
        strides=strides,
        num_blocks=num_blocks)

  x = BatchNormalization()(x)
  x = tf.keras.layers.Activation('relu')(x)
  # NOTE: pool_size=8 assumes the final feature map is 8x8, i.e. 32x32 input
  # (see the docstring); other input sizes leave a spatial remainder.
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = layers.DenseMultihead(
      num_classes,
      kernel_initializer='he_normal',
      activation=None,
      ensemble_size=ensemble_size)(x)
  return tf.keras.Model(inputs=inputs, outputs=x)
def resnet50(input_shape, num_classes, ensemble_size, width_multiplier=1):
  """Builds a multiheaded ResNet50.

  Using strided conv, pooling, four groups of residual blocks, and pooling, the
  network maps spatial features of size 224x224 -> 112x112 -> 56x56 -> 28x28 ->
  14x14 -> 7x7 (Table 1 of He et al. (2015)).

  The per-member inputs are stacked along the channel axis before the stem
  convolution, and the network ends in a `layers.DenseMultihead` layer
  parameterized by `ensemble_size` with no activation.

  Args:
    input_shape: Shape tuple of input excluding batch dimension, with 4
      entries: (ensemble_size, spatial_dim_1, spatial_dim_2, channels).
    num_classes: Number of output classes.
    ensemble_size: Number of ensembles i.e. number of heads and inputs.
    width_multiplier: Multiply the number of filters for wide ResNet.

  Returns:
    tf.keras.Model.

  Raises:
    ValueError: If `input_shape` does not have exactly 4 dimensions or its
      first dimension is not `ensemble_size`.
  """
  input_shape = list(input_shape)
  # Explicit exceptions rather than `assert`: asserts are stripped under -O,
  # and this is validation of caller-supplied arguments.
  if len(input_shape) != 4:
    raise ValueError(
        'input_shape must have 4 dimensions (ensemble_size, spatial, spatial, '
        'channels), got {}'.format(input_shape))
  if input_shape[0] != ensemble_size:
    raise ValueError('the first dimension of input_shape must be ensemble_size')
  spatial_1, spatial_2, channels = input_shape[1:]

  inputs = tf.keras.layers.Input(shape=input_shape)
  # (E, S1, S2, C) -> (S1, S2, C, E), then merge the ensemble axis into the
  # channel axis so every member is processed by one shared trunk.
  x = tf.keras.layers.Permute([2, 3, 4, 1])(inputs)
  x = tf.keras.layers.Reshape([spatial_1, spatial_2, channels * ensemble_size])(x)

  # Stem: 7x7/2 conv (explicit padding of 3 + 'valid') then 3x3/2 max pool.
  x = tf.keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(x)
  x = tf.keras.layers.Conv2D(
      width_multiplier * 64,
      kernel_size=7,
      strides=2,
      padding='valid',
      use_bias=False,
      kernel_initializer='he_normal',
      name='conv1')(x)
  x = tf.keras.layers.BatchNormalization(
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON,
      name='bn_conv1')(x)
  x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.MaxPooling2D(3, strides=2, padding='same')(x)

  # (stage, num_blocks, strides, base filters) for the four residual groups;
  # filters are scaled by width_multiplier.
  stage_specs = (
      (2, 3, 1, (64, 64, 256)),
      (3, 4, 2, (128, 128, 512)),
      (4, 6, 2, (256, 256, 1024)),
      (5, 3, 2, (512, 512, 2048)),
  )
  for stage, num_blocks, strides, base_filters in stage_specs:
    x = group(
        x,
        [width_multiplier * f for f in base_filters],
        stage=stage,
        num_blocks=num_blocks,
        strides=strides)

  x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
  x = layers.DenseMultihead(
      num_classes,
      activation=None,
      kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
      ensemble_size=ensemble_size,
      name='fc1000')(x)
  return tf.keras.Model(inputs=inputs, outputs=x, name='resnet50')