Code example #1
File: layer_stack_test.py  Project: deepmind/dm-haiku
 def inner_fn(x, extra):
     out = basic.Linear(
         x.shape[1],
         w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
         b_init=initializers.Constant(extra),
     )(x)
     return out, out
Code example #2
File: conv_test.py  Project: shyamalschandra/haiku
def create_constant_initializers(w, b, with_bias):
  if with_bias:
    return {
        "w_init": initializers.Constant(w),
        "b_init": initializers.Constant(b)
    }
  else:
    return {"w_init": initializers.Constant(w)}
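The helper above simply packages constant initializers into keyword arguments. A minimal sketch of how such a dictionary could be fed to a Haiku convolution through the public API; the `hk` import, `hk.Conv2D`, the `net` wrapper, and the shapes below are illustrative assumptions, not part of conv_test.py:

import haiku as hk
import jax
import jax.numpy as jnp

def net(x):
  init_kwargs = {
      "w_init": hk.initializers.Constant(0.5),
      "b_init": hk.initializers.Constant(1.0),
  }
  # Every weight starts at 0.5 and every bias at 1.0.
  return hk.Conv2D(output_channels=4, kernel_shape=3, **init_kwargs)(x)

forward = hk.without_apply_rng(hk.transform(net))
x = jnp.ones([1, 8, 8, 3])
params = forward.init(jax.random.PRNGKey(0), x)
out = forward.apply(params, x)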
Code example #3
File: layer_stack_test.py  Project: deepmind/dm-haiku
 def inner_fn(x):
     # Here we initialize the layer to an identity + 1, while later we multiply
     # each parameter by the index `n`.
     return basic.Linear(
         x.shape[1],
         w_init=initializers.Constant(jnp.eye(x.shape[1])),
         b_init=initializers.Constant(1.0),
     )(x)
Code example #4
File: layer_stack_test.py  Project: deepmind/dm-haiku
 def inner_fn(x, extra):
     # Compared to previous test we pass in the `extra` argument as an
     # additional input, in order to directly initialize the parameters to the
     # index `n` of the iteration.
     out = basic.Linear(
         x.shape[1],
         w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
         b_init=initializers.Constant(extra),
     )(x)
     return out, out
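A rough sketch of how a per-layer `extra` argument like this could be threaded through a stack, assuming the public hk.layer_stack with with_per_layer_inputs=True (older releases expose it as hk.experimental.layer_stack); the layer count, width, and input shapes are illustrative:

import haiku as hk
import jax
import jax.numpy as jnp

num_layers, width = 4, 3

def inner_fn(x, extra):
  out = hk.Linear(
      x.shape[1],
      w_init=hk.initializers.Constant(extra * jnp.eye(x.shape[1])),
      b_init=hk.initializers.Constant(extra),
  )(x)
  return out, out

def forward(x, extra):
  stack = hk.layer_stack(num_layers, with_per_layer_inputs=True)(inner_fn)
  return stack(x, extra)

f = hk.transform(forward)
x = jnp.zeros([2, width])
extra = jnp.arange(num_layers, dtype=jnp.float32)  # one value per stacked layer
params = f.init(jax.random.PRNGKey(0), x, extra)
out, per_layer_out = f.apply(params, None, x, extra)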
Code example #5
File: layer_norm_test.py  Project: vballoli/dm-haiku
    def test_simple_case_var(self):
        layer = layer_norm.LayerNorm([1, 2],
                                     create_scale=True,
                                     create_offset=True,
                                     scale_init=initializers.Constant(0.5),
                                     offset_init=initializers.Constant(2.0))

        inputs = np.ones([2, 3, 3, 5])

        outputs = layer(inputs)
        for x in np.nditer(outputs):
            self.assertEqual(x, 2.0)
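Every output is 2.0 because the constant inputs normalize to zero, so the 0.5 scale has no effect and only the 2.0 offset remains. A minimal sketch of the same computation through the public API (the `forward` wrapper and transform wiring are illustrative; the test itself relies on Haiku's test harness):

import haiku as hk
import jax
import numpy as np

def forward(x):
  layer = hk.LayerNorm(axis=[1, 2],
                       create_scale=True,
                       create_offset=True,
                       scale_init=hk.initializers.Constant(0.5),
                       offset_init=hk.initializers.Constant(2.0))
  return layer(x)

f = hk.without_apply_rng(hk.transform(forward))
inputs = np.ones([2, 3, 3, 5])
params = f.init(jax.random.PRNGKey(0), inputs)
outputs = f.apply(params, inputs)  # every element is 2.0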
Code example #6
    def test_simple_case_var(self):
        layer = group_norm.GroupNorm(groups=5,
                                     create_scale=True,
                                     create_offset=True,
                                     scale_init=initializers.Constant(0.5),
                                     offset_init=initializers.Constant(2.0))

        inputs = jnp.ones([2, 3, 3, 10])

        outputs = layer(inputs)
        for x in np.nditer(outputs):
            self.assertEqual(x, 2.0)
Code example #7
  def test_initializers(self):
    as_np_f64 = lambda t: np.array(t, dtype=np.float64)
    # This just makes sure we can call the initializers in accordance with the
    # API and get the right shapes and dtypes out.
    inits = [
        initializers.Constant(42.0),
        initializers.Constant(as_np_f64(42.0)),
        initializers.RandomNormal(),
        initializers.RandomNormal(2.0),
        initializers.RandomNormal(as_np_f64(2.0)),
        initializers.RandomUniform(),
        initializers.RandomUniform(3.0),
        initializers.RandomUniform(as_np_f64(3.0)),
        initializers.VarianceScaling(),
        initializers.VarianceScaling(2.0),
        initializers.VarianceScaling(as_np_f64(2.0)),
        initializers.VarianceScaling(2.0, mode="fan_in"),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in",
                                     fan_in_axes=[0]),
        initializers.VarianceScaling(2.0, mode="fan_in", fan_in_axes=[0]),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in"),
        initializers.VarianceScaling(2.0, mode="fan_out"),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_out"),
        initializers.VarianceScaling(2.0, mode="fan_avg"),
        initializers.VarianceScaling(as_np_f64(2.0), mode="fan_avg"),
        initializers.VarianceScaling(2.0, distribution="truncated_normal"),
        initializers.VarianceScaling(
            as_np_f64(2.0), distribution="truncated_normal"),
        initializers.VarianceScaling(2.0, distribution="normal"),
        initializers.VarianceScaling(as_np_f64(2.0), distribution="normal"),
        initializers.VarianceScaling(2.0, distribution="uniform"),
        initializers.VarianceScaling(as_np_f64(2.0), distribution="uniform"),
        initializers.UniformScaling(),
        initializers.UniformScaling(2.0),
        initializers.UniformScaling(as_np_f64(2.0)),
        initializers.TruncatedNormal(),
        initializers.Orthogonal(),
        initializers.Identity(),
        initializers.Identity(as_np_f64(2.0)),

        # Users are supposed to be able to use these.
        jnp.zeros,
        jnp.ones,
    ]

    # TODO(ibab): Test other shapes as well.
    shape = (20, 42)

    dtype = jnp.float32
    for init in inits:
      generated = init(shape, dtype)
      self.assertEqual(generated.shape, shape)
      self.assertEqual(generated.dtype, dtype)
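The random initializers draw their values from hk.next_rng_key(), so they need to run inside a Haiku transform (the test above is assumed to be wrapped accordingly); Constant has no randomness and can also be evaluated directly. A small illustrative check:

import haiku as hk
import jax.numpy as jnp

init = hk.initializers.Constant(42.0)
w = init((20, 42), jnp.float32)  # no transform or RNG key needed
assert w.shape == (20, 42)
assert w.dtype == jnp.float32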
Code example #8
    def test_complex_dtype(self):
        if jax.local_devices()[0].platform == "tpu":
            self.skipTest("Complex dtype not supported by TPU")
        # This just makes sure we can call the initializers in accordance with the
        # API and get the right shapes and dtypes out.
        inits = [
            initializers.Constant(42. + 1j * 1729.),
            initializers.RandomNormal(),
            initializers.RandomNormal(2.0),
            initializers.RandomNormal(2. - 3j),
            initializers.TruncatedNormal(),
            initializers.TruncatedNormal(2.),
            initializers.TruncatedNormal(2., 1. - 1j),

            # Users are supposed to be able to use these.
            jnp.zeros,
            jnp.ones,
        ]

        shape = (5, 13, 17)

        dtype = jnp.complex64
        for init in inits:
            generated = init(shape, dtype)
            self.assertEqual(generated.shape, shape)
            self.assertEqual(generated.dtype, dtype)
Code example #9
    def test_initializers(self, dtype):
        # This just makes sure we can call the initializers in accordance with the
        # API and get the right shapes and dtypes out.
        inits = [
            initializers.Constant(42.0),
            initializers.RandomNormal(),
            initializers.RandomNormal(2.0),
            initializers.RandomUniform(),
            initializers.RandomUniform(3.0),
            initializers.VarianceScaling(),
            initializers.VarianceScaling(2.0),
            initializers.VarianceScaling(2.0, mode="fan_in"),
            initializers.VarianceScaling(2.0, mode="fan_out"),
            initializers.VarianceScaling(2.0, mode="fan_avg"),
            initializers.VarianceScaling(2.0, distribution="truncated_normal"),
            initializers.VarianceScaling(2.0, distribution="normal"),
            initializers.VarianceScaling(2.0, distribution="uniform"),
            initializers.UniformScaling(),
            initializers.UniformScaling(2.0),
            initializers.TruncatedNormal(),

            # Users are supposed to be able to use these.
            jnp.zeros,
            jnp.ones,
        ]

        # TODO(ibab): Test other shapes as well.
        shape = (20, 42)

        for init in inits:
            generated = init(shape, dtype)
            self.assertEqual(generated.shape, shape)
            self.assertEqual(generated.dtype, dtype)
Code example #10
 def maybe_initialize(self, shape, dtype):
     """If uninitialized sets the average to ``0`` of the given shape/dtype."""
     base.get_state("counter", (),
                    jnp.int32,
                    init=initializers.Constant(-self._warmup_length))
     base.get_state("hidden", shape, dtype, init=jnp.zeros)
     base.get_state("average", shape, dtype, init=jnp.zeros)
Code example #11
 def test_simple_case_with_scale(self):
     layer = rms_norm.RMSNorm(axis=[1, 2],
                              eps=0.0,
                              scale_init=initializers.Constant(0.5))
     inputs = np.full(shape=[2, 3, 3, 5], fill_value=2.0)
     outputs = layer(inputs)
     for x in np.nditer(outputs):
         self.assertEqual(x, 0.5)
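The expected value follows directly from the definition: with eps=0 and constant inputs of 2.0, the root-mean-square over the normalized axes is 2.0, each element normalizes to 1.0, and the 0.5 scale gives 0.5. A minimal sketch of the same check through the public API (the transform wiring is illustrative; the test itself relies on Haiku's test harness):

import haiku as hk
import jax
import numpy as np

def forward(x):
  return hk.RMSNorm(axis=[1, 2], eps=0.0,
                    scale_init=hk.initializers.Constant(0.5))(x)

f = hk.without_apply_rng(hk.transform(forward))
inputs = np.full([2, 3, 3, 5], 2.0)
params = f.init(jax.random.PRNGKey(0), inputs)
outputs = f.apply(params, inputs)  # all entries equal 0.5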
Code example #12
 def test_constant_with_list(self, k, dtype, broadcast):
   init = initializers.Constant(k)
   shape = (1, 1, len(k)) if broadcast else (len(k),)
   actual = init(shape, dtype)
   expected = jnp.broadcast_to(jnp.asarray(k).astype(dtype), shape)
   np.testing.assert_array_equal(actual, expected)
   self.assertEqual(actual.shape, shape)
   self.assertEqual(actual.dtype, dtype)
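Illustratively, Constant accepts array-like values and broadcasts them to the requested shape, which is exactly what the parameterized test above asserts (the concrete list and shapes below are assumptions):

import haiku as hk
import jax.numpy as jnp

init = hk.initializers.Constant([1.0, 2.0, 3.0])
print(init((1, 1, 3), jnp.float32))  # [[[1. 2. 3.]]]
print(init((3,), jnp.float32))       # [1. 2. 3.]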
Code example #13
File: layer_stack_test.py  Project: deepmind/dm-haiku
 def f_with_container_state(x):
     hk_layer = basic.Linear(width,
                             w_init=initializers.Constant(
                                 jnp.eye(width)))
     layer_output = hk_layer(x)
     layer_state = {
         "raw_output": layer_output,
         "output_projection": jnp.sum(layer_output)
     }
     return layer_output + jnp.ones_like(layer_output), layer_state
Code example #14
    def __call__(self, value, update_stats=True):
        """Updates the EMA and returns the new value.

    Args:
      value: The array-like object on which you would like to perform an
        exponential decay.
      update_stats: A Boolean, whether to update the internal state
        of this object to reflect the input value. When `update_stats` is False
        the internal stats will remain unchanged.

    Returns:
      The exponentially weighted average of the input value.
    """
        if not isinstance(value, jnp.ndarray):
            value = jnp.asarray(value)

        counter = base.get_state(
            "counter", (),
            jnp.int32,
            init=initializers.Constant(-self._warmup_length))
        counter += 1

        decay = jax.lax.convert_element_type(self._decay, value.dtype)
        if self._warmup_length > 0:
            decay = self._cond(counter <= 0, 0.0, decay, value.dtype)

        one = jnp.ones([], value.dtype)
        hidden = base.get_state("hidden",
                                value.shape,
                                value.dtype,
                                init=jnp.zeros)
        hidden = hidden * decay + value * (one - decay)

        average = hidden
        if self._zero_debias:
            average /= (one - jnp.power(decay, counter))

        if update_stats:
            base.set_state("counter", counter)
            base.set_state("hidden", hidden)
            base.set_state("average", average)

        return average
Code example #15
File: moving_averages.py  Project: ibab/haiku
    def __call__(self, value, update_stats=True):
        """Updates the EMA and returns the new value.

    Args:
      value: The array-like object on which you would like to perform an
        exponential decay.
      update_stats: A Boolean, whether to update the internal state
        of this object to reflect the input value. When `update_stats` is False
        the internal stats will remain unchanged.

    Returns:
      The exponentially weighted average of the input value.

    """
        value = jnp.asarray(value)  # Ensure value has a dtype.
        prev_counter = base.get_state(
            "counter",
            shape=(),
            dtype=jnp.int32,
            init=initializers.Constant(-self._warmup_length))
        prev_hidden = base.get_state("hidden",
                                     shape=value.shape,
                                     dtype=value.dtype,
                                     init=jnp.zeros)

        decay = jnp.asarray(self._decay).astype(value.dtype)
        counter = prev_counter + 1
        decay = self._cond(jnp.less_equal(counter, 0), 0.0, decay, value.dtype)
        hidden = prev_hidden * decay + value * (1 - decay)

        if self._zero_debias:
            average = hidden / (1. - jnp.power(decay, counter))
        else:
            average = hidden

        if update_stats:
            base.set_state("counter", counter)
            base.set_state("hidden", hidden)
            base.set_state("average", average)
        return average
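Both versions of __call__ keep "counter", "hidden" and "average" in Haiku state, so a caller drives the module with a stateful transform. A rough usage sketch with the public API (the decay value and input are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  ema = hk.ExponentialMovingAverage(decay=0.99)
  return ema(x)

f = hk.transform_with_state(forward)
x = jnp.ones([3])
params, state = f.init(jax.random.PRNGKey(0), x)
for _ in range(5):
  # `state` carries the counter, hidden and average values between calls.
  avg, state = f.apply(params, state, None, x)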
Code example #16
File: layer_stack_test.py  Project: deepmind/dm-haiku
 def f_with_multi_args(x, a, b):
     return basic.Linear(width,
                         w_init=initializers.Constant(
                             jnp.eye(width)))(x) * a + b, None