Exemplo n.º 1
0
 def pure_body_fun(val):
     """Runs `body_fun` once with the carried Haiku state installed.

     Args:
       val: A `(loop_value, state)` pair carried between iterations.

     Returns:
       A `(new_loop_value, new_state)` pair, where `new_state` is the internal
       state read back after `body_fun` has run.
     """
     loop_value, state_before = val
     # NOTE(review): unlike the other loop bodies in this file this does not call
     # reserve_up_to_full_rng_block() before snapshotting state — confirm that is
     # intended for this loop primitive.
     with temporary_internal_state(state_before):
         with base.push_jax_trace_level():
             new_value = body_fun(loop_value)
             state_after = internal_state()
             return new_value, state_after
Exemplo n.º 2
0
 def pure_body_fun(i, val):
     """Runs a single loop step of `body_fun` with explicit Haiku state.

     Args:
       i: Loop index, forwarded unchanged to `body_fun`.
       val: A `(state, loop_value)` pair carried between iterations.

     Returns:
       A `(new_state, new_loop_value)` pair.
     """
     state_before, loop_value = val
     with temporary_internal_state(state_before):
         with base.push_jax_trace_level():
             new_value = body_fun(i, loop_value)
             # Top up the RNG reserve before reading the state back; presumably
             # this keeps the carried state's structure stable across steps.
             reserve_up_to_full_rng_block()
             state_after = internal_state()
             return state_after, new_value
Exemplo n.º 3
0
 def new_branch_fun(operand):
     """Calls `branch_fun` with the given Haiku state installed.

     Args:
       operand: A `(state, branch_args)` pair; `branch_args` is splatted into
         `branch_fun`.

     Returns:
       A `(branch_output, state_after)` pair, where `state_after` is the full
       internal state read after the branch has run.
     """
     state, branch_args = operand
     with temporary_internal_state(state):
         with base.push_jax_trace_level():
             result = branch_fun(*branch_args)
             reserve_up_to_full_rng_block()
             # TODO(tomhennigan) Return difference of state in/out here.
             state_after = internal_state()
             return result, state_after
Exemplo n.º 4
0
 def stateful_fun(*args, **kwargs):
     """Calls `fun` with Haiku state taken from the `hk_state` keyword.

     The `hk_state` keyword is removed from `kwargs` before `fun` is called.

     Returns:
       `(out, (aux, state_out))`. When `has_aux` is false, `aux` is `None`.
       `state_out` is `difference(state_in, internal_state())` computed after
       `fun` ran (presumably only the changed parts — confirm).
     """
     state_in = kwargs.pop("hk_state")
     with temporary_internal_state(state_in):
         with base.push_jax_trace_level():
             result = fun(*args, **kwargs)
             if has_aux:
                 out, aux = result
             else:
                 out, aux = result, None
             state_out = difference(state_in, internal_state())
             return out, (aux, state_out)
Exemplo n.º 5
0
 def stateful_fun(carry, x):
     """Runs one step of `f` with the Haiku state threaded through the carry.

     Args:
       carry: A `(user_carry, state)` pair.
       x: The per-step input, forwarded unchanged to `f`.

     Returns:
       `((new_user_carry, new_state), y)`, where `y` is the second value
       returned by `f`.
     """
     user_carry, state = carry
     with temporary_internal_state(state):
         # `f` runs under assert_no_new_parameters, so it must not create
         # parameters.
         with base.assert_no_new_parameters():
             with base.push_jax_trace_level():
                 user_carry, y = f(user_carry, x)
         reserve_up_to_full_rng_block()
         # `params=False` presumably leaves parameters out of the carried state.
         next_carry = (user_carry, internal_state(params=False))
         return next_carry, y
Exemplo n.º 6
0
    def pure_fun(args, state_in):
        """Calls `fun(*args)` with `state_in` installed as the Haiku state.

        Args:
          args: Positional arguments for `fun`, passed as a single sequence.
          state_in: The incoming `InternalState`.

        Returns:
          `(out, state_out)`, where `state_out` is
          `difference(state_in, internal_state())` computed after `fun` ran.
        """
        if split_rng:
            # With `split_rng`, `state_in.rng` holds a raw RNG key rather than the
            # internal state of a PRNGSequence, so build that internal state here.
            seq_state = base.PRNGSequence(state_in.rng).internal_state
            state_in = InternalState(state_in.params, state_in.state, seq_state)

        with temporary_internal_state(state_in):
            with base.push_jax_trace_level():
                out = fun(*args)
                return out, difference(state_in, internal_state())
Exemplo n.º 7
0
 def stateless_fun(state, *args, **kwargs):
     """Calls `fun` with `state` installed and returns only its output.

     Any changes `fun` makes to the internal state are not returned to the
     caller.
     """
     with temporary_internal_state(state):
         with base.push_jax_trace_level():
             return fun(*args, **kwargs)
Exemplo n.º 8
0
 def stateful_fun(*args, **kwargs):
     """Calls `fun` with Haiku state from the `hk_state` keyword (Python state shared).

     The `hk_state` keyword is removed from `kwargs` before `fun` is called.
     The state is installed with `share_python_state=True`.

     Returns:
       `(out, state_delta)`, where `state_delta` is
       `difference(state_in, internal_state())` computed after `fun` ran.
     """
     hk_state = kwargs.pop("hk_state")
     with temporary_internal_state(hk_state, share_python_state=True):
         with base.push_jax_trace_level():
             result = fun(*args, **kwargs)
             state_delta = difference(hk_state, internal_state())
             return result, state_delta