def test_with_rng(self, seed):
    """`base.with_rng(key)` overrides the RNG seen by `base.next_rng_key()`.

    Inside `base.with_rng(key)`, the first key drawn from
    `base.next_rng_key()` is expected to equal the second half of
    `jax.random.split(key)`. Outside it, the context RNG (seeded from
    `ctx_key` via `base.new_context`) must produce a different sample.

    Args:
      seed: Integer seed for the overriding key. Presumably this is supplied
        by a parameterization decorator outside this view.
    """
    # Context key derived from `seed` so that it differs from `key`.
    # NOTE(review): seed * 2 + 1 == seed only when seed == -1.
    ctx_key = jax.random.PRNGKey(seed * 2 + 1)
    key = jax.random.PRNGKey(seed)
    _, next_key = jax.random.split(key)
    # Convert to a Python float so that both assertions below compare like
    # types. The observed outputs are also converted with `.item()`.
    expected_output = jax.random.uniform(next_key, ()).item()

    with base.new_context(rng=ctx_key):
        without_with_rng_out = jax.random.uniform(
            base.next_rng_key(), ()).item()

    with base.new_context(rng=ctx_key):
        with base.with_rng(key):
            with_rng_out = jax.random.uniform(base.next_rng_key(), ()).item()

    self.assertNotEqual(without_with_rng_out, expected_output)
    self.assertEqual(with_rng_out, expected_output)
def with_decorator():
    """Return a scalar uniform sample drawn while `base.with_rng(key)` is active.

    NOTE(review): `key` is neither a parameter nor a local here. It must be
    resolved from an enclosing or module scope; confirm where it is defined.
    """
    with base.with_rng(key):
        sample_key = base.next_rng_key()
        return jax.random.uniform(sample_key, ())
def test_with_rng_no_transform(self):
    """Entering `base.with_rng` outside an `hk.transform` must raise ValueError."""
    pattern = "must be used as part of an `hk.transform`"
    with self.assertRaisesRegex(ValueError, pattern):
        rng_key = jax.random.PRNGKey(428)
        with base.with_rng(rng_key):
            pass
def with_rng_example():
    """Enter and immediately exit `base.with_rng` with a fixed key (seed 42).

    NOTE(review): the body is empty, so the key is never consumed. If
    `base.with_rng` requires an active transform, calling this outside one
    will raise; confirm the intended usage.
    """
    fixed_key = jax.random.PRNGKey(42)
    with base.with_rng(fixed_key):
        pass
def maybe_with_rng(key):
    """Return `base.with_rng(key)`, or a no-op context manager when `key` is None.

    Args:
      key: RNG key to install, or `None` to return a context manager that
        does nothing.

    Returns:
      `base.with_rng(key)` when `key` is not None, otherwise `nullcontext()`.
    """
    if key is None:
        return nullcontext()
    return base.with_rng(key)