def test_input_signatures_serial_batch_norm(self):
  """Checks that a signature set on a Serial reaches both of its sublayers.

  The combinator wraps a BatchNorm (a layer that actively uses state)
  followed by a Relu. Assigning the combinator's input signature must
  complete without raising, and afterwards each sublayer must report the
  same signature as its own input signature.
  """
  shape = (3, 28, 28)

  norm_layer = normalization.BatchNorm()
  activation = core.Relu()
  model = cb.Serial(norm_layer, activation)

  # Assigning the signature on the combinator is the operation under test.
  model.input_signature = ShapeDtype(shape)

  # The first check covers what enters batch norm; the second covers what
  # exits it, since that is the signature the following Relu receives.
  for sublayer in (norm_layer, activation):
    self.assertEqual(sublayer.input_signature, ShapeDtype(shape))
def test_batch_norm(self):
  """Checks BatchNorm state before and after one call, and its output.

  The input is the deterministic sequence 0, 1, ..., 23 reshaped to
  (2, 3, 4), normalized over all three axes. The test verifies:
    * the state right after initialization is (0, 1, 0) in slots 0..2;
    * after one forward call, slot 0 equals 0.001 * mean, slot 1 equals
      0.999 + 0.001 * variance, and slot 2 equals 1;
    * the output equals (input - mean) / sqrt(variance + eps).
  """
  shape = (2, 3, 4)
  dtype = np.float32
  signature = ShapeDtype(shape, dtype)
  epsilon = 1e-5

  # Deterministic input: the integers 0..23 laid out in `shape`.
  flat_values = np.arange(np.prod(shape), dtype=dtype)
  batch = np.reshape(flat_values, shape)
  mean = 11.5  # Mean of 0..23.
  variance = 47.9167  # Population variance of 0..23 (575 / 12), rounded.

  layer = normalization.BatchNorm(axis=(0, 1, 2))
  _, _ = layer.initialize_once(signature)

  # State immediately after initialization.
  initial_state = layer.state
  onp.testing.assert_allclose(initial_state[0], 0)
  onp.testing.assert_allclose(initial_state[1], 1)
  self.assertEqual(initial_state[2], 0)

  # One forward pass; the layer's state is read again afterwards.
  normalized = layer(batch)
  updated_state = layer.state

  # NOTE(review): the 0.999 / 0.001 weights look like a moving average with
  # momentum 0.999 -- presumably BatchNorm's default; confirm in the layer.
  expected_slot0 = mean * 0.001
  expected_slot1 = 0.999 + variance * 0.001
  onp.testing.assert_allclose(updated_state[0], expected_slot0)
  onp.testing.assert_allclose(updated_state[1], expected_slot1, rtol=1e-6)
  self.assertEqual(updated_state[2], 1)

  # The output is normalized with the statistics of this very batch.
  expected_output = (batch - mean) / np.sqrt(variance + epsilon)
  onp.testing.assert_allclose(normalized, expected_output, rtol=1e-6)
def test_batch_norm_shape(self):
  """Checks that a default BatchNorm reports a shape equal to its input's.

  The shape returned by `base.check_shape_agreement` for a
  (29, 5, 7, 20) input signature must equal that signature's shape.
  """
  signature = ShapeDtype((29, 5, 7, 20))
  layer = normalization.BatchNorm()
  reported_shape = base.check_shape_agreement(layer, signature)
  self.assertEqual(reported_shape, signature.shape)