Esempio n. 1
0
 def test_difference_rng(self):
   """Drawing a key should register only in the rng part of the state diff."""
   state_before = stateful.internal_state()
   base.next_rng_key()
   state_after = stateful.internal_state()
   delta = stateful.difference(state_before, state_after)
   # No parameters or state were created, so only the rng entry may differ.
   self.assertEmpty(delta.params)
   self.assertEmpty(delta.state)
   self.assertIsNotNone(delta.rng)
Esempio n. 2
0
 def test_vmap_split_rng_without_default(self, require_split_rng):
   """An explicit `split_rng` argument wins over `vmap.require_split_rng`.

   Args:
     require_split_rng: Value temporarily assigned to the
       `stateful.vmap.require_split_rng` attribute while the test runs.
   """
   # Tests that when split_rng is passed explicitly the value of
   # require_split_rng has no impact.
   x = jnp.arange(2)
   # The flag lives on the shared `stateful.vmap` object. Remember its prior
   # value and always restore it, so a failing assertion cannot leak the
   # override into unrelated tests.
   previous = stateful.vmap.require_split_rng
   stateful.vmap.require_split_rng = require_split_rng
   try:
     k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=True)(x)
     self.assertTrue((k1 != k2).all())
     k1, k2 = stateful.vmap(lambda x: base.next_rng_key(), split_rng=False)(x)
     self.assertTrue((k1 == k2).all())
   finally:
     stateful.vmap.require_split_rng = previous
Esempio n. 3
0
 def test_vmap_no_split_rng(self):
   """With `split_rng=False` every mapped element receives the same key."""
   first_key = base.next_rng_key()
   shared_key_fn = stateful.vmap(
       lambda _: base.next_rng_key(), split_rng=False)
   inputs = jnp.arange(4)
   k1, k2, k3, k4 = shared_key_fn(inputs)
   last_key = base.next_rng_key()
   # All four mapped keys must be identical...
   for lhs, rhs in ((k1, k2), (k2, k3), (k3, k4)):
     np.testing.assert_array_equal(lhs, rhs)
   # ...but differ from the keys drawn outside the vmap, which must in turn
   # differ from each other.
   self.assertFalse(np.array_equal(first_key, k1))
   self.assertFalse(np.array_equal(last_key, k1))
   self.assertFalse(np.array_equal(first_key, last_key))
Esempio n. 4
0
            def __call__(self, x):
                """Adds parameter "a", then runs a transparently lifted inner transform."""
                x += base.get_parameter("a", shape=[10, 10], init=jnp.zeros)

                def inner_fn(x):
                    return InnerModule(name="inner")(x)

                inner = transform.transform(inner_fn)
                lifted_init = lift.transparent_lift(inner.init)
                # One key to initialise the inner params, a second to apply them.
                init_key = base.next_rng_key()
                inner_params = lifted_init(init_key, x)
                apply_key = base.next_rng_key()
                return inner.apply(inner_params, apply_key, x)
Esempio n. 5
0
 def test_vmap_split_rng(self):
   """With `split_rng=True` the mapped keys and surrounding keys all differ."""
   key_before = base.next_rng_key()
   split_key_fn = stateful.vmap(lambda _: base.next_rng_key(), split_rng=True)
   k1, k2, k3, k4 = split_key_fn(jnp.arange(4))
   key_after = base.next_rng_key()
   named_keys = dict(k1=k1, k2=k2, k3=k3, k4=k4,
                     key_before=key_before, key_after=key_after)
   # Compare every unordered pair exactly once.
   for (a_name, a), (b_name, b) in it.combinations(named_keys.items(), 2):
     self.assertFalse(
         np.array_equal(a, b),
         msg=f"Keys should not be equal, but {a_name} == {b_name}")
Esempio n. 6
0
  def test_with_rng(self, seed):
    """`base.with_rng(key)` overrides the context key inside its block."""
    outer_key = jax.random.PRNGKey(seed * 2 + 1)
    override_key = jax.random.PRNGKey(seed)
    # The test expects next_rng_key() under with_rng(override_key) to yield
    # the second half of a split of override_key.
    expected_output = jax.random.uniform(jax.random.split(override_key)[1], ())

    def draw():
      return jax.random.uniform(base.next_rng_key(), ()).item()

    with base.new_context(rng=outer_key):
      without_decorator_out = draw()

    with base.new_context(rng=outer_key), base.with_rng(override_key):
      with_decorator_out = draw()

    self.assertNotEqual(without_decorator_out, expected_output)
    self.assertEqual(with_decorator_out, expected_output)
Esempio n. 7
0
 def f(x):
     """Runs a CountingModule when params are not frozen.

     When params are frozen, runs a `stateful.while_loop` whose condition
     draws an rng key, and returns None.
     """
     counter = CountingModule(op=lambda x: x + 1)
     if base.params_frozen():
         # NOTE(review): the loop condition calls next_rng_key(); presumably
         # the surrounding test expects stateful.while_loop to reject this.
         stateful.while_loop(lambda _: base.next_rng_key(), lambda x: x, x)
         return None
     return counter(x)
Esempio n. 8
0
 def outer_fn(x):
     """Applies Bias, then vmaps a lifted, stateful inner transform over x."""
     assert x.ndim == 2
     x = Bias()(x)
     inner = base.transform(inner_fn, state=True)
     # The inner transform is initialised from the first row of x only.
     lifted_init = lift.lift(inner.init)
     inner_p, inner_s = lifted_init(base.next_rng_key(), x[0])
     # Params and state are broadcast (in_axes=None); x is mapped on axis 0.
     batched_apply = jax.vmap(inner.apply, in_axes=(None, None, 0))
     batched_result = batched_apply(inner_p, inner_s, x)
     return batched_result[0]
Esempio n. 9
0
 def outer_fn(x):
     """Applies Bias, then vmaps a lifted inner transform (no apply rng) over x."""
     assert x.ndim == 2
     x = Bias()(x)
     inner = transform.without_apply_rng(transform.transform(inner_fn))
     first_row = x[0]
     inner_p = lift.lift(inner.init)(base.next_rng_key(), first_row)
     # Params are broadcast (in_axes=None); x is mapped along its first axis.
     return jax.vmap(inner.apply, in_axes=(None, 0))(inner_p, x)
Esempio n. 10
0
    def test_includes_no_param_modules(self):
        """A module without parameters (dropout) still appears in the table."""
        dropout_cls = basic.to_module(
            lambda x: basic.dropout(base.next_rng_key(), 0.5, x))
        x = jnp.ones([4])
        f = lambda: dropout_cls(name="dropout")(x)
        self.assertEqual(
            tabulate_to_list(f, columns=("module", )),
            [["dropout (ToModuleWrapper)"]])
Esempio n. 11
0
            def __call__(self, carry, x):
                """Adds parameter "w", then applies a lifted inner transform.

                Returns `carry` unchanged alongside the inner transform's output.
                """
                x += base.get_parameter("w", shape=[], init=jnp.zeros)

                inner = transform.transform(inner_fn)
                # A key is only drawn while initialising; otherwise None is passed.
                if transform.running_init():
                    init_key = base.next_rng_key()
                else:
                    init_key = None
                lifted_init = lift.lift(inner.init, allow_reuse=self._allow_reuse)
                params = lifted_init(init_key, x)
                return carry, inner.apply(params, None, x)
Esempio n. 12
0
 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
   """Samples a scaled orthogonal array of the given shape and dtype.

   The target shape is viewed as a 2D matrix with `shape[self.axis]` rows.
   That matrix is filled from the QR decomposition of a Gaussian sample and
   then reshaped back to `shape`.

   Args:
     shape: Target shape; must have at least two dimensions.
     dtype: dtype of the sample and of the returned array.

   Returns:
     An array of shape `shape`, multiplied by `self.scale` (cast to `dtype`).

   Raises:
     ValueError: If `shape` has fewer than two dimensions.
   """
   if len(shape) < 2:
     raise ValueError('Orthogonal initializer requires at least a 2D shape.')
   n_rows = shape[self.axis]
   n_cols = np.prod(shape) // n_rows
   # Decompose a tall (or square) matrix; transpose back below if needed.
   matrix_shape = (n_rows, n_cols) if n_rows > n_cols else (n_cols, n_rows)
   norm_dst = jax.random.normal(base.next_rng_key(), matrix_shape, dtype)
   q_mat, r_mat = jnp.linalg.qr(norm_dst)
   # Enforce Q is uniformly distributed
   q_mat *= jnp.sign(jnp.diag(r_mat))
   if n_rows < n_cols:
     q_mat = q_mat.T
   q_mat = jnp.reshape(q_mat, (n_rows,) + tuple(np.delete(shape, self.axis)))
   q_mat = jnp.moveaxis(q_mat, 0, self.axis)
   # Cast the scale so a strongly-typed scale (e.g. a float64 numpy scalar)
   # cannot promote the result away from the requested dtype.
   return jax.lax.convert_element_type(self.scale, dtype) * q_mat
Esempio n. 13
0
 def test_no_rng(self):
     """Expects next_rng_key() to raise a ValueError about a missing PRNGKey."""
     expected_msg = "must pass a non-None PRNGKey"
     with self.assertRaisesRegex(ValueError, expected_msg):
         base.next_rng_key()
Esempio n. 14
0
 def with_decorator():
     """Draws a uniform scalar using a key obtained inside `with_rng(key)`."""
     with base.with_rng(key):
         sample_key = base.next_rng_key()
         return jax.random.uniform(sample_key, ())
Esempio n. 15
0
 def without_decorator():
     """Draws a uniform scalar using the ambient `next_rng_key()`."""
     sample_key = base.next_rng_key()
     return jax.random.uniform(sample_key, ())
Esempio n. 16
0
 def three():
     """Appends three successive rng keys to the enclosing `rngs` list."""
     remaining = 3
     while remaining:
         rngs.append(base.next_rng_key())
         remaining -= 1
Esempio n. 17
0
 def branch_f(i):
   """Consumes `i` rng keys (for their side effect only) and returns `i`."""
   draws = range(i)
   for _ in draws:
     base.next_rng_key()
   return i
Esempio n. 18
0
 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
     """Samples `mean + stddev * z`, with z a standard normal cut to [-2, 2].

     Args:
       shape: Shape of the returned array.
       dtype: dtype of the returned array. `mean` and `stddev` are cast to it
         so they cannot promote the result to a wider type.

     Returns:
       An array of `shape` and `dtype`.
     """
     m = jax.lax.convert_element_type(self.mean, dtype)
     s = jax.lax.convert_element_type(self.stddev, dtype)
     unscaled = jax.random.truncated_normal(base.next_rng_key(), -2., 2.,
                                            shape, dtype)
     return s * unscaled + m
Esempio n. 19
0
 def body_fun(_, x):
   # Consume ten rng keys for their side effect only; x passes through as-is.
   for _unused in range(10):
     base.next_rng_key()
   return x
Esempio n. 20
0
 def true_branch(x):
   # Draw and discard exactly one rng key; x is returned unchanged.
   base.next_rng_key()
   return x
Esempio n. 21
0
 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
     """Samples `mean + stddev * z`, with z a standard normal cut to [-2, 2].

     `mean` and `stddev` are cast to `dtype` before use.
     """
     mean = jax.lax.convert_element_type(self.mean, dtype)
     stddev = jax.lax.convert_element_type(self.stddev, dtype)
     unscaled = jax.random.truncated_normal(
         base.next_rng_key(), -2., 2., shape, dtype)
     return stddev * unscaled + mean
Esempio n. 22
0
 def add_random(x):
     """Returns x plus one standard-normal scalar drawn with next_rng_key()."""
     noise = jax.random.normal(base.next_rng_key())
     return x + noise
Esempio n. 23
0
 def f():
   """Applies basic.dropout with 0.25 to a 3x3 array of ones."""
   ones = jnp.ones([3, 3])
   return basic.dropout(base.next_rng_key(), 0.25, ones)
Esempio n. 24
0
 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
     """Samples uniformly between `self.minval` and `self.maxval`."""
     key = base.next_rng_key()
     return jax.random.uniform(key, shape=shape, dtype=dtype,
                               minval=self.minval, maxval=self.maxval)
Esempio n. 25
0
 def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray:
     """Samples `mean + stddev * N(0, 1)` with mean/stddev cast to `dtype`."""
     mean = lax.convert_element_type(self._mean, dtype)
     stddev = lax.convert_element_type(self._stddev, dtype)
     standard = jax.random.normal(base.next_rng_key(), shape, dtype)
     return mean + stddev * standard
Esempio n. 26
0
 def test_rng_no_transform(self):
     """Expects next_rng_key() outside an `hk.transform` to raise ValueError."""
     pattern = "must be used as part of an `hk.transform`"
     with self.assertRaisesRegex(ValueError, pattern):
         base.next_rng_key()
Esempio n. 27
0
 def f():
     """Returns two successive keys from next_rng_key(), in draw order."""
     return base.next_rng_key(), base.next_rng_key()
Esempio n. 28
0
 def test_rng(self):
     """Successive next_rng_key() calls must return different key values."""
     import numpy as np  # Local import keeps this test self-contained.

     a = base.next_rng_key()
     b = base.next_rng_key()
     self.assertIsNot(a, b)
     # assertIsNot only proves two distinct objects were returned, which holds
     # for any two freshly created arrays; also require the contents to differ.
     self.assertFalse(np.array_equal(a, b))
Esempio n. 29
0
 def test_dropout_connects(self):
     """Smoke test: dropout on a 3x3 array of ones runs without error."""
     inputs = jnp.ones([3, 3])
     basic.dropout(base.next_rng_key(), 0.25, inputs)
Esempio n. 30
0
 def false_branch(x):
   # Draw and discard two rng keys; x is returned unchanged.
   for _ in range(2):
     base.next_rng_key()
   return x