Example #1
0
    def test_assert_no_new_parameters(self):
        """`assert_no_new_parameters` allows reuse but rejects creation.

        Inside a fresh context, parameter "w" is created first. Fetching "w"
        again under the guard must succeed. Fetching a not-yet-created "x"
        under the guard must raise an AssertionError whose message names it.
        """
        zeros_init = jnp.zeros
        with base.new_context():
            # Create "w" up front so the next lookup is a reuse, not a creation.
            base.get_parameter("w", [], init=zeros_init)

            # Reusing an existing parameter must not trip the guard.
            with base.assert_no_new_parameters():
                base.get_parameter("w", [], init=zeros_init)

            # Creating "x" under the guard must fail; the expectation is the
            # outer manager so it observes the guard's error.
            expected_msg = "New parameters were created: .*x"
            with self.assertRaisesRegex(AssertionError, expected_msg), \
                 base.assert_no_new_parameters():
                base.get_parameter("x", [], init=zeros_init)
Example #2
0
 def stateful_fun(carry, x):
     """Calls `f` for one step with internal state threaded through `carry`.

     `carry` is unpacked as ``(inner_carry, state)``. `state` is installed via
     `temporary_internal_state`. `f(inner_carry, x)` is then called under
     `base.assert_no_new_parameters()`, so `f` may reuse existing parameters
     but must not create new ones. The outgoing carry pairs `f`'s new carry
     with `internal_state(params=False)`, captured while `state` is still
     installed. The ``(carry, x) -> (carry, out)`` shape suggests this is a
     `lax.scan`-style body; confirm against the caller.

     Returns:
       ``((new_inner_carry, new_state), out)``.
     """
     inner_carry, state = carry
     with temporary_internal_state(state):
         no_new_params = base.assert_no_new_parameters()
         with no_new_params:
             inner_carry, out = f(inner_carry, x)
         # Snapshot (without params) must happen inside the temporary-state
         # context so it reflects whatever `f` did to the installed state.
         next_carry = (inner_carry, internal_state(params=False))
         return next_carry, out
Example #3
0
 def stateful_fun(carry, x):
     """Calls `f` for one step with internal state threaded through `carry`.

     `carry` is unpacked as ``(inner_carry, state)``. While `state` is
     installed via `temporary_internal_state`, `f(inner_carry, x)` runs
     inside both `base.assert_no_new_parameters()` (so `f` must not create
     parameters) and `base.push_jax_trace_level()`. Next,
     `reserve_up_to_full_rng_block()` is called. Finally,
     `internal_state(params=False)` is snapshotted into the outgoing carry.
     The ``(carry, x) -> (carry, out)`` shape suggests a `lax.scan`-style
     body; confirm against the caller.

     Returns:
       ``((new_inner_carry, new_state), out)``.
     """
     inner_carry, state = carry
     with temporary_internal_state(state):
         # Nested form of the two-manager `with`: the parameter guard is
         # entered first and exited last, exactly as before.
         with base.assert_no_new_parameters():
             with base.push_jax_trace_level():
                 inner_carry, out = f(inner_carry, x)
         # NOTE(review): by its name this is RNG bookkeeping; confirm against
         # its definition. It runs before the snapshot below, so any internal
         # state it updates is included in the returned carry.
         reserve_up_to_full_rng_block()
         next_carry = (inner_carry, internal_state(params=False))
         return next_carry, out