Esempio n. 1
0
    def test_slice_axis(self):
        """Check that a slice axis spec matches the equivalent explicit axes.

        On the rank-4 input used here, ``slice(1, -1)`` is expected to
        normalize over the same axes as the tuple ``(1, 2)``, so both layers
        must produce exactly equal outputs.
        """
        sliced_norm = rms_norm.RMSNorm(slice(1, -1))
        explicit_norm = rms_norm.RMSNorm((1, 2))
        batch = np.random.uniform(low=0, high=10, size=[3, 4, 4, 5])

        np.testing.assert_array_equal(sliced_norm(batch), explicit_norm(batch))
Esempio n. 2
0
    def test_connection(self):
        """Apply RMSNorm to a rank-4 zero array under several axis specs.

        Returns:
            A list with the layer outputs: first one per single integer axis
            (0 through 3), then one for each of the slice specs
            ``slice(1, None)``, ``slice(2, None)`` and ``slice(1, -1)``.
        """
        data = jnp.zeros([2, 3, 4, 5])

        # One freshly built layer per individual axis of the input.
        norms = [rms_norm.RMSNorm(axis=axis)(data) for axis in range(4)]

        # Slice-valued axis specs, appended in a fixed order.
        for axis_slice in (slice(1, None), slice(2, None), slice(1, -1)):
            norms.append(rms_norm.RMSNorm(axis=axis_slice)(data))

        return norms
Esempio n. 3
0
 def test_simple_case_with_scale(self):
     """Check RMSNorm output on a constant input with a constant scale.

     The layer normalizes over axes 1 and 2 with ``eps=0.0`` and a scale
     initialized to the constant 0.5. For an input filled with 2.0 the test
     expects every output element to equal exactly 0.5.
     """
     layer = rms_norm.RMSNorm(axis=[1, 2],
                              eps=0.0,
                              scale_init=initializers.Constant(0.5))
     inputs = np.full(shape=[2, 3, 3, 5], fill_value=2.0)
     outputs = layer(inputs)
     # Single vectorized check instead of a per-element Python loop: the
     # scalar is broadcast against the output, and a failure reports all
     # mismatching elements rather than stopping at the first one.
     np.testing.assert_array_equal(outputs, 0.5)
Esempio n. 4
0
 def test_invalid_axis(self, axis):
     """Check that constructing RMSNorm with a bad ``axis`` raises ValueError.

     Args:
         axis: The invalid axis value to pass to the constructor
             (presumably supplied by a parameterized decorator not visible
             here - TODO confirm).
     """
     expected_message = "`axis` should be an int, slice or iterable of ints."
     with self.assertRaisesRegex(ValueError, expected_message):
         rms_norm.RMSNorm(axis)
Esempio n. 5
0
 def test_zero_inputs(self):
     """Check that an all-zero input produces an all-zero output.

     The layer normalizes over axes 1 and 2 with its default ``eps``; the
     test expects every output element to equal exactly 0.0.
     """
     layer = rms_norm.RMSNorm([1, 2])
     inputs = np.zeros([2, 3, 3, 5])
     outputs = layer(inputs)
     # Single vectorized check instead of a per-element Python loop; a NaN
     # in the output still fails because it does not match the 0.0 target.
     np.testing.assert_array_equal(outputs, 0.0)
Esempio n. 6
0
 def test_simple_case(self):
     """Check RMSNorm output on a constant input.

     The layer normalizes over axes 1 and 2 with ``eps=0.0``. For an input
     filled with 2.0 the test expects every output element to equal
     exactly 1.0.
     """
     layer = rms_norm.RMSNorm([1, 2], eps=0.0)
     inputs = np.full(shape=[2, 3, 3, 5], fill_value=2.0)
     outputs = layer(inputs)
     # Single vectorized check instead of a per-element Python loop: the
     # scalar is broadcast against the output, and a failure reports all
     # mismatching elements rather than stopping at the first one.
     np.testing.assert_array_equal(outputs, 1.0)
Esempio n. 7
0
 def f(x):
     """Build an RMSNorm over the last axis and return its output for ``x``."""
     return rms_norm.RMSNorm(axis=-1)(x)