import jax.numpy as jnp

# `common` and `jax_common` are assumed to be modules from the surrounding
# package (they provide the optimizer configurations and the NAdamW math).


def optimizer_for_idx(idx, training_steps):
  """Get a nadamw optimizer for the given configuration and training_steps.

  Args:
    idx: int, the index into the learned optimizer list.
    training_steps: int, total number of training steps that the model will be
      trained for.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  config = common.get_optimizer_config(idx)
  config['training_steps'] = training_steps
  config['use_bias_correction'] = True  # Always true for now.
  hyper_params = jax_common.NAdamWHyperParams(**config)

  def init(x0):
    # State is (params, gradient accumulator, squared gradient accumulator).
    return x0, jnp.zeros_like(x0), jnp.zeros_like(x0)

  def update(i, g, state):
    x = state[0]
    state = jax_common.NAdamWParamState(*state[1:])
    # nadamw_update returns the step (delta) plus the new accumulator state.
    update, new_s = jax_common.nadamw_update(i, hyper_params, x, state, g)
    new_x = x + update
    return new_x, new_s[0], new_s[1]

  def get_params(state):
    x, _, _ = state
    return x

  return init, update, get_params
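
# Minimal usage sketch (not part of the original module): the returned triple
# matches the (init, update, get_params) convention of
# jax.example_libraries.optimizers, and the element-wise init/update above
# suggest it operates on a single parameter array (or is meant to be wrapped
# by a per-leaf decorator), so the demo assumes a single ndarray.
# `_toy_loss` and `_demo_triple_api` are illustrative names introduced here.
import jax
import jax.numpy as jnp


def _toy_loss(x):
  # A simple quadratic bowl, used only to exercise the optimizer.
  return jnp.sum(x ** 2)


def _demo_triple_api(num_steps=100):
  init_fun, update_fun, get_params = optimizer_for_idx(0, num_steps)
  state = init_fun(jnp.ones((10,)))
  for step in range(num_steps):
    grads = jax.grad(_toy_loss)(get_params(state))
    state = update_fun(step, grads, state)
  return get_params(state)
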
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

# `common`, `jax_common`, and the `InitUpdateWithParams` container are assumed
# to come from the surrounding package.


def optimizer_for_idx(idx, training_steps):
  """Get a nadamw optimizer for the given configuration and training_steps.

  Unlike regular Optix functions, the update function returned here
  additionally takes a parameter argument.

  Args:
    idx: int, the index into the learned optimizer list.
    training_steps: int, total number of training steps that the model will be
      trained for.

  Returns:
    An (init_fn, update_with_params_fn) tuple.
  """
  config = common.get_optimizer_config(idx)
  config["training_steps"] = training_steps
  config["use_bias_correction"] = True  # Always true for now.
  hyper_params = jax_common.NAdamWHyperParams(**config)

  def init(params):
    zero_initial = tree_map(jnp.zeros_like, params)
    return zero_initial, zero_initial, 0

  def update_fn(grads, params, state):
    """Compute the update.

    Args:
      grads: pytree of ndarray, gradient values.
      params: pytree of ndarray, parameter values.
      state: A tuple of (gradient accumulators, squared gradient accumulators,
        idx).

    Returns:
      step: pytree of ndarray, the step to be added to the parameter values.
      next_state: A tuple of (gradient accumulators, squared gradient
        accumulators, idx).
    """
    # Note: `idx` here is the step counter stored in the optimizer state; it
    # shadows the optimizer-index argument of the enclosing function.
    grad_acc, grad_sq_acc, idx = state

    def update_one(g, p, g_acc, g_sq_acc):
      s = jax_common.NAdamWParamState(g_acc, g_sq_acc)
      # nadamw_update returns the step (delta) to add to the parameters,
      # together with the new accumulator state.
      new_x, new_s = jax_common.nadamw_update(idx, hyper_params, p, s, g)
      return new_x, new_s

    # The following flattens, applies a map, extracts values out via zip,
    # then unflattens.
    flat_gs, tree_def = tree_flatten(grads)
    flat_ps, _ = tree_flatten(params)
    flat_s0, _ = tree_flatten(grad_acc)
    flat_s1, _ = tree_flatten(grad_sq_acc)

    next_param_states = tree_map(update_one, flat_gs, flat_ps, flat_s0,
                                 flat_s1)

    flat_step, flat_next_ss = zip(*next_param_states)
    flat_next_grad_acc, flat_next_grad_sq_acc = zip(*flat_next_ss)

    step = tree_unflatten(tree_def, flat_step)
    next_grad_acc = tree_unflatten(tree_def, flat_next_grad_acc)
    next_grad_sq_acc = tree_unflatten(tree_def, flat_next_grad_sq_acc)

    return step, (next_grad_acc, next_grad_sq_acc, idx + 1)

  return InitUpdateWithParams(init, update_fn)
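
# Minimal usage sketch (not part of the original module): the returned pair is
# used like an Optix/Optax-style gradient transformation whose update also
# receives the current parameters. Per the update_fn docstring, the returned
# step is added to the params. The pair is unpacked positionally so no field
# names of InitUpdateWithParams are assumed; `_demo_tree_api` and the toy
# pytree/loss are illustrative names introduced here.
import jax
import jax.numpy as jnp


def _demo_tree_api(num_steps=100):
  init_fn, update_fn = optimizer_for_idx(0, num_steps)
  params = {"w": jnp.ones((10,)), "b": jnp.zeros((10,))}
  state = init_fn(params)

  def loss(p):
    # Toy objective over the parameter pytree.
    return jnp.sum(p["w"] ** 2) + jnp.sum((p["b"] - 1.0) ** 2)

  for _ in range(num_steps):
    grads = jax.grad(loss)(params)
    step, state = update_fn(grads, params, state)
    # Apply the step additively, leaf by leaf.
    params = jax.tree_util.tree_map(lambda p, s: p + s, params, step)
  return params
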