def test_summary(self):
  model = nn.Sequential([
      lambda x: tf.reshape(x, [-1, 3]),
      AffineMeanFieldNormal(output_size=5, input_size=3),
      AffineMeanFieldNormal(output_size=2, input_size=5),
  ])
  # Expected trainable size: (3 * 5 + 5) + (5 * 2 + 2) = 32 float32 scalars,
  # consistent with a kernel and a bias per affine layer.
  self.assertEqual('trainable size: 32 / 0.000 MiB / {float32: 32}',
                   model.summary().split('\n')[-1])
def run_bnn_test(self, make_conv):
  # 1 Prepare Dataset

  train_size = 128
  batch_size = 2
  evidence_shape = [28, 28, 1]
  target_shape = [10]
  train_dataset = tf.data.Dataset.from_tensor_slices((
      tf.random.uniform([train_size] + evidence_shape,
                        maxval=1, dtype=tf.float32),
      tf.math.softmax(tf.random.normal([train_size] + target_shape)),
  ))
  train_dataset = nn.util.tune_dataset(
      train_dataset,
      batch_size=batch_size,
      shuffle_size=int(train_size / 7))

  # 2 Specify Model

  scale = tfp.util.TransformedVariable(1., tfb.Softplus())
  n = tf.cast(train_size, tf.float32)
  bnn = nn.Sequential([
      make_conv(evidence_shape[-1], 32, filter_shape=7, strides=2,
                penalty_weight=1. / n),
      tf.nn.elu,
      # nn.util.trace('conv1'),    # [b, 14, 14, 32]

      nn.util.flatten_rightmost,
      # nn.util.trace('flat1'),    # [b, 14 * 14 * 32]

      nn.AffineVariationalReparameterization(
          14 * 14 * 32, np.prod(target_shape) - 1,
          penalty_weight=1. / n),
      # nn.util.trace('affine1'),  # [b, 9]

      nn.Lambda(
          eval_final_fn=lambda loc: tfb.SoftmaxCentered()(  # pylint: disable=g-long-lambda
              tfd.Independent(tfd.Normal(loc, scale),
                              reinterpreted_batch_ndims=1)),
          also_track=scale),
      # nn.util.trace('head'),     # [b, 10]
  ], name='bayesian_neural_net')
  self.evaluate([v.initializer for v in bnn.trainable_variables])

  # 3 Train.

  train_iter = iter(train_dataset)

  def loss_fn():
    x, y = next(train_iter)
    nll = -tf.reduce_mean(bnn(x).log_prob(y), axis=-1)
    kl = bnn.extra_loss  # Already normalized via `penalty_weight=1. / n`.
    return nll + kl, (nll, kl)

  opt = tf.optimizers.Adam()
  fit_op = nn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables)
  for _ in range(2):
    loss, (nll, kl) = fit_op()  # pylint: disable=unused-variable
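# Illustrative sketch only, not part of the original suite: `run_bnn_test` is
# a harness, and a concrete test method would bind `make_conv` to a
# variational convolution factory, e.g. with `functools.partial` so that the
# harness's positional/keyword arguments pass straight through. The layer
# name and `padding='same'` default (needed to reach the [b, 14, 14, 32]
# shape traced above) are assumptions, not taken from this file:
#
#   make_conv = functools.partial(
#       nn.ConvolutionVariationalReparameterization, padding='same')
#   self.run_bnn_test(make_conv)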
def test_works_correctly(self):
  input_size = 3
  output_size = 5
  model = nn.Sequential([
      lambda x: tf.reshape(x, [-1, input_size]),
      AffineMeanFieldNormal(output_size=5, input_size=input_size),
      AffineMeanFieldNormal(output_size=output_size, input_size=5),
  ])
  self.assertLen(model.trainable_variables, 4)
  self.evaluate([v.initializer for v in model.trainable_variables])
  self.assertLen(model.layers, 3)
  self.assertEqual(
      '<Sequential: name=lambda__AffineMeanFieldNormal_AffineMeanFieldNormal>',
      str(model))

  x = tf.zeros([2, 1, input_size])
  y = model(x)
  self.assertIsInstance(y, tfd.Independent)
  self.assertAllEqual((2, output_size), y.distribution.loc.shape)

  extra_loss = [
      l.extra_loss for l in model.layers
      if getattr(l, 'extra_loss', None) is not None
  ]
  extra_result = [
      l.extra_result for l in model.layers
      if getattr(l, 'extra_result', None) is not None
  ]
  self.assertIsNone(model.extra_result)
  self.assertAllEqual([(2,), (2,)], [x.shape for x in extra_loss])
  extra_loss_, extra_result_, model_extra_loss_ = self.evaluate(
      [extra_loss, extra_result, model.extra_loss])
  self.assertAllGreaterEqual(extra_loss_, 0.)
  self.assertAllEqual([[2, 3], [2, 5]], extra_result_)
  self.assertAllClose(sum(extra_loss_), model_extra_loss_,
                      rtol=1e-3, atol=1e-3)
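# Note: `AffineMeanFieldNormal` is a test helper defined elsewhere in this
# file. The assertions above pin down its contract: two trainable variables
# per layer, a non-negative per-example `extra_loss` of shape [batch], an
# `extra_result` equal to the input's shape, and an output distribution
# wrapped as `tfd.Independent`.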
def run_bnn_test(self, make_conv, make_deconv):
  # 1 Prepare Dataset

  train_size = 128
  batch_size = 2
  train_dataset = tf.data.Dataset.from_tensor_slices(
      tf.random.uniform([train_size, 28, 28, 1],
                        maxval=1, dtype=tf.float32))
  train_dataset = nn.util.tune_dataset(
      train_dataset,
      batch_size=batch_size,
      shuffle_size=int(train_size / 7))
  train_iter = iter(train_dataset)
  x = next(train_iter)
  input_channels = int(x.shape[-1])

  # 2 Specify Model

  bottleneck_size = 2
  scale = tfp.util.TransformedVariable(1., tfb.Softplus())
  bnn = nn.Sequential([
      make_conv(input_channels, 32, filter_shape=5, strides=2),
      tf.nn.elu,
      # nn.util.trace('conv1'),    # [b, 14, 14, 32]

      nn.util.flatten_rightmost(ndims=3),
      # nn.util.trace('flat1'),    # [b, 14 * 14 * 32]

      nn.AffineVariationalReparameterization(
          14 * 14 * 32, bottleneck_size),
      # nn.util.trace('affine1'),  # [b, 2]

      lambda x: x[..., tf.newaxis, tf.newaxis, :],
      # nn.util.trace('expand'),   # [b, 1, 1, 2]

      make_deconv(2, 64, filter_shape=7, strides=1, padding='valid'),
      tf.nn.elu,
      # nn.util.trace('deconv1'),  # [b, 7, 7, 64]

      make_deconv(64, 32, filter_shape=4, strides=4),
      tf.nn.elu,
      # nn.util.trace('deconv2'),  # [b, 28, 28, 32]

      make_conv(32, 1, filter_shape=2, strides=1),  # No activation.
      # nn.util.trace('conv2'),    # [b, 28, 28, 1]

      nn.Lambda(
          eval_fn=lambda loc: tfd.Independent(  # pylint: disable=g-long-lambda
              tfb.Sigmoid()(tfd.Normal(loc, scale)),
              reinterpreted_batch_ndims=3),
          also_track=scale),
      # nn.util.trace('head'),     # [b, 28, 28, 1]
  ], name='bayesian_autoencoder')

  # 3 Train.

  def loss_fn():
    x = next(train_iter)
    nll = -tf.reduce_mean(bnn(x).log_prob(x), axis=-1)
    # No `penalty_weight` was set on the layers, so normalize the KL here.
    kl = bnn.extra_loss / tf.cast(train_size, tf.float32)
    loss = nll + kl
    return loss, (nll, kl)

  opt = tf.optimizers.Adam()
  fit_op = nn.util.make_fit_op(loss_fn, opt, bnn.trainable_variables)
  for _ in range(2):
    loss, (nll, kl) = fit_op()  # pylint: disable=unused-variable
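# Illustrative sketch only (the layer names are assumptions, not taken from
# this file): a concrete autoencoder test would pass paired conv/deconv
# factories, e.g.
#
#   make_conv = functools.partial(
#       nn.ConvolutionVariationalReparameterization, padding='same')
#   make_deconv = functools.partial(
#       nn.ConvolutionTransposeVariationalReparameterization, padding='same')
#   self.run_bnn_test(make_conv, make_deconv)
#
# Call-site keywords override `functools.partial` defaults, so the explicit
# `padding='valid'` on the bottleneck deconv above still takes effect.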