Example #1

# Assumed imports for this snippet: keras and tfp are the standard
# TensorFlow Keras and TensorFlow Probability packages; `resnet` and
# `uq_utils` are local modules of the surrounding project.
from tensorflow import keras
import tensorflow_probability as tfp

import resnet
import uq_utils


def build_model(opts):
    """Builds a ResNet keras.models.Model."""
    is_dropout_last = opts.method in ('ll_dropout', 'dropout',
                                      'dropout_nofirst', 'wide_dropout')
    is_dropout_all = opts.method in ('dropout', 'dropout_nofirst',
                                     'wide_dropout')
    all_dropout_rate = opts.dropout_rate if is_dropout_all else None
    last_dropout_rate = opts.dropout_rate if is_dropout_last else None

    # Empirical-Bayes prior on kernel weights, used by the variational layers.
    eb_prior_fn = uq_utils.make_prior_fn_for_empirical_bayes(
        opts.init_prior_scale_mean, opts.init_prior_scale_std)

    keras_in = keras.layers.Input(shape=opts.image_shape)
    # ResNet v1 backbone; returns the pre-logits feature tensor.
    net = resnet.build_resnet_v1(
        keras_in,
        depth=opts.resnet_depth,
        variational=opts.method == 'svi',
        std_prior_scale=opts.std_prior_scale,
        eb_prior_fn=eb_prior_fn,
        always_on_dropout_rate=all_dropout_rate,
        no_first_layer_dropout=opts.method == 'dropout_nofirst',
        examples_per_epoch=opts.examples_per_epoch,
        num_filters=opts.num_resnet_filters)
    # Model head: deterministic dense layer for 'vanilla', MC dropout before
    # the dense layer for the dropout methods, and a variational dense layer
    # for the SVI methods.
    if opts.method == 'vanilla':
        keras_out = keras.layers.Dense(opts.num_classes,
                                       kernel_initializer='he_normal')(net)
    elif is_dropout_last:
        # training=True keeps dropout active at test time (Monte Carlo
        # dropout).
        net = keras.layers.Dropout(last_dropout_rate)(net, training=True)
        keras_out = keras.layers.Dense(opts.num_classes,
                                       kernel_initializer='he_normal')(net)
    elif opts.method in ('svi', 'll_svi'):
        divergence_fn = uq_utils.make_divergence_fn_for_empirical_bayes(
            opts.std_prior_scale, opts.examples_per_epoch)

        keras_out = tfp.layers.DenseReparameterization(
            opts.num_classes,
            kernel_prior_fn=eb_prior_fn,
            kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(
                loc_initializer=keras.initializers.he_normal()),
            kernel_divergence_fn=divergence_fn)(net)
    else:
        raise ValueError('Unexpected method: %s' % opts.method)

    return keras.models.Model(inputs=keras_in, outputs=keras_out)
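
For illustration, here is a minimal sketch of calling build_model. The `opts` namespace is an assumption made for this sketch: the attribute names mirror exactly what the function reads, but the values are arbitrary CIFAR-10-style settings, not options from the original project.

import types

# Hypothetical options object; every attribute read by build_model is set.
opts = types.SimpleNamespace(
    method='dropout',  # or 'vanilla', 'dropout_nofirst', 'wide_dropout',
                       # 'll_dropout', 'svi', 'll_svi'
    dropout_rate=0.1,
    init_prior_scale_mean=-1.0,
    init_prior_scale_std=0.1,
    std_prior_scale=1.5,
    image_shape=(32, 32, 3),
    resnet_depth=20,
    num_resnet_filters=16,
    examples_per_epoch=50000,
    num_classes=10)

model = build_model(opts)
model.summary()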
Example #2

# Assumed imports, as in Example #1; `logging` is taken to be absl's
# logging module (the stdlib `logging` would also work here).
from absl import logging
from tensorflow import keras
import tensorflow_probability as tfp

import uq_utils


def _resnet_layer(inputs,
                  num_filters=16,
                  kernel_size=3,
                  strides=1,
                  activation='relu',
                  depth=20,
                  batch_norm=True,
                  conv_first=True,
                  variational=False,
                  std_prior_scale=1.5,
                  eb_prior_fn=None,
                  always_on_dropout_rate=None,
                  examples_per_epoch=None):
  """2D Convolution-Batch Normalization-Activation stack builder.

  Args:
    inputs (tensor): Input tensor from input image or previous layer.
    num_filters (int): Conv2D number of filters.
    kernel_size (int): Conv2D square kernel dimensions.
    strides (int): Conv2D square stride dimensions.
    activation (string): Activation function string.
    depth (int): ResNet depth; used for initialization scale.
    batch_norm (bool): Whether to include batch normalization.
    conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False).
    variational (bool): Whether to use a variational convolutional layer.
    std_prior_scale (float): Scale for log-normal hyperprior.
    eb_prior_fn (callable): Empirical Bayes prior for use with TFP layers.
    always_on_dropout_rate (float): Dropout rate (active in train and test).
    examples_per_epoch (int): Number of examples per epoch for variational KL.

  Returns:
    x (tensor): Tensor usable as input to the next layer.
  """
  # Variational path: Flipout convolution with the empirical-Bayes prior;
  # otherwise a standard L2-regularized convolution.
  if variational:
    divergence_fn = uq_utils.make_divergence_fn_for_empirical_bayes(
        std_prior_scale, examples_per_epoch)

    def fixup_init(shape, dtype=None):
      """Fixup initialization; see https://arxiv.org/abs/1901.09321."""
      return keras.initializers.he_normal()(shape, dtype=dtype) * depth**(-1/4)

    conv = tfp.layers.Convolution2DFlipout(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_prior_fn=eb_prior_fn,
        kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(
            loc_initializer=fixup_init),
        kernel_divergence_fn=divergence_fn)
  else:
    conv = keras.layers.Conv2D(num_filters,
                               kernel_size=kernel_size,
                               strides=strides,
                               padding='same',
                               kernel_initializer='he_normal',
                               kernel_regularizer=keras.regularizers.l2(1e-4))

  def apply_conv(net):
    logging.info('Applying conv layer; always_on_dropout=%s.',
                 always_on_dropout_rate)
    if always_on_dropout_rate:
      # training=True keeps dropout active at test time as well.
      net = keras.layers.Dropout(always_on_dropout_rate)(net, training=True)
    return conv(net)

  x = inputs
  x = apply_conv(x) if conv_first else x
  x = (keras.layers.BatchNormalization()(x)
       if batch_norm and not variational else x)
  x = keras.layers.Activation(activation)(x) if activation is not None else x
  x = x if conv_first else apply_conv(x)
  return x
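
A quick usage sketch, assuming the imports above; the input shape, filter count, and dropout rate are illustrative values, not settings from the original project.

inputs = keras.layers.Input(shape=(32, 32, 3))

# Plain conv-BN-ReLU block.
x = _resnet_layer(inputs, num_filters=16)

# The same block with always-on dropout applied before the convolution
# (active at both train and test time), as used by the dropout methods
# in Example #1.
x = _resnet_layer(x, num_filters=16, always_on_dropout_rate=0.1)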