def run_bnn_test(self, make_conv):
  # 1  Prepare Dataset

  train_size = 128
  batch_size = 2
  evidence_shape = [28, 28, 1]
  target_shape = [10]

  train_dataset = tf.data.Dataset.from_tensor_slices((
      tf.random.uniform([train_size] + evidence_shape,
                        maxval=1, dtype=tf.float32),
      tf.math.softmax(tf.random.normal([train_size] + target_shape)),
  ))
  train_dataset = nn.util.tune_dataset(
      train_dataset,
      batch_size=batch_size,
      shuffle_size=int(train_size / 7))

  # 2  Specify Model

  scale = tfp.util.TransformedVariable(1., tfb.Softplus())
  n = tf.cast(train_size, tf.float32)
  bnn = nn.Sequential([
      make_conv(evidence_shape[-1], 32, filter_shape=7, strides=2,
                penalty_weight=1. / n),
      tf.nn.elu,
      # nn.util.trace('conv1'),    # [b, 14, 14, 32]
      nn.util.flatten_rightmost,
      # nn.util.trace('flat1'),    # [b, 14 * 14 * 32]
      # `SoftmaxCentered` maps R^{k-1} onto the k-simplex, hence the
      # `np.prod(target_shape) - 1` output size.
      nn.AffineVariationalReparameterization(
          14 * 14 * 32, np.prod(target_shape) - 1,
          penalty_weight=1. / n),
      # nn.util.trace('affine1'),  # [b, 9]
      nn.Lambda(
          eval_final_fn=lambda loc: tfb.SoftmaxCentered()(  # pylint: disable=g-long-lambda
              tfd.Independent(tfd.Normal(loc, scale),
                              reinterpreted_batch_ndims=1)),
          also_track=scale),
      # nn.util.trace('head'),     # [b, 10]
  ], name='bayesian_neural_network')

  self.evaluate([v.initializer for v in bnn.trainable_variables])

  # 3  Train.

  train_iter = iter(train_dataset)

  def loss_fn():
    x, y = next(train_iter)
    nll = -tf.reduce_mean(bnn(x).log_prob(y), axis=-1)
    kl = bnn.extra_loss  # Already normalized via `penalty_weight=1. / n`.
    return nll + kl, (nll, kl)

  opt = tf.optimizers.Adam()
  fit_op = nn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables)
  for _ in range(2):
    loss, (nll, kl) = fit_op()  # pylint: disable=unused-variable
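
# A minimal sketch of a driver for `run_bnn_test` above; this harness is an
# assumption, not part of the excerpt. A concrete test case supplies a
# `make_conv` factory, e.g. one built from
# `nn.ConvolutionVariationalReparameterization` via `functools.partial`
# (assumes `import functools` at module top). `run_bnn_test` then calls it as
# `make_conv(in_channels, out_channels, filter_shape=..., strides=...,
# penalty_weight=...)`; `padding='same'` is assumed so that the 28x28 input
# downsamples to the 14x14 shape noted in the trace comments.
def test_conv_variational_reparameterization(self):
  make_conv = functools.partial(
      nn.ConvolutionVariationalReparameterization, padding='same')
  self.run_bnn_test(make_conv)
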
def test_basic(self):
  shift = tf.Variable(1.)
  scale = tfp.util.TransformedVariable(1., tfb.Exp())
  f = nn.Lambda(
      eval_fn=lambda x: tfd.Normal(loc=x + shift, scale=scale),
      extra_loss_fn=lambda x: tf.norm(x.loc),
      # `scale` will be tracked through the distribution but not `shift`.
      also_track=shift)
  x = tf.zeros([1, 2])
  y = f(x)
  self.assertIsInstance(y, tfd.Normal)
  self.assertLen(f.trainable_variables, 2)
  if tf.executing_eagerly():
    # We want to specifically check the values when in eager mode to ensure
    # we're not leaking graph tensors. The precise value doesn't matter.
    self.assertGreaterEqual(f.extra_loss, 0.)
    self.assertIsNone(f.extra_result)
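
# Usage sketch for `nn.Lambda` outside the test harness (an assumption, not
# part of the excerpt). Variables that appear only inside the closure, like
# `offset` below, are invisible to the layer's variable tracking unless
# listed in `also_track`; variables reachable through the returned object,
# like a `TransformedVariable` used as a distribution parameter, are picked
# up automatically, as the `assertLen(..., 2)` check above relies on.
offset = tf.Variable(0.)
add_offset = nn.Lambda(
    eval_fn=lambda x: x + offset,
    also_track=offset)  # Omit this and `trainable_variables` is empty.
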
def run_bnn_test(self, make_conv, make_deconv):
  # 1  Prepare Dataset

  train_size = 128
  batch_size = 2
  train_dataset = tf.data.Dataset.from_tensor_slices(
      tf.random.uniform([train_size, 28, 28, 1],
                        maxval=1, dtype=tf.float32))
  train_dataset = nn.util.tune_dataset(
      train_dataset,
      batch_size=batch_size,
      shuffle_size=int(train_size / 7))
  train_iter = iter(train_dataset)
  x = next(train_iter)
  input_channels = int(x.shape[-1])

  # 2  Specify Model

  bottleneck_size = 2

  scale = tfp.util.TransformedVariable(1., tfb.Softplus())

  bnn = nn.Sequential([
      make_conv(input_channels, 32, filter_shape=5, strides=2),
      tf.nn.elu,
      # nn.util.trace('conv1'),    # [b, 14, 14, 32]
      nn.util.flatten_rightmost(ndims=3),
      # nn.util.trace('flat1'),    # [b, 14 * 14 * 32]
      nn.AffineVariationalReparameterization(
          14 * 14 * 32, bottleneck_size),
      # nn.util.trace('affine1'),  # [b, 2]
      lambda x: x[..., tf.newaxis, tf.newaxis, :],
      # nn.util.trace('expand'),   # [b, 1, 1, 2]
      make_deconv(2, 64, filter_shape=7, strides=1, padding='valid'),
      tf.nn.elu,
      # nn.util.trace('deconv1'),  # [b, 7, 7, 64]
      make_deconv(64, 32, filter_shape=4, strides=4),
      tf.nn.elu,
      # nn.util.trace('deconv2'),  # [b, 28, 28, 32]
      make_conv(32, 1, filter_shape=2, strides=1),  # No activation.
      # nn.util.trace('deconv3'),  # [b, 28, 28, 1]
      nn.Lambda(
          eval_fn=lambda loc: tfd.Independent(  # pylint: disable=g-long-lambda
              tfb.Sigmoid()(tfd.Normal(loc, scale)),
              reinterpreted_batch_ndims=3),
          also_track=scale),
      # nn.util.trace('head'),     # [b, 28, 28, 1]
  ], name='bayesian_autoencoder')

  # 3  Train.

  def loss_fn():
    x = next(train_iter)
    nll = -tf.reduce_mean(bnn(x).log_prob(x), axis=-1)
    # Unlike the classifier test above, `penalty_weight` was not set, so the
    # KL penalty is normalized by the dataset size here instead.
    kl = bnn.extra_loss / tf.cast(train_size, tf.float32)
    loss = nll + kl
    return loss, (nll, kl)

  opt = tf.optimizers.Adam()
  fit_op = nn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables)
  for _ in range(2):
    loss, (nll, kl) = fit_op()  # pylint: disable=unused-variable
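
# As above, an assumed driver (not part of the excerpt) pairing conv and
# deconv factories; `nn.ConvolutionTransposeVariationalReparameterization`
# is the transpose-convolution analogue in `tfp.experimental.nn`. Keywords
# bound via `functools.partial` are overridden by explicit call-site
# keywords, so the `padding='valid'` passed inside `run_bnn_test` still
# takes effect for the first deconv layer.
def test_conv_transpose_variational_reparameterization(self):
  make_conv = functools.partial(
      nn.ConvolutionVariationalReparameterization, padding='same')
  make_deconv = functools.partial(
      nn.ConvolutionTransposeVariationalReparameterization, padding='same')
  self.run_bnn_test(make_conv, make_deconv)
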