Exemplo n.º 1
0
def _get_rng_stack(count: int) -> jnp.ndarray:
    """Returns ``count`` RNG keys stacked along the leading axis.

    When Haiku has an RNG key available, that key is split into ``count``
    keys. Otherwise a zero-filled ``uint32`` array of shape ``[count, 2]`` is
    returned as a placeholder (presumably mirroring the shape of raw default
    PRNG keys — confirm if custom key implementations are in use).
    """
    key = base.maybe_next_rng_key()
    if key is None:
        # No RNG in scope: hand back dummy all-zero keys instead of None.
        return jnp.zeros([count, 2], dtype=jnp.uint32)
    return jax.random.split(key, count)
Exemplo n.º 2
0
    def __call__(self, x, *args_ys, reverse=False):
        """Applies the wrapped layer ``self._count`` times using ``jax.lax.scan``.

        Parameters for every layer are created by a scan over the layer's init
        function (lifted with ``lift.transparent_lift``). A second scan then
        applies each layer with its own slice of those parameters and its own
        RNG key.

        Args:
          x: Initial carry passed to the first layer.
          *args_ys: Extra per-layer inputs; each is scanned over (sliced along
            its leading axis) alongside the layers.
          reverse: If True, the apply scan runs in reverse order. It is also
            forwarded to the layer when ``self._pass_reverse_to_layer_fn`` is
            set. Parameter initialization always scans forward.

        Returns:
          A tuple ``(x_out, zs)``: the final carry and the per-layer outputs
          stacked by ``jax.lax.scan``.

        Raises:
          LayerStackStateError: If the wrapped layer uses Haiku state.
        """
        count = self._count
        # Build the *stateful* transform and check for state explicitly. The
        # stateless ``transform.transform`` only reports state usage (as a
        # plain ValueError) when its init/apply functions are called, so a
        # try/except around merely constructing it never caught that error.
        init_fn, apply_fn = transform.transform_with_state(self._call_wrapped)

        def check_no_state(state):
            # Same check ``transform.transform`` performs internally, but it
            # raises the LayerStack-specific error instead of a ValueError.
            if state:
                raise LayerStackStateError(
                    "LayerStack can only be used in Haiku "
                    "functions which do not make use of Haiku "
                    "state.")

        def per_layer_init_fn(c, a):
            c, rng = c
            if rng is not None:
                # Separate keys for this layer's init and apply, plus one that
                # is carried on to the next layer.
                rng, next_rng, apply_rng = jax.random.split(rng, 3)
            else:
                rng, next_rng, apply_rng = None, None, None
            params, state = init_fn(rng, c, *a)
            check_no_state(state)
            # Run the freshly initialised layer so that the next layer is
            # initialised from this layer's output.
            (c, _), state = apply_fn(params, None, apply_rng, c, *a)
            check_no_state(state)
            return (c, next_rng), params

        def scanned_init_fn(x, rng):
            _, params = jax.lax.scan(per_layer_init_fn, (x, rng),
                                     args_ys,
                                     length=count)
            return params

        rng = base.maybe_next_rng_key()
        lifted_init_fn = lift.transparent_lift(scanned_init_fn)
        params = lifted_init_fn(x, rng)

        # Use scan during apply, threading through random seed so that it's
        # unique for each layer.
        def layer(carry: LayerStackCarry,
                  scanned: LayerStackScanned) -> Tuple[LayerStackCarry, Any]:
            rng = scanned.rng
            params = scanned.params

            kwargs = {}
            if self._pass_reverse_to_layer_fn:
                kwargs["reverse"] = reverse
            (out_x, z), state = apply_fn(params, None, rng, carry.x,
                                         *scanned.args_ys, **kwargs)
            check_no_state(state)
            return LayerStackCarry(x=out_x), z

        # One key per layer (all-zero placeholders when no RNG is available).
        rng = _get_rng_stack(count)

        carry = LayerStackCarry(x=x)
        scanned = LayerStackScanned(params=params, rng=rng, args_ys=args_ys)

        carry, zs = jax.lax.scan(layer,
                                 carry,
                                 scanned,
                                 length=count,
                                 unroll=self._unroll,
                                 reverse=reverse)
        return carry.x, zs
Exemplo n.º 3
0
 def maybe_three():
     """Append three results of ``base.maybe_next_rng_key()`` to ``maybes``."""
     num_keys = 3
     for _ in range(num_keys):
         key = base.maybe_next_rng_key()
         maybes.append(key)
Exemplo n.º 4
0
 def test_maybe_no_rng(self):
     """maybe_next_rng_key should return None (presumably run without an RNG key)."""
     key = base.maybe_next_rng_key()
     self.assertIsNone(key)
Exemplo n.º 5
0
 def test_maybe_rng_no_transform(self):
     """Calling maybe_next_rng_key here must raise a ValueError about hk.transform."""
     expected_message = "must be used as part of an `hk.transform`"
     with self.assertRaisesRegex(ValueError, expected_message):
         base.maybe_next_rng_key()