def optimizer_for_idx(idx, training_steps):
  """Get a nadamw optimizer for the given configuration and training_steps.

  Args:
    idx: int
      The index into the learned optimizer list.
    training_steps: int
      total number of training steps that the model will be trained.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  config = common.get_optimizer_config(idx)
  config['training_steps'] = training_steps
  config['use_bias_correction'] = True  # always true for now.
  hyper_params = jax_common.NAdamWHyperParams(**config)

  def init(x0):
    return x0, jnp.zeros_like(x0), jnp.zeros_like(x0)

  def update(i, g, state):
    x = state[0]
    state = jax_common.NAdamWParamState(*state[1:])
    update, new_s = jax_common.nadamw_update(i, hyper_params, x, state, g)
    new_x = x + update
    return new_x, new_s[0], new_s[1]

  def get_params(state):
    x, _, _ = state
    return x

  return init, update, get_params
def optimizer_for_idx(idx, training_steps):
    """Get a nadamw optimizer for the given configuration and training_steps.

  Unlike regular Optix functions, the update function returned here additionally
  takes a parameter argument.

  Args:
    idx: int The index into the learned optimizer list.
    training_steps: int total number of training steps that the model will be
      trained.

  Returns:
    An (init_fn, update_with_params_fn) tuple.
  """
    config = common.get_optimizer_config(idx)
    config["training_steps"] = training_steps
    config["use_bias_correction"] = True  # always true for now.
    hyper_params = jax_common.NAdamWHyperParams(**config)

    def init(params):
        zero_initial = tree_map(jnp.zeros_like, params)
        return zero_initial, zero_initial, 0

    def update_fn(grads, params, state):
        """Compute the update.

    Args:
      grads: pytree of ndarray
        Gradient values.
      params: pytree of ndarray
        Parameter values.
      state:
        A tuple of (gradient accumulators, squared gradient accumulators, idx)
    Returns:
      step: pytree of ndarray
        The step to be added to the parameter values.
      next_state:
        A tuple of (gradient accumulators, squared gradient accumulators, idx)
    """

        grad_acc, grad_sq_acc, idx = state

        def update_one(g, p, g_acc, g_sq_acc):
            s = jax_common.NAdamWParamState(g_acc, g_sq_acc)
            new_x, new_s = jax_common.nadamw_update(idx, hyper_params, p, s, g)
            return new_x, new_s

        # the following flattens, applies a map, extracts values out via zip,
        # then unflattens.
        flat_gs, tree_def = tree_flatten(grads)
        flat_ps, _ = tree_flatten(params)
        flat_s0, _ = tree_flatten(grad_acc)
        flat_s1, _ = tree_flatten(grad_sq_acc)

        next_param_states = tree_map(update_one, flat_gs, flat_ps, flat_s0,
                                     flat_s1)

        flat_step, flat_next_ss = zip(*next_param_states)
        flat_next_grad_acc, flat_next_grad_sq_acc = zip(*flat_next_ss)

        step = tree_unflatten(tree_def, flat_step)
        next_grad_acc = tree_unflatten(tree_def, flat_next_grad_acc)
        next_grad_sq_acc = tree_unflatten(tree_def, flat_next_grad_sq_acc)

        return step, (next_grad_acc, next_grad_sq_acc, idx + 1)

    return InitUpdateWithParams(init, update_fn)