Example #1
    def test_bessels_correction(self):
        # Bessel's correction is currently only used in the fused case. In the
        # future, it may be used in the nonfused case as well.

        x = tf.constant([0.0, 2.0], shape=[2, 1, 1, 1])
        layer = batch_normalization.BatchNormalization(
            momentum=0.5, moving_variance_initializer="zeros"
        )
        layer(x, training=True)
        self.assertTrue(layer.fused)
        # Since fused is used, Bessel's correction is used. The variance of [0,
        # 2] is 2 with Bessel's correction. Since the momentum is 0.5, the
        # variance is 2 * 0.5 == 1.
        self.assertAllEqual(self.evaluate(layer.moving_variance), [1.0])

        x = tf.constant([0.0, 2.0], shape=[2, 1, 1, 1, 1])
        layer = batch_normalization.BatchNormalization(
            momentum=0.5, moving_variance_initializer="zeros"
        )
        layer(x, training=True)
        self.assertTrue(layer.fused)
        # Since fused is used, Bessel's correction is used. The variance of [0,
        # 2] is 2 with Bessel's correction. Since the momentum is 0.5, the
        # variance is 2 * 0.5 == 1.
        self.assertAllEqual(self.evaluate(layer.moving_variance), [1.0])
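        # Not part of the original test, just the arithmetic behind the two
        # assertions above: with Bessel's correction the variance of [0, 2] is
        # ((0 - 1)**2 + (2 - 1)**2) / (2 - 1) == 2 (np.var([0., 2.], ddof=1)),
        # and the moving variance starts at 0, so one update gives
        # 0 * momentum + 2 * (1 - momentum) == 2 * 0.5 == 1.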

    def test_batchnorm_non_trainable_with_fit(self):
        # We use the same data shape for all the data we use in this test.
        # This will prevent any used tf.functions from retracing.
        # This helps us verify that changing trainable and recompiling really
        # does update the training loop, rather than a different data shape
        # triggering a retrace.
        data_shape = (100, 3)

        inputs = keras.Input((3,))
        bn = batch_normalization.BatchNormalization()
        outputs = bn(inputs)
        model = keras.Model(inputs, outputs)
        model.compile('rmsprop',
                      'mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        model.fit(np.random.random(data_shape), np.random.random(data_shape))

        test_data = np.random.random(data_shape)
        test_targets = np.random.random(data_shape)
        test_loss = model.evaluate(test_data, test_targets)

        bn.trainable = False
        model.compile('rmsprop',
                      'mse',
                      run_eagerly=testing_utils.should_run_eagerly())
        train_loss = model.train_on_batch(test_data, test_targets)
        self.assertAlmostEqual(test_loss, train_loss)
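The second test above relies on BatchNormalization's special-cased trainable
behavior in TF 2: once bn.trainable = False, the layer runs in inference mode
even when called with training=True, so a training step computes the same
forward pass (and hence the same MSE loss) as evaluation. A minimal standalone
sketch of that behavior, not taken from the test file and assuming eager TF 2.x:

import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = np.random.random((4, 3)).astype("float32")

bn(x, training=True)  # build the layer and update the moving statistics once
bn.trainable = False

# With trainable=False the layer ignores training=True and normalizes with its
# moving statistics, so both calls should produce the same output.
out_train = bn(x, training=True)
out_infer = bn(x, training=False)
np.testing.assert_allclose(out_train.numpy(), out_infer.numpy(), rtol=1e-5)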
Example #3
    def test_fused_batchnorm_empty_batch(self):
        # Test case for https://github.com/tensorflow/tensorflow/issues/52986:
        # create a simple strategy with the enable_partial_batch_handling flag
        # turned on, to trigger the empty-batch code path in fused batch norm.
        strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
        strategy.extended.enable_partial_batch_handling = True
        with strategy.scope():
            layer = batch_normalization.BatchNormalization()

        def fn():
            with tf.GradientTape() as tape:
                x = tf.ones((0, 2, 2, 2))
                layer(x, training=True)
            return tape

        tape = strategy.run(fn)

        self.assertTrue(layer.fused)

        # The moving statistics should still exist after the empty-batch call.
        self.assertIsNotNone(layer.moving_mean)
        self.assertIsNotNone(layer.moving_variance)

        # gamma and beta should have been watched by the tape even though the
        # batch was empty.
        tape_vars = tape.watched_variables()
        self.assertAllEqual(layer.gamma, tape_vars[0])
        self.assertAllEqual(layer.beta, tape_vars[1])

    def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
        # Test case for GitHub issue #32380.
        norm = batch_normalization.BatchNormalization(virtual_batch_size=8)
        inp = keras.layers.Input(shape=(None, None, 3))
        _ = norm(inp)

    def test_v2_fused_attribute(self):
        norm = batch_normalization.BatchNormalization()
        self.assertIsNone(norm.fused)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, True)

        norm = batch_normalization.BatchNormalization()
        self.assertIsNone(norm.fused)
        inp = keras.layers.Input(shape=(4, 4))
        norm(inp)
        self.assertEqual(norm.fused, False)

        norm = batch_normalization.BatchNormalization()
        self.assertIsNone(norm.fused)
        inp = keras.layers.Input(shape=(4, 4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, True)

        norm = batch_normalization.BatchNormalization(virtual_batch_size=2)
        self.assertEqual(norm.fused, False)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, False)

        norm = batch_normalization.BatchNormalization(fused=False)
        self.assertEqual(norm.fused, False)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, False)

        norm = batch_normalization.BatchNormalization(fused=True, axis=[3])
        self.assertEqual(norm.fused, True)
        inp = keras.layers.Input(shape=(4, 4, 4))
        norm(inp)
        self.assertEqual(norm.fused, True)

        with self.assertRaisesRegex(ValueError, 'fused.*renorm'):
            batch_normalization.BatchNormalization(fused=True, renorm=True)

        with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
            batch_normalization.BatchNormalization(fused=True, axis=2)

        with self.assertRaisesRegex(ValueError, 'fused.*when axis is 1 or 3'):
            batch_normalization.BatchNormalization(fused=True, axis=[1, 3])

        with self.assertRaisesRegex(ValueError, 'fused.*virtual_batch_size'):
            batch_normalization.BatchNormalization(fused=True,
                                                   virtual_batch_size=2)

        with self.assertRaisesRegex(ValueError, 'fused.*adjustment'):
            batch_normalization.BatchNormalization(fused=True,
                                                   adjustment=lambda _: (1, 0))

        norm = batch_normalization.BatchNormalization(fused=True)
        self.assertEqual(norm.fused, True)
        inp = keras.layers.Input(shape=(4, 4))
        with self.assertRaisesRegex(ValueError, '4D or 5D input tensors'):
            norm(inp)
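A condensed, standalone restatement of the rank rule that test_v2_fused_attribute
walks through above (a sketch assuming TF 2.x, not taken from the test file):
fused stays None until the layer is built, and the fused kernel is only picked
for 4D/5D inputs normalized over the channel axis.

import tensorflow as tf

# Keras adds the batch dimension, so these are 3D, 4D and 5D inputs.
for shape in [(4, 4), (4, 4, 4), (4, 4, 4, 4)]:
    bn = tf.keras.layers.BatchNormalization()
    print(shape, bn.fused)         # None: the layer is not built yet
    bn(tf.keras.Input(shape=shape))
    print(shape, bn.fused)         # False for the 3D input, True for 4D and 5D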