Exemplo n.º 1
0
  def test_with_rng(self, seed):
    """Checks that `base.with_rng` overrides the context's rng sequence.

    Inside `with base.with_rng(key)`, the first `base.next_rng_key()` is
    expected to match the second output of `jax.random.split(key)`. Outside
    it, the context's own rng (seeded from a different key) must produce a
    different sample.

    Args:
      seed: Integer used to derive both the override key and the context key.
    """
    context_key = jax.random.PRNGKey(seed * 2 + 1)
    override_key = jax.random.PRNGKey(seed)
    want = jax.random.uniform(jax.random.split(override_key)[1], ())

    def draw_uniform():
      # Draws one scalar uniform sample from whatever rng is currently active.
      return jax.random.uniform(base.next_rng_key(), ()).item()

    with base.new_context(rng=context_key):
      plain = draw_uniform()

    # `with_rng(...)` is evaluated only after `new_context` has been entered,
    # exactly as in a nested `with`.
    with base.new_context(rng=context_key), base.with_rng(override_key):
      overridden = draw_uniform()

    self.assertNotEqual(plain, want)
    self.assertEqual(overridden, want)
Exemplo n.º 2
0
 def with_decorator():
     """Returns a scalar uniform sample drawn while `key` seeds the rng.

     `key` is a free name resolved from the enclosing scope (not visible here).
     """
     with base.with_rng(key):
         rng_key = base.next_rng_key()
         sample = jax.random.uniform(rng_key, ())
     return sample
Exemplo n.º 3
0
 def test_with_rng_no_transform(self):
     """Checks that `base.with_rng` raises ValueError outside a transform."""
     # In a combined `with`, the second manager is created only after the
     # first has been entered, so any error from `base.with_rng` is caught.
     with self.assertRaisesRegex(
             ValueError, "must be used as part of an `hk.transform`"), \
             base.with_rng(jax.random.PRNGKey(428)):
         pass
Exemplo n.º 4
0
def with_rng_example():
    """Enters and immediately exits a `base.with_rng` block seeded with 42."""
    seed_key = jax.random.PRNGKey(42)
    with base.with_rng(seed_key):
        pass
Exemplo n.º 5
0
def maybe_with_rng(key):
    """Returns a context manager that installs `key` as the rng, if given.

    Args:
      key: A key to pass to `base.with_rng`, or None.

    Returns:
      `nullcontext()` when `key` is None, otherwise `base.with_rng(key)`.
    """
    if key is None:
        return nullcontext()
    return base.with_rng(key)