Example #1
0
 def test_prng_reserve(self):
     """Checks that keys drawn after `reserve(10)` match one 11-way split.

     Ten keys are pulled from a `PRNGSequence` that has reserved ten keys
     up front, and compared element-wise with entries 1..10 of
     `jax.random.split(k, num=11)`.
     """
     k = jax.random.PRNGKey(42)
     s = base.PRNGSequence(k)
     s.reserve(10)
     hk_keys = tuple(next(s) for _ in range(10))
     # Entry 0 of the split is skipped; only the remaining ten are compared.
     jax_keys = tuple(jax.random.split(k, num=11)[1:])
     # `jax.tree_multimap` was deprecated and then removed from JAX;
     # `jax.tree_util.tree_map` accepts several trees and replaces it.
     jax.tree_util.tree_map(np.testing.assert_array_equal, hk_keys, jax_keys)
Example #2
0
def temporary_internal_state(state: InternalState):
    """Returns a frame-stack context built from the given internal state.

    The current frame is evolved with the params, state and (optional) rng
    taken from `state`; a non-None rng is wrapped in a `PRNGSequence` first.
    """
    rng_seq = None if state.rng is None else base.PRNGSequence(state.rng)
    evolved = base.current_frame().evolve(
        params=state.params,
        state=state.state,
        rng=rng_seq,
    )
    return base.frame_stack(evolved)
Example #3
0
    def __call__(
        self,
        inputs: jnp.ndarray,
        dropout_rate: Optional[float] = None,
        rng=None,
    ) -> jnp.ndarray:
        """Connects the module to some inputs.

        Args:
          inputs: A Tensor of shape `[batch_size, input_size]`.
          dropout_rate: Optional dropout rate.
          rng: Optional RNG key. Required when using dropout.

        Returns:
          output: The output of the model of size `[batch_size, output_size]`.

        Raises:
          ValueError: If exactly one of `dropout_rate` and `rng` is given.
        """
        use_dropout = dropout_rate is not None
        has_rng = rng is not None
        if use_dropout and not has_rng:
            raise ValueError("When using dropout an rng key must be passed.")
        if has_rng and not use_dropout:
            raise ValueError("RNG should only be passed when using dropout.")

        key_seq = base.PRNGSequence(rng) if has_rng else None
        last_index = len(self._layers) - 1

        out = inputs
        for index, layer in enumerate(self._layers):
            out = layer(out)
            # The final layer is left untouched unless `_activate_final` is set.
            if index == last_index and not self._activate_final:
                continue
            # Dropout (when enabled) is applied just before the activation.
            if use_dropout:
                out = basic.dropout(next(key_seq), dropout_rate, out)
            out = self._activation(out)

        return out
Example #4
0
    def test_with_container_state(self):
        """Checks layer_stack when the per-layer output is a dict of arrays.

        The stacked function returns a dict alongside its main output; the
        test verifies the stacked shapes of each dict entry and the values
        it expects for them (raw outputs 0, 1, 2 and a projection sum of 12).
        """
        stack_height = 3
        batch_size = 2
        width = 2

        def f_with_container_state(x):
            identity_init = initializers.Constant(jnp.eye(width))
            linear_out = basic.Linear(width, w_init=identity_init)(x)
            per_layer = {
                "raw_output": linear_out,
                "output_projection": jnp.sum(linear_out),
            }
            return linear_out + jnp.ones_like(linear_out), per_layer

        @multi_transform.without_apply_rng
        @transform.transform
        def hk_fn(x):
            stacked = layer_stack.layer_stack(
                stack_height, with_per_layer_inputs=True)
            return stacked(f_with_container_state)(x)

        inputs = jnp.zeros([batch_size, width])
        keys = base.PRNGSequence(19)
        params = hk_fn.init(next(keys), inputs)
        output, collected = hk_fn.apply(params, inputs)

        # Per-layer entries gain a leading axis of length `stack_height`.
        self.assertEqual(collected["raw_output"].shape,
                         (stack_height, batch_size, width))
        self.assertEqual(output.shape, (batch_size, width))
        self.assertEqual(collected["output_projection"].shape,
                         (stack_height, ))

        projection_total = np.sum(collected["output_projection"])
        np.testing.assert_equal(projection_total, np.array(12.))

        # Layer i is expected to emit a constant array filled with i.
        expected_raw = np.array([0., 1., 2.])[..., None, None]
        all_match = np.all(collected["raw_output"] == expected_raw)
        np.testing.assert_equal(all_match, np.array(True))
Example #5
0
    def test_with_per_layer_inputs_multi_args(self):
        """Test layer_stack with per-layer inputs with multiple arguments."""
        stack_height = 3
        batch_size = 5
        width = 4

        def f_with_multi_args(x, a, b):
            linear = basic.Linear(
                width, w_init=initializers.Constant(jnp.eye(width)))
            # No per-layer output is produced, hence the trailing None.
            return linear(x) * a + b, None

        @multi_transform.without_apply_rng
        @transform.transform
        def hk_fn(x):
            stacked = layer_stack.layer_stack(
                stack_height, with_per_layer_inputs=True)
            scales = jnp.full([stack_height], 2.)
            offsets = jnp.ones([stack_height])
            return stacked(f_with_multi_args)(x, scales, offsets)

        inputs = jnp.zeros([batch_size, width])
        keys = base.PRNGSequence(19)
        params = hk_fn.init(next(keys), inputs)
        output, per_layer = hk_fn.apply(params, inputs)

        self.assertIsNone(per_layer)
        self.assertEqual(output.shape, (batch_size, width))
        # The test expects every output entry to equal 7 for zero inputs.
        expected = np.full([batch_size, width], 7.)
        np.testing.assert_equal(output, expected)
Example #6
0
def to_prng_sequence(rng, err_msg) -> Optional[base.PRNGSequence]:
    """Wraps `rng` in a `PRNGSequence`, passing `None` through unchanged.

    Args:
      rng: Value to hand to `base.PRNGSequence`, or `None`.
      err_msg: Message used for the `ValueError` raised on failure.

    Returns:
      A `PRNGSequence` built from `rng`, or `None` if `rng` is `None`.

    Raises:
      ValueError: If constructing the sequence raises any exception; the
        original exception is chained as the cause.
    """
    if rng is None:
        return None
    try:
        return base.PRNGSequence(rng)
    except Exception as cause:
        raise ValueError(err_msg) from cause
Example #7
0
    def test_reverse_with_pass_reverse_to_layer_fn(self):
        """Checks nested layer stacks run forwards and in reverse."""
        # Each inner layer is a Linear whose weights are `extra * I` and whose
        # bias is `extra`, so with extra_n = n * alpha the forward pass runs
        #   x_n = n * alpha * (x_{n-1} + 1),   n = 1, ..., N
        # where N = stack_height. The reversed pass is expected to run
        #   y_{n-1} = (N - n + 1) * alpha * (y_n + 1)
        # Here the N iterations are split over two nested layer stacks.
        stack_height = 4
        batch_size = 3
        width = 2
        total_multiplier = 24
        alpha = jnp.power(total_multiplier, -1. / stack_height)
        forward, backward = self._compute_weights(stack_height, alpha)
        half_height = stack_height // 2

        def inner_fn(x, extra):
            features = x.shape[1]
            linear = basic.Linear(
                features,
                w_init=initializers.Constant(extra * jnp.eye(features)),
                b_init=initializers.Constant(extra),
            )
            out = linear(x)
            return out, out

        def outer_fn(x, extra, reverse=False):
            inner_stack = layer_stack.layer_stack(
                half_height, with_per_layer_inputs=True)
            return inner_stack(inner_fn)(x, extra, reverse=reverse)

        @multi_transform.without_apply_rng
        @transform.transform
        def hk_fn(x, extra, reverse=False):
            outer_stack = layer_stack.layer_stack(
                2, with_per_layer_inputs=True, pass_reverse_to_layer_fn=True)
            return outer_stack(outer_fn)(x, extra, reverse=reverse)

        layer_index = jnp.arange(stack_height).reshape([2, half_height]) + 1
        extra = layer_index * alpha
        keys = base.PRNGSequence(19)
        init_value = 1 + jax.random.uniform(next(keys), [batch_size, width])
        params = hk_fn.init(next(keys), init_value, extra)

        def check_trajectory(states, weights):
            # Every intermediate state must be an affine map of init_value.
            for state, (scale, shift) in zip(states, weights):
                np.testing.assert_allclose(
                    state, scale * init_value + shift, rtol=1e-6)

        # Forward pass.
        x_n, x_all = hk_fn.apply(params, init_value, extra)
        self.assertEqual(x_all.shape[:2], (2, half_height))
        # Flatten the two stacking axes into a single per-layer axis.
        x_all = x_all.reshape((stack_height, ) + tuple(x_all.shape[2:]))
        check_trajectory(x_all, forward)
        np.testing.assert_allclose(x_n, x_all[-1], rtol=1e-6)

        # Reversed pass.
        y_0, y_all = hk_fn.apply(params, init_value, extra, reverse=True)
        self.assertEqual(y_all.shape[:2], (2, half_height))
        y_all = y_all.reshape((stack_height, ) + tuple(y_all.shape[2:]))
        check_trajectory(y_all, reversed(backward))
        np.testing.assert_allclose(y_0, y_all[0], rtol=1e-6)
Example #8
0
 def test_prng_reserve_twice(self):
     """Checks keys drawn after two consecutive `reserve(2)` calls.

     The four keys pulled from the sequence are compared with two chained
     3-way splits: the first entry of each split is carried forward as the
     next key to split, and the other two entries are the expected keys.
     """
     k = jax.random.PRNGKey(42)
     s = base.PRNGSequence(k)
     s.reserve(2)
     s.reserve(2)
     hk_keys = tuple(next(s) for _ in range(4))
     k, subkey1, subkey2 = tuple(jax.random.split(k, num=3))
     _, subkey3, subkey4 = tuple(jax.random.split(k, num=3))
     jax_keys = (subkey1, subkey2, subkey3, subkey4)
     # `jax.tree_multimap` was deprecated and then removed from JAX;
     # `jax.tree_util.tree_map` accepts several trees and replaces it.
     jax.tree_util.tree_map(np.testing.assert_array_equal, hk_keys, jax_keys)
Example #9
0
    def pure_fun(args, state_in):
        """Runs `fun(*args)` under a temporary internal state.

        Returns the output of `fun` together with whatever `difference`
        reports between `state_in` and the internal state after the call.
        """
        if split_rng:
            # NOTE: In the case of split_rng we receive an RNG key (rather
            # than the internal state of a PRNGSequence) so we need to
            # construct that here.
            seq_state = base.PRNGSequence(state_in.rng).internal_state
            state_in = InternalState(
                state_in.params, state_in.state, seq_state)

        with temporary_internal_state(state_in):
            with base.push_jax_trace_level():
                result = fun(*args)
                return result, difference(state_in, internal_state())
Example #10
0
    def test_reverse_with_additional_inputs(self):
        """Checks forward/reverse layer stacks fed with per-layer inputs."""
        # Each layer is a Linear whose weights are `extra * I` and whose bias
        # is `extra`, so with extra_n = n * alpha the forward pass runs
        #   x_n = n * alpha * (x_{n-1} + 1),   n = 1, ..., N
        # where N = stack_height. The reversed pass is expected to run
        #   y_{n-1} = (N - n + 1) * alpha * (y_n + 1)
        stack_height = 4
        batch_size = 3
        width = 2
        total_multiplier = 24
        alpha = jnp.power(total_multiplier, -1. / stack_height)
        forward, backward = self._compute_weights(stack_height, alpha)

        def inner_fn(x, extra):
            # `extra` arrives as a per-layer input and is used directly as
            # the constant initial value for both the weights and the bias.
            features = x.shape[1]
            linear = basic.Linear(
                features,
                w_init=initializers.Constant(extra * jnp.eye(features)),
                b_init=initializers.Constant(extra),
            )
            out = linear(x)
            return out, out

        @multi_transform.without_apply_rng
        @transform.transform
        def hk_fn(x, extra, reverse=False):
            stack = layer_stack.layer_stack(
                stack_height, with_per_layer_inputs=True)
            return stack(inner_fn)(x, extra, reverse=reverse)

        extra = (jnp.arange(stack_height) + 1) * alpha
        keys = base.PRNGSequence(19)
        init_value = 1 + jax.random.uniform(next(keys), [batch_size, width])
        params = hk_fn.init(next(keys), init_value, extra)

        def check_trajectory(states, weights):
            # Every intermediate state must be an affine map of init_value.
            for state, (scale, shift) in zip(states, weights):
                np.testing.assert_allclose(
                    state, scale * init_value + shift, rtol=1e-6)

        # Forward pass.
        x_n, x_all = hk_fn.apply(params, init_value, extra)
        self.assertEqual(x_all.shape[0], stack_height)
        check_trajectory(x_all, forward)
        np.testing.assert_allclose(x_n, x_all[-1], rtol=1e-6)

        # Reversed pass.
        y_0, y_all = hk_fn.apply(params, init_value, extra, reverse=True)
        self.assertEqual(y_all.shape[0], stack_height)
        check_trajectory(y_all, reversed(backward))
        np.testing.assert_allclose(y_0, y_all[0], rtol=1e-6)
Example #11
0
 def test_prng_sequence(self, seed, wrap_seed):
     """Compares PRNGSequence draws against manually split keys.

     The sequence is built either from the raw seed or from a PRNGKey made
     from it (depending on `wrap_seed`); its first two keys must produce
     the same normal samples as two successive manual splits.
     """
     # Reference values obtained by splitting the key by hand.
     root = jax.random.PRNGKey(seed)
     carry, first_key = jax.random.split(root)
     _, second_key = jax.random.split(carry)
     expected = [jax.random.normal(key, [])
                 for key in (first_key, second_key)]

     # Values obtained through the sequence under test.
     source = jax.random.PRNGKey(seed) if wrap_seed else seed
     key_seq = base.PRNGSequence(source)
     actual = [jax.random.normal(next(key_seq), []) for _ in range(2)]

     for want, got in zip(expected, actual):
         self.assertEqual(want, got)
Example #12
0
def temporary_internal_state(state: InternalState):
    """Pushes a temporary copy of the internal state.

    Params and state that are `None` in the given state fall back to the
    values of the current internal state; a non-None rng is wrapped in a
    `PRNGSequence`.
    """
    snapshot = copy_structure(state)
    rng_seq = None if snapshot.rng is None else base.PRNGSequence(snapshot.rng)

    fallback = internal_state()
    params = fallback.params if snapshot.params is None else snapshot.params
    inner_state = fallback.state if snapshot.state is None else snapshot.state

    new_frame = base.current_frame().evolve(
        params=params,
        state=inner_state,
        rng=rng_seq,
    )
    return base.frame_stack(new_frame)
Example #13
0
    def test_reverse(self):
        """Checks forward and reversed application of a scaled layer stack.

        The final output of each direction is compared with the affine map
        of the input given by the last entry of the corresponding weight
        list from `self._compute_weights`.
        """
        # The layer stack below runs iteratively the update equation:
        # x_n = n * alpha * (x_{n-1} + 1)
        # for n={1, ..., N}, where N = stack_height
        # The reverse layer stack as a result runs the update equation:
        # y_{n-1} = (N - n + 1) * alpha * (y_n + 1)
        width = 2
        batch_size = 3
        stack_height = 4
        alpha = jnp.power(24, -1. / stack_height)
        forward, backward = self._compute_weights(stack_height, alpha)

        def inner_fn(x):
            # Here we initialize the layer to an identity + 1, while later we
            # multiply each parameter by the index `n`.
            return basic.Linear(
                x.shape[1],
                w_init=initializers.Constant(jnp.eye(x.shape[1])),
                b_init=initializers.Constant(1.0),
            )(x)

        @multi_transform.without_apply_rng
        @transform.transform
        def hk_fn(x, reverse=False):
            return layer_stack.layer_stack(stack_height)(inner_fn)(
                x, reverse=reverse)

        key_seq = base.PRNGSequence(19)
        init_value = 1 + jax.random.uniform(next(key_seq), [batch_size, width])

        def mul_by_m(x):
            # Pad the multipliers 1..N with trailing axes so they broadcast
            # along the leading axis of `x` (presumably the layer axis).
            m_x = jnp.arange(stack_height) + 1
            while m_x.ndim < x.ndim:
                m_x = m_x[..., None]
            return x * m_x * alpha

        # The top-level `jax.tree_map` alias is deprecated and removed from
        # recent JAX releases; `jax.tree_util.tree_map` is the stable name.
        params = jax.tree_util.tree_map(
            mul_by_m, hk_fn.init(next(key_seq), init_value))

        a, b = forward[-1]
        x_n = hk_fn.apply(params, init_value)
        np.testing.assert_allclose(x_n, a * init_value + b, rtol=1e-6)

        a, b = backward[-1]
        y_0 = hk_fn.apply(params, init_value, reverse=True)
        np.testing.assert_allclose(y_0, a * init_value + b, rtol=1e-6)
Example #14
0
 def create_random_values(key_or_seed):
     """Draws two scalar normal samples from successive sequence keys."""
     key_seq = base.PRNGSequence(key_or_seed)
     first = jax.random.normal(next(key_seq), [])
     second = jax.random.normal(next(key_seq), [])
     return first, second
Example #15
0
 def test_prng_sequence_wrong_shape(self):
     """Expects a ValueError when a stack of two keys is given as the key."""
     expected_error = "key did not have expected shape and/or dtype"
     with self.assertRaisesRegex(ValueError, expected_error):
         stacked_keys = jax.random.split(jax.random.PRNGKey(42), 2)
         base.PRNGSequence(stacked_keys)
Example #16
0
 def test_prng_sequence_invalid_input(self):
     """Expects a ValueError when the sequence is built from a string."""
     self.assertRaisesRegex(
         ValueError, "not a JAX PRNGKey", base.PRNGSequence, "nonsense")
Example #17
0
 def test_prng_sequence_split(self):
     """Checks that `take(10)` matches entries 1..10 of an 11-way split."""
     k = jax.random.PRNGKey(42)
     s = base.PRNGSequence(k)
     hk_keys = s.take(10)
     # Entry 0 of the split is skipped; only the remaining ten are compared.
     jax_keys = tuple(jax.random.split(k, num=11)[1:])
     # `jax.tree_multimap` was deprecated and then removed from JAX;
     # `jax.tree_util.tree_map` accepts several trees and replaces it.
     jax.tree_util.tree_map(np.testing.assert_array_equal, hk_keys, jax_keys)