def test_normalize_train(self):
    """Run `Normalize` in 'train' mode and check state and values are kept.

    Builds a random (2, 3, 4, 5) input, initializes the layer from the
    input's signature, applies it once, and asserts that the layer state
    equals the state captured before the call and that the output is
    almost equal to the input.
    """
    inputs = np.random.uniform(size=(2, 3, 4, 5))
    layer = normalization.Normalize(mode='train', epsilon=0.0)
    layer.init(shapes.signature(inputs))

    # Snapshot the state before the forward call so it can be compared after.
    state_before = layer.state
    outputs = layer(inputs)

    np.testing.assert_equal(layer.state, state_before)
    np.testing.assert_almost_equal(inputs, outputs)
def test_normalize_collect(self):
    """Run `Normalize` in 'collect' mode and check state and values change.

    Builds a random (2, 3, 4, 5) input, initializes the layer from the
    input's signature, applies it once, and asserts that both NumPy
    comparisons fail: the state no longer equals the state captured before
    the call, and the output is not almost equal to the input.
    """
    inputs = np.random.uniform(size=(2, 3, 4, 5))
    layer = normalization.Normalize(mode='collect')
    layer.init(shapes.signature(inputs))

    # Snapshot the state before the forward call so it can be compared after.
    state_before = layer.state
    outputs = layer(inputs)

    # The NumPy helpers raise AssertionError on mismatch, so expecting that
    # error asserts the compared values differ.
    self.assertRaises(
        AssertionError, np.testing.assert_equal, layer.state, state_before
    )
    self.assertRaises(
        AssertionError, np.testing.assert_almost_equal, inputs, outputs
    )