def pure_body_fun(val):
  """Runs one iteration of ``body_fun`` with Haiku state threaded through.

  Args:
    val: A ``(loop_value, state)`` pair carried by the while loop.

  Returns:
    A ``(new_loop_value, new_state)`` pair, where ``new_state`` is the internal
    state captured after ``body_fun`` has run.
  """
  loop_val, carried_state = val
  with temporary_internal_state(carried_state):
    with base.push_jax_trace_level():
      next_val = body_fun(loop_val)
      # NOTE(review): the fori_loop, scan and cond bodies call
      # reserve_up_to_full_rng_block() before capturing state; this body does
      # not. If body_fun draws RNG keys, the carried RNG state may change
      # structure between iterations. Confirm against the enclosing
      # while_loop (which must also reserve before the loop) before changing.
      next_state = internal_state()
      return next_val, next_state
def pure_body_fun(i, val):
  """Runs ``body_fun`` for loop index ``i`` with Haiku state threaded through.

  Args:
    i: The current loop index, forwarded to ``body_fun``.
    val: A ``(state, loop_value)`` pair carried by the fori loop.

  Returns:
    A ``(new_state, new_loop_value)`` pair, in the same order as ``val``.
  """
  carried_state, loop_val = val
  with temporary_internal_state(carried_state):
    with base.push_jax_trace_level():
      loop_val = body_fun(i, loop_val)
      # Presumably refills the RNG reserve so the state handed to the next
      # iteration keeps a consistent structure -- confirm in the helper.
      reserve_up_to_full_rng_block()
      return internal_state(), loop_val
def new_branch_fun(operand):
  """Calls ``branch_fun`` inside the Haiku state carried with its operands.

  Args:
    operand: A ``(state, operands)`` pair; ``operands`` is unpacked as the
      positional arguments of ``branch_fun``.

  Returns:
    A ``(branch_output, state)`` pair, where ``state`` is the full internal
    state observed after the branch ran.
  """
  branch_state, branch_args = operand
  with temporary_internal_state(branch_state):
    with base.push_jax_trace_level():
      result = branch_fun(*branch_args)
      reserve_up_to_full_rng_block()
      # TODO(tomhennigan) Return difference of state in/out here.
      final_state = internal_state()
      return result, final_state
def stateful_fun(*args, **kwargs):
  """Wraps ``fun`` so that Haiku state is an explicit input and output.

  The incoming state is read from the ``hk_state`` keyword argument, which is
  removed from ``kwargs`` before ``fun`` is called.

  Returns:
    ``(out, (aux, state_delta))``. ``aux`` is ``None`` unless ``has_aux`` is
    set, in which case ``fun`` must return an ``(out, aux)`` pair.
    ``state_delta`` is ``difference(state_in, <state after fun>)``.
  """
  incoming = kwargs.pop("hk_state")
  with temporary_internal_state(incoming):
    with base.push_jax_trace_level():
      result = fun(*args, **kwargs)
      if has_aux:
        out, aux = result
      else:
        out, aux = result, None
      return out, (aux, difference(incoming, internal_state()))
def stateful_fun(carry, x):
  """Runs one scan step of ``f`` with Haiku state (minus params) in the carry.

  Args:
    carry: A ``(user_carry, state)`` pair.
    x: The current slice of the scanned input, forwarded to ``f``.

  Returns:
    ``((new_user_carry, new_state), out)``, where ``new_state`` is captured
    with ``params=False``.
  """
  user_carry, step_state = carry
  with temporary_internal_state(step_state):
    # As the guard's name indicates, ``f`` must not create new parameters here.
    with base.assert_no_new_parameters(), base.push_jax_trace_level():
      user_carry, out = f(user_carry, x)
    reserve_up_to_full_rng_block()
    next_carry = (user_carry, internal_state(params=False))
    return next_carry, out
def pure_fun(args, state_in):
  """Runs ``fun(*args)`` under ``state_in`` and reports the state change.

  Args:
    args: Positional arguments for ``fun``, passed as a single sequence.
    state_in: The incoming internal state. When ``split_rng`` is set, its
      ``rng`` field holds a bare RNG key rather than the internal state of a
      PRNGSequence.

  Returns:
    ``(out, state_delta)``, where ``state_delta`` is
    ``difference(state_in, <state after fun>)``.
  """
  if split_rng:
    # With split_rng we receive an RNG key (rather than the internal state of
    # a PRNGSequence), so build a sequence from that key and use its state.
    seeded_rng = base.PRNGSequence(state_in.rng).internal_state
    state_in = InternalState(state_in.params, state_in.state, seeded_rng)
  with temporary_internal_state(state_in):
    with base.push_jax_trace_level():
      result = fun(*args)
      delta = difference(state_in, internal_state())
      return result, delta
def stateless_fun(state, *args, **kwargs):
  """Evaluates ``fun`` under ``state`` and discards any state it produces.

  Args:
    state: Internal state installed while ``fun`` runs.
    *args: Positional arguments forwarded to ``fun``.
    **kwargs: Keyword arguments forwarded to ``fun``.

  Returns:
    Only the output of ``fun``. Any state changes made while running ``fun``
    are deliberately not returned.
  """
  with temporary_internal_state(state):
    with base.push_jax_trace_level():
      return fun(*args, **kwargs)
def stateful_fun(*args, **kwargs):
  """Runs ``fun`` under the state passed via ``hk_state`` and returns its delta.

  The ``hk_state`` keyword argument is removed from ``kwargs`` before ``fun``
  is called. The temporary state is installed with ``share_python_state=True``.

  Returns:
    ``(out, difference(state_in, <state after fun>))``.
  """
  incoming = kwargs.pop("hk_state")
  with temporary_internal_state(incoming, share_python_state=True):
    with base.push_jax_trace_level():
      result = fun(*args, **kwargs)
      delta = difference(incoming, internal_state())
      return result, delta