def update_one(g, p, g_acc, g_sq_acc):
     """Run one NAdamW step for a single parameter via ``jax_common``.

     Packs the two accumulators into a ``jax_common.NAdamWParamState`` and
     delegates to ``jax_common.nadamw_update``. ``idx`` and ``hyper_params``
     are not parameters of this function; they are read from the enclosing
     scope.

     NOTE(review): ``idx`` is passed in the slot where ``update`` passes its
     step counter ``i`` -- confirm ``idx`` really holds the current step.

     Args:
       g: Gradient forwarded to ``nadamw_update``.
       p: Parameter value forwarded to ``nadamw_update`` (presumably one leaf
         of a parameter pytree -- confirm at the call site).
       g_acc: First accumulator field of ``NAdamWParamState`` (presumably the
         gradient accumulator).
       g_sq_acc: Second accumulator field of ``NAdamWParamState`` (presumably
         the squared-gradient accumulator).

     Returns:
       The 2-tuple unpacked from ``nadamw_update``: its first result and the
       updated accumulator state. NOTE(review): ``update`` adds this first
       result to the parameter (``x + update``), so it is presumably a step
       delta rather than the new parameter value -- confirm.
     """
     accumulators = jax_common.NAdamWParamState(g_acc, g_sq_acc)
     step_result, next_state = jax_common.nadamw_update(
         idx, hyper_params, p, accumulators, g)
     return step_result, next_state
 def update(i, g, state):
   """Advance a flat ``(param, acc_0, acc_1, ...)`` state by one NAdamW step.

   ``hyper_params`` is not a parameter of this function; it is read from the
   enclosing scope.

   Args:
     i: Step counter, forwarded as the first argument of
       ``jax_common.nadamw_update``.
     g: Gradient for the parameter held in ``state[0]``.
     state: Indexable whose element 0 is the parameter and whose remaining
       elements are passed positionally to ``jax_common.NAdamWParamState``.

   Returns:
     A 3-tuple ``(new_param, next_state[0], next_state[1])``, where
     ``new_param`` is the old parameter plus the first result of
     ``nadamw_update``.

   NOTE(review): the ``(i, g, state)`` signature resembles the
   ``jax.example_libraries.optimizers`` update-function convention -- confirm.
   """
   param = state[0]
   acc_state = jax_common.NAdamWParamState(*state[1:])
   # The first result of nadamw_update is applied as an additive delta. It is
   # named ``delta`` (not ``update``) so it does not shadow this function.
   delta, next_state = jax_common.nadamw_update(
       i, hyper_params, param, acc_state, g)
   new_param = param + delta
   return new_param, next_state[0], next_state[1]