Example #1
    def test_batchnorm_non_trainable_with_fit(self):
        # We use the same data shape for all the data in this test, so that
        # none of the tf.functions involved retrace. This lets us verify that
        # changing `trainable` and recompiling really does update the training
        # loop, rather than a new data shape triggering a retrace.
        data_shape = (100, 3)

        inputs = keras.Input((3, ))
        bn = normalization_v2.BatchNormalization()
        outputs = bn(inputs)
        model = keras.Model(inputs, outputs)
        model.compile('rmsprop',
                      'mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        model.fit(np.random.random(data_shape), np.random.random(data_shape))

        test_data = np.random.random(data_shape)
        test_targets = np.random.random(data_shape)
        test_loss = model.evaluate(test_data, test_targets)

        bn.trainable = False
        model.compile('rmsprop',
                      'mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        train_loss = model.train_on_batch(test_data, test_targets)
        self.assertAlmostEqual(test_loss, train_loss)
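Why the final assertion holds: once `bn.trainable` is False and the model is recompiled, the layer normalizes with its frozen moving statistics during both `evaluate` and `train_on_batch`, so both see the same loss. Below is a minimal NumPy sketch of that inference-mode computation (the names are illustrative, not Keras internals; `eps` matches the layer's default epsilon of 1e-3):

    import numpy as np

    def bn_inference(x, moving_mean, moving_var, gamma, beta, eps=1e-3):
        # Inference-mode batch norm: normalize with the stored moving
        # statistics rather than the current batch's statistics.
        return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

    x = np.random.random((100, 3))
    mean, var = np.zeros(3), np.ones(3)
    gamma, beta = np.ones(3), np.zeros(3)
    # The output depends only on the stored statistics, not on the batch,
    # so repeated passes over the same data produce the same loss.
    np.testing.assert_allclose(bn_inference(x, mean, var, gamma, beta),
                               x / np.sqrt(1.001))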
Example #2
    def test_bessels_correction(self):
        # Bessel's correction is currently only used in the fused case. In the
        # future, it may be used in the nonfused case as well.

        x = tf.constant([0., 2.], shape=[2, 1, 1, 1])
        layer = normalization_v2.BatchNormalization(
            momentum=0.5, moving_variance_initializer='zeros')
        layer(x, training=True)
        self.assertTrue(layer.fused)
        # Since the fused implementation is used, Bessel's correction applies.
        # The variance of [0, 2] is 2 with Bessel's correction. With momentum
        # 0.5 and the moving variance initialized to zeros, one update gives
        # 0 * 0.5 + 2 * 0.5 == 1.
        self.assertAllEqual(self.evaluate(layer.moving_variance), [1.])

        x = tf.constant([0., 2.], shape=[2, 1, 1, 1, 1])
        layer = normalization_v2.BatchNormalization(
            momentum=0.5, moving_variance_initializer='zeros')
        layer(x, training=True)
        self.assertFalse(layer.fused)
        # Since the fused implementation is not used, Bessel's correction is
        # not applied. The variance of [0, 2] is 1 without it, so one update
        # gives 0 * 0.5 + 1 * 0.5 == 0.5.
        self.assertAllEqual(self.evaluate(layer.moving_variance), [0.5])
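The expected values follow from a single update of the moving variance, which starts at zero (the 'zeros' initializer) and is updated as moving_var = moving_var * momentum + batch_var * (1 - momentum). A quick NumPy check of both branches:

    import numpy as np

    x = np.array([0., 2.])
    momentum = 0.5

    var_population = x.var()    # 1.0: no Bessel's correction (ddof=0)
    var_bessel = x.var(ddof=1)  # 2.0: with Bessel's correction

    # One moving-variance update starting from the 'zeros' initializer.
    update = lambda batch_var: 0. * momentum + batch_var * (1 - momentum)
    assert update(var_bessel) == 1.0      # fused path
    assert update(var_population) == 0.5  # non-fused path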
Example #3
    def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
        # Test case for GitHub issue #32380.
        norm = normalization_v2.BatchNormalization(virtual_batch_size=8)
        inp = keras.layers.Input(shape=(None, None, 3))
        _ = norm(inp)
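For context, `virtual_batch_size=8` performs "ghost batch normalization": each batch is split into sub-batches of 8 that are normalized independently, which is why the layer must reason about shapes even when the spatial dims are None. A rough NumPy sketch of the per-sub-batch statistics, assuming the batch size divides evenly (the helper name is illustrative):

    import numpy as np

    def ghost_batch_stats(x, virtual_batch_size):
        # Split the batch into sub-batches and compute per-sub-batch,
        # per-feature statistics, which is what virtual batching does
        # conceptually.
        n = x.shape[0]
        assert n % virtual_batch_size == 0
        sub = x.reshape(n // virtual_batch_size, virtual_batch_size, -1)
        return sub.mean(axis=1), sub.var(axis=1)

    means, variances = ghost_batch_stats(np.random.random((32, 3)), 8)
    print(means.shape, variances.shape)  # (4, 3) (4, 3)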
Example #4
    def test_v2_fused_attribute(self):
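        # Auto-selection: for a 4D input (batch plus three dims) with the
        # default axis, the layer picks the fused implementation.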
        norm = normalization_v2.BatchNormalization()
        self.assertIsNone(norm.fused)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, True)

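        # A 3D input cannot use the fused implementation.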
        norm = normalization_v2.BatchNormalization()
        self.assertIsNone(norm.fused)
        inp = keras.layers.Input(shape=(4, 4))
        norm(inp)
        self.assertEqual(norm.fused, False)

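        # Auto-selection also declines fused for a 5D input.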
        norm = normalization_v2.BatchNormalization()
        self.assertIsNone(norm.fused)
        inp = keras.layers.Input(shape=(4, 4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, False)

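        # virtual_batch_size is incompatible with fused, so fused is already
        # False at construction time.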
        norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
        self.assertEqual(norm.fused, False)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, False)

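        # An explicit fused=False is honored.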
        norm = normalization_v2.BatchNormalization(fused=False)
        self.assertEqual(norm.fused, False)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, False)

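        # fused=True with axis=[3] is accepted, since it is equivalent to the
        # supported axis 3.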
        norm = normalization_v2.BatchNormalization(fused=True, axis=[3])
        self.assertEqual(norm.fused, True)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, True)

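        # Requesting fused=True together with incompatible options raises at
        # construction time.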
        with self.assertRaisesRegex(ValueError, 'fused.*renorm'):
            normalization_v2.BatchNormalization(fused=True, renorm=True)

        with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
            normalization_v2.BatchNormalization(fused=True, axis=2)

        with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
            normalization_v2.BatchNormalization(fused=True, axis=[1, 3])

        with self.assertRaisesRegex(ValueError, 'fused.*virtual_batch_size'):
            normalization_v2.BatchNormalization(fused=True,
                                                virtual_batch_size=2)

        with self.assertRaisesRegex(ValueError, 'fused.*adjustment'):
            normalization_v2.BatchNormalization(fused=True,
                                                adjustment=lambda _: (1, 0))

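        # With fused=True, calling the layer on anything but a 4D or 5D input
        # raises.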
        norm = normalization_v2.BatchNormalization(fused=True)
        self.assertEqual(norm.fused, True)
        inp = keras.layers.Input(shape=(4, 4))
        with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
            norm(inp)
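Taken together, the cases above pin down the selection rules this version of the layer applies. Here is a hypothetical helper restating them (a summary of the assertions, not the Keras implementation or API):

    def would_use_fused(input_ndims, fused=None, axis=-1, renorm=False,
                        virtual_batch_size=None, adjustment=None):
        # Hypothetical summary of the assertions above, not Keras API.
        axes = list(axis) if isinstance(axis, (list, tuple)) else [axis]
        blocked = (renorm or virtual_batch_size is not None
                   or adjustment is not None)
        if fused:
            # In the layer itself the first two errors are raised in
            # __init__ and the ndims error when the layer is called.
            if blocked:
                raise ValueError('fused cannot be combined with renorm, '
                                 'virtual_batch_size, or adjustment')
            if axes not in ([1], [3], [-1]):
                raise ValueError('fused is only supported when axis is 1 or 3')
            if input_ndims not in (4, 5):
                raise ValueError('fused requires 4D or 5D input tensors')
            return True
        if fused is False or blocked:
            return False
        # fused left unspecified: pick the fused implementation only for
        # a 4D input.
        return input_ndims == 4

For example, `would_use_fused(4)` returns True while `would_use_fused(5)` returns False, matching the first and third cases in the test above.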