def _get_rng_stack(count: int) -> jnp.ndarray:
    """Builds a stack of `count` RNG entries, one per scanned layer.

    Args:
      count: Number of entries to produce.

    Returns:
      The result of splitting the key from `base.maybe_next_rng_key()` into
      `count` keys, or a `[count, 2]` array of `uint32` zeros when that call
      returns `None`.
    """
    key = base.maybe_next_rng_key()
    if key is None:
        # No key to split: hand back an all-zero placeholder of the same
        # leading length so callers can still scan over it.
        # NOTE(review): the trailing dimension of 2 presumably mirrors the
        # raw layout of the default JAX PRNG key - confirm.
        return jnp.zeros([count, 2], dtype=jnp.uint32)
    return jax.random.split(key, count)
def __call__(self, x, *args_ys, reverse=False):
    """Applies the wrapped layer function `self._count` times using scans.

    Parameters are created first by scanning an init/apply pair over the
    layers inside a transparently lifted function. The stacked parameters are
    then scanned over, together with a per-layer RNG stack and `args_ys`, to
    compute the output.

    Args:
      x: Initial value of the carry threaded through the layers.
      *args_ys: Extra inputs that are scanned over alongside the layers.
      reverse: Forwarded to the apply-time `jax.lax.scan`. It is also passed
        to the layer function as the `reverse` keyword when
        `self._pass_reverse_to_layer_fn` is truthy.

    Returns:
      A `(final_x, zs)` tuple: the carry value after the final scan step and
      the per-step second outputs collected by the scan.

    Raises:
      LayerStackStateError: If `transform.transform` raises `ValueError` for
        the wrapped function.
    """
    num_layers = self._count

    # NOTE(review): only the `transform.transform` call itself is guarded;
    # confirm this is where the state-related `ValueError` is raised.
    try:
        init_fn, apply_fn = transform.transform(self._call_wrapped)
    except ValueError as e:
        raise LayerStackStateError(
            "LayerStack can only be used in Haiku "
            "functions which do not make use of Haiku "
            "state.") from e

    def per_layer_init_fn(init_carry, layer_args):
        # The carry pairs the running activation with the RNG key (or None)
        # handed down from the previous layer.
        h, key = init_carry
        if key is None:
            init_key = next_key = apply_key = None
        else:
            # One key to initialise this layer, one to pass on to the next
            # layer, and one to apply this layer.
            init_key, next_key, apply_key = jax.random.split(key, 3)
        layer_params = init_fn(init_key, h, *layer_args)
        # Run the freshly initialised layer so the next one sees its output.
        h, _ = apply_fn(layer_params, apply_key, h, *layer_args)
        return (h, next_key), layer_params

    def scanned_init_fn(x, rng):
        # Only the stacked per-layer parameters are needed from this scan.
        _, stacked = jax.lax.scan(
            per_layer_init_fn, (x, rng), args_ys, length=num_layers)
        return stacked

    init_rng = base.maybe_next_rng_key()
    stacked_params = lift.transparent_lift(scanned_init_fn)(x, init_rng)

    # The apply pass is also a scan; every step receives its own entry of the
    # RNG stack so randomness differs from layer to layer.
    def layer(carry: LayerStackCarry,
              scanned: LayerStackScanned) -> Tuple[LayerStackCarry, Any]:
        extra_kwargs = (
            {"reverse": reverse} if self._pass_reverse_to_layer_fn else {})
        out_x, z = apply_fn(
            scanned.params, scanned.rng, carry.x, *scanned.args_ys,
            **extra_kwargs)
        return LayerStackCarry(x=out_x), z

    scan_inputs = LayerStackScanned(
        params=stacked_params,
        rng=_get_rng_stack(num_layers),
        args_ys=args_ys)

    final_carry, zs = jax.lax.scan(
        layer,
        LayerStackCarry(x=x),
        scan_inputs,
        length=num_layers,
        unroll=self._unroll,
        reverse=reverse)
    return final_carry.x, zs
def maybe_three():
    """Appends the results of three `base.maybe_next_rng_key()` calls to `maybes`."""
    for _ in range(3):
        key = base.maybe_next_rng_key()
        maybes.append(key)
def test_maybe_no_rng(self):
    """Checks that `base.maybe_next_rng_key()` returns `None` in this test's context."""
    key = base.maybe_next_rng_key()
    self.assertIsNone(key)
def test_maybe_rng_no_transform(self):
    """Checks that `base.maybe_next_rng_key()` raises `ValueError` with the expected message."""
    expected_message = "must be used as part of an `hk.transform`"
    with self.assertRaisesRegex(ValueError, expected_message):
        base.maybe_next_rng_key()