def run_bnn_test(self, make_conv):
    # 1  Prepare Dataset

    train_size = 128
    batch_size = 2
    evidence_shape = [28, 28, 1]
    target_shape = [10]

    train_dataset = tf.data.Dataset.from_tensor_slices((
        tf.random.uniform([train_size] + evidence_shape,
                          maxval=1, dtype=tf.float32),
        tf.math.softmax(tf.random.normal([train_size] + target_shape)),
    ))
    train_dataset = nn.util.tune_dataset(
        train_dataset,
        batch_size=batch_size,
        shuffle_size=int(train_size / 7))

    # 2  Specify Model

    scale = tfp.util.TransformedVariable(1., tfb.Softplus())  # Positivity-constrained scale.
    n = tf.cast(train_size, tf.float32)
    bnn = nn.Sequential([
        make_conv(evidence_shape[-1], 32, filter_shape=7, strides=2,
                  penalty_weight=1. / n),
        tf.nn.elu,
        # nn.util.trace('conv1'),    # [b, 14, 14, 32]
        nn.util.flatten_rightmost,
        # nn.util.trace('flat1'),    # [b, 14 * 14 * 32]
        nn.AffineVariationalReparameterization(
            14 * 14 * 32, np.prod(target_shape) - 1,
            penalty_weight=1. / n),
        # nn.util.trace('affine1'),  # [b, 9]
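        # SoftmaxCentered maps a 9-dimensional input onto the 10-class simplex,
        # hence the `np.prod(target_shape) - 1` output units above.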
        nn.Lambda(
            eval_final_fn=lambda loc: tfb.SoftmaxCentered()(  # pylint: disable=g-long-lambda
                tfd.Independent(tfd.Normal(loc, scale),
                                reinterpreted_batch_ndims=1)),
            also_track=scale),
        # nn.util.trace('head'),     # [b, 10]
    ], name='bayesian_autoencoder')

    self.evaluate([v.initializer for v in bnn.trainable_variables])

    # 3  Train.

    train_iter = iter(train_dataset)
    def loss_fn():
      x, y = next(train_iter)
      nll = -tf.reduce_mean(bnn(x).log_prob(y), axis=-1)
      kl = bnn.extra_loss  # Already normalized via penalty_weight=1. / n in the layers.
      return nll + kl, (nll, kl)
    opt = tf.optimizers.Adam()
    fit_op = nn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables)
    for _ in range(2):
      loss, (nll, kl) = fit_op()  # pylint: disable=unused-variable
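
For reference, run_bnn_test above takes a make_conv factory supplied by the caller. Below is a minimal sketch of one way it might be driven, assuming nn.ConvolutionVariationalReparameterization as the layer builder with 'same' padding; both are assumptions, since the actual parameterization is not shown in this example.

def test_conv_variational_reparameterization(self):
  def make_conv(*args, **kwargs):
    # Hypothetical factory: default to 'same' padding so the 7x7, stride-2
    # convolution of a 28x28 image yields the 14x14 map traced above.
    kwargs.setdefault('padding', 'same')
    return nn.ConvolutionVariationalReparameterization(*args, **kwargs)
  self.run_bnn_test(make_conv)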
Example #2
 def test_basic(self):
     shift = tf.Variable(1.)
     scale = tfp.util.TransformedVariable(1., tfb.Exp())
     f = nn.Lambda(
         eval_fn=lambda x: tfd.Normal(loc=x + shift, scale=scale),
         extra_loss_fn=lambda x: tf.norm(x.loc),
         # `scale` will be tracked through the distribution but not `shift`.
         also_track=shift)
     x = tf.zeros([1, 2])
     y = f(x)
     self.assertIsInstance(y, tfd.Normal)
     self.assertLen(f.trainable_variables, 2)
     if tf.executing_eagerly():
         # We want to specifically check the values when in eager mode to ensure
         # we're not leaking graph tensors. The precise value doesn't matter.
         self.assertGreaterEqual(f.extra_loss, 0.)
     self.assertIsNone(f.extra_result)
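
To connect this back to the training pattern in the surrounding examples, here is a minimal sketch that folds the layer's extra_loss (the norm penalty defined above) into an objective; the batch, targets, and optimizer are made up for illustration.

xs = tf.random.normal([8, 2])  # Made-up batch.
ys = tf.random.normal([8, 2])  # Made-up targets.

def loss_fn():
  nll = -tf.reduce_mean(f(xs).log_prob(ys))
  penalty = f.extra_loss  # The tf.norm penalty from extra_loss_fn above.
  return nll + penalty, (nll, penalty)

opt = tf.optimizers.Adam()
fit_op = nn.util.make_fit_op(loss_fn, opt, f.trainable_variables)
loss, (nll, penalty) = fit_op()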
Example #3
    def run_bnn_test(self, make_conv, make_deconv):
        # 1  Prepare Dataset

        train_size = 128
        batch_size = 2
        train_dataset = tf.data.Dataset.from_tensor_slices(
            tf.random.uniform([train_size, 28, 28, 1],
                              maxval=1,
                              dtype=tf.float32))
        train_dataset = nn.util.tune_dataset(train_dataset,
                                             batch_size=batch_size,
                                             shuffle_size=int(train_size / 7))
        train_iter = iter(train_dataset)
        x = next(train_iter)
        input_channels = int(x.shape[-1])

        # 2  Specify Model

        bottleneck_size = 2

        scale = tfp.util.TransformedVariable(1., tfb.Softplus())  # Positivity-constrained scale.

        bnn = nn.Sequential(
            [
                make_conv(input_channels, 32, filter_shape=5, strides=2),
                tf.nn.elu,
                # nn.util.trace('conv1'),    # [b, 14, 14, 32]
                nn.util.flatten_rightmost(ndims=3),
                # nn.util.trace('flat1'),    # [b, 14 * 14 * 32]
                nn.AffineVariationalReparameterization(14 * 14 * 32,
                                                       bottleneck_size),
                # nn.util.trace('affine1'),  # [b, 2]
                lambda x: x[..., tf.newaxis, tf.newaxis, :],
                # nn.util.trace('expand'),   # [b, 1, 1, 2]
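                # The newaxis expansion gives the bottleneck a 1x1 spatial
                # extent so the deconvolutions below can upsample it to 28x28.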
                make_deconv(2, 64, filter_shape=7, strides=1, padding='valid'),
                tf.nn.elu,
                # nn.util.trace('deconv1'),  # [b, 7, 7, 64]
                make_deconv(64, 32, filter_shape=4, strides=4),
                tf.nn.elu,
                # nn.util.trace('deconv2'),  # [b, 28, 28, 32]
                make_conv(32, 1, filter_shape=2, strides=1),
                # No activation.
                # nn.util.trace('deconv3'),  # [b, 28, 28, 1]
                nn.Lambda(
                    eval_fn=lambda loc: tfd.Independent(  # pylint: disable=g-long-lambda
                        tfb.Sigmoid()(tfd.Normal(loc, scale)),
                        reinterpreted_batch_ndims=3),
                    also_track=scale),
                # nn.util.trace('head'),     # [b, 28, 28, 1]
            ],
            name='bayesian_autoencoder')

        # 3  Train.

        def loss_fn():
            x = next(train_iter)
            nll = -tf.reduce_mean(bnn(x).log_prob(x), axis=-1)
            # No penalty_weight was set on the layers, so normalize the KL here.
            kl = bnn.extra_loss / tf.cast(train_size, tf.float32)
            loss = nll + kl
            return loss, (nll, kl)

        opt = tf.optimizers.Adam()
        fit_op = nn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables)
        for _ in range(2):
            loss, (nll, kl) = fit_op()  # pylint: disable=unused-variable
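
As in the first example, the convolution factories are supplied by the caller. A minimal sketch of a driver follows, assuming the variational reparameterization convolution and transposed-convolution layers; any other builders exposing the same interface would fit.

    def test_conv_deconv_variational_reparameterization(self):
        def make_conv(*args, **kwargs):
            # Hypothetical factory: 'same' padding preserves the 28 -> 14 -> 28
            # spatial sizes traced in the comments above.
            kwargs.setdefault('padding', 'same')
            return nn.ConvolutionVariationalReparameterization(*args, **kwargs)

        def make_deconv(*args, **kwargs):
            # Hypothetical factory: padding is left to the call site
            # (e.g. padding='valid' for the first deconvolution above).
            return nn.ConvolutionTransposeVariationalReparameterization(
                *args, **kwargs)

        self.run_bnn_test(make_conv, make_deconv)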