예제 #1
0
  def test_connection(self):
    """Connects LayerNorm to a rank-4 zero input for several `axis` settings.

    The integer axes 0 through 3 are tried first, followed by three slice
    axes. Every layer is built with `create_scale=True` and
    `create_offset=True`.

    Returns:
      A list with one layer output per axis setting, in the order tried.
    """
    data = jnp.zeros([2, 3, 4, 5])
    axes = list(range(4)) + [slice(1, None), slice(2, None), slice(1, -1)]
    return [
        layer_norm.LayerNorm(
            axis=axis, create_scale=True, create_offset=True)(data)
        for axis in axes
    ]
예제 #2
0
  def test_slice_axis(self):
    """Checks that `axis=slice(1, -1)` matches `axis=(1, 2)` on a rank-4 input.

    Neither layer creates its own parameters; the same `scale` and `offset`
    arrays are passed at call time and the two outputs must be exactly equal.
    """
    no_params = dict(create_scale=False, create_offset=False)
    slice_layer = layer_norm.LayerNorm(slice(1, -1), **no_params)
    axis_layer = layer_norm.LayerNorm((1, 2), **no_params)

    # The draw order (inputs, scale, offset) is kept so seeded runs produce
    # the same values as before.
    inputs = np.random.uniform(low=0, high=10, size=[3, 4, 4, 5])
    scale = np.random.normal(loc=1.0, size=(5,))
    offset = np.random.normal(size=(5,))

    np.testing.assert_array_equal(
        slice_layer(inputs, scale, offset),
        axis_layer(inputs, scale, offset))
예제 #3
0
  def test_create_offset_and_offset_provided(self):
    """Expects a ValueError when `offset` is passed to a layer that creates one.

    The layer is built with `create_offset=True`; supplying `offset` at call
    time must be rejected with the message matched below.
    """
    expected_error = (
        "Cannot pass `offset` at call time if `create_offset=True`.")
    layer = layer_norm.LayerNorm([2], create_scale=False, create_offset=True)
    inputs = np.ones([2, 3, 4])
    offset = np.ones([4])

    with self.assertRaisesRegex(ValueError, expected_error):
      layer(inputs, offset=offset)
예제 #4
0
 def f(x):
     """Builds a LayerNorm over the last axis and applies it to `x`.

     `create_scale`, `create_offset` and `use_fast_variance` are free
     variables resolved outside this function.
     """
     options = dict(
         axis=-1,
         param_axis=-1,
         create_scale=create_scale,
         create_offset=create_offset,
         use_fast_variance=use_fast_variance,
     )
     return layer_norm.LayerNorm(**options)(x)
예제 #5
0
 def test_no_offset_beta_init_provided(self):
     """Expects construction to fail when `offset_init` is set without an offset.

     Passing `offset_init` together with `create_offset=False` must raise a
     ValueError with the message matched below.
     """
     expected_error = "Cannot set `offset_init` if `create_offset=False`."
     with self.assertRaisesRegex(ValueError, expected_error):
         layer_norm.LayerNorm(
             3, create_scale=True, create_offset=False, offset_init=np.zeros)
예제 #6
0
 def test_multiple_param_axis(self, param_axis, param_shape):
     """Checks the created parameter shapes for a given `param_axis`.

     Args:
       param_axis: Value forwarded to the layer's `param_axis` argument.
       param_shape: Shape expected for both the scale and offset parameters.
     """
     ln = layer_norm.LayerNorm(
         axis=-1, create_scale=True, create_offset=True,
         param_axis=param_axis)
     # The layer is called once before its parameters are inspected.
     ln(jnp.ones([3, 4, 5, 6]))

     for key in ("layer_norm/scale", "layer_norm/offset"):
         self.assertEqual(ln.params_dict()[key].shape, param_shape)
예제 #7
0
 def test_no_scale_and_init_provided(self):
     """Expects construction to fail when `scale_init` is set without a scale.

     Passing `scale_init` together with `create_scale=False` must raise a
     ValueError with the message matched below.
     """
     expected_error = "Cannot set `scale_init` if `create_scale=False`."
     with self.assertRaisesRegex(ValueError, expected_error):
         layer_norm.LayerNorm(
             3, create_scale=False, create_offset=True, scale_init=np.ones)
예제 #8
0
 def test_error_prone_param_axis(self):
     """Pins the parameter shape when `axis=1` is used without `param_axis`.

     NOTE: This test defends current, potentially error prone behaviour
     (passing axis!=-1 and not passing param_axis). It will be removed in a
     future version of Haiku.
     """
     ln = layer_norm.LayerNorm(axis=1, create_scale=True, create_offset=True)
     ln(jnp.ones([3, 4, 5, 6]))

     # Both parameters are expected to have shape (6,), the size of the last
     # input dimension.
     expected_shape = (6, )
     for key in ("layer_norm/scale", "layer_norm/offset"):
         self.assertEqual(ln.params_dict()[key].shape, expected_shape)
예제 #9
0
    def test_simple_case(self):
        """Checks that an all-ones input normalizes to zero everywhere.

        The layer normalizes over axes 1 and 2 and applies no scale or
        offset.
        """
        layer = layer_norm.LayerNorm(
            [1, 2], create_scale=False, create_offset=False)
        outputs = layer(np.ones([2, 3, 3, 5]))

        # Every element of the output is checked individually.
        for value in np.nditer(outputs):
            self.assertEqual(value, 0.0)
예제 #10
0
    def test_simple_case(self, use_fast_variance):
        """Checks that an all-ones input normalizes to zero everywhere.

        Args:
          use_fast_variance: Value forwarded to the layer's
            `use_fast_variance` argument.
        """
        layer = layer_norm.LayerNorm(
            [1, 2],
            create_scale=False,
            create_offset=False,
            use_fast_variance=use_fast_variance,
            param_axis=-1,
        )
        outputs = layer(np.ones([2, 3, 3, 5]))

        # Every element of the output is checked individually.
        for value in np.nditer(outputs):
            self.assertEqual(value, 0.0)
예제 #11
0
    def test_connection(self):
        """Connects LayerNorm to a rank-4 zero input for several `axis` values.

        Each layer creates both scale and offset and uses `param_axis=-1`.
        The test only requires that every call completes; outputs are
        discarded.
        """
        data = jnp.zeros([2, 3, 4, 5])
        axes = (0, 1, 2, 3, slice(1, None), slice(2, None), slice(1, -1))
        for axis in axes:
            layer_norm.LayerNorm(axis, True, True, param_axis=-1)(data)
예제 #12
0
    def test_simple_case_tensor(self):
        """Checks call-time `scale` and `offset` on an all-ones input.

        With scale 0.5 and offset 2.0 passed at call time, every output
        element is expected to equal 2.0.
        """
        layer = layer_norm.LayerNorm(
            [1, 2], create_scale=False, create_offset=False)

        ones = np.ones([2, 3, 3, 5])
        half_scale = np.full((5, ), 0.5)
        two_offset = np.full((5, ), 2.0)

        outputs = layer(ones, half_scale, two_offset)
        for value in np.nditer(outputs):
            self.assertEqual(value, 2.0)
예제 #13
0
    def test_simple_case_var(self):
        """Checks layer-created scale and offset on an all-ones input.

        The scale is initialized to the constant 0.5 and the offset to the
        constant 2.0; every output element is expected to equal 2.0.
        """
        layer = layer_norm.LayerNorm(
            [1, 2],
            create_scale=True,
            create_offset=True,
            scale_init=initializers.Constant(0.5),
            offset_init=initializers.Constant(2.0),
        )
        outputs = layer(np.ones([2, 3, 3, 5]))

        for value in np.nditer(outputs):
            self.assertEqual(value, 2.0)
예제 #14
0
 def test_param_axis_not_required_for_final_axis(self):
     """Checks that `axis=-1` works without `param_axis`.

     After one call on a (3, 4, 5, 6) input, both the scale and the offset
     parameters are expected to have shape (6,).
     """
     ln = layer_norm.LayerNorm(axis=-1, create_scale=True, create_offset=True)
     ln(jnp.ones([3, 4, 5, 6]))

     expected_shape = (6, )
     for key in ("layer_norm/scale", "layer_norm/offset"):
         self.assertEqual(ln.params_dict()[key].shape, expected_shape)
예제 #15
0
 def f(x):
     """Builds a LayerNorm over the last axis and applies it to `x`.

     `create_scale` and `create_offset` are free variables resolved outside
     this function.
     """
     options = dict(
         axis=-1, create_scale=create_scale, create_offset=create_offset)
     return layer_norm.LayerNorm(**options)(x)
예제 #16
0
 def test_param_axis_required_for_non_final_axis(self, axis):
     """Expects a ValueError asking for `param_axis` when the layer is called.

     Args:
       axis: Axis setting under test, supplied by the test parameterization
         (non-final, per the test name).
     """
     ln = layer_norm.LayerNorm(
         axis=axis, create_scale=True, create_offset=True)
     inputs = jnp.ones([3, 4, 5, 6])
     expected_error = "pass.*param_axis.*in the ctor"

     # Construction succeeds; the error is only expected on the call itself.
     with self.assertRaisesRegex(ValueError, expected_error):
         ln(inputs)
예제 #17
0
 def test_invalid_axis(self, axis):
     """Expects construction to reject an unsupported `axis` value.

     Args:
       axis: Axis value under test, supplied by the test parameterization.
     """
     expected_error = "`axis` should be an int, slice or iterable of ints."
     with self.assertRaisesRegex(ValueError, expected_error):
         layer_norm.LayerNorm(
             axis=axis, create_scale=False, create_offset=False)