def test_connection(self):
    """Connect LayerNorm to zero-valued data for several axis specifications.

    Each integer axis of the rank-4 input is used in turn, followed by three
    slice-valued axes; scale and offset are created in every case.

    Returns:
        The list of layer outputs, in the order the layers were applied.
    """
    data = jnp.zeros([2, 3, 4, 5])
    axes = [*range(4), slice(1, None), slice(2, None), slice(1, -1)]
    # Each layer is built and applied before the next one is constructed,
    # exactly as in a sequential loop.
    return [
        layer_norm.LayerNorm(
            axis=axis, create_scale=True, create_offset=True)(data)
        for axis in axes
    ]
def test_slice_axis(self):
    """A slice axis must give the same output as the explicit axes (1, 2).

    Both layers are applied to the same random rank-4 input with the same
    externally supplied scale and offset, and the results are compared
    exactly.
    """
    layer_kwargs = dict(create_scale=False, create_offset=False)
    slice_layer = layer_norm.LayerNorm(slice(1, -1), **layer_kwargs)
    axis_layer = layer_norm.LayerNorm((1, 2), **layer_kwargs)

    # Random draws are kept in the order: inputs, scale, offset.
    inputs = np.random.uniform(size=[3, 4, 4, 5], low=0, high=10)
    scale = np.random.normal(size=(5,), loc=1.0)
    offset = np.random.normal(size=(5,))

    np.testing.assert_array_equal(
        slice_layer(inputs, scale, offset),
        axis_layer(inputs, scale, offset),
    )
def test_create_offset_and_offset_provided(self):
    """Passing `offset` at call time raises when `create_offset=True`."""
    layer = layer_norm.LayerNorm([2], create_offset=True, create_scale=False)
    inputs = np.ones([2, 3, 4])
    offset = np.ones([4])
    expected_message = (
        "Cannot pass `offset` at call time if `create_offset=True`.")
    with self.assertRaisesRegex(ValueError, expected_message):
        layer(inputs, offset=offset)
def f(x):
    """Normalize `x` over its last axis with a freshly built LayerNorm.

    `create_scale`, `create_offset` and `use_fast_variance` are free
    variables resolved from the surrounding scope.
    """
    return layer_norm.LayerNorm(
        axis=-1,
        create_scale=create_scale,
        create_offset=create_offset,
        use_fast_variance=use_fast_variance,
        param_axis=-1,
    )(x)
def test_no_offset_beta_init_provided(self):
    """Supplying `offset_init` with `create_offset=False` raises."""
    expected_message = "Cannot set `offset_init` if `create_offset=False`."
    with self.assertRaisesRegex(ValueError, expected_message):
        layer_norm.LayerNorm(
            3,
            create_scale=True,
            create_offset=False,
            offset_init=np.zeros,
        )
def test_multiple_param_axis(self, param_axis, param_shape):
    """Scale and offset have shape `param_shape` for the given `param_axis`.

    Args:
        param_axis: Value forwarded to the layer's `param_axis` argument.
        param_shape: Shape expected for both created parameters.
    """
    ln = layer_norm.LayerNorm(-1, True, True, param_axis=param_axis)
    # The layer is applied once before its parameters are inspected.
    ln(jnp.ones([3, 4, 5, 6]))
    for key in ("layer_norm/scale", "layer_norm/offset"):
        self.assertEqual(ln.params_dict()[key].shape, param_shape)
def test_no_scale_and_init_provided(self):
    """Supplying `scale_init` with `create_scale=False` raises."""
    expected_message = "Cannot set `scale_init` if `create_scale=False`."
    with self.assertRaisesRegex(ValueError, expected_message):
        layer_norm.LayerNorm(
            3,
            create_scale=False,
            create_offset=True,
            scale_init=np.ones,
        )
def test_error_prone_param_axis(self):
    """Parameters have shape (6,) when axis=1 and no `param_axis` is given."""
    # NOTE: This test defends current, potentially error prone behaviour
    # (passing axis!=-1 and not passing param_axis). It will be removed in a
    # future version of Haiku.
    ln = layer_norm.LayerNorm(1, True, True)
    ln(jnp.ones([3, 4, 5, 6]))
    for key in ("layer_norm/scale", "layer_norm/offset"):
        self.assertEqual(ln.params_dict()[key].shape, (6,))
def test_simple_case(self):
    """An all-ones input normalizes to zero with no scale or offset."""
    layer = layer_norm.LayerNorm(
        [1, 2], create_scale=False, create_offset=False)
    outputs = layer(np.ones([2, 3, 3, 5]))
    # Every element of the output is checked individually.
    for value in np.nditer(outputs):
        self.assertEqual(value, 0.0)
def test_simple_case(self, use_fast_variance):
    """An all-ones input normalizes to zero with no scale or offset.

    Args:
        use_fast_variance: Forwarded to the layer's `use_fast_variance`
            argument.
    """
    layer = layer_norm.LayerNorm(
        [1, 2],
        create_scale=False,
        create_offset=False,
        use_fast_variance=use_fast_variance,
        param_axis=-1,
    )
    outputs = layer(np.ones([2, 3, 3, 5]))
    # Every element of the output is checked individually.
    for value in np.nditer(outputs):
        self.assertEqual(value, 0.0)
def test_connection(self):
    """Connect LayerNorm to zero-valued data for several axis specifications.

    Each integer axis of the rank-4 input is used in turn, followed by three
    slice-valued axes, always with scale and offset created and
    `param_axis=-1`. The outputs are discarded.
    """
    data = jnp.zeros([2, 3, 4, 5])
    axes = (0, 1, 2, 3, slice(1, None), slice(2, None), slice(1, -1))
    for axis in axes:
        layer_norm.LayerNorm(axis, True, True, param_axis=-1)(data)
def test_simple_case_tensor(self):
    """With an all-ones input, every output equals the supplied offset.

    Scale (0.5) and offset (2.0) are passed at call time rather than created
    by the layer.
    """
    layer = layer_norm.LayerNorm(
        [1, 2], create_scale=False, create_offset=False)
    inputs = np.ones([2, 3, 3, 5])
    scale = np.full((5,), 0.5)
    offset = np.full((5,), 2.0)
    outputs = layer(inputs, scale, offset)
    # Every element of the output is checked individually.
    for value in np.nditer(outputs):
        self.assertEqual(value, 2.0)
def test_simple_case_var(self):
    """With an all-ones input, every output equals the initial offset.

    Scale and offset are created by the layer with constant initializers
    (0.5 and 2.0 respectively).
    """
    layer = layer_norm.LayerNorm(
        [1, 2],
        create_scale=True,
        create_offset=True,
        scale_init=initializers.Constant(0.5),
        offset_init=initializers.Constant(2.0),
    )
    outputs = layer(np.ones([2, 3, 3, 5]))
    # Every element of the output is checked individually.
    for value in np.nditer(outputs):
        self.assertEqual(value, 2.0)
def test_param_axis_not_required_for_final_axis(self):
    """With axis=-1 and no `param_axis`, parameters have shape (6,)."""
    ln = layer_norm.LayerNorm(-1, True, True)
    # The layer is applied once before its parameters are inspected.
    ln(jnp.ones([3, 4, 5, 6]))
    for key in ("layer_norm/scale", "layer_norm/offset"):
        self.assertEqual(ln.params_dict()[key].shape, (6,))
def f(x):
    """Normalize `x` over its last axis with a freshly built LayerNorm.

    `create_scale` and `create_offset` are free variables resolved from the
    surrounding scope.
    """
    return layer_norm.LayerNorm(
        axis=-1,
        create_scale=create_scale,
        create_offset=create_offset,
    )(x)
def test_param_axis_required_for_non_final_axis(self, axis):
    """Applying the layer without `param_axis` raises for the given `axis`.

    Args:
        axis: Axis specification passed to the layer constructor.
    """
    ln = layer_norm.LayerNorm(axis, True, True)
    inputs = jnp.ones([3, 4, 5, 6])
    # The error is expected when the layer is called, not when it is built.
    with self.assertRaisesRegex(ValueError, "pass.*param_axis.*in the ctor"):
        ln(inputs)
def test_invalid_axis(self, axis):
    """Constructing the layer with an unsupported `axis` value raises.

    Args:
        axis: The axis value under test, expected to be rejected.
    """
    expected_message = "`axis` should be an int, slice or iterable of ints."
    with self.assertRaisesRegex(ValueError, expected_message):
        layer_norm.LayerNorm(axis, create_scale=False, create_offset=False)