def test_assert_no_new_parameters(self):
    """Checks that `base.assert_no_new_parameters` only fires on new params.

    Re-fetching a parameter that already exists inside the guard is
    allowed. Creating a previously unseen parameter inside the guard must
    raise an AssertionError whose message names the new parameter.
    """

    def fetch_scalar(param_name):
        # Every parameter in this test is a shape-[] value initialised with
        # jnp.zeros; only the name varies.
        return base.get_parameter(param_name, [], init=jnp.zeros)

    with base.new_context():
        fetch_scalar("w")

        # "w" was created just above, so fetching it again creates nothing.
        with base.assert_no_new_parameters():
            fetch_scalar("w")

        # "x" has never been requested, so fetching it creates a parameter
        # and the guard must complain on exit.
        expected_message = "New parameters were created: .*x"
        with self.assertRaisesRegex(AssertionError, expected_message):
            with base.assert_no_new_parameters():
                fetch_scalar("x")
def stateful_fun(carry, x):
    """Runs the enclosing scope's `f` for one step with Haiku state threaded.

    Presumably used as a `jax.lax.scan`-style body ((carry, x) -> (carry, y))
    — TODO confirm against the enclosing function.

    Args:
      carry: a pair `(user_carry, state)`. `user_carry` is forwarded to `f`.
        `state` is installed via `temporary_internal_state` while `f` runs.
      x: the per-step input, forwarded unchanged to `f`.

    Returns:
      `((new_user_carry, captured_state), step_output)`, where
      `captured_state` is `internal_state(params=False)` read while `state`
      is still installed.

    Raises:
      AssertionError: (via `base.assert_no_new_parameters`) if `f` creates
        any parameter that did not already exist.
    """
    user_carry, threaded_state = carry
    with temporary_internal_state(threaded_state):
        # Calling `f` must not create parameters; any new ones would
        # trip this guard.
        with base.assert_no_new_parameters():
            user_carry, step_output = f(user_carry, x)
        # Read the state while the temporary state is still active.
        # NOTE(review): presumably the outer state is restored on exit, so
        # reading it after the `with` would be wrong — confirm.
        # `params=False` presumably leaves parameters out of the carried
        # state; verify against `internal_state`.
        next_carry = (user_carry, internal_state(params=False))
        return next_carry, step_output
def stateful_fun(carry, x):
    """Runs the enclosing scope's `f` for one step with Haiku state threaded.

    Presumably used as a `jax.lax.scan`-style body ((carry, x) -> (carry, y))
    — TODO confirm against the enclosing function.

    Args:
      carry: a pair `(user_carry, state)`. `user_carry` is forwarded to `f`.
        `state` is installed via `temporary_internal_state` while `f` runs.
      x: the per-step input, forwarded unchanged to `f`.

    Returns:
      `((new_user_carry, captured_state), step_output)`, where
      `captured_state` is `internal_state(params=False)` read after
      `reserve_up_to_full_rng_block()` and while `state` is still installed.

    Raises:
      AssertionError: (via `base.assert_no_new_parameters`) if `f` creates
        any parameter that did not already exist.
    """
    user_carry, threaded_state = carry
    with temporary_internal_state(threaded_state):
        # `with A, B:` is equivalent to `with A: with B:`, so `f` runs with
        # both guards active and they exit in reverse order.
        with base.assert_no_new_parameters():
            # NOTE(review): presumably marks that `f` runs inside a new JAX
            # trace (the scan body) — confirm against base.push_jax_trace_level.
            with base.push_jax_trace_level():
                user_carry, step_output = f(user_carry, x)
        # NOTE(review): presumably pads RNG consumption up to a whole block
        # so the carried RNG state has the same structure on every step —
        # TODO confirm. It must run before the state is captured below.
        reserve_up_to_full_rng_block()
        # Read the state while the temporary state is still active.
        next_carry = (user_carry, internal_state(params=False))
        return next_carry, step_output