def test_batchnorm_non_trainable_with_fit(self):
    """Recompiling after freezing BatchNormalization must change training.

    The model is fit once with the layer trainable and then evaluated. The
    layer is then marked non-trainable, the model is recompiled, and a
    single `train_on_batch` on the evaluation data is expected to report
    (almost) the same loss as `evaluate` did.
    """
    # Every array passed to the model shares this one shape. That keeps any
    # tf.function in use from retracing because of a new input shape, so a
    # passing test shows that flipping `trainable` and recompiling is what
    # updated the training loop, not a shape-triggered retrace.
    shape = (100, 3)

    layer_input = keras.Input((3, ))
    batch_norm = normalization_v2.BatchNormalization()
    layer_output = batch_norm(layer_input)
    model = keras.Model(layer_input, layer_output)

    def recompile():
        # Same compile configuration before and after freezing the layer.
        model.compile(
            'rmsprop',
            'mse',
            run_eagerly=testing_utils.should_run_eagerly())

    recompile()
    fit_inputs = np.random.random(shape)
    fit_targets = np.random.random(shape)
    model.fit(fit_inputs, fit_targets)

    eval_inputs = np.random.random(shape)
    eval_targets = np.random.random(shape)
    evaluate_loss = model.evaluate(eval_inputs, eval_targets)

    # Freeze the layer; the recompile is what should make training pick
    # up the new `trainable` value.
    batch_norm.trainable = False
    recompile()
    frozen_train_loss = model.train_on_batch(eval_inputs, eval_targets)

    self.assertAlmostEqual(evaluate_loss, frozen_train_loss)
def test_bessels_correction(self):
    """Moving variance uses Bessel's correction only on the fused path.

    Bessel's correction is currently only used in the fused case. In the
    future, it may be used in the nonfused case as well.
    """

    def layer_after_one_training_step(input_shape):
        # Feed the two values [0, 2] once in training mode to a fresh layer
        # whose moving variance starts at zero, and hand the layer back.
        batch = tf.constant([0., 2.], shape=input_shape)
        bn_layer = normalization_v2.BatchNormalization(
            momentum=0.5, moving_variance_initializer='zeros')
        bn_layer(batch, training=True)
        return bn_layer

    # 4D input: the fused implementation is expected, so Bessel's
    # correction applies. The variance of [0, 2] is 2 with the correction,
    # and with momentum 0.5 the moving variance becomes 2 * 0.5 == 1.
    fused_layer = layer_after_one_training_step([2, 1, 1, 1])
    self.assertTrue(fused_layer.fused)
    self.assertAllEqual(self.evaluate(fused_layer.moving_variance), [1.])

    # 5D input: the non-fused implementation is expected, so there is no
    # Bessel's correction. The variance of [0, 2] is 1 without it, and with
    # momentum 0.5 the moving variance becomes 1 * 0.5 == 0.5.
    unfused_layer = layer_after_one_training_step([2, 1, 1, 1, 1])
    self.assertFalse(unfused_layer.fused)
    self.assertAllEqual(self.evaluate(unfused_layer.moving_variance), [0.5])
def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
    """Calling the layer on unknown spatial dims with a virtual batch size.

    Regression test for GitHub issue 32380. The test passes as long as
    applying the layer to the symbolic input does not raise.
    """
    symbolic_input = keras.layers.Input(shape=(None, None, 3))
    layer = normalization_v2.BatchNormalization(virtual_batch_size=8)
    # Only the call itself matters here; its result is intentionally unused.
    layer(symbolic_input)
def test_v2_fused_attribute(self):
    """Check the value of `fused` after construction and after first call.

    Covers three groups of cases:
      * `fused` left unspecified: it is `None` after construction and is
        expected to resolve to True or False once the layer is called,
        depending on the input shape.
      * `fused` (or an option that rules it out) given explicitly: the
        attribute is expected to hold that value before and after the call.
      * Unsupported combinations with `fused=True`: a `ValueError` is
        expected, either from the constructor or from the first call.
    """
    # --- `fused` unspecified: resolved on first call from the input. ---
    norm = normalization_v2.BatchNormalization()
    self.assertIsNone(norm.fused)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, True)

    norm = normalization_v2.BatchNormalization()
    self.assertIsNone(norm.fused)
    inp = keras.layers.Input(shape=(4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization()
    self.assertIsNone(norm.fused)
    inp = keras.layers.Input(shape=(4, 4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    # --- Constructor arguments that already determine `fused`. ---
    # A virtual batch size is expected to disable fusing up front.
    norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
    self.assertEqual(norm.fused, False)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    norm = normalization_v2.BatchNormalization(fused=False)
    self.assertEqual(norm.fused, False)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, False)

    # A single-element axis list is accepted together with fused=True.
    norm = normalization_v2.BatchNormalization(fused=True, axis=[3])
    self.assertEqual(norm.fused, True)
    inp = keras.layers.Input(shape=(4, 4, 4))
    norm(inp)
    self.assertEqual(norm.fused, True)

    # --- fused=True combined with unsupported constructor options. ---
    with self.assertRaisesRegex(ValueError, 'fused.*renorm'):
        normalization_v2.BatchNormalization(fused=True, renorm=True)

    with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
        normalization_v2.BatchNormalization(fused=True, axis=2)

    with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
        normalization_v2.BatchNormalization(fused=True, axis=[1, 3])

    with self.assertRaisesRegex(ValueError, 'fused.*virtual_batch_size'):
        normalization_v2.BatchNormalization(fused=True, virtual_batch_size=2)

    with self.assertRaisesRegex(ValueError, 'fused.*adjustment'):
        normalization_v2.BatchNormalization(
            fused=True, adjustment=lambda _: (1, 0))

    # --- fused=True is accepted at construction but rejected at call ---
    # --- time when the input is not a 4D or 5D tensor.               ---
    norm = normalization_v2.BatchNormalization(fused=True)
    self.assertEqual(norm.fused, True)
    inp = keras.layers.Input(shape=(4, 4))
    with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
        norm(inp)