def build_model(opts):
  """Builds a ResNet keras.models.Model."""
  is_dropout_last = opts.method in (
      'll_dropout', 'dropout', 'dropout_nofirst', 'wide_dropout')
  is_dropout_all = opts.method in ('dropout', 'dropout_nofirst', 'wide_dropout')
  all_dropout_rate = opts.dropout_rate if is_dropout_all else None
  last_dropout_rate = opts.dropout_rate if is_dropout_last else None

  eb_prior_fn = uq_utils.make_prior_fn_for_empirical_bayes(
      opts.init_prior_scale_mean, opts.init_prior_scale_std)

  keras_in = keras.layers.Input(shape=opts.image_shape)
  net = resnet.build_resnet_v1(
      keras_in,
      depth=opts.resnet_depth,
      variational=opts.method == 'svi',
      std_prior_scale=opts.std_prior_scale,
      eb_prior_fn=eb_prior_fn,
      always_on_dropout_rate=all_dropout_rate,
      no_first_layer_dropout=opts.method == 'dropout_nofirst',
      examples_per_epoch=opts.examples_per_epoch,
      num_filters=opts.num_resnet_filters)

  if opts.method == 'vanilla':
    keras_out = keras.layers.Dense(
        opts.num_classes, kernel_initializer='he_normal')(net)
  elif is_dropout_last:
    # Dropout before the final Dense layer, kept active at test time
    # (training=True) so repeated forward passes give Monte Carlo samples.
    net = keras.layers.Dropout(last_dropout_rate)(net, training=True)
    keras_out = keras.layers.Dense(
        opts.num_classes, kernel_initializer='he_normal')(net)
  elif opts.method in ('svi', 'll_svi'):
    # Variational (mean-field) output layer with an empirical-Bayes prior.
    divergence_fn = uq_utils.make_divergence_fn_for_empirical_bayes(
        opts.std_prior_scale, opts.examples_per_epoch)
    keras_out = tfp.layers.DenseReparameterization(
        opts.num_classes,
        kernel_prior_fn=eb_prior_fn,
        kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(
            loc_initializer=keras.initializers.he_normal()),
        kernel_divergence_fn=divergence_fn)(net)

  return keras.models.Model(inputs=keras_in, outputs=keras_out)
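
# Illustrative usage sketch (not part of the original module). The attribute
# names on `opts` mirror what build_model reads above; the values below are
# placeholder assumptions, normally supplied by the experiment's flags/config.
#
#   import types
#   opts = types.SimpleNamespace(
#       method='dropout',          # or 'vanilla', 'll_dropout', 'svi', 'll_svi', ...
#       dropout_rate=0.1,
#       init_prior_scale_mean=0.0,
#       init_prior_scale_std=1.0,
#       std_prior_scale=1.5,
#       image_shape=(32, 32, 3),   # e.g. CIFAR-sized inputs
#       resnet_depth=20,
#       num_resnet_filters=16,
#       examples_per_epoch=50000,
#       num_classes=10,
#   )
#   model = build_model(opts)
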
def _resnet_layer(inputs,
                  num_filters=16,
                  kernel_size=3,
                  strides=1,
                  activation='relu',
                  depth=20,
                  batch_norm=True,
                  conv_first=True,
                  variational=False,
                  std_prior_scale=1.5,
                  eb_prior_fn=None,
                  always_on_dropout_rate=None,
                  examples_per_epoch=None):
  """2D Convolution-Batch Normalization-Activation stack builder.

  Args:
    inputs (tensor): Input tensor from input image or previous layer.
    num_filters (int): Conv2D number of filters.
    kernel_size (int): Conv2D square kernel dimensions.
    strides (int): Conv2D square stride dimensions.
    activation (string): Activation function string.
    depth (int): ResNet depth; used for initialization scale.
    batch_norm (bool): Whether to include batch normalization.
    conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False).
    variational (bool): Whether to use a variational convolutional layer.
    std_prior_scale (float): Scale for log-normal hyperprior.
    eb_prior_fn (callable): Empirical Bayes prior for use with TFP layers.
    always_on_dropout_rate (float): Dropout rate (active in train and test).
    examples_per_epoch (int): Number of examples per epoch for variational KL.

  Returns:
    x (tensor): Tensor as input to the next layer.
  """
  if variational:
    divergence_fn = uq_utils.make_divergence_fn_for_empirical_bayes(
        std_prior_scale, examples_per_epoch)

    def fixup_init(shape, dtype=None):
      """Fixup initialization; see https://arxiv.org/abs/1901.09321."""
      return keras.initializers.he_normal()(shape, dtype=dtype) * depth**(-1/4)

    conv = tfp.layers.Convolution2DFlipout(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_prior_fn=eb_prior_fn,
        kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(
            loc_initializer=fixup_init),
        kernel_divergence_fn=divergence_fn)
  else:
    conv = keras.layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer='he_normal',
        kernel_regularizer=keras.regularizers.l2(1e-4))

  def apply_conv(net):
    logging.info('Applying conv layer; always_on_dropout=%s.',
                 always_on_dropout_rate)
    if always_on_dropout_rate:
      net = keras.layers.Dropout(always_on_dropout_rate)(net, training=True)
    return conv(net)

  x = inputs
  x = apply_conv(x) if conv_first else x
  x = (keras.layers.BatchNormalization()(x)
       if batch_norm and not variational else x)
  x = keras.layers.Activation(activation)(x) if activation is not None else x
  x = x if conv_first else apply_conv(x)
  return x
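
# Illustrative call sketch (not part of the original module): how the
# conv-bn-activation stacks built by this helper might be chained inside a
# ResNet block builder such as build_resnet_v1. Shapes and rates are
# assumptions, not values from the original code.
#
#   keras_in = keras.layers.Input(shape=(32, 32, 3))
#   x = _resnet_layer(keras_in, num_filters=16)            # conv-bn-relu
#   x = _resnet_layer(x, num_filters=16, strides=2,
#                     always_on_dropout_rate=0.1)          # downsample, MC dropout
#   x = _resnet_layer(x, num_filters=32, activation=None,
#                     conv_first=False)                    # bn-conv, no activation
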