def comp02(theta, state0, inputs):
    """Run `recurrent.recurrent_static` under a new JaxContext and reduce to a loss.

    `prng_key`, `global_step` and `cell_fn` are not parameters: they are
    resolved as free variables from a scope not visible here (presumably an
    enclosing test function -- confirm).

    Args:
      theta: Passed through unchanged as the first argument of
        `recurrent.recurrent_static`.
      state0: Passed through unchanged as the second argument.
      inputs: Passed through unchanged as the third argument.

    Returns:
      `jnp.sum` over every element of `final_state.y`, plus `jnp.sum` over
      every element of `cum_states.y`, where `(final_state, cum_states)` is
      the pair returned by `recurrent.recurrent_static`.
    """
    ctx = base_layer.JaxContext.new_context(
        prng_key=prng_key, global_step=global_step)
    with ctx:
        last, stacked = recurrent.recurrent_static(
            theta, state0, inputs, cell_fn)
        # Keep the original operand order (final first, accumulated second).
        last_total = jnp.sum(last.y)
        stacked_total = jnp.sum(stacked.y)
        return last_total + stacked_total
def comp01(theta, state0, inputs):
    """Call `recurrent.recurrent_static` and hand back its two results as a tuple.

    `cell_fn` is a free variable resolved from a scope not visible here
    (presumably an enclosing test function -- confirm). Unlike `comp02`, no
    JaxContext is opened in this function.

    Args:
      theta: First positional argument forwarded to `recurrent_static`.
      state0: Second positional argument forwarded to `recurrent_static`.
      inputs: Third positional argument forwarded to `recurrent_static`.

    Returns:
      A `(final_state, cum_states)` tuple built from the two values that
      `recurrent_static` produced.
    """
    # Unpack explicitly so a result that is not exactly two items still fails
    # here, and so a tuple is always returned.
    last, stacked = recurrent.recurrent_static(theta, state0, inputs, cell_fn)
    return last, stacked