def update(i, g, state):
    """Run one NAdamW update on a parameter packed together with its state.

    Args:
        i: Step value forwarded as the first argument of
            ``jax_common.nadamw_update``.
        g: Gradient forwarded as the last argument of
            ``jax_common.nadamw_update``.
        state: Indexable sequence. Element 0 is the parameter value; the
            remaining elements are unpacked positionally into
            ``jax_common.NAdamWParamState``.

    Returns:
        A 3-tuple: the parameter plus the computed update, followed by
        elements 0 and 1 of the new optimizer state.
    """
    param = state[0]
    opt_state = jax_common.NAdamWParamState(*state[1:])
    # ``hyper_params`` is not a parameter of this function; it is resolved
    # from the enclosing/module scope at call time.
    delta, next_state = jax_common.nadamw_update(
        i, hyper_params, param, opt_state, g
    )
    # Flatten the new state back into the same positional layout as the input.
    return param + delta, next_state[0], next_state[1]
def update_one(g, p, g_acc, g_sq_acc):
    """Run ``jax_common.nadamw_update`` for a single parameter.

    Args:
        g: Gradient forwarded as the last argument of ``nadamw_update``.
        p: Parameter value forwarded to ``nadamw_update``.
        g_acc: First field used to build ``jax_common.NAdamWParamState``.
        g_sq_acc: Second field used to build ``jax_common.NAdamWParamState``.

    Returns:
        The 2-tuple produced by ``jax_common.nadamw_update``, unchanged.
    """
    param_state = jax_common.NAdamWParamState(g_acc, g_sq_acc)
    # ``idx`` and ``hyper_params`` are not parameters of this function; they
    # are resolved from the enclosing/module scope at call time.
    # NOTE(review): the first result is passed through as-is and is NOT added
    # to ``p`` here -- confirm whether ``nadamw_update`` yields a delta or an
    # already-updated parameter, and that the caller treats it accordingly.
    first_result, next_state = jax_common.nadamw_update(
        idx, hyper_params, p, param_state, g
    )
    return first_result, next_state
def apply_param_gradient(self, step, hyper_params, param, state, grad):
    """Apply a single NAdamW update to ``param``.

    Args:
        step: Step value forwarded as the first argument of
            ``jax_common.nadamw_update``.
        hyper_params: Hyperparameters forwarded to ``nadamw_update``.
        param: Current parameter value.
        state: Current optimizer state for ``param``.
        grad: Gradient for ``param``.

    Returns:
        A 2-tuple of ``param`` plus the computed update, and the new state
        returned by ``nadamw_update``.
    """
    delta, next_state = jax_common.nadamw_update(
        step, hyper_params, param, state, grad
    )
    updated_param = param + delta
    return updated_param, next_state