Example #1
def test_mean_field_fn(self):
    # mean_field_fn returns a (prior, posterior) pair of factory callables that
    # plug directly into the tfp.layers variational layer constructors.
    p_fn, q_fn = sampling.mean_field_fn()
    layer = tfp.layers.DenseLocalReparameterization(
        100,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn)
    self.assertIsInstance(layer, tfp.layers.DenseLocalReparameterization)
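
The pair returned by sampling.mean_field_fn is used here only through the tfp.layers
constructor arguments, so its definition is not shown. As a rough orientation, a
mean-field factory compatible with that interface could look like the hypothetical
sketch below; note the five-argument signature, which Example #3 also exercises via
p_fn(tf.float32, (), 'test_prior', True, tf.get_variable). The real
sampling.mean_field_fn returns a (prior, posterior) pair of such factories and may
differ in details such as initializers and empirical-Bayes handling; everything named
_hypothetical below is an assumption, not the module's code.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

def _hypothetical_mean_field_factory():
    """Sketch of a tfp.layers-compatible factory; not the actual mean_field_fn."""
    def _fn(dtype, shape, name, trainable, add_variable_fn):
        loc = add_variable_fn(name + '_loc', shape=shape, dtype=dtype,
                              initializer=tf.random_normal_initializer(stddev=0.1),
                              trainable=trainable)
        untransformed_scale = add_variable_fn(
            name + '_untransformed_scale', shape=shape, dtype=dtype,
            initializer=tf.constant_initializer(np.log(np.expm1(1.))),
            trainable=trainable)
        # Softplus keeps the scale positive; Independent treats all weight
        # dimensions as event dimensions of a single distribution.
        normal = tfp.distributions.Normal(
            loc=loc, scale=tf.nn.softplus(untransformed_scale))
        return tfp.distributions.Independent(
            normal, reinterpreted_batch_ndims=len(shape))
    return _fn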
Example #2
def res_net(n_examples,
            input_shape,
            num_classes,
            batchnorm=False,
            variational='full'):
    """Wrapper for build_resnet_v1.

  Args:
    n_examples (int): number of training points.
    input_shape (list): input shape.
    num_classes (int): number of classes (CIFAR10 has 10).
    batchnorm (bool): use of batchnorm layers.
    variational (str): 'none', 'hybrid', 'full'. whether to use variational
      inference for zero, some, or all layers.

  Returns:
      model (Model): Keras model instance whose output is a
        tfp.distributions.Categorical distribution.
  """
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = build_resnet_v1(inputs,
                        depth=20,
                        variational=variational,
                        batchnorm=batchnorm,
                        n_examples=n_examples)

    p_fn, q_fn = mean_field_fn(empirical_bayes=True)

    def normalized_kl_fn(q, p, _):
        # Scale the weight-space KL by the dataset size so that each variational
        # layer contributes KL(q || p) / n_examples to the per-example loss.
        return tfp.distributions.kl_divergence(q, p) / tf.cast(
            n_examples, tf.float32)

    logits = tfp.layers.DenseLocalReparameterization(
        num_classes,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(x)
    outputs = tf.keras.layers.Lambda(lambda x: ed.Categorical(logits=x))(
        logits)
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)
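
Each kernel_divergence_fn/bias_divergence_fn result is collected as a Keras layer
loss, so the already 1/n_examples-scaled KL terms can be read back off the model. A
minimal sketch of how a training objective is typically assembled from them (the
dataset size, the name nll, and the loss assembly itself are assumptions, not part of
this example):

model = res_net(n_examples=50000, input_shape=(32, 32, 3), num_classes=10,
                batchnorm=False, variational='full')
# Sum of KL(q || p) / n_examples over every variational layer in the network.
kl_penalty = tf.add_n(model.losses)
# total_loss = nll + kl_penalty, where nll is the mean negative log-likelihood
# of the minibatch under the model's ed.Categorical output.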
Example #3
    def test_sample_auxiliary_op(self):
        p_fn, q_fn = sampling.mean_field_fn()
        p = p_fn(tf.float32, (), 'test_prior', True,
                 tf.get_variable).distribution
        q = q_fn(tf.float32, (), 'test_posterior', True,
                 tf.get_variable).distribution

        # Test benign auxiliary variable
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 1e-10)
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(0.5),
                                   session=sess)
        print(sess.run(q.scale))

        sess.run(sample_op)

        tolerance = 0.0001
        self.assertLess(np.abs(sess.run(p.scale) - 1.), tolerance)
        self.assertLess(np.abs(sess.run(p.loc) - 1.), tolerance)
        self.assertLess(np.abs(sess.run(q.scale) - 0.5), tolerance)
        self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)

        # Test fully determining auxiliary variable
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 1. - 1e-10)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(.5), session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc) - sess.run(p.loc)), tolerance)
        self.assertLess(sess.run(p.scale), tolerance)
        self.assertLess(sess.run(q.scale), tolerance)

        # Test delta posterior
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1.1, session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(1e-10),
                                   session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc) - 1.1), tolerance)
        self.assertLess(sess.run(q.scale), tolerance)

        # Test prior is posterior
        sample_op, _ = sampling.sample_auxiliary_op(p, q, 0.5)
        sess.run(tf.global_variables_initializer())
        p.loc.load(1., session=sess)
        p.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)
        q.loc.load(1., session=sess)
        q.untransformed_scale.load(self._softplus_inverse_np(1.), session=sess)

        sess.run(sample_op)

        self.assertLess(np.abs(sess.run(q.loc - p.loc)), tolerance)
        self.assertLess(np.abs(sess.run(q.scale - p.scale)), tolerance)
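
The helper self._softplus_inverse_np used above is not part of the excerpt. Its
intent is clear from context (it maps a desired scale back to the untransformed,
pre-softplus variable), so a plausible definition would be the numpy inverse of
softplus, shown here as an assumption rather than the test class's actual code:

import numpy as np

def _softplus_inverse_np(self, x):
    # Inverse of softplus(y) = log(1 + exp(y)): returns log(exp(x) - 1).
    return np.log(np.expm1(x))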
Example #4
def _resnet_layer(inputs,
                  num_filters=16,
                  kernel_size=3,
                  strides=1,
                  activation='relu',
                  depth=20,
                  batchnorm=False,
                  conv_first=True,
                  variational=False,
                  n_examples=None):
    """2D Convolution-Batch Normalization-Activation stack builder.

  Args:
    inputs (tensor): input tensor from input image or previous layer
    num_filters (int): Conv2D number of filters
    kernel_size (int): Conv2D square kernel dimensions
    strides (int): Conv2D square stride dimensions
    activation (string): Activation function string.
    depth (int): ResNet depth; used for initialization scale.
    batchnorm (bool): whether to include batch normalization
    conv_first (bool): conv-bn-activation (True) or bn-activation-conv (False)
    variational (bool): Whether to use a variational convolutional layer.
    n_examples (int): Number of examples per epoch for variational KL.

  Returns:
      x (tensor): tensor as input to the next layer
  """
    if variational:

        def fixup_init(shape, dtype=None):
            """Fixup initialization; see https://arxiv.org/abs/1901.09321."""
            return keras.initializers.he_normal()(
                shape, dtype=dtype) * depth**(-1 / 4)

        p_fn, q_fn = mean_field_fn(empirical_bayes=True,
                                   initializer=fixup_init)

        def normalized_kl_fn(q, p, _):
            # As in res_net above, scale the KL by the dataset size so that the
            # divergence enters the loss once per training example.
            return tfp.distributions.kl_divergence(q, p) / tf.cast(
                n_examples, tf.float32)

        conv = tfp.layers.Convolution2DFlipout(
            num_filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            kernel_prior_fn=p_fn,
            kernel_posterior_fn=q_fn,
            kernel_divergence_fn=normalized_kl_fn)
    else:
        conv = keras.layers.Conv2D(
            num_filters,
            kernel_size=kernel_size,
            strides=strides,
            padding='same',
            kernel_initializer='he_normal',
            kernel_regularizer=keras.regularizers.l2(1e-4))

    def apply_conv(net):
        return conv(net)

    x = inputs
    x = apply_conv(x) if conv_first else x
    if batchnorm:
        x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation(activation)(x) if activation is not None else x
    x = x if conv_first else apply_conv(x)
    return x
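
A minimal usage sketch (hypothetical shapes and filter counts; assumes the same
module-level imports as _resnet_layer itself). The block is applied functionally, so
deterministic and variational layers can be mixed within one network, which is what
the 'hybrid' setting of res_net relies on:

inputs = keras.layers.Input(shape=(32, 32, 3))
# Deterministic first block, then a variational block with the per-example KL.
x = _resnet_layer(inputs, num_filters=16, variational=False)
x = _resnet_layer(x, num_filters=16, variational=True, n_examples=50000)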
Example #5
def lenet5(n_examples, input_shape, num_classes):
    """Builds Bayesian LeNet5."""
    p_fn, q_fn = mean_field_fn(empirical_bayes=True)

    def normalized_kl_fn(q, p, _):
        return q.kl_divergence(p) / tf.cast(n_examples, tf.float32)

    inputs = tf.keras.layers.Input(shape=input_shape)
    conv1 = tfp.layers.Convolution2DFlipout(
        6,
        kernel_size=5,
        padding='SAME',
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(inputs)
    pool1 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding='SAME')(conv1)
    conv2 = tfp.layers.Convolution2DFlipout(
        16,
        kernel_size=5,
        padding='SAME',
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(pool1)
    pool2 = tf.keras.layers.MaxPooling2D(pool_size=[2, 2],
                                         strides=[2, 2],
                                         padding='SAME')(conv2)
    conv3 = tfp.layers.Convolution2DFlipout(
        120,
        kernel_size=5,
        padding='SAME',
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(pool2)
    flatten = tf.keras.layers.Flatten()(conv3)
    dense1 = tfp.layers.DenseLocalReparameterization(
        84,
        activation=tf.nn.relu,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(flatten)
    dense2 = tfp.layers.DenseLocalReparameterization(
        num_classes,
        kernel_prior_fn=p_fn,
        kernel_posterior_fn=q_fn,
        bias_prior_fn=p_fn,
        bias_posterior_fn=q_fn,
        kernel_divergence_fn=normalized_kl_fn,
        bias_divergence_fn=normalized_kl_fn)(dense1)
    outputs = tf.keras.layers.Lambda(lambda x: ed.Categorical(logits=x))(
        dense2)
    return tf.keras.models.Model(inputs=inputs, outputs=outputs)
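
A minimal usage sketch (MNIST-shaped inputs; the dataset size of 60000 is an
assumption, not part of the example). Because every kernel and bias gets a posterior
location and scale instead of a single point estimate, the summary shows roughly
twice the trainable parameters of a deterministic LeNet5, plus any trainable
empirical-Bayes prior parameters:

model = lenet5(n_examples=60000, input_shape=(28, 28, 1), num_classes=10)
model.summary()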