def optimizer_for_idx(idx, training_steps, iteration=None):
  """Get a nadamw optimizer for the given configuration and training_steps."""
  # TODO(lmetz) the global step is obtained here. Ideally, we should be using
  # the value used by the underlying tensorflow optimizer but at this moment
  # we don't have access to it.
  if iteration is None:
    logging.warning("Iteration not passed in! Using the default global_step "
                    "for keeping track of training progress")
    iteration = tf.train.get_or_create_global_step()

  cfg = common.get_optimizer_config(idx)
  fn = get_cosine_learning_rate_fn(
      training_steps=training_steps,
      learning_rate=cfg["learning_rate"],
      min_learning_rate_mult=cfg["min_learning_rate_mult"],
      constant_fraction=cfg["constant_fraction"],
      warmup_fraction=cfg["warmup_fraction"])

  return NAdamWOptimizer(
      learning_rate=fn(iteration),
      beta1=cfg["beta1"],
      beta2=cfg["beta2"],
      epsilon=cfg["epsilon"],
      l2_weight_decay=cfg["l2_weight_decay"],
      adamw_weight_decay=cfg["adamw_weight_decay"],
  )
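

# Usage sketch (illustrative addition, not part of the original module). It
# assumes the module-level `tf` behaves like TensorFlow 1.x (consistent with
# tf.train.get_or_create_global_step above) and that NAdamWOptimizer follows
# the standard tf.train.Optimizer interface; the toy loss is hypothetical.
def _example_usage_tf_optimizer(training_steps=10000):
  w = tf.get_variable("w", shape=[4], initializer=tf.ones_initializer())
  loss = tf.reduce_sum(tf.square(w))  # hypothetical toy loss
  global_step = tf.train.get_or_create_global_step()
  opt = optimizer_for_idx(idx=0, training_steps=training_steps,
                          iteration=global_step)
  # minimize() builds the training op and increments the global step.
  return opt.minimize(loss, global_step=global_step)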
def optimizer_for_idx(idx, training_steps):
  """Get a nadamw optimizer for the given configuration and training_steps.

  Args:
    idx: int, the index into the learned optimizer list.
    training_steps: int, total number of training steps that the model will
      be trained for.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  config = common.get_optimizer_config(idx)
  config['training_steps'] = training_steps
  config['use_bias_correction'] = True  # always true for now.
  hyper_params = jax_common.NAdamWHyperParams(**config)

  def init(x0):
    return x0, jnp.zeros_like(x0), jnp.zeros_like(x0)

  def update(i, g, state):
    x = state[0]
    state = jax_common.NAdamWParamState(*state[1:])
    update, new_s = jax_common.nadamw_update(i, hyper_params, x, state, g)
    new_x = x + update
    return new_x, new_s[0], new_s[1]

  def get_params(state):
    x, _, _ = state
    return x

  return init, update, get_params
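

# Usage sketch (illustrative addition, not part of the original module). It
# drives the (init_fun, update_fun, get_params) triple on a single parameter
# array; the quadratic `toy_loss` is a hypothetical stand-in.
def _example_usage_jax_triple(training_steps=100):
  import jax
  import jax.numpy as jnp

  def toy_loss(w):
    return jnp.sum(w ** 2)

  init_fun, update_fun, get_params = optimizer_for_idx(
      idx=0, training_steps=training_steps)
  opt_state = init_fun(jnp.ones([4]))
  for i in range(training_steps):
    grads = jax.grad(toy_loss)(get_params(opt_state))
    opt_state = update_fun(i, grads, opt_state)
  return get_params(opt_state)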
def keras_optimizer_for_idx(idx, training_steps):
  """Get a nadamw optimizer for the given configuration and training_steps."""
  cfg = common.get_optimizer_config(idx)
  decay = CustomCosineDecay(
      training_steps=training_steps,
      learning_rate=cfg["learning_rate"],
      min_learning_rate_mult=cfg["min_learning_rate_mult"],
      constant_fraction=cfg["constant_fraction"],
      warmup_fraction=cfg["warmup_fraction"])

  return NAdamWKeras(
      learning_rate=decay,
      beta1=cfg["beta1"],
      beta2=cfg["beta2"],
      epsilon=cfg["epsilon"],
      l2_weight_decay=cfg["l2_weight_decay"],
      adamw_weight_decay=cfg["adamw_weight_decay"],
  )
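

# Usage sketch (illustrative addition, not part of the original module). It
# assumes NAdamWKeras is a standard Keras optimizer object, so it can be
# passed straight to model.compile; the tiny model and loss are hypothetical.
def _example_usage_keras_optimizer(training_steps=1000):
  import tensorflow as tf

  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(
      optimizer=keras_optimizer_for_idx(idx=0, training_steps=training_steps),
      loss="mse")
  return model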
def optimizer_for_idx(idx, training_steps):
  """Get a nadamw optimizer for the given configuration and training_steps.

  Unlike regular Optix functions, the update function returned here
  additionally takes a parameter argument.

  Args:
    idx: int, the index into the learned optimizer list.
    training_steps: int, total number of training steps that the model will
      be trained for.

  Returns:
    An (init_fn, update_with_params_fn) tuple.
  """
  config = common.get_optimizer_config(idx)
  config["training_steps"] = training_steps
  config["use_bias_correction"] = True  # always true for now.
  hyper_params = jax_common.NAdamWHyperParams(**config)

  def init(params):
    zero_initial = tree_map(jnp.zeros_like, params)
    return zero_initial, zero_initial, 0

  def update_fn(grads, params, state):
    """Compute the update.

    Args:
      grads: pytree of ndarray, gradient values.
      params: pytree of ndarray, parameter values.
      state: a tuple of (gradient accumulators, squared gradient
        accumulators, idx).

    Returns:
      step: pytree of ndarray, the step to be added to the parameter values.
      next_state: a tuple of (gradient accumulators, squared gradient
        accumulators, idx).
    """
    grad_acc, grad_sq_acc, idx = state

    def update_one(g, p, g_acc, g_sq_acc):
      s = jax_common.NAdamWParamState(g_acc, g_sq_acc)
      new_x, new_s = jax_common.nadamw_update(idx, hyper_params, p, s, g)
      return new_x, new_s

    # The following flattens, applies a map, extracts values out via zip,
    # then unflattens.
    flat_gs, tree_def = tree_flatten(grads)
    flat_ps, _ = tree_flatten(params)
    flat_s0, _ = tree_flatten(grad_acc)
    flat_s1, _ = tree_flatten(grad_sq_acc)

    next_param_states = tree_map(update_one, flat_gs, flat_ps, flat_s0,
                                 flat_s1)

    flat_step, flat_next_ss = zip(*next_param_states)
    flat_next_grad_acc, flat_next_grad_sq_acc = zip(*flat_next_ss)

    step = tree_unflatten(tree_def, flat_step)
    next_grad_acc = tree_unflatten(tree_def, flat_next_grad_acc)
    next_grad_sq_acc = tree_unflatten(tree_def, flat_next_grad_sq_acc)

    return step, (next_grad_acc, next_grad_sq_acc, idx + 1)

  return InitUpdateWithParams(init, update_fn)
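

# Usage sketch (illustrative addition, not part of the original module). It
# assumes InitUpdateWithParams unpacks into (init_fn, update_fn) as the
# docstring above describes; the parameter pytree and `toy_loss` are
# hypothetical. The returned step is added to the parameters, matching the
# update_fn docstring.
def _example_usage_update_with_params(training_steps=100):
  import jax
  import jax.numpy as jnp

  def toy_loss(params):
    return jnp.sum(params["w"] ** 2)

  init_fn, update_fn = optimizer_for_idx(idx=0, training_steps=training_steps)
  params = {"w": jnp.ones([4])}
  state = init_fn(params)
  for _ in range(training_steps):
    grads = jax.grad(toy_loss)(params)
    step, state = update_fn(grads, params, state)
    params = jax.tree_util.tree_map(lambda p, s: p + s, params, step)
  return params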
def optimizer_for_idx(idx, training_steps):
  """Get an OptimizerDef for the given configuration and training_steps."""
  config = common.get_optimizer_config(idx)
  config['training_steps'] = training_steps
  return NAdamWCosineDecay(**config)
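

# Usage sketch (illustrative addition, not part of the original module). It
# assumes NAdamWCosineDecay is a flax.optim.OptimizerDef, per the docstring,
# and therefore follows the (legacy) flax.optim create / apply_gradient flow;
# the parameter pytree and `toy_loss` are hypothetical.
def _example_usage_optimizer_def(training_steps=100):
  import jax
  import jax.numpy as jnp

  def toy_loss(params):
    return jnp.sum(params["w"] ** 2)

  optimizer_def = optimizer_for_idx(idx=0, training_steps=training_steps)
  optimizer = optimizer_def.create({"w": jnp.ones([4])})
  for _ in range(training_steps):
    grads = jax.grad(toy_loss)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)
  return optimizer.target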