def test_batchnorm_correctness(self, distribution, fused, optimizer, cloning):
  """Check that a trained BatchNormalization layer standardizes its input.

  A Sequential model holding a single BatchNormalization layer is fit on
  normally distributed data (mean 5.0, standard deviation 10.0) using the
  inputs as their own targets. After the layer's beta is subtracted from the
  predictions and the result is divided by its gamma, the output is expected
  to have mean ~0 and standard deviation ~1 (absolute tolerance 0.1).

  Args:
    distribution: Object whose `scope()` wraps model construction; it is also
      handed to `keras_test_lib.batch_wrapper`.
    fused: Forwarded to the `fused` argument of `BatchNormalization`.
    optimizer: Zero-argument callable; its result is passed to
      `model.compile`.
    cloning: Forwarded to the `cloning` argument of `model.compile`.
  """
  sample_shape = (10, 20, 30)
  batch_size = 32

  with self.cached_session():
    with distribution.scope():
      model = keras.models.Sequential()
      norm = keras.layers.BatchNormalization(
          input_shape=sample_shape, momentum=0.8, fused=fused)
      model.add(norm)
      model.compile(loss='mse', optimizer=optimizer(), cloning=cloning)

    # Inputs are drawn from a normal distribution centered on 5.0 with a
    # standard deviation of 10.0 (`scale` is the std-dev, not the variance).
    inputs = np.random.normal(
        loc=5.0, scale=10.0, size=(1000,) + sample_shape).astype('float32')

    # Training pairs each sample with itself as the target.
    # NOTE(review): batch_wrapper is a project helper; presumably it batches
    # the dataset in a way suited to `distribution` -- confirm.
    train_data = keras_test_lib.batch_wrapper(
        dataset_ops.Dataset.from_tensor_slices((inputs, inputs)).repeat(100),
        batch_size, distribution)
    predict_data = keras_test_lib.batch_wrapper(
        dataset_ops.Dataset.from_tensor_slices(inputs).repeat(100),
        batch_size, distribution)

    model.fit(train_data, epochs=4, verbose=0, steps_per_epoch=10)
    out = model.predict(predict_data, steps=2)

    # Undo the layer's learned shift (beta) and scale (gamma) in place so
    # that only the normalization itself is being checked.
    beta_value = keras.backend.eval(norm.beta)
    out -= beta_value
    gamma_value = keras.backend.eval(norm.gamma)
    out /= gamma_value

    np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
    np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
Esempio n. 2
0
  def test_batchnorm_correctness(self, distribution, fused, optimizer, cloning):
    """Check that a trained BatchNormalization layer standardizes its input.

    A Sequential model holding a single BatchNormalization layer is fit on
    normally distributed data (mean 5.0, standard deviation 10.0) using the
    inputs as their own targets. After the layer's beta is subtracted from
    the predictions and the result is divided by its gamma, the output is
    expected to have mean ~0 and standard deviation ~1 (absolute tolerance
    0.1).

    Args:
      distribution: Object whose `scope()` wraps model construction; it is
        also handed to `keras_test_lib.batch_wrapper`.
      fused: Forwarded to the `fused` argument of `BatchNormalization`.
      optimizer: Zero-argument callable; its result is passed to
        `model.compile`.
      cloning: Forwarded to the `cloning` argument of `model.compile`.
    """
    sample_shape = (10, 20, 30)
    batch_size = 32

    with self.cached_session():
      with distribution.scope():
        model = keras.models.Sequential()
        norm = keras.layers.BatchNormalization(
            input_shape=sample_shape, momentum=0.8, fused=fused)
        model.add(norm)
        model.compile(loss='mse', optimizer=optimizer(), cloning=cloning)

      # Inputs are drawn from a normal distribution centered on 5.0 with a
      # standard deviation of 10.0 (`scale` is the std-dev, not the variance).
      inputs = np.random.normal(
          loc=5.0, scale=10.0, size=(1000,) + sample_shape).astype('float32')

      # Training pairs each sample with itself as the target.
      # NOTE(review): batch_wrapper is a project helper; presumably it
      # batches the dataset in a way suited to `distribution` -- confirm.
      train_data = keras_test_lib.batch_wrapper(
          dataset_ops.Dataset.from_tensor_slices((inputs, inputs)).repeat(100),
          batch_size, distribution)
      predict_data = keras_test_lib.batch_wrapper(
          dataset_ops.Dataset.from_tensor_slices(inputs).repeat(100),
          batch_size, distribution)

      model.fit(train_data, epochs=4, verbose=0, steps_per_epoch=10)
      out = model.predict(predict_data, steps=2)

      # Undo the layer's learned shift (beta) and scale (gamma) in place so
      # that only the normalization itself is being checked.
      beta_value = keras.backend.eval(norm.beta)
      out -= beta_value
      gamma_value = keras.backend.eval(norm.gamma)
      out /= gamma_value

      np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
      np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)