def basic_block(
    inputs,
    filters,
    strides,
    alpha_initializer,
    gamma_initializer,
    alpha_regularizer,
    gamma_regularizer,
    use_additive_perturbation,
    ensemble_size,
    random_sign_init,
    dropout_rate,
    prior_mean,
    prior_stddev,
):
    """Residual block: two BN -> ReLU -> rank-1 conv stages plus a shortcut.

    The original description calls these "two 3x3 convs"; the kernel size is
    not set in this function and comes from the module-level `Conv2DRank1`
    callable. The first conv uses `strides`, the second uses stride 1. If the
    input and the main-path output have incompatible shapes, the shortcut is
    projected with a `Conv2DRank1` of kernel size 1 using the same `strides`.

    Args:
      inputs: tf.Tensor.
      filters: Number of filters for each conv.
      strides: Stride for the first conv and for the projection shortcut.
      alpha_initializer: The initializer for the alpha parameters.
      gamma_initializer: The initializer for the gamma parameters.
      alpha_regularizer: The regularizer for the alpha parameters.
      gamma_regularizer: The regularizer for the gamma parameters.
      use_additive_perturbation: Whether or not to use additive perturbations
        instead of multiplicative perturbations.
      ensemble_size: Number of ensemble members.
      random_sign_init: Value used to initialize trainable deterministic
        initializers, as applicable. Values greater than zero result in
        initialization to a random sign vector, where random_sign_init is the
        probability of a 1 value. Values less than zero result in
        initialization from a Gaussian with mean 1 and standard deviation
        equal to -random_sign_init.
      dropout_rate: Dropout rate (forwarded to
        `rank1_bnn_utils.make_initializer`).
      prior_mean: Mean of the prior.
      prior_stddev: Standard deviation of the prior.

    Returns:
      tf.Tensor, the sum of the (possibly projected) shortcut and the main
      path.
    """

    def _rank1_kwargs():
        # Re-invoke the factory helpers for every layer, exactly as the
        # original inline code did, rather than sharing one set of objects.
        return dict(
            alpha_initializer=rank1_bnn_utils.make_initializer(
                alpha_initializer, random_sign_init, dropout_rate),
            gamma_initializer=rank1_bnn_utils.make_initializer(
                gamma_initializer, random_sign_init, dropout_rate),
            alpha_regularizer=rank1_bnn_utils.make_regularizer(
                alpha_regularizer, prior_mean, prior_stddev),
            gamma_regularizer=rank1_bnn_utils.make_regularizer(
                gamma_regularizer, prior_mean, prior_stddev),
            use_additive_perturbation=use_additive_perturbation,
            ensemble_size=ensemble_size,
        )

    shortcut = inputs
    residual = inputs

    # Main path: only the first conv may change the spatial resolution.
    for conv_strides in (strides, 1):
        residual = BatchNormalization()(residual)
        residual = tf.keras.layers.Activation('relu')(residual)
        residual = Conv2DRank1(
            filters, strides=conv_strides, **_rank1_kwargs())(residual)

    # Project the shortcut only when it cannot be added to the main path.
    if not shortcut.shape.is_compatible_with(residual.shape):
        shortcut = Conv2DRank1(
            filters, kernel_size=1, strides=strides,
            **_rank1_kwargs())(shortcut)

    return tf.keras.layers.add([shortcut, residual])
def wide_resnet_rank1(
    input_shape,
    depth,
    width_multiplier,
    num_classes,
    alpha_initializer,
    gamma_initializer,
    alpha_regularizer,
    gamma_regularizer,
    use_additive_perturbation,
    ensemble_size,
    random_sign_init,
    dropout_rate,
    prior_mean,
    prior_stddev,
):
    """Builds a Wide ResNet with rank-1 layers.

    Following Zagoruyko and Komodakis (2016), it accepts a width multiplier on
    the number of filters. Three groups of residual blocks are stacked with
    strides 1, 2, 2 and base filter counts 16, 32, 64; per the original
    description this maps spatial features of size 32x32 -> 16x16 -> 8x8. The
    head is BN -> ReLU -> 8x8 average pooling -> flatten -> `DenseRank1`.

    Args:
      input_shape: Shape of a single input, forwarded as `shape=` to
        `tf.keras.layers.Input`.
      depth: Total number of convolutional layers. "n" in WRN-n-k. It differs
        from He et al. (2015)'s notation which uses the maximum depth of the
        network counting non-conv layers like dense. Must satisfy
        `(depth - 4) % 6 == 0`.
      width_multiplier: Integer to multiply the number of typical filters by.
        "k" in WRN-n-k.
      num_classes: Number of output classes.
      alpha_initializer: The initializer for the alpha parameters.
      gamma_initializer: The initializer for the gamma parameters.
      alpha_regularizer: The regularizer for the alpha parameters.
      gamma_regularizer: The regularizer for the gamma parameters.
      use_additive_perturbation: Whether or not to use additive perturbations
        instead of multiplicative perturbations.
      ensemble_size: Number of ensemble members.
      random_sign_init: Value used to initialize trainable deterministic
        initializers, as applicable. Values greater than zero result in
        initialization to a random sign vector, where random_sign_init is the
        probability of a 1 value. Values less than zero result in
        initialization from a Gaussian with mean 1 and standard deviation
        equal to -random_sign_init.
      dropout_rate: Dropout rate (forwarded to
        `rank1_bnn_utils.make_initializer`).
      prior_mean: Mean of the prior.
      prior_stddev: Standard deviation of the prior.

    Returns:
      tf.keras.Model.

    Raises:
      ValueError: If `depth` is not of the form 6n+4.
    """
    num_blocks, remainder = divmod(depth - 4, 6)
    if remainder != 0:
        raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')

    def _rank1_kwargs():
        # Called once per layer so the factory helpers run for each layer,
        # matching the original inline construction.
        return dict(
            alpha_initializer=rank1_bnn_utils.make_initializer(
                alpha_initializer, random_sign_init, dropout_rate),
            gamma_initializer=rank1_bnn_utils.make_initializer(
                gamma_initializer, random_sign_init, dropout_rate),
            alpha_regularizer=rank1_bnn_utils.make_regularizer(
                alpha_regularizer, prior_mean, prior_stddev),
            gamma_regularizer=rank1_bnn_utils.make_regularizer(
                gamma_regularizer, prior_mean, prior_stddev),
            use_additive_perturbation=use_additive_perturbation,
            ensemble_size=ensemble_size,
        )

    # `group` receives the raw specs and builds its own initializers.
    group_kwargs = dict(
        num_blocks=num_blocks,
        alpha_initializer=alpha_initializer,
        gamma_initializer=gamma_initializer,
        alpha_regularizer=alpha_regularizer,
        gamma_regularizer=gamma_regularizer,
        use_additive_perturbation=use_additive_perturbation,
        ensemble_size=ensemble_size,
        random_sign_init=random_sign_init,
        dropout_rate=dropout_rate,
        prior_mean=prior_mean,
        prior_stddev=prior_stddev,
    )

    inputs = tf.keras.layers.Input(shape=input_shape)
    net = Conv2DRank1(16, strides=1, **_rank1_kwargs())(inputs)

    # (strides, base filters) for the three residual groups.
    for group_strides, base_filters in ((1, 16), (2, 32), (2, 64)):
        net = group(
            net,
            filters=base_filters * width_multiplier,
            strides=group_strides,
            **group_kwargs)

    net = BatchNormalization()(net)
    net = tf.keras.layers.Activation('relu')(net)
    net = tf.keras.layers.AveragePooling2D(pool_size=8)(net)
    net = tf.keras.layers.Flatten()(net)

    logits = ed.layers.DenseRank1(
        num_classes,
        kernel_initializer='he_normal',
        activation=None,
        **_rank1_kwargs())(net)
    return tf.keras.Model(inputs=inputs, outputs=logits)
def resnet50_rank1(
    input_shape,
    num_classes,
    alpha_initializer,
    gamma_initializer,
    alpha_regularizer,
    gamma_regularizer,
    use_additive_perturbation,
    ensemble_size,
    random_sign_init,
    dropout_rate,
    prior_stddev,
    use_tpu,
    use_ensemble_bn,
):
    """Builds ResNet50 with rank 1 priors.

    Using strided conv, pooling, four groups of residual blocks, and pooling,
    the network maps spatial features of size 224x224 -> 112x112 -> 56x56 ->
    28x28 -> 14x14 -> 7x7 (Table 1 of He et al. (2015)).

    This function takes no `prior_mean`; regularizers are built with the
    constant 1.0 passed to `rank1_bnn_utils.make_regularizer` alongside
    `prior_stddev`.

    Args:
      input_shape: Shape tuple of input excluding batch dimension.
      num_classes: Number of output classes.
      alpha_initializer: The initializer for the alpha parameters.
      gamma_initializer: The initializer for the gamma parameters.
      alpha_regularizer: The regularizer for the alpha parameters.
      gamma_regularizer: The regularizer for the gamma parameters.
      use_additive_perturbation: Whether or not to use additive perturbations
        instead of multiplicative perturbations.
      ensemble_size: Number of ensemble members.
      random_sign_init: Value used to initialize trainable deterministic
        initializers, as applicable. Values greater than zero result in
        initialization to a random sign vector, where random_sign_init is the
        probability of a 1 value. Values less than zero result in
        initialization from a Gaussian with mean 1 and standard deviation
        equal to -random_sign_init.
      dropout_rate: Dropout rate (forwarded to
        `rank1_bnn_utils.make_initializer`).
      prior_stddev: Standard deviation of the prior.
      use_tpu: Whether the model runs on TPU.
      use_ensemble_bn: Whether to use `EnsembleSyncBatchNormalization`
        instead of `ed.layers.ensemble_batchnorm`.

    Returns:
      tf.keras.Model named 'resnet50'.
    """

    def _initializers():
        # Evaluated per layer, alpha first, as in the original call sites.
        return dict(
            alpha_initializer=rank1_bnn_utils.make_initializer(
                alpha_initializer, random_sign_init, dropout_rate),
            gamma_initializer=rank1_bnn_utils.make_initializer(
                gamma_initializer, random_sign_init, dropout_rate),
        )

    def _regularizers():
        return dict(
            alpha_regularizer=rank1_bnn_utils.make_regularizer(
                alpha_regularizer, 1., prior_stddev),
            gamma_regularizer=rank1_bnn_utils.make_regularizer(
                gamma_regularizer, 1., prior_stddev),
        )

    # Arguments every residual group shares; `group` gets the raw specs.
    group_kwargs = dict(
        alpha_initializer=alpha_initializer,
        gamma_initializer=gamma_initializer,
        alpha_regularizer=alpha_regularizer,
        gamma_regularizer=gamma_regularizer,
        use_additive_perturbation=use_additive_perturbation,
        ensemble_size=ensemble_size,
        random_sign_init=random_sign_init,
        dropout_rate=dropout_rate,
        prior_stddev=prior_stddev,
        use_tpu=use_tpu,
        use_ensemble_bn=use_ensemble_bn,
    )
    # (filters, stage, num_blocks, strides) for the four residual groups.
    stage_specs = (
        ([64, 64, 256], 2, 3, 1),
        ([128, 128, 512], 3, 4, 2),
        ([256, 256, 1024], 4, 6, 2),
        ([512, 512, 2048], 5, 3, 2),
    )

    inputs = tf.keras.layers.Input(shape=input_shape)

    # Stem: explicit zero padding followed by a 'valid' 7x7 stride-2 conv.
    net = tf.keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(inputs)
    net = ed.layers.Conv2DRank1(
        64,
        kernel_size=7,
        strides=2,
        padding='valid',
        use_bias=False,
        **_initializers(),
        kernel_initializer='he_normal',
        **_regularizers(),
        use_additive_perturbation=use_additive_perturbation,
        name='conv1',
        ensemble_size=ensemble_size)(net)
    if use_ensemble_bn:
        net = EnsembleSyncBatchNormalization(
            ensemble_size=ensemble_size, name='bn_conv1')(net)
    else:
        net = ed.layers.ensemble_batchnorm(
            net,
            ensemble_size=ensemble_size,
            use_tpu=use_tpu,
            momentum=BATCH_NORM_DECAY,
            epsilon=BATCH_NORM_EPSILON,
            name='bn_conv1')
    net = tf.keras.layers.Activation('relu')(net)
    net = tf.keras.layers.MaxPooling2D(
        3, strides=(2, 2), padding='same')(net)

    for stage_filters, stage, num_blocks, strides in stage_specs:
        net = group(
            net,
            stage_filters,
            stage=stage,
            num_blocks=num_blocks,
            strides=strides,
            **group_kwargs)

    net = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(net)

    # The kernel initializer is constructed between the initializer and
    # regularizer factories, preserving the original evaluation order.
    logits = ed.layers.DenseRank1(
        num_classes,
        **_initializers(),
        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
        **_regularizers(),
        use_additive_perturbation=use_additive_perturbation,
        ensemble_size=ensemble_size,
        activation=None,
        name='fc1000')(net)
    return tf.keras.Model(inputs=inputs, outputs=logits, name='resnet50')
def bottleneck_block(
    inputs,
    filters,
    stage,
    block,
    strides,
    alpha_initializer,
    gamma_initializer,
    alpha_regularizer,
    gamma_regularizer,
    use_additive_perturbation,
    ensemble_size,
    random_sign_init,
    dropout_rate,
    prior_stddev,
    use_tpu,
    use_ensemble_bn,
):
    """Residual block with 1x1 -> 3x3 -> 1x1 convs in main path.

    Note that strides appear in the second conv (3x3) rather than the first
    (1x1). This is also known as "ResNet v1.5" as it differs from He et al.
    (2015) (http://torch.ch/blog/2016/02/04/resnets.html).

    Every conv is followed by batch norm; ReLU follows the first two and the
    final sum. If the main-path output shape is incompatible with `inputs`,
    the shortcut is projected by a strided 1x1 conv plus batch norm.

    Args:
      inputs: tf.Tensor.
      filters: Sequence of exactly three integers, the filter counts of the
        three main-path convs.
      stage: Integer, current stage label, used for generating layer names.
      block: 'a','b'..., current block label, used for generating layer names.
      strides: Strides for the second conv layer in the block (also used by
        the projection shortcut).
      alpha_initializer: The initializer for the alpha parameters.
      gamma_initializer: The initializer for the gamma parameters.
      alpha_regularizer: The regularizer for the alpha parameters.
      gamma_regularizer: The regularizer for the gamma parameters.
      use_additive_perturbation: Whether or not to use additive perturbations
        instead of multiplicative perturbations.
      ensemble_size: Number of ensemble members.
      random_sign_init: Value used to initialize trainable deterministic
        initializers, as applicable. Values greater than zero result in
        initialization to a random sign vector, where random_sign_init is the
        probability of a 1 value. Values less than zero result in
        initialization from a Gaussian with mean 1 and standard deviation
        equal to -random_sign_init.
      dropout_rate: Dropout rate (forwarded to
        `rank1_bnn_utils.make_initializer`).
      prior_stddev: Standard deviation of the prior.
      use_tpu: Whether the model runs on TPU.
      use_ensemble_bn: Whether to use ensemble sync BN.

    Returns:
      tf.Tensor.
    """
    reduce_filters, spatial_filters, expand_filters = filters
    name_suffix = str(stage) + block + '_branch'
    conv_name_base = 'res' + name_suffix
    bn_name_base = 'bn' + name_suffix

    def _conv(num_filters, kernel_size, branch, **conv_kwargs):
        # One bias-free rank-1 conv; `conv_kwargs` carries strides/padding
        # for the layers that set them.
        return ed.layers.Conv2DRank1(
            num_filters,
            kernel_size=kernel_size,
            use_bias=False,
            alpha_initializer=rank1_bnn_utils.make_initializer(
                alpha_initializer, random_sign_init, dropout_rate),
            gamma_initializer=rank1_bnn_utils.make_initializer(
                gamma_initializer, random_sign_init, dropout_rate),
            kernel_initializer='he_normal',
            alpha_regularizer=rank1_bnn_utils.make_regularizer(
                alpha_regularizer, 1., prior_stddev),
            gamma_regularizer=rank1_bnn_utils.make_regularizer(
                gamma_regularizer, 1., prior_stddev),
            use_additive_perturbation=use_additive_perturbation,
            name=conv_name_base + branch,
            ensemble_size=ensemble_size,
            **conv_kwargs)

    def _batch_norm(tensor, branch):
        # Select between the two batch-norm implementations.
        bn_name = bn_name_base + branch
        if use_ensemble_bn:
            return EnsembleSyncBatchNormalization(
                ensemble_size=ensemble_size, name=bn_name)(tensor)
        return ed.layers.ensemble_batchnorm(
            tensor,
            ensemble_size=ensemble_size,
            use_tpu=use_tpu,
            momentum=BATCH_NORM_DECAY,
            epsilon=BATCH_NORM_EPSILON,
            name=bn_name)

    relu = 'relu'

    # Main path: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.
    main = _conv(reduce_filters, 1, '2a')(inputs)
    main = _batch_norm(main, '2a')
    main = tf.keras.layers.Activation(relu)(main)

    main = _conv(
        spatial_filters, 3, '2b', strides=strides, padding='same')(main)
    main = _batch_norm(main, '2b')
    main = tf.keras.layers.Activation(relu)(main)

    main = _conv(expand_filters, 1, '2c')(main)
    main = _batch_norm(main, '2c')

    # Shortcut path: identity unless the shapes disagree.
    shortcut = inputs
    if not main.shape.is_compatible_with(shortcut.shape):
        shortcut = _conv(expand_filters, 1, '1', strides=strides)(inputs)
        shortcut = _batch_norm(shortcut, '1')

    merged = tf.keras.layers.add([main, shortcut])
    return tf.keras.layers.Activation(relu)(merged)
def resnet50_het_rank1(
    input_shape,
    num_classes,
    alpha_initializer,
    gamma_initializer,
    alpha_regularizer,
    gamma_regularizer,
    use_additive_perturbation,
    ensemble_size,
    random_sign_init,
    dropout_rate,
    prior_stddev,
    use_tpu,
    use_ensemble_bn,
    num_factors,
    temperature,
    num_mc_samples,
    eps=1e-5,
):
    """Builds ResNet50 with rank 1 priors and a heteroscedastic output layer.

    Using strided conv, pooling, four groups of residual blocks, and pooling,
    the network maps spatial features of size 224x224 -> 112x112 -> 56x56 ->
    28x28 -> 14x14 -> 7x7 (Table 1 of He et al. (2015)).

    The pooled features feed `ed.layers.MCSoftmaxDenseFACustomLayers`, which
    is given three `DenseRank1` heads: 'scale_layer' with
    `num_classes * num_factors` units, 'loc_layer' with `num_classes` units,
    and 'diag_layer' with `num_classes` units and a softplus activation. The
    output layer is configured with `logits_only=True` and
    `share_samples_across_batch=True`.

    Args:
      input_shape: Shape tuple of input excluding batch dimension.
      num_classes: Number of output classes.
      alpha_initializer: The initializer for the alpha parameters.
      gamma_initializer: The initializer for the gamma parameters.
      alpha_regularizer: The regularizer for the alpha parameters.
      gamma_regularizer: The regularizer for the gamma parameters.
      use_additive_perturbation: Whether or not to use additive perturbations
        instead of multiplicative perturbations.
      ensemble_size: Number of ensemble members.
      random_sign_init: Value used to initialize trainable deterministic
        initializers, as applicable. Values greater than zero result in
        initialization to a random sign vector, where random_sign_init is the
        probability of a 1 value. Values less than zero result in
        initialization from a Gaussian with mean 1 and standard deviation
        equal to -random_sign_init.
      dropout_rate: Dropout rate (forwarded to
        `rank1_bnn_utils.make_initializer`).
      prior_stddev: Standard deviation of the prior.
      use_tpu: Whether the model runs on TPU.
      use_ensemble_bn: Whether to use `EnsembleSyncBatchNormalization`
        instead of `ed.layers.ensemble_batchnorm`.
      num_factors: Integer. Number of factors to use in approximation to full
        rank covariance matrix. It is required that num_factors > 0.
      temperature: Float or scalar `Tensor` representing the softmax
        temperature.
      num_mc_samples: The number of Monte-Carlo samples used to estimate the
        predictive distribution; used for both training and testing.
      eps: Float forwarded to the output layer. Per the original
        documentation, probabilities are clipped into [eps, 1.0] before the
        log is applied.

    Returns:
      tf.keras.Model named 'resnet50'.

    Raises:
      ValueError: If `num_factors` is not greater than zero.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not num_factors > 0:
        raise ValueError(
            'num_factors must be greater than 0; got {}.'.format(num_factors))

    def _init(spec):
        return rank1_bnn_utils.make_initializer(
            spec, random_sign_init, dropout_rate)

    def _reg(spec):
        # No prior mean is accepted by this builder; 1.0 is passed instead.
        return rank1_bnn_utils.make_regularizer(spec, 1., prior_stddev)

    def _dense_head(units, activation, name):
        # One rank-1 dense head consumed by the heteroscedastic output layer.
        return ed.layers.DenseRank1(
            units,
            alpha_initializer=_init(alpha_initializer),
            gamma_initializer=_init(gamma_initializer),
            kernel_initializer=tf.keras.initializers.RandomNormal(
                stddev=0.01),
            alpha_regularizer=_reg(alpha_regularizer),
            gamma_regularizer=_reg(gamma_regularizer),
            use_additive_perturbation=use_additive_perturbation,
            ensemble_size=ensemble_size,
            activation=activation,
            name=name)

    # Arguments every residual group shares; `group` gets the raw specs.
    group_kwargs = dict(
        alpha_initializer=alpha_initializer,
        gamma_initializer=gamma_initializer,
        alpha_regularizer=alpha_regularizer,
        gamma_regularizer=gamma_regularizer,
        use_additive_perturbation=use_additive_perturbation,
        ensemble_size=ensemble_size,
        random_sign_init=random_sign_init,
        dropout_rate=dropout_rate,
        prior_stddev=prior_stddev,
        use_tpu=use_tpu,
        use_ensemble_bn=use_ensemble_bn,
    )
    # (filters, stage, num_blocks, strides) for the four residual groups.
    stage_specs = (
        ([64, 64, 256], 2, 3, 1),
        ([128, 128, 512], 3, 4, 2),
        ([256, 256, 1024], 4, 6, 2),
        ([512, 512, 2048], 5, 3, 2),
    )

    inputs = tf.keras.layers.Input(shape=input_shape)

    # Stem: explicit zero padding followed by a 'valid' 7x7 stride-2 conv.
    net = tf.keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(inputs)
    net = ed.layers.Conv2DRank1(
        64,
        kernel_size=7,
        strides=2,
        padding='valid',
        use_bias=False,
        alpha_initializer=_init(alpha_initializer),
        gamma_initializer=_init(gamma_initializer),
        kernel_initializer='he_normal',
        alpha_regularizer=_reg(alpha_regularizer),
        gamma_regularizer=_reg(gamma_regularizer),
        use_additive_perturbation=use_additive_perturbation,
        name='conv1',
        ensemble_size=ensemble_size)(net)
    if use_ensemble_bn:
        net = EnsembleSyncBatchNormalization(
            ensemble_size=ensemble_size, name='bn_conv1')(net)
    else:
        net = ed.layers.ensemble_batchnorm(
            net,
            ensemble_size=ensemble_size,
            use_tpu=use_tpu,
            momentum=BATCH_NORM_DECAY,
            epsilon=BATCH_NORM_EPSILON,
            name='bn_conv1')
    net = tf.keras.layers.Activation('relu')(net)
    net = tf.keras.layers.MaxPooling2D(
        3, strides=(2, 2), padding='same')(net)

    for stage_filters, stage, num_blocks, strides in stage_specs:
        net = group(
            net,
            stage_filters,
            stage=stage,
            num_blocks=num_blocks,
            strides=strides,
            **group_kwargs)

    net = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(net)

    # Heads are created in the original order: scale, loc, diag.
    scale_layer = _dense_head(num_classes * num_factors, None, 'scale_layer')
    loc_layer = _dense_head(num_classes, None, 'loc_layer')
    diag_layer = _dense_head(num_classes, tf.math.softplus, 'diag_layer')

    logits = ed.layers.MCSoftmaxDenseFACustomLayers(
        scale_layer=scale_layer,
        loc_layer=loc_layer,
        diag_layer=diag_layer,
        temperature=temperature,
        train_mc_samples=num_mc_samples,
        test_mc_samples=num_mc_samples,
        share_samples_across_batch=True,
        num_classes=num_classes,
        num_factors=num_factors,
        logits_only=True,
        eps=eps,
        dtype=tf.float32,
        name='fc1000')(net)
    return tf.keras.Model(inputs=inputs, outputs=logits, name='resnet50')