Example 1
    def test_mean_field_fn(self):
        p_fn, q_fn = sampling.mean_field_fn()
        layer = tfp.layers.DenseLocalReparameterization(
            100,
            kernel_prior_fn=p_fn,
            kernel_posterior_fn=q_fn,
            bias_prior_fn=p_fn,
            bias_posterior_fn=q_fn)
        self.assertIsInstance(layer, tfp.layers.DenseLocalReparameterization)
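The factories returned by `mean_field_fn()` follow the calling convention that
`tfp.layers` expects for `kernel_prior_fn`/`kernel_posterior_fn`, which Example
4 below exercises directly. A minimal sketch of that convention, assuming the
test module's `sampling` import and TF1-style `tf.get_variable` (the shapes
and variable names are illustrative):

p_fn, q_fn = sampling.mean_field_fn()
# Signature: (dtype, shape, name, trainable, add_variable_fn); the result
# wraps a tfp.distributions instance, exposed via `.distribution`.
prior = p_fn(tf.float32, (100,), 'prior', True, tf.get_variable).distribution
posterior = q_fn(tf.float32, (100,), 'posterior', True,
                 tf.get_variable).distribution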
Example 2
import tensorflow as tf
import tensorflow_probability as tfp

# `build_resnet_v1`, `mean_field_fn`, and `ed` (edward2) are provided by the
# surrounding project and its imports.


def res_net(n_examples,
            input_shape,
            num_classes,
            batchnorm=False,
            variational='full'):
  """Wrapper for build_resnet_v1.

  Args:
    n_examples (int): number of training points.
    input_shape (list): input shape.
    num_classes (int): number of classes (CIFAR10 has 10).
    batchnorm (bool): use of batchnorm layers.
    variational (str): one of 'none', 'hybrid', or 'full'; whether to use
      variational inference for zero, some, or all layers.

  Returns:
      model (Model): Keras model instance whose output is an ed.Categorical
        RandomVariable (wrapping a tfp.distributions.Categorical).
  """
  inputs = tf.keras.layers.Input(shape=input_shape)
  x = build_resnet_v1(
      inputs,
      depth=20,
      variational=variational,
      batchnorm=batchnorm,
      n_examples=n_examples)

  p_fn, q_fn = mean_field_fn(empirical_bayes=True)

  def normalized_kl_fn(q, p, _):
    return tfp.distributions.kl_divergence(q, p) / tf.cast(
        n_examples, tf.float32)

  logits = tfp.layers.DenseLocalReparameterization(
      num_classes,
      kernel_prior_fn=p_fn,
      kernel_posterior_fn=q_fn,
      bias_prior_fn=p_fn,
      bias_posterior_fn=q_fn,
      kernel_divergence_fn=normalized_kl_fn,
      bias_divergence_fn=normalized_kl_fn)(x)
  outputs = tf.keras.layers.Lambda(lambda x: ed.Categorical(logits=x))(logits)
  return tf.keras.models.Model(inputs=inputs, outputs=outputs)
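Each forward pass through the variational layers re-samples the weights, so
predictive probabilities are usually Monte Carlo averages over several passes.
A minimal sketch, assuming eager-style execution and that `model(x)` returns
the edward2 RandomVariable built by the Lambda head above (`x` and the sample
count are placeholders):

def predict_probs(model, x, num_samples=10):
  # Average per-sample class probabilities over re-sampled weights.
  samples = [tf.nn.softmax(model(x).distribution.logits)
             for _ in range(num_samples)]
  return tf.reduce_mean(tf.stack(samples, axis=0), axis=0)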
Example 3
import tensorflow as tf
import tensorflow_probability as tfp

# `mean_field_fn` and `ed` (edward2) are provided by the surrounding project.


def lenet5(n_examples, input_shape, num_classes):
    """Builds a Bayesian LeNet-5."""
    p_fn, q_fn = mean_field_fn(empirical_bayes=True)

    def normalized_kl_fn(q, p, _):
        return q.kl_divergence(p) / tf.cast(n_examples, tf.float32)

    inputs = tf.keras.layers.Input(shape=input_shape)
    conv1 = tfp.layers.Convolution2DFlipout(
        6,
        kernel_size=5,
        padding='SAME',
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(inputs)
    pool1 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding='SAME')(conv1)
    conv2 = tfp.layers.Convolution2DFlipout(
        16,
        kernel_size=5,
        padding='SAME',
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(pool1)
    pool2 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding='SAME')(conv2)
    conv3 = tfp.layers.Convolution2DFlipout(
        120,
        kernel_size=5,
        padding='SAME',
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(pool2)
    flatten = tf.keras.layers.Flatten()(conv3)
    dense1 = tfp.layers.DenseLocalReparameterization(
        84,
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(flatten)
    dense2 = tfp.layers.DenseLocalReparameterization(
        num_classes,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(dense1)
    outputs = tf.keras.layers.Lambda(lambda x: ed.Categorical(logits=x))(
        dense2)
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)
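The KL terms registered through `kernel_divergence_fn`/`bias_divergence_fn`
accumulate in `model.losses`, so the training objective is the negative
log-likelihood plus their sum. A minimal sketch, with placeholder batches `x`
and `y`:

model = lenet5(n_examples=60000, input_shape=(28, 28, 1), num_classes=10)
output = model(x)  # edward2 RandomVariable over classes
nll = -tf.reduce_mean(output.distribution.log_prob(y))
kl = sum(model.losses)  # per-layer KLs, already divided by n_examples
elbo_loss = nll + kl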
Example 4
    def test_sample_auxiliary_op(self):
        p_fn, q_fn = sampling.mean_field_fn()
        p = p_fn(tf.float32, (), 'test_prior', True,
                 tf.get_variable).distribution
        q = q_fn(tf.float32, (), 'test_posterior', True,
                 tf.get_variable).distribution

        # Test benign auxiliary variable
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 1e-10)
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(0.5),
                                   session=sess)

        sess.run(sample_op)

        tolerance = 0.0001
        self.assertLess(np.abs(sess.run(p.scale) - 1.), tolerance)
        self.assertLess(np.abs(sess.run(p.loc) - 1.), tolerance)
        self.assertLess(np.abs(sess.run(q.scale) - 0.5), tolerance)
        self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)

        # Test fully determining auxiliary variable
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 1. - 1e-10)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(.5), session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc) - sess.run(p.loc)), tolerance)
        self.assertLess(sess.run(p.scale), tolerance)
        self.assertLess(sess.run(q.scale), tolerance)

        # Test delta posterior
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(1e-10),
                                   session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)
        self.assertLess(sess.run(q.scale), tolerance)

        # Test prior is posterior
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1., session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc - p.loc)), tolerance)
        self.assertLess(np.abs(sess.run(q.scale - p.scale)), tolerance)
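The test relies on a project-local `_softplus_inverse_np` helper to set
`untransformed_scale` so that `scale = softplus(untransformed_scale)` lands on
an exact target value. Its body is not shown here; a plausible NumPy
implementation would be:

import numpy as np

def _softplus_inverse_np(x):
    # Inverse of softplus(y) = log(1 + exp(y)); expm1 preserves precision
    # for small x.
    return np.log(np.expm1(x))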
Example 5
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# `mean_field_fn` is provided by the surrounding project.


def conv_net(n_examples, input_shape, num_classes):
    """Builds a simple feed-forward Bayesian convolutional network."""
    p_fn, q_fn = mean_field_fn(empirical_bayes=True)

    def normalized_kl_fn(q, p, _):
        return tfp.distributions.kl_divergence(q, p) / tf.cast(
            n_examples, tf.float32)

    inputs = tf.keras.layers.Input(shape=input_shape)
    conv1 = tfp.layers.Convolution2DFlipout(
        6,
        kernel_size=5,
        padding="SAME",
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(inputs)
    pool1 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding="SAME")(conv1)
    conv2 = tfp.layers.Convolution2DFlipout(
        16,
        kernel_size=5,
        padding="SAME",
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(pool1)
    pool2 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding="SAME")(conv2)
    conv3 = tfp.layers.Convolution2DFlipout(
        120,
        kernel_size=5,
        padding="SAME",
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(pool2)
    flatten = tf.keras.layers.Flatten()(conv3)
    dense1 = tfp.layers.DenseLocalReparameterization(
        84,
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(flatten)
    dense2 = tfp.layers.DenseLocalReparameterization(
        num_classes,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(dense1)

    output_dist = tfp.layers.DistributionLambda(
        lambda o: tfd.Categorical(logits=o))(dense2)
    return tf.keras.models.Model(inputs=inputs, outputs=output_dist)
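Because the head is a DistributionLambda, calling the model yields a
tfd.Categorical, so the usual TFP pattern is to minimize its negative
log-likelihood directly (the dataset names below are placeholders):

model = conv_net(n_examples=60000, input_shape=(28, 28, 1), num_classes=10)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    # NLL of the labels; the layers' KL terms are added via `model.losses`.
    loss=lambda y, rv_y: -rv_y.log_prob(y))
model.fit(train_images, train_labels, epochs=10, batch_size=128)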
Example 6
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras

# `mean_field_fn` is provided by the surrounding project.


def _resnet_layer(inputs,
                  num_filters=16,
                  kernel_size=3,
                  strides=1,
                  activation='relu',
                  depth=20,
                  batchnorm=False,
                  conv_first=True,
                  variational=False,
                  n_examples=None):
  """2D Convolution-Batch Normalization-Activation stack builder.

  Args:
    inputs (tensor): input tensor from input image or previous layer
    num_filters (int): Conv2D number of filters
    kernel_size (int): Conv2D square kernel dimensions
    strides (int): Conv2D square stride dimensions
    activation (string): Activation function string.
    depth (int): ResNet depth; used for initialization scale.
    batchnorm (bool): whether to include batch normalization
    conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)
    variational (bool): Whether to use a variational convolutional layer.
    n_examples (int): Number of examples per epoch for variational KL.

  Returns:
      x (tensor): tensor as input to the next layer
  """
  if variational:

    def fixup_init(shape, dtype=None):
      """Fixup initialization; see https://arxiv.org/abs/1901.09321."""
      return keras.initializers.he_normal()(
          shape, dtype=dtype) * depth**(-1 / 4)

    p_fn, q_fn = mean_field_fn(empirical_bayes=True, initializer=fixup_init)

    def normalized_kl_fn(q, p, _):
      return tfp.distributions.kl_divergence(q, p) / tf.cast(
          n_examples, tf.float32)

    conv = tfp.layers.Convolution2DFlipout(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn)
  else:
    conv = keras.layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        kernel_initializer='he_normal',
        kernel_regularizer=keras.regularizers.l2(1e-4))

  x = inputs
  x = conv(x) if conv_first else x
  if batchnorm:
    x = keras.layers.BatchNormalization()(x)
  x = keras.layers.Activation(activation)(x) if activation is not None else x
  x = x if conv_first else conv(x)
  return x
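A stack entry can then be built either variationally or deterministically from
the same helper; for instance (input shape and example count are illustrative):

inputs = tf.keras.layers.Input(shape=(32, 32, 3))
x = _resnet_layer(inputs, num_filters=16, variational=True, n_examples=50000)
x = _resnet_layer(x, num_filters=16, strides=2, batchnorm=True,
                  variational=False)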