def test_difference_rng(self):
    """Drawing one RNG key shows up only in the `rng` part of the state diff."""
    state_before = stateful.internal_state()
    base.next_rng_key()
    state_after = stateful.internal_state()
    delta = stateful.difference(state_before, state_after)
    # Neither parameters nor state should be touched by consuming a key.
    for unchanged in (delta.params, delta.state):
        self.assertEmpty(unchanged)
    self.assertIsNotNone(delta.rng)
def test_vmap_split_rng_without_default(self, require_split_rng):
    """An explicit `split_rng` argument overrides `vmap.require_split_rng`.

    The module-level flag is restored in a ``finally`` block, so a failing
    assertion cannot leak the parameterized value into later tests.

    Args:
      require_split_rng: value temporarily assigned to
        ``stateful.vmap.require_split_rng`` for the duration of the test.
    """
    # Tests that when split_rng is passed explicitly the value of
    # require_split_rng has no impact.
    x = jnp.arange(2)
    previous = stateful.vmap.require_split_rng
    stateful.vmap.require_split_rng = require_split_rng
    try:
        k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=True)(x)
        self.assertTrue((k1 != k2).all())
        k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=False)(x)
        self.assertTrue((k1 == k2).all())
    finally:
        # Restore whatever was configured before, even if an assertion failed.
        stateful.vmap.require_split_rng = previous
def test_vmap_no_split_rng(self):
    """With split_rng=False all vmapped calls see one identical, fresh key."""
    key_before = base.next_rng_key()
    f = stateful.vmap(lambda _: base.next_rng_key(), split_rng=False)
    x = jnp.arange(4)
    k1, k2, k3, k4 = f(x)
    key_after = base.next_rng_key()

    # Every mapped key equals its neighbour, i.e. all four are the same.
    mapped = (k1, k2, k3, k4)
    for left, right in zip(mapped, mapped[1:]):
        np.testing.assert_array_equal(left, right)

    # The shared key differs from the keys drawn outside the vmap.
    for outside in (key_before, key_after):
        self.assertFalse(np.array_equal(outside, k1))
    self.assertFalse(np.array_equal(key_before, key_after))
def __call__(self, x):
    """Add parameter "a", then run a transparently-lifted `InnerModule`.

    The inner transform is initialised through `lift.transparent_lift` and
    then applied with a second, separately drawn key.
    """
    bias = base.get_parameter("a", shape=[10, 10], init=jnp.zeros)
    x += bias

    def inner_fn(x):
        return InnerModule(name="inner")(x)

    inner_transformed = transform.transform(inner_fn)
    lifted_init = lift.transparent_lift(inner_transformed.init)
    init_key = base.next_rng_key()
    inner_params = lifted_init(init_key, x)
    apply_key = base.next_rng_key()
    return inner_transformed.apply(inner_params, apply_key, x)
def test_vmap_split_rng(self):
    """With split_rng=True every mapped key, and the outside keys, are distinct."""
    key_before = base.next_rng_key()
    f = stateful.vmap(lambda _: base.next_rng_key(), split_rng=True)
    x = jnp.arange(4)
    k1, k2, k3, k4 = f(x)
    key_after = base.next_rng_key()

    # Test that none of the keys are equal (insertion order drives pairing).
    keys_by_name = {
        "k1": k1,
        "k2": k2,
        "k3": k3,
        "k4": k4,
        "key_before": key_before,
        "key_after": key_after,
    }
    for (a_name, a), (b_name, b) in it.combinations(keys_by_name.items(), 2):
        self.assertFalse(
            np.array_equal(a, b),
            msg=f"Keys should not be equal, but {a_name} == {b_name}")
def test_with_rng(self, seed):
    """`base.with_rng(key)` overrides the context key for subsequent draws."""

    def draw_uniform():
        return jax.random.uniform(base.next_rng_key(), ()).item()

    ctx_key = jax.random.PRNGKey(seed * 2 + 1)
    key = jax.random.PRNGKey(seed)
    # The test expects the first key handed out under with_rng(key) to be the
    # second half of jax.random.split(key).
    _, next_key = jax.random.split(key)
    expected_output = jax.random.uniform(next_key, ())

    with base.new_context(rng=ctx_key):
        without_decorator_out = draw_uniform()
    with base.new_context(rng=ctx_key), base.with_rng(key):
        with_decorator_out = draw_uniform()

    self.assertNotEqual(without_decorator_out, expected_output)
    self.assertEqual(with_decorator_out, expected_output)
def f(x):
    """Call a CountingModule at init; otherwise run a key-drawing while_loop.

    When params are not frozen the module output is returned. Otherwise a
    `stateful.while_loop` whose condition draws an RNG key is run and the
    function falls through, returning None. NOTE(review): presumably this
    exercises how drawing keys inside a loop condition is handled; confirm
    against the calling test.
    """
    # The module is always constructed, in both branches.
    m = CountingModule(op=lambda x: x + 1)
    if not base.params_frozen():
        return m(x)
    stateful.while_loop(lambda _: base.next_rng_key(), lambda x: x, x)
def outer_fn(x):
    """Apply `Bias`, then vmap a lifted, stateful inner transform over axis 0.

    The inner transform is initialised on the first row, and its params and
    state are broadcast (in_axes None) to every row. Only the first element of
    the vmapped apply result is returned.
    """
    assert x.ndim == 2
    x = Bias()(x)
    inner = base.transform(inner_fn, state=True)
    lifted_init = lift.lift(inner.init)
    inner_p, inner_s = lifted_init(base.next_rng_key(), x[0])
    batched_apply = jax.vmap(inner.apply, in_axes=(None, None, 0))
    result = batched_apply(inner_p, inner_s, x)
    return result[0]
def outer_fn(x):
    """Apply `Bias`, then vmap a lifted, rng-free inner transform over axis 0.

    The inner params come from initialising on the first row and are
    broadcast (in_axes None) to every row of `x`.
    """
    assert x.ndim == 2
    x = Bias()(x)
    inner = transform.without_apply_rng(transform.transform(inner_fn))
    lifted_init = lift.lift(inner.init)
    inner_p = lifted_init(base.next_rng_key(), x[0])
    batched_apply = jax.vmap(inner.apply, in_axes=(None, 0))
    return batched_apply(inner_p, x)
def test_includes_no_param_modules(self):
    """A parameter-free module (dropout wrapper) still gets a tabulate row."""
    dropout_cls = basic.to_module(
        lambda x: basic.dropout(base.next_rng_key(), 0.5, x))
    inputs = jnp.ones([4])
    f = lambda: dropout_cls(name="dropout")(inputs)
    self.assertEqual(
        tabulate_to_list(f, columns=("module", )),
        [["dropout (ToModuleWrapper)"]])
def __call__(self, carry, x):
    """Add scalar parameter "w", lift an inner transform and apply it.

    A key is only drawn while `transform.running_init()` is true; otherwise
    None is passed to the lifted init. `self._allow_reuse` is forwarded to
    `lift.lift`. Returns ``(carry, inner_output)`` with `carry` untouched.
    """
    x += base.get_parameter("w", shape=[], init=jnp.zeros)
    inner = transform.transform(inner_fn)
    if transform.running_init():
        init_key = base.next_rng_key()
    else:
        init_key = None
    lifted_init = lift.lift(inner.init, allow_reuse=self._allow_reuse)
    params = lifted_init(init_key, x)
    return carry, inner.apply(params, None, x)
def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
    """Sample a (semi-)orthogonal tensor of the given `shape` and `dtype`.

    The tensor is treated as a matrix whose rows run along `self.axis` and
    whose columns are the flattened remaining dimensions. That matrix has
    orthonormal columns (if it has more rows than columns) or orthonormal
    rows (otherwise), and is finally multiplied by `self.scale`.

    Raises:
      ValueError: if `shape` has fewer than two dimensions.
    """
    if len(shape) < 2:
        raise ValueError('Orthogonal initializer requires at least a 2D shape.')
    n_rows = shape[self.axis]
    n_cols = np.prod(shape) // n_rows
    # Always decompose a tall (or square) matrix; transpose back below.
    matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
    norm_dst = jax.random.normal(base.next_rng_key(), matrix_shape, dtype)
    q_mat, r_mat = jnp.linalg.qr(norm_dst)
    # Enforce Q is uniformly distributed: flip each column of Q by the sign of
    # the matching diagonal entry of R so the decomposition is unique.
    q_mat *= jnp.sign(jnp.diag(r_mat))
    if n_rows < n_cols:
        q_mat = q_mat.T
    q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
    q_mat = jnp.moveaxis(q_mat, 0, self.axis)
    # Cast the scale so a NumPy/JAX scalar or array cannot promote the result
    # to a different dtype than the one requested.
    scale = jax.lax.convert_element_type(self.scale, dtype)
    return scale * q_mat
def test_no_rng(self):
    """Requesting a key when no PRNGKey was supplied raises ValueError."""
    self.assertRaisesRegex(
        ValueError, "must pass a non-None PRNGKey", base.next_rng_key)
def with_decorator():
    """Draw a scalar uniform sample using a key taken under `with_rng(key)`."""
    with base.with_rng(key):
        sample_key = base.next_rng_key()
        sample = jax.random.uniform(sample_key, ())
    return sample
def without_decorator():
    """Draw a scalar uniform sample using the ambient context key."""
    sample_key = base.next_rng_key()
    return jax.random.uniform(sample_key, ())
def three():
    """Draw three keys in order and append them to the enclosing `rngs` list."""
    rngs.extend(base.next_rng_key() for _ in range(3))
def branch_f(i):
    """Consume `i` RNG keys (discarding them) and return `i` unchanged."""
    for _draw in range(i):
        base.next_rng_key()
    return i
def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
    """Sample a truncated normal tensor, then scale by stddev and shift by mean.

    The unscaled draw is truncated to [-2, 2] before scaling.
    """
    # Cast hyper-parameters to the requested dtype so that a NumPy/JAX scalar
    # or array mean/stddev cannot promote the output to a different dtype.
    m = jax.lax.convert_element_type(self.mean, dtype)
    s = jax.lax.convert_element_type(self.stddev, dtype)
    unscaled = jax.random.truncated_normal(base.next_rng_key(), -2., 2., shape,
                                           dtype)
    return s * unscaled + m
def body_fun(_, x):
    """Loop body: consume ten RNG keys (discarded) and pass `x` through."""
    for _step in range(10):
        base.next_rng_key()
    return x
def true_branch(x):
    """Branch that consumes exactly one RNG key and returns `x` unchanged."""
    base.next_rng_key()
    return x
def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
    """Sample a truncated normal (cut at [-2, 2]), scaled by stddev plus mean.

    Mean and stddev are cast to `dtype` first so they do not change the
    output dtype.
    """
    mean = jax.lax.convert_element_type(self.mean, dtype)
    stddev = jax.lax.convert_element_type(self.stddev, dtype)
    sample_key = base.next_rng_key()
    unscaled = jax.random.truncated_normal(sample_key, -2., 2., shape, dtype)
    return stddev * unscaled + mean
def add_random(x):
    """Return `x` plus one scalar standard-normal sample."""
    noise = jax.random.normal(base.next_rng_key())
    return x + noise
def f():
    """Apply dropout (with 0.25 as the rate argument) to a 3x3 tensor of ones."""
    dropout_key = base.next_rng_key()
    inputs = jnp.ones([3, 3])
    return basic.dropout(dropout_key, 0.25, inputs)
def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
    """Sample uniformly in [self.minval, self.maxval) with the given shape/dtype."""
    sample_key = base.next_rng_key()
    return jax.random.uniform(
        sample_key, shape, dtype, minval=self.minval, maxval=self.maxval)
def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
    """Sample a normal tensor with mean `self._mean` and stddev `self._stddev`.

    Both hyper-parameters are cast to `dtype` before use.
    """
    mean = lax.convert_element_type(self._mean, dtype)
    stddev = lax.convert_element_type(self._stddev, dtype)
    standard = jax.random.normal(base.next_rng_key(), shape, dtype)
    return mean + stddev * standard
def test_rng_no_transform(self):
    """Requesting a key outside any `hk.transform` raises ValueError."""
    self.assertRaisesRegex(
        ValueError, "must be used as part of an `hk.transform`",
        base.next_rng_key)
def f():
    """Draw two keys in sequence and return them as a tuple (first, second)."""
    return base.next_rng_key(), base.next_rng_key()
def test_rng(self):
    """Two consecutive `next_rng_key` calls must return different key values.

    Checks values rather than object identity. Two separate calls always
    return distinct objects, so an identity check would pass even if both
    keys held the same bits.
    """
    a = base.next_rng_key()
    b = base.next_rng_key()
    self.assertTrue((a != b).any())
def test_dropout_connects(self):
    """Smoke test: dropout on a 3x3 tensor of ones runs without error."""
    dropout_key = base.next_rng_key()
    inputs = jnp.ones([3, 3])
    basic.dropout(dropout_key, 0.25, inputs)
def false_branch(x):
    """Branch that consumes exactly two RNG keys and returns `x` unchanged."""
    for _draw in range(2):
        base.next_rng_key()
    return x