def test_slice_axis(self):
    """Check that a slice axis and an explicit axis tuple give equal outputs.

    One layer is built with ``slice(1, -1)`` and another with ``(1, 2)``;
    both are applied to the same random rank-4 input and the results must
    match exactly.
    """
    layer_from_slice = rms_norm.RMSNorm(slice(1, -1))
    layer_from_axes = rms_norm.RMSNorm((1, 2))
    sample = np.random.uniform(low=0, high=10, size=[3, 4, 4, 5])
    # Arguments are evaluated left to right, so the slice-based layer is
    # still applied before the tuple-based one.
    np.testing.assert_array_equal(
        layer_from_slice(sample), layer_from_axes(sample)
    )
def test_connection(self):
    """Apply RMSNorm to a zero input for a range of axis specifications.

    Covers each integer axis of the rank-4 input followed by three slice
    specifications. The outputs are collected and returned; no assertion is
    made on their values here.
    """
    data = jnp.zeros([2, 3, 4, 5])
    # One layer per integer axis, built in ascending axis order.
    norms = [rms_norm.RMSNorm(axis=axis)(data) for axis in range(4)]
    # Slice-valued axes follow the integer ones.
    for axis_slice in (slice(1, None), slice(2, None), slice(1, -1)):
        norms.append(rms_norm.RMSNorm(axis=axis_slice)(data))
    return norms
def test_simple_case_with_scale(self):
    """Check the output for a constant input when the scale is fixed at 0.5.

    The layer normalizes over axes ``[1, 2]`` with ``eps=0.0`` and a scale
    initialized to the constant 0.5; every output element is expected to be
    exactly 0.5 for an input filled with 2.0.
    """
    layer = rms_norm.RMSNorm(
        axis=[1, 2],
        eps=0.0,
        scale_init=initializers.Constant(0.5),
    )
    inputs = np.full(shape=[2, 3, 3, 5], fill_value=2.0)

    outputs = layer(inputs)

    # A single vectorized comparison reports every mismatching element at
    # once instead of stopping at the first failing scalar.
    np.testing.assert_array_equal(outputs, 0.5)
def test_invalid_axis(self, axis):
    """Check that building RMSNorm with the given ``axis`` raises ValueError.

    Args:
        axis: The axis value passed to the constructor; the test expects it
            to be rejected.
    """
    expected_message = "`axis` should be an int, slice or iterable of ints."
    with self.assertRaisesRegex(ValueError, expected_message):
        rms_norm.RMSNorm(axis)
def test_zero_inputs(self):
    """Check that an all-zero input yields an all-zero output.

    The layer normalizes over axes ``[1, 2]`` with its default ``eps``; every
    output element is expected to be exactly 0.0.
    """
    layer = rms_norm.RMSNorm([1, 2])
    inputs = np.zeros([2, 3, 3, 5])

    outputs = layer(inputs)

    # A single vectorized comparison reports every mismatching element at
    # once instead of stopping at the first failing scalar.
    np.testing.assert_array_equal(outputs, 0.0)
def test_simple_case(self):
    """Check the output for a constant input with ``eps=0.0``.

    The layer normalizes over axes ``[1, 2]``; for an input filled with 2.0
    every output element is expected to be exactly 1.0.
    """
    layer = rms_norm.RMSNorm([1, 2], eps=0.0)
    inputs = np.full(shape=[2, 3, 3, 5], fill_value=2.0)

    outputs = layer(inputs)

    # A single vectorized comparison reports every mismatching element at
    # once instead of stopping at the first failing scalar.
    np.testing.assert_array_equal(outputs, 1.0)
def f(x):
    """Apply a newly constructed RMSNorm over the last axis to ``x``."""
    return rms_norm.RMSNorm(axis=-1)(x)