def testDenseRank1BatchEnsemble(self, alpha_initializer, gamma_initializer,
                                bias_initializer):
  tf.keras.backend.set_learning_phase(1)  # training time
  ensemble_size = 3
  examples_per_model = 4
  input_dim = 5
  output_dim = 5
  inputs = tf.random.normal([examples_per_model, input_dim])
  batched_inputs = tf.tile(inputs, [ensemble_size, 1])
  layer = rank1_bnn_layers.DenseRank1(
      output_dim,
      alpha_initializer=alpha_initializer,
      gamma_initializer=gamma_initializer,
      bias_initializer=bias_initializer,
      alpha_regularizer=None,
      gamma_regularizer=None,
      activation=None,
      ensemble_size=ensemble_size)
  output = layer(batched_inputs)
  manual_output = [
      layer.dense(inputs * layer.alpha[i]) * layer.gamma[i] + layer.bias[i]
      for i in range(ensemble_size)
  ]
  manual_output = tf.concat(manual_output, axis=0)
  expected_shape = (ensemble_size * examples_per_model, output_dim)
  self.assertEqual(output.shape, expected_shape)
  self.assertAllClose(output, manual_output)
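# Illustration only (not part of the original test): the per-member
# computation that the manual_output check above encodes. For ensemble
# member i, DenseRank1 applies a shared dense kernel with a rank-1 input
# scale alpha_i and output scale gamma_i:
#
#   output_i = dense(inputs * alpha_i) * gamma_i + bias_i
#
# A minimal standalone sketch of that forward pass:
def _rank1_dense_member_sketch(inputs, kernel, alpha, gamma, bias):
  # inputs: [batch, d_in]; kernel: [d_in, d_out]; alpha: [d_in];
  # gamma, bias: [d_out].
  return tf.matmul(inputs * alpha, kernel) * gamma + bias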
def testDenseRank1Model(self):
  inputs = np.random.rand(3, 4, 4, 1).astype(np.float32)
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(3,
                             kernel_size=2,
                             padding='SAME',
                             activation=tf.nn.relu),
      tf.keras.layers.Flatten(),
      rank1_bnn_layers.DenseRank1(2, activation=None),
  ])
  outputs = model(inputs, training=True)
  self.assertEqual(outputs.shape, (3, 2))
  self.assertLen(model.losses, 2)
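# Hedged note (not from the original test file): the two entries in
# model.losses are presumably the KL regularization terms contributed by
# DenseRank1's default alpha and gamma regularizers. A sketch of how they
# would enter a training objective; `nll` and `dataset_size` are
# hypothetical placeholders.
def _total_loss_sketch(model, nll, dataset_size):
  kl = tf.add_n(model.losses)  # sum the per-layer KL penalties
  return nll + kl / dataset_size  # KL is typically scaled by dataset size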
def testDenseRank1AlphaGamma(self, alpha_initializer, gamma_initializer,
                             all_close, use_additive_perturbation,
                             ensemble_size):
  tf.keras.backend.set_learning_phase(1)  # training time
  inputs = np.random.rand(5 * ensemble_size, 12).astype(np.float32)
  model = rank1_bnn_layers.DenseRank1(
      4,
      ensemble_size=ensemble_size,
      alpha_initializer=alpha_initializer,
      gamma_initializer=gamma_initializer,
      # Thread the flag through; otherwise this test parameter is unused.
      use_additive_perturbation=use_additive_perturbation,
      activation=None)
  outputs1 = model(inputs)
  outputs2 = model(inputs)
  self.assertEqual(outputs1.shape, (5 * ensemble_size, 4))
  if all_close:
    self.assertAllClose(outputs1, outputs2)
  else:
    self.assertNotAllClose(outputs1, outputs2)
  model.get_config()
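# Hedged sketch (decorator values are illustrative, not the originals):
# these test methods presumably receive their arguments from absl's
# parameterized decorator, along the lines of
#
# @parameterized.parameters(
#     # Deterministic alpha/gamma: two forward passes agree (all_close=True).
#     ('trainable_deterministic', 'trainable_deterministic', True, False, 1),
#     # Sampled alpha/gamma: two forward passes differ (all_close=False).
#     ('trainable_normal', 'trainable_normal', False, False, 4),
# )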
def wide_resnet(input_shape, depth, width_multiplier, num_classes,
                alpha_initializer, gamma_initializer, alpha_regularizer,
                gamma_regularizer, use_additive_perturbation, ensemble_size,
                random_sign_init, dropout_rate, prior_mean, prior_stddev):
  """Builds Wide ResNet.

  Following Zagoruyko and Komodakis (2016), it accepts a width multiplier on
  the number of filters. Using three groups of residual blocks, the network
  maps spatial features of size 32x32 -> 16x16 -> 8x8.

  Args:
    input_shape: Shape tuple of input excluding batch dimension.
    depth: Total number of convolutional layers. "n" in WRN-n-k. It differs
      from He et al. (2015)'s notation, which uses the maximum depth of the
      network counting non-conv layers like dense.
    width_multiplier: Integer to multiply the number of typical filters by.
      "k" in WRN-n-k.
    num_classes: Number of output classes.
    alpha_initializer: The initializer for the alpha parameters.
    gamma_initializer: The initializer for the gamma parameters.
    alpha_regularizer: The regularizer for the alpha parameters.
    gamma_regularizer: The regularizer for the gamma parameters.
    use_additive_perturbation: Whether or not to use additive perturbations
      instead of multiplicative perturbations.
    ensemble_size: Number of ensemble members.
    random_sign_init: Value used to initialize trainable deterministic
      initializers, as applicable. Values greater than zero result in
      initialization to a random sign vector, where random_sign_init is the
      probability of a 1 value. Values less than zero result in
      initialization from a Gaussian with mean 1 and standard deviation equal
      to -random_sign_init.
    dropout_rate: Dropout rate.
    prior_mean: Mean of the prior.
    prior_stddev: Standard deviation of the prior.

  Returns:
    tf.keras.Model.
  """
  if (depth - 4) % 6 != 0:
    raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
  num_blocks = (depth - 4) // 6
  inputs = tf.keras.layers.Input(shape=input_shape)
  # Conv2DRank1 and BatchNormalization are assumed to be module-level helpers
  # (e.g., functools.partial wrappers fixing kernel_size/padding and the
  # batch-norm hyperparameters).
  x = Conv2DRank1(
      16,
      strides=1,
      alpha_initializer=utils.make_initializer(alpha_initializer,
                                               random_sign_init,
                                               dropout_rate),
      gamma_initializer=utils.make_initializer(gamma_initializer,
                                               random_sign_init,
                                               dropout_rate),
      alpha_regularizer=utils.make_regularizer(alpha_regularizer,
                                               prior_mean, prior_stddev),
      gamma_regularizer=utils.make_regularizer(gamma_regularizer,
                                               prior_mean, prior_stddev),
      use_additive_perturbation=use_additive_perturbation,
      ensemble_size=ensemble_size)(inputs)
  for strides, filters in zip([1, 2, 2], [16, 32, 64]):
    x = group(x,
              filters=filters * width_multiplier,
              strides=strides,
              num_blocks=num_blocks,
              alpha_initializer=alpha_initializer,
              gamma_initializer=gamma_initializer,
              alpha_regularizer=alpha_regularizer,
              gamma_regularizer=gamma_regularizer,
              use_additive_perturbation=use_additive_perturbation,
              ensemble_size=ensemble_size,
              random_sign_init=random_sign_init,
              dropout_rate=dropout_rate,
              prior_mean=prior_mean,
              prior_stddev=prior_stddev)
  x = BatchNormalization()(x)
  x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = rank1_bnn_layers.DenseRank1(
      num_classes,
      alpha_initializer=utils.make_initializer(alpha_initializer,
                                               random_sign_init,
                                               dropout_rate),
      gamma_initializer=utils.make_initializer(gamma_initializer,
                                               random_sign_init,
                                               dropout_rate),
      kernel_initializer='he_normal',
      activation=None,
      alpha_regularizer=utils.make_regularizer(alpha_regularizer,
                                               prior_mean, prior_stddev),
      gamma_regularizer=utils.make_regularizer(gamma_regularizer,
                                               prior_mean, prior_stddev),
      use_additive_perturbation=use_additive_perturbation,
      ensemble_size=ensemble_size)(x)
  return tf.keras.Model(inputs=inputs, outputs=x)
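# Hedged usage sketch (argument values are illustrative, not prescribed by
# this module): a WRN-28-10 rank-1 BNN for 10-class 32x32 inputs.
# 'trainable_normal' and 'normal_kl_divergence' are assumed to be names that
# utils.make_initializer / utils.make_regularizer accept.
#
# model = wide_resnet(
#     input_shape=(32, 32, 3),
#     depth=28,  # must satisfy depth = 6n + 4
#     width_multiplier=10,
#     num_classes=10,
#     alpha_initializer='trainable_normal',
#     gamma_initializer='trainable_normal',
#     alpha_regularizer='normal_kl_divergence',
#     gamma_regularizer='normal_kl_divergence',
#     use_additive_perturbation=False,
#     ensemble_size=4,
#     random_sign_init=-0.5,
#     dropout_rate=0.,
#     prior_mean=1.,
#     prior_stddev=0.1)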
def rank1_resnet50(input_shape, num_classes, alpha_initializer,
                   gamma_initializer, alpha_regularizer, gamma_regularizer,
                   use_additive_perturbation, ensemble_size, random_sign_init,
                   dropout_rate, prior_stddev, use_tpu):
  """Builds ResNet50 with rank 1 priors.

  Using strided conv, pooling, four groups of residual blocks, and pooling,
  the network maps spatial features of size 224x224 -> 112x112 -> 56x56 ->
  28x28 -> 14x14 -> 7x7 (Table 1 of He et al. (2015)).

  Args:
    input_shape: Shape tuple of input excluding batch dimension.
    num_classes: Number of output classes.
    alpha_initializer: The initializer for the alpha parameters.
    gamma_initializer: The initializer for the gamma parameters.
    alpha_regularizer: The regularizer for the alpha parameters.
    gamma_regularizer: The regularizer for the gamma parameters.
    use_additive_perturbation: Whether or not to use additive perturbations
      instead of multiplicative perturbations.
    ensemble_size: Number of ensemble members.
    random_sign_init: Value used to initialize trainable deterministic
      initializers, as applicable. Values greater than zero result in
      initialization to a random sign vector, where random_sign_init is the
      probability of a 1 value. Values less than zero result in
      initialization from a Gaussian with mean 1 and standard deviation equal
      to -random_sign_init.
    dropout_rate: Dropout rate.
    prior_stddev: Standard deviation of the prior.
    use_tpu: Whether the model runs on TPU.

  Returns:
    tf.keras.Model.
  """
  group_ = functools.partial(
      group,
      alpha_initializer=alpha_initializer,
      gamma_initializer=gamma_initializer,
      alpha_regularizer=alpha_regularizer,
      gamma_regularizer=gamma_regularizer,
      use_additive_perturbation=use_additive_perturbation,
      ensemble_size=ensemble_size,
      random_sign_init=random_sign_init,
      dropout_rate=dropout_rate,
      prior_stddev=prior_stddev,
      use_tpu=use_tpu)
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = tf.keras.layers.ZeroPadding2D(padding=3, name='conv1_pad')(inputs)
  x = rank1_bnn_layers.Conv2DRank1(
      64,
      kernel_size=7,
      strides=2,
      padding='valid',
      use_bias=False,
      alpha_initializer=utils.make_initializer(alpha_initializer,
                                               random_sign_init,
                                               dropout_rate),
      gamma_initializer=utils.make_initializer(gamma_initializer,
                                               random_sign_init,
                                               dropout_rate),
      kernel_initializer='he_normal',
      alpha_regularizer=utils.make_regularizer(alpha_regularizer, 1.,
                                               prior_stddev),
      gamma_regularizer=utils.make_regularizer(gamma_regularizer, 1.,
                                               prior_stddev),
      use_additive_perturbation=use_additive_perturbation,
      name='conv1',
      ensemble_size=ensemble_size)(x)
  x = ed.layers.ensemble_batchnorm(
      x,
      ensemble_size=ensemble_size,
      use_tpu=use_tpu,
      momentum=BATCH_NORM_DECAY,
      epsilon=BATCH_NORM_EPSILON,
      name='bn_conv1')
  x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.MaxPooling2D(3, strides=(2, 2), padding='same')(x)
  x = group_(x, [64, 64, 256], stage=2, num_blocks=3, strides=1)
  x = group_(x, [128, 128, 512], stage=3, num_blocks=4, strides=2)
  x = group_(x, [256, 256, 1024], stage=4, num_blocks=6, strides=2)
  x = group_(x, [512, 512, 2048], stage=5, num_blocks=3, strides=2)
  x = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')(x)
  x = rank1_bnn_layers.DenseRank1(
      num_classes,
      alpha_initializer=utils.make_initializer(alpha_initializer,
                                               random_sign_init,
                                               dropout_rate),
      gamma_initializer=utils.make_initializer(gamma_initializer,
                                               random_sign_init,
                                               dropout_rate),
      kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
      alpha_regularizer=utils.make_regularizer(alpha_regularizer, 1.,
                                               prior_stddev),
      gamma_regularizer=utils.make_regularizer(gamma_regularizer, 1.,
                                               prior_stddev),
      use_additive_perturbation=use_additive_perturbation,
      ensemble_size=ensemble_size,
      activation=None,
      name='fc1000')(x)
  return tf.keras.Model(inputs=inputs, outputs=x, name='resnet50')
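# Hedged usage sketch (values illustrative): an ImageNet-scale rank-1
# ResNet-50 with four ensemble members. Initializer/regularizer names are
# assumed, as in the wide_resnet sketch above.
#
# model = rank1_resnet50(
#     input_shape=(224, 224, 3),
#     num_classes=1000,
#     alpha_initializer='trainable_normal',
#     gamma_initializer='trainable_normal',
#     alpha_regularizer='normal_kl_divergence',
#     gamma_regularizer='normal_kl_divergence',
#     use_additive_perturbation=False,
#     ensemble_size=4,
#     random_sign_init=-0.5,
#     dropout_rate=0.,
#     prior_stddev=0.1,
#     use_tpu=False)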
def rank1_resnet_v1(input_shape, depth, num_classes, width_multiplier,
                    alpha_initializer, gamma_initializer, alpha_regularizer,
                    gamma_regularizer, use_additive_perturbation,
                    ensemble_size, random_sign_init, dropout_rate):
  """Builds Bayesian rank-1 prior ResNet v1.

  Args:
    input_shape: Shape tuple of input excluding batch dimension.
    depth: ResNet depth.
    num_classes: Number of output classes.
    width_multiplier: Integer to multiply the number of typical filters by.
    alpha_initializer: The initializer for the alpha parameters.
    gamma_initializer: The initializer for the gamma parameters.
    alpha_regularizer: The regularizer for the alpha parameters.
    gamma_regularizer: The regularizer for the gamma parameters.
    use_additive_perturbation: Whether or not to use additive perturbations
      instead of multiplicative perturbations.
    ensemble_size: Number of ensemble members.
    random_sign_init: Value used to initialize trainable deterministic
      initializers, as applicable. Values greater than zero result in
      initialization to a random sign vector, where random_sign_init is the
      probability of a 1 value. Values less than zero result in
      initialization from a Gaussian with mean 1 and standard deviation equal
      to -random_sign_init.
    dropout_rate: Dropout rate.

  Returns:
    tf.keras.Model.
  """
  if (depth - 2) % 6 != 0:
    raise ValueError('depth should be 6n+2 (e.g., 20, 32, 44).')
  filters = 16 * width_multiplier
  num_res_blocks = int((depth - 2) / 6)
  resnet_layer = functools.partial(
      rank1_resnet_layer,
      alpha_initializer=alpha_initializer,
      gamma_initializer=gamma_initializer,
      alpha_regularizer=alpha_regularizer,
      gamma_regularizer=gamma_regularizer,
      use_additive_perturbation=use_additive_perturbation,
      ensemble_size=ensemble_size,
      random_sign_init=random_sign_init,
      dropout_rate=dropout_rate)
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = resnet_layer(inputs,
                   filters=filters,
                   kernel_size=3,
                   strides=1,
                   activation='relu')
  for stack in range(3):
    for res_block in range(num_res_blocks):
      strides = 1
      if stack > 0 and res_block == 0:  # first layer but not first stack
        strides = 2  # downsample
      y = resnet_layer(x,
                       filters=filters,
                       kernel_size=3,
                       strides=strides,
                       activation='relu')
      y = resnet_layer(y,
                       filters=filters,
                       kernel_size=3,
                       strides=1,
                       activation=None)
      if stack > 0 and res_block == 0:  # first layer but not first stack
        # Linear projection residual shortcut connection to match changed
        # dims.
        x = resnet_layer(x,
                         filters=filters,
                         kernel_size=1,
                         strides=strides,
                         activation=None)
      x = tf.keras.layers.add([x, y])
      x = tf.keras.layers.Activation('relu')(x)
    filters *= 2

  # v1 does not use BN after the last shortcut connection-ReLU.
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  x = rank1_bnn_layers.DenseRank1(
      num_classes,
      activation=None,
      alpha_initializer=utils.make_initializer(alpha_initializer,
                                               random_sign_init,
                                               dropout_rate),
      gamma_initializer=utils.make_initializer(gamma_initializer,
                                               random_sign_init,
                                               dropout_rate),
      kernel_initializer='he_normal',
      alpha_regularizer=alpha_regularizer,
      gamma_regularizer=gamma_regularizer,
      use_additive_perturbation=use_additive_perturbation,
      ensemble_size=ensemble_size)(x)
  model = tf.keras.Model(inputs=inputs, outputs=x)
  return model
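# Hedged usage sketch (values illustrative): a rank-1 ResNet-20 v1
# (depth = 6n + 2 with n = 3). Initializer/regularizer names are assumed,
# as in the sketches above.
#
# model = rank1_resnet_v1(
#     input_shape=(32, 32, 3),
#     depth=20,
#     num_classes=10,
#     width_multiplier=1,
#     alpha_initializer='trainable_normal',
#     gamma_initializer='trainable_normal',
#     alpha_regularizer='normal_kl_divergence',
#     gamma_regularizer='normal_kl_divergence',
#     use_additive_perturbation=False,
#     ensemble_size=4,
#     random_sign_init=-0.5,
#     dropout_rate=0.)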