def inner_fn(x, extra):
  out = basic.Linear(
      x.shape[1],
      w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
      b_init=initializers.Constant(extra),
  )(x)
  return out, out
def create_constant_initializers(w, b, with_bias):
  if with_bias:
    return {
        "w_init": initializers.Constant(w),
        "b_init": initializers.Constant(b),
    }
  else:
    return {"w_init": initializers.Constant(w)}
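# A minimal usage sketch for the helper above (added for illustration, not
# from the original source): the returned kwargs dict can be splatted
# straight into a layer constructor, so a test can parameterize over
# `with_bias`. Assumes the same `basic.Linear` module used elsewhere here.
init_kwargs = create_constant_initializers(w=2.0, b=1.0, with_bias=True)
layer = basic.Linear(output_size=4, with_bias=True, **init_kwargs)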
def inner_fn(x):
  # Here we initialize the layer to an identity + 1, while later we multiply
  # each parameter by the index `n`.
  return basic.Linear(
      x.shape[1],
      w_init=initializers.Constant(jnp.eye(x.shape[1])),
      b_init=initializers.Constant(1.0),
  )(x)
def inner_fn(x, extra):
  # Compared to the previous test, we pass in the `extra` argument as an
  # additional input, in order to directly initialize the parameters to the
  # index `n` of the iteration.
  out = basic.Linear(
      x.shape[1],
      w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
      b_init=initializers.Constant(extra),
  )(x)
  return out, out
def test_simple_case_var(self):
  layer = layer_norm.LayerNorm(
      [1, 2],
      create_scale=True,
      create_offset=True,
      scale_init=initializers.Constant(0.5),
      offset_init=initializers.Constant(2.0))
  inputs = np.ones([2, 3, 3, 5])
  outputs = layer(inputs)
  for x in np.nditer(outputs):
    self.assertEqual(x, 2.0)
def test_simple_case_var(self):
  layer = group_norm.GroupNorm(
      groups=5,
      create_scale=True,
      create_offset=True,
      scale_init=initializers.Constant(0.5),
      offset_init=initializers.Constant(2.0))
  inputs = jnp.ones([2, 3, 3, 10])
  outputs = layer(inputs)
  for x in np.nditer(outputs):
    self.assertEqual(x, 2.0)
def test_initializers(self):
  as_np_f64 = lambda t: np.array(t, dtype=np.float64)
  # This just makes sure we can call the initializers in accordance with the
  # API and get the right shapes and dtypes out.
  inits = [
      initializers.Constant(42.0),
      initializers.Constant(as_np_f64(42.0)),
      initializers.RandomNormal(),
      initializers.RandomNormal(2.0),
      initializers.RandomNormal(as_np_f64(2.0)),
      initializers.RandomUniform(),
      initializers.RandomUniform(3.0),
      initializers.RandomUniform(as_np_f64(3.0)),
      initializers.VarianceScaling(),
      initializers.VarianceScaling(2.0),
      initializers.VarianceScaling(as_np_f64(2.0)),
      initializers.VarianceScaling(2.0, mode="fan_in"),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in",
                                   fan_in_axes=[0]),
      initializers.VarianceScaling(2.0, mode="fan_in", fan_in_axes=[0]),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_in"),
      initializers.VarianceScaling(2.0, mode="fan_out"),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_out"),
      initializers.VarianceScaling(2.0, mode="fan_avg"),
      initializers.VarianceScaling(as_np_f64(2.0), mode="fan_avg"),
      initializers.VarianceScaling(2.0, distribution="truncated_normal"),
      initializers.VarianceScaling(
          as_np_f64(2.0), distribution="truncated_normal"),
      initializers.VarianceScaling(2.0, distribution="normal"),
      initializers.VarianceScaling(as_np_f64(2.0), distribution="normal"),
      initializers.VarianceScaling(2.0, distribution="uniform"),
      initializers.VarianceScaling(as_np_f64(2.0), distribution="uniform"),
      initializers.UniformScaling(),
      initializers.UniformScaling(2.0),
      initializers.UniformScaling(as_np_f64(2.0)),
      initializers.TruncatedNormal(),
      initializers.Orthogonal(),
      initializers.Identity(),
      initializers.Identity(as_np_f64(2.0)),
      # Users are supposed to be able to use these.
      jnp.zeros,
      jnp.ones,
  ]

  # TODO(ibab): Test other shapes as well.
  shape = (20, 42)
  dtype = jnp.float32

  for init in inits:
    generated = init(shape, dtype)
    self.assertEqual(generated.shape, shape)
    self.assertEqual(generated.dtype, dtype)
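# A hedged sketch of the calling convention the test above exercises: every
# Haiku initializer is a callable taking (shape, dtype) and returning an
# array of exactly that shape and dtype. Constant needs no randomness, so it
# can run outside a transformed function; stochastic initializers such as
# RandomNormal draw their key from hk.next_rng_key() and therefore must run
# inside hk.transform.
const = initializers.Constant(42.0)
out = const((2, 3), jnp.float32)
assert out.shape == (2, 3)
assert out.dtype == jnp.float32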
def test_complex_dtype(self):
  if jax.local_devices()[0].platform == "tpu":
    self.skipTest("Complex dtype not supported by TPU")
  # This just makes sure we can call the initializers in accordance with the
  # API and get the right shapes and dtypes out.
  inits = [
      initializers.Constant(42. + 1j * 1729.),
      initializers.RandomNormal(),
      initializers.RandomNormal(2.0),
      initializers.RandomNormal(2. - 3j),
      initializers.TruncatedNormal(),
      initializers.TruncatedNormal(2.),
      initializers.TruncatedNormal(2., 1. - 1j),
      # Users are supposed to be able to use these.
      jnp.zeros,
      jnp.ones,
  ]

  shape = (5, 13, 17)
  dtype = jnp.complex64

  for init in inits:
    generated = init(shape, dtype)
    self.assertEqual(generated.shape, shape)
    self.assertEqual(generated.dtype, dtype)
def test_initializers(self, dtype):
  # This just makes sure we can call the initializers in accordance with the
  # API and get the right shapes and dtypes out.
  inits = [
      initializers.Constant(42.0),
      initializers.RandomNormal(),
      initializers.RandomNormal(2.0),
      initializers.RandomUniform(),
      initializers.RandomUniform(3.0),
      initializers.VarianceScaling(),
      initializers.VarianceScaling(2.0),
      initializers.VarianceScaling(2.0, mode="fan_in"),
      initializers.VarianceScaling(2.0, mode="fan_out"),
      initializers.VarianceScaling(2.0, mode="fan_avg"),
      initializers.VarianceScaling(2.0, distribution="truncated_normal"),
      initializers.VarianceScaling(2.0, distribution="normal"),
      initializers.VarianceScaling(2.0, distribution="uniform"),
      initializers.UniformScaling(),
      initializers.UniformScaling(2.0),
      initializers.TruncatedNormal(),
      # Users are supposed to be able to use these.
      jnp.zeros,
      jnp.ones,
  ]

  # TODO(ibab): Test other shapes as well.
  shape = (20, 42)

  for init in inits:
    generated = init(shape, dtype)
    self.assertEqual(generated.shape, shape)
    self.assertEqual(generated.dtype, dtype)
def maybe_initialize(self, shape, dtype):
  """If uninitialized, sets the average to ``0`` of the given shape/dtype."""
  base.get_state("counter", (), jnp.int32,
                 init=initializers.Constant(-self._warmup_length))
  base.get_state("hidden", shape, dtype, init=jnp.zeros)
  base.get_state("average", shape, dtype, init=jnp.zeros)
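# A hedged sketch of why `maybe_initialize` exists, assuming the method above
# is exposed on hk.ExponentialMovingAverage (in some Haiku versions this
# helper is named `initialize` instead): calling it before the first
# `__call__` materializes the "counter", "hidden" and "average" state
# entries, so hk.transform_with_state reports them from init() even when the
# average has not yet been updated.
import haiku as hk

def f(x):
  ema = hk.ExponentialMovingAverage(decay=0.99)
  ema.maybe_initialize(x.shape, x.dtype)  # creates state, leaves it at zero
  return x

f = hk.transform_with_state(f)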
def test_simple_case_with_scale(self):
  layer = rms_norm.RMSNorm(
      axis=[1, 2], eps=0.0, scale_init=initializers.Constant(0.5))
  inputs = np.full(shape=[2, 3, 3, 5], fill_value=2.0)
  outputs = layer(inputs)
  for x in np.nditer(outputs):
    self.assertEqual(x, 0.5)
def test_constant_with_list(self, k, dtype, broadcast):
  init = initializers.Constant(k)
  shape = (1, 1, len(k)) if broadcast else (len(k),)
  actual = init(shape, dtype)
  expected = jnp.broadcast_to(jnp.asarray(k).astype(dtype), shape)
  np.testing.assert_array_equal(actual, expected)
  self.assertEqual(actual.shape, shape)
  self.assertEqual(actual.dtype, dtype)
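# For example (a sketch using only the API exercised above), a per-channel
# constant list broadcasts against leading singleton dimensions, which is
# handy for seeding a bias whose trailing axis matches the list length:
init = initializers.Constant([0.1, 0.2, 0.3])
b = init((1, 1, 3), jnp.float32)  # == [[[0.1, 0.2, 0.3]]]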
def f_with_container_state(x):
  hk_layer = basic.Linear(
      width, w_init=initializers.Constant(jnp.eye(width)))
  layer_output = hk_layer(x)
  layer_state = {
      "raw_output": layer_output,
      "output_projection": jnp.sum(layer_output),
  }
  return layer_output + jnp.ones_like(layer_output), layer_state
def __call__(self, value, update_stats=True):
  """Updates the EMA and returns the new value.

  Args:
    value: The array-like object on which you would like to perform an
      exponential decay.
    update_stats: A Boolean, whether to update the internal state of this
      object to reflect the input value. When `update_stats` is False the
      internal stats will remain unchanged.

  Returns:
    The exponentially weighted average of the input value.
  """
  if not isinstance(value, jnp.ndarray):
    value = jnp.asarray(value)

  counter = base.get_state(
      "counter", (), jnp.int32,
      init=initializers.Constant(-self._warmup_length))
  counter += 1

  decay = jax.lax.convert_element_type(self._decay, value.dtype)
  if self._warmup_length > 0:
    decay = self._cond(counter <= 0, 0.0, decay, value.dtype)

  one = jnp.ones([], value.dtype)
  hidden = base.get_state("hidden", value.shape, value.dtype, init=jnp.zeros)
  hidden = hidden * decay + value * (one - decay)

  average = hidden
  if self._zero_debias:
    average /= (one - jnp.power(decay, counter))

  if update_stats:
    base.set_state("counter", counter)
    base.set_state("hidden", hidden)
    base.set_state("average", average)

  return average
def __call__(self, value, update_stats=True):
  """Updates the EMA and returns the new value.

  Args:
    value: The array-like object on which you would like to perform an
      exponential decay.
    update_stats: A Boolean, whether to update the internal state of this
      object to reflect the input value. When `update_stats` is False the
      internal stats will remain unchanged.

  Returns:
    The exponentially weighted average of the input value.
  """
  value = jnp.asarray(value)  # Ensure value has a dtype.

  prev_counter = base.get_state(
      "counter",
      shape=(),
      dtype=jnp.int32,
      init=initializers.Constant(-self._warmup_length))
  prev_hidden = base.get_state(
      "hidden", shape=value.shape, dtype=value.dtype, init=jnp.zeros)

  decay = jnp.asarray(self._decay).astype(value.dtype)
  counter = prev_counter + 1
  decay = self._cond(jnp.less_equal(counter, 0), 0.0, decay, value.dtype)
  hidden = prev_hidden * decay + value * (1 - decay)

  if self._zero_debias:
    average = hidden / (1. - jnp.power(decay, counter))
  else:
    average = hidden

  if update_stats:
    base.set_state("counter", counter)
    base.set_state("hidden", hidden)
    base.set_state("average", average)

  return average
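# A hedged usage sketch for the EMA `__call__` above, assuming the module is
# exposed as hk.ExponentialMovingAverage. Its "counter"/"hidden"/"average"
# buffers live in Haiku state, so the call must be wrapped with
# hk.transform_with_state and the state threaded through every apply.
import haiku as hk
import jax.numpy as jnp

def ema_fn(x):
  return hk.ExponentialMovingAverage(decay=0.9)(x)

ema_fn = hk.transform_with_state(ema_fn)
params, state = ema_fn.init(None, jnp.array(1.0))  # no RNG needed here
avg, state = ema_fn.apply(params, state, None, jnp.array(2.0))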
def f_with_multi_args(x, a, b):
  return basic.Linear(
      width, w_init=initializers.Constant(jnp.eye(width)))(x) * a + b, None