def test_prng_reserve(self):
  k = jax.random.PRNGKey(42)
  s = base.PRNGSequence(k)
  s.reserve(10)
  hk_keys = tuple(next(s) for _ in range(10))
  jax_keys = tuple(jax.random.split(k, num=11)[1:])
  jax.tree_util.tree_map(np.testing.assert_array_equal, hk_keys, jax_keys)
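# A minimal usage sketch of the behaviour tested above (not part of the test
# itself): the class under test is exposed publicly as `hk.PRNGSequence`.
# `reserve(n)` performs one `jax.random.split(key, n + 1)` and buffers the
# resulting subkeys, which is why the keys drawn below match manual splitting.
import haiku as hk
import jax

seq = hk.PRNGSequence(jax.random.PRNGKey(42))
seq.reserve(10)                        # a single split; ten subkeys buffered
keys = [next(seq) for _ in range(10)]  # drains the buffer, no extra splits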
def temporary_internal_state(state: InternalState):
  rng = state.rng
  if rng is not None:
    rng = base.PRNGSequence(rng)
  frame = base.current_frame()
  frame = frame.evolve(params=state.params, state=state.state, rng=rng)
  return base.frame_stack(frame)
def __call__(
    self,
    inputs: jnp.ndarray,
    dropout_rate: Optional[float] = None,
    rng=None,
) -> jnp.ndarray:
  """Connects the module to some inputs.

  Args:
    inputs: A Tensor of shape `[batch_size, input_size]`.
    dropout_rate: Optional dropout rate.
    rng: Optional RNG key. Required when using dropout.

  Returns:
    output: The output of the model of size `[batch_size, output_size]`.
  """
  if dropout_rate is not None and rng is None:
    raise ValueError("When using dropout an rng key must be passed.")
  elif dropout_rate is None and rng is not None:
    raise ValueError("RNG should only be passed when using dropout.")

  rng = base.PRNGSequence(rng) if rng is not None else None
  num_layers = len(self._layers)

  for i, layer in enumerate(self._layers):
    inputs = layer(inputs)
    if i < (num_layers - 1) or self._activate_final:
      # Only perform dropout if we are activating the output.
      if dropout_rate is not None:
        inputs = basic.dropout(next(rng), dropout_rate, inputs)
      inputs = self._activation(inputs)

  return inputs
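# A minimal usage sketch of the call signature above, assuming it belongs to
# an `hk.nets.MLP`-style module (the enclosing class is not shown here):
# dropout needs an RNG key at call time, and a key passed without a dropout
# rate is rejected.
import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  mlp = hk.nets.MLP([128, 10])
  return mlp(x, dropout_rate=0.1, rng=hk.next_rng_key())

fn = hk.transform(forward)
x = jnp.ones([8, 32])
params = fn.init(jax.random.PRNGKey(0), x)
out = fn.apply(params, jax.random.PRNGKey(1), x)  # key feeds hk.next_rng_key()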
def test_with_container_state(self):
  width = 2
  batch_size = 2
  stack_height = 3

  def f_with_container_state(x):
    hk_layer = basic.Linear(
        width, w_init=initializers.Constant(jnp.eye(width)))
    layer_output = hk_layer(x)
    layer_state = {
        "raw_output": layer_output,
        "output_projection": jnp.sum(layer_output),
    }
    return layer_output + jnp.ones_like(layer_output), layer_state

  @multi_transform.without_apply_rng
  @transform.transform
  def hk_fn(x):
    return layer_stack.layer_stack(
        stack_height, with_per_layer_inputs=True)(f_with_container_state)(x)

  x = jnp.zeros([batch_size, width])
  key_seq = base.PRNGSequence(19)
  params = hk_fn.init(next(key_seq), x)
  output, z = hk_fn.apply(params, x)
  self.assertEqual(z["raw_output"].shape, (stack_height, batch_size, width))
  self.assertEqual(output.shape, (batch_size, width))
  self.assertEqual(z["output_projection"].shape, (stack_height,))
  np.testing.assert_equal(np.sum(z["output_projection"]), np.array(12.))
  np.testing.assert_equal(
      np.all(z["raw_output"] == np.array([0., 1., 2.])[..., None, None]),
      np.array(True))
def test_with_per_layer_inputs_multi_args(self):
  """Tests layer_stack with per-layer inputs with multiple arguments."""
  width = 4
  batch_size = 5
  stack_height = 3

  def f_with_multi_args(x, a, b):
    return basic.Linear(
        width, w_init=initializers.Constant(jnp.eye(width)))(x) * a + b, None

  @multi_transform.without_apply_rng
  @transform.transform
  def hk_fn(x):
    return layer_stack.layer_stack(
        stack_height, with_per_layer_inputs=True)(f_with_multi_args)(
            x, jnp.full([stack_height], 2.), jnp.ones([stack_height]))

  x = jnp.zeros([batch_size, width])
  key_seq = base.PRNGSequence(19)
  params = hk_fn.init(next(key_seq), x)
  output, z = hk_fn.apply(params, x)
  self.assertIsNone(z)
  self.assertEqual(output.shape, (batch_size, width))
  np.testing.assert_equal(output, np.full([batch_size, width], 7.))
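# For context, a minimal sketch of the public API these tests exercise
# (`layer_stack` is exposed as `hk.experimental.layer_stack`): the wrapped
# function is applied `num_layers` times in sequence, with each repetition's
# parameters stacked along a new leading axis.
import haiku as hk
import jax
import jax.numpy as jnp

def block(x):
  return jax.nn.relu(hk.Linear(16)(x))  # input and output widths must match

def forward(x):
  return hk.experimental.layer_stack(4)(block)(x)  # four stacked copies

fn = hk.transform(forward)
x = jnp.zeros([2, 16])
params = fn.init(jax.random.PRNGKey(0), x)  # each param has leading dim 4
out = fn.apply(params, None, x)             # no rng needed at apply time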
def to_prng_sequence(rng, err_msg) -> Optional[base.PRNGSequence]:
  if rng is not None:
    try:
      rng = base.PRNGSequence(rng)
    except Exception as e:
      raise ValueError(err_msg) from e
  return rng
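# A hypothetical usage sketch for the helper above: it normalizes an optional
# key (or anything `PRNGSequence` accepts) into a sequence, re-raising
# construction failures with a caller-supplied message.
import jax

rng = to_prng_sequence(
    jax.random.PRNGKey(0), err_msg="`rng` must be a valid JAX PRNGKey")
if rng is not None:
  key = next(rng)  # draw fresh keys as needed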
def test_reverse_with_pass_reverse_to_layer_fn(self):
  # The layer stack below iteratively runs the update equation:
  #     x_n = n * alpha * (x_{n-1} + 1)
  # with x_0 = 1, for n = {1, ..., N}, where N = stack_height.
  # The reversed layer stack therefore runs the update equation:
  #     y_{n-1} = (N - n + 1) * alpha * (y_n + 1)
  # with y_N = 1, for n = {N, ..., 1}, where N = stack_height.
  # This test is equivalent to the previous one, but nests the iterations in
  # two layer stacks.
  width = 2
  batch_size = 3
  stack_height = 4
  total_multiplier = 24
  alpha = jnp.power(total_multiplier, -1. / stack_height)
  forward, backward = self._compute_weights(stack_height, alpha)

  def inner_fn(x, extra):
    out = basic.Linear(
        x.shape[1],
        w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
        b_init=initializers.Constant(extra),
    )(x)
    return out, out

  def outer_fn(x, extra, reverse=False):
    return layer_stack.layer_stack(
        stack_height // 2,
        with_per_layer_inputs=True)(inner_fn)(x, extra, reverse=reverse)

  @multi_transform.without_apply_rng
  @transform.transform
  def hk_fn(x, extra, reverse=False):
    return layer_stack.layer_stack(
        2,
        with_per_layer_inputs=True,
        pass_reverse_to_layer_fn=True)(outer_fn)(x, extra, reverse=reverse)

  extra = jnp.arange(stack_height).reshape([2, stack_height // 2]) + 1
  extra = extra * alpha
  key_seq = base.PRNGSequence(19)
  init_value = 1 + jax.random.uniform(next(key_seq), [batch_size, width])
  params = hk_fn.init(next(key_seq), init_value, extra)

  x_n, x_all = hk_fn.apply(params, init_value, extra)
  self.assertEqual(x_all.shape[:2], (2, stack_height // 2))
  x_all = x_all.reshape((stack_height, *x_all.shape[2:]))
  for x_t, (a, b) in zip(x_all, forward):
    np.testing.assert_allclose(x_t, a * init_value + b, rtol=1e-6)
  np.testing.assert_allclose(x_n, x_all[-1], rtol=1e-6)

  y_0, y_all = hk_fn.apply(params, init_value, extra, reverse=True)
  self.assertEqual(y_all.shape[:2], (2, stack_height // 2))
  y_all = y_all.reshape((stack_height, *y_all.shape[2:]))
  for y_t, (a, b) in zip(y_all, reversed(backward)):
    np.testing.assert_allclose(y_t, a * init_value + b, rtol=1e-6)
  np.testing.assert_allclose(y_0, y_all[0], rtol=1e-6)
def test_prng_reserve_twice(self):
  k = jax.random.PRNGKey(42)
  s = base.PRNGSequence(k)
  s.reserve(2)
  s.reserve(2)
  hk_keys = tuple(next(s) for _ in range(4))
  k, subkey1, subkey2 = tuple(jax.random.split(k, num=3))
  _, subkey3, subkey4 = tuple(jax.random.split(k, num=3))
  jax_keys = (subkey1, subkey2, subkey3, subkey4)
  jax.tree_util.tree_map(np.testing.assert_array_equal, hk_keys, jax_keys)
def pure_fun(args, state_in):
  if split_rng:
    # NOTE: In the case of split_rng we receive an RNG key (rather than the
    # internal state of a PRNGSequence), so we need to construct that here.
    rng = base.PRNGSequence(state_in.rng).internal_state
    state_in = InternalState(state_in.params, state_in.state, rng)

  with temporary_internal_state(state_in), \
       base.push_jax_trace_level():
    out = fun(*args)
    state_out = difference(state_in, internal_state())
    return out, state_out
def test_reverse_with_additional_inputs(self):
  # The layer stack below iteratively runs the update equation:
  #     x_n = n * alpha * (x_{n-1} + 1)
  # with x_0 = 1, for n = {1, ..., N}, where N = stack_height.
  # The reversed layer stack therefore runs the update equation:
  #     y_{n-1} = (N - n + 1) * alpha * (y_n + 1)
  # with y_N = 1, for n = {N, ..., 1}, where N = stack_height.
  width = 2
  batch_size = 3
  stack_height = 4
  total_multiplier = 24
  alpha = jnp.power(total_multiplier, -1. / stack_height)
  forward, backward = self._compute_weights(stack_height, alpha)

  def inner_fn(x, extra):
    # Compared to the previous test we pass in the `extra` argument as an
    # additional input, in order to directly initialize the parameters to the
    # index `n` of the iteration.
    out = basic.Linear(
        x.shape[1],
        w_init=initializers.Constant(extra * jnp.eye(x.shape[1])),
        b_init=initializers.Constant(extra),
    )(x)
    return out, out

  @multi_transform.without_apply_rng
  @transform.transform
  def hk_fn(x, extra, reverse=False):
    return layer_stack.layer_stack(
        stack_height,
        with_per_layer_inputs=True)(inner_fn)(x, extra, reverse=reverse)

  extra = jnp.arange(stack_height) + 1
  extra = extra * alpha
  key_seq = base.PRNGSequence(19)
  init_value = 1 + jax.random.uniform(next(key_seq), [batch_size, width])
  params = hk_fn.init(next(key_seq), init_value, extra)

  x_n, x_all = hk_fn.apply(params, init_value, extra)
  self.assertEqual(x_all.shape[0], stack_height)
  for x_t, (a, b) in zip(x_all, forward):
    np.testing.assert_allclose(x_t, a * init_value + b, rtol=1e-6)
  np.testing.assert_allclose(x_n, x_all[-1], rtol=1e-6)

  y_0, y_all = hk_fn.apply(params, init_value, extra, reverse=True)
  self.assertEqual(y_all.shape[0], stack_height)
  for y_t, (a, b) in zip(y_all, reversed(backward)):
    np.testing.assert_allclose(y_t, a * init_value + b, rtol=1e-6)
  np.testing.assert_allclose(y_0, y_all[0], rtol=1e-6)
def test_prng_sequence(self, seed, wrap_seed):
  # Values using our sequence.
  key_or_seed = jax.random.PRNGKey(seed) if wrap_seed else seed
  key_seq = base.PRNGSequence(key_or_seed)
  seq_v1 = jax.random.normal(next(key_seq), [])
  seq_v2 = jax.random.normal(next(key_seq), [])
  # Generate values using manual splitting.
  key = jax.random.PRNGKey(seed)
  key, temp_key = jax.random.split(key)
  raw_v1 = jax.random.normal(temp_key, [])
  _, temp_key = jax.random.split(key)
  raw_v2 = jax.random.normal(temp_key, [])
  self.assertEqual(raw_v1, seq_v1)
  self.assertEqual(raw_v2, seq_v2)
def temporary_internal_state(state: InternalState):
  """Pushes a temporary copy of the internal state."""
  state = copy_structure(state)
  rng = state.rng
  if rng is not None:
    rng = base.PRNGSequence(rng)
  current_state = internal_state()
  params = state.params
  if params is None:
    params = current_state.params
  state = state.state
  if state is None:
    state = current_state.state
  frame = base.current_frame()
  frame = frame.evolve(params=params, state=state, rng=rng)
  return base.frame_stack(frame)
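# A hypothetical sketch of the pattern the context manager above supports in
# Haiku's stateful combinators (compare `pure_fun` earlier in this section;
# `some_wrapped_fn` is illustrative): run a function against a snapshot of
# params/state/rng, then extract only what changed.
state_before = internal_state()
with temporary_internal_state(state_before):
  result = some_wrapped_fn()  # hypothetical wrapped callable
changes = difference(state_before, internal_state())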
def test_reverse(self):
  # The layer stack below iteratively runs the update equation:
  #     x_n = n * alpha * (x_{n-1} + 1)
  # with x_0 = 1, for n = {1, ..., N}, where N = stack_height.
  # The reversed layer stack therefore runs the update equation:
  #     y_{n-1} = (N - n + 1) * alpha * (y_n + 1)
  # with y_N = 1, for n = {N, ..., 1}, where N = stack_height.
  width = 2
  batch_size = 3
  stack_height = 4
  alpha = jnp.power(24, -1. / stack_height)
  forward, backward = self._compute_weights(stack_height, alpha)

  def inner_fn(x):
    # Here we initialize the layer to an identity + 1, while later we
    # multiply each parameter by the index `n`.
    return basic.Linear(
        x.shape[1],
        w_init=initializers.Constant(jnp.eye(x.shape[1])),
        b_init=initializers.Constant(1.0),
    )(x)

  @multi_transform.without_apply_rng
  @transform.transform
  def hk_fn(x, reverse=False):
    return layer_stack.layer_stack(stack_height)(inner_fn)(x, reverse=reverse)

  key_seq = base.PRNGSequence(19)
  init_value = 1 + jax.random.uniform(next(key_seq), [batch_size, width])

  def mul_by_m(x):
    m_x = jnp.arange(stack_height) + 1
    while m_x.ndim < x.ndim:
      m_x = m_x[..., None]
    return x * m_x * alpha

  params = jax.tree_util.tree_map(
      mul_by_m, hk_fn.init(next(key_seq), init_value))

  a, b = forward[-1]
  x_n = hk_fn.apply(params, init_value)
  np.testing.assert_allclose(x_n, a * init_value + b, rtol=1e-6)

  a, b = backward[-1]
  y_0 = hk_fn.apply(params, init_value, reverse=True)
  np.testing.assert_allclose(y_0, a * init_value + b, rtol=1e-6)
def create_random_values(key_or_seed):
  key_seq = base.PRNGSequence(key_or_seed)
  return (jax.random.normal(next(key_seq), []),
          jax.random.normal(next(key_seq), []))
def test_prng_sequence_wrong_shape(self):
  with self.assertRaisesRegex(
      ValueError, "key did not have expected shape and/or dtype"):
    base.PRNGSequence(jax.random.split(jax.random.PRNGKey(42), 2))
def test_prng_sequence_invalid_input(self):
  with self.assertRaisesRegex(ValueError, "not a JAX PRNGKey"):
    base.PRNGSequence("nonsense")
def test_prng_sequence_split(self):
  k = jax.random.PRNGKey(42)
  s = base.PRNGSequence(k)
  hk_keys = s.take(10)
  jax_keys = tuple(jax.random.split(k, num=11)[1:])
  jax.tree_util.tree_map(np.testing.assert_array_equal, hk_keys, jax_keys)