# Assumed imports for these Trax test snippets; module paths have shifted
# across Trax versions, so treat this header as a best-effort reconstruction.
import numpy as onp
from trax.layers import base
from trax.layers import combinators as cb
from trax.layers import core
from trax.layers import normalization
from trax.shapes import ShapeDtype
from trax.fastmath import numpy as np  # older versions: trax.math / trax.backend

def test_input_signatures_serial_batch_norm(self):
        # Include a layer that actively uses state.
        batch_norm = normalization.BatchNorm()
        relu = core.Relu()
        batch_norm_and_relu = cb.Serial(batch_norm, relu)

        # Assigning an input signature to the combinator should run without
        # errors and propagate the correct signature to each sublayer.
        batch_norm_and_relu.input_signature = ShapeDtype((3, 28, 28))
        self.assertEqual(batch_norm.input_signature, ShapeDtype((3, 28, 28)))
        self.assertEqual(relu.input_signature, ShapeDtype((3, 28, 28)))
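The assertions above depend on cb.Serial threading the signature through its
sublayers; since BatchNorm and Relu both preserve shape, every sublayer sees
the same signature. Below is a minimal, self-contained sketch of that
propagation logic under the assumption of shape-preserving sublayers;
TinyLayer, TinySerial, and set_input_signature are hypothetical names for
illustration, not Trax's API.

class TinyLayer:
    """Hypothetical stand-in for a shape-preserving layer such as Relu."""
    def __init__(self):
        self.input_signature = None
    def output_signature(self, sig):
        return sig  # shape-preserving: output signature equals input signature

class TinySerial:
    """Hypothetical stand-in for serial signature propagation."""
    def __init__(self, *sublayers):
        self.sublayers = sublayers
    def set_input_signature(self, sig):
        # Record the signature on each sublayer, then thread it forward.
        for layer in self.sublayers:
            layer.input_signature = sig
            sig = layer.output_signature(sig)

bn, relu = TinyLayer(), TinyLayer()
TinySerial(bn, relu).set_input_signature((3, 28, 28))
assert bn.input_signature == (3, 28, 28)
assert relu.input_signature == (3, 28, 28)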
Example #2
def test_batch_norm(self):
   input_shape = (2, 3, 4)
   input_dtype = np.float32
   input_signature = ShapeDtype(input_shape, input_dtype)
   eps = 1e-5
   inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                     input_shape)
   m1 = 11.5  # Mean of this arange input: average of 0..23.
   v1 = 47.9167  # Variance of this arange input: (24**2 - 1) / 12.
   layer = normalization.BatchNorm(axis=(0, 1, 2))
   _, _ = layer.initialize_once(input_signature)
   state = layer.state
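   # Fresh state: running mean 0, running variance 1, step counter 0.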
   onp.testing.assert_allclose(state[0], 0)
   onp.testing.assert_allclose(state[1], 1)
   self.assertEqual(state[2], 0)
   out = layer(inp1)
   state = layer.state
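   # One batch moves the running statistics a fraction (1 - 0.999) toward the
   # batch mean/variance and increments the step counter.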
   onp.testing.assert_allclose(state[0], m1 * 0.001)
   onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
   self.assertEqual(state[2], 1)
   onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                               rtol=1e-6)
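The hard-coded constants are easy to verify with plain NumPy: the input is
arange(24), whose mean is 11.5 and whose population variance is
(24**2 - 1) / 12 ≈ 47.9167, and the state assertions follow from an
exponential moving average with momentum 0.999 (inferred here from the
asserted values). A stand-alone sketch of that arithmetic, not Trax's
internal code:

import numpy as onp

x = onp.arange(24, dtype=onp.float32)
m1 = x.mean()  # 11.5
v1 = x.var()   # 47.9167 == (24**2 - 1) / 12
momentum = 0.999  # assumed; inferred from the asserted state values

# Running statistics start at mean=0, var=1 and move a fraction
# (1 - momentum) toward the batch statistics after one forward pass.
running_mean = momentum * 0.0 + (1 - momentum) * m1  # == m1 * 0.001
running_var = momentum * 1.0 + (1 - momentum) * v1   # == 0.999 + v1 * 0.001

onp.testing.assert_allclose(running_mean, 11.5 * 0.001, rtol=1e-6)
onp.testing.assert_allclose(running_var, 0.999 + 47.9167 * 0.001, rtol=1e-6)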
Example #3
def test_batch_norm_shape(self):
   input_signature = ShapeDtype((29, 5, 7, 20))
   result_shape = base.check_shape_agreement(normalization.BatchNorm(),
                                             input_signature)
   self.assertEqual(result_shape, input_signature.shape)
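check_shape_agreement should report the input shape unchanged: batch norm
subtracts a broadcast mean and divides by a broadcast standard deviation,
both elementwise on the input. A quick plain-NumPy illustration of that shape
invariance, assuming statistics over the leading axes as in the previous
example (this is not Trax code):

import numpy as onp

x = onp.random.uniform(size=(29, 5, 7, 20)).astype(onp.float32)
mean = x.mean(axis=(0, 1, 2), keepdims=True)  # shape (1, 1, 1, 20)
var = x.var(axis=(0, 1, 2), keepdims=True)
normalized = (x - mean) / onp.sqrt(var + 1e-5)
assert normalized.shape == x.shape  # normalization preserves the input shape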