Example #1
def optimizer_for_idx(idx, training_steps, iteration=None):
  """Get a nadamw optimizer for the given configuration and training_steps."""
  # TODO(lmetz) the global step is obtained here. Ideally, we should be using
  # the value used by the underlying tensorflow optimizer but at this moment
  # we don't have access to it.
  if iteration is None:
    logging.warning("Iteration not passed in! Using the default global_step "
                    "for keeping track of training progress")
    iteration = tf.train.get_or_create_global_step()

  cfg = common.get_optimizer_config(idx)

  fn = get_cosine_learning_rate_fn(
      training_steps=training_steps,
      learning_rate=cfg["learning_rate"],
      min_learning_rate_mult=cfg["min_learning_rate_mult"],
      constant_fraction=cfg["constant_fraction"],
      warmup_fraction=cfg["warmup_fraction"])

  return NAdamWOptimizer(
      learning_rate=fn(iteration),
      beta1=cfg["beta1"],
      beta2=cfg["beta2"],
      epsilon=cfg["epsilon"],
      l2_weight_decay=cfg["l2_weight_decay"],
      adamw_weight_decay=cfg["adamw_weight_decay"],
  )
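# Hypothetical usage sketch (not part of the original source). It assumes that
# NAdamWOptimizer implements the standard tf.compat.v1.train.Optimizer
# interface and that the code runs under TF1-style graph execution; the
# variable and loss below are placeholders chosen purely for illustration.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
global_step = tf.train.get_or_create_global_step()
opt = optimizer_for_idx(idx=0, training_steps=10000, iteration=global_step)

w = tf.get_variable("w", shape=[10])
loss = tf.reduce_sum(tf.square(w))
train_op = opt.minimize(loss, global_step=global_step)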
def optimizer_for_idx(idx, training_steps):
  """Get a nadamw optimizer for the given configuration and training_steps.

  Args:
    idx: int
      The index into the learned optimizer list.
    training_steps: int
      total number of training steps that the model will be trained.

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  config = common.get_optimizer_config(idx)
  config['training_steps'] = training_steps
  config['use_bias_correction'] = True  # always true for now.
  hyper_params = jax_common.NAdamWHyperParams(**config)

  def init(x0):
    # State is (params, gradient accumulator, squared-gradient accumulator).
    return x0, jnp.zeros_like(x0), jnp.zeros_like(x0)

  def update(i, g, state):
    x = state[0]
    state = jax_common.NAdamWParamState(*state[1:])
    step, new_s = jax_common.nadamw_update(i, hyper_params, x, state, g)
    new_x = x + step
    return new_x, new_s[0], new_s[1]

  def get_params(state):
    x, _, _ = state
    return x

  return init, update, get_params
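# Hypothetical usage sketch (not part of the original source): the returned
# (init, update, get_params) triple follows the convention of
# jax.example_libraries.optimizers, so a toy training loop on a quadratic loss
# looks like the following. The optimizer index and step count are arbitrary.
import jax
import jax.numpy as jnp

training_steps = 100
init_fn, update_fn, get_params = optimizer_for_idx(0, training_steps)


def quadratic_loss(x):
  return jnp.sum(x ** 2)


opt_state = init_fn(jnp.ones(4))
for i in range(training_steps):
  grad = jax.grad(quadratic_loss)(get_params(opt_state))
  opt_state = update_fn(i, grad, opt_state)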
Example #3
def keras_optimizer_for_idx(idx, training_steps):
    """Get a nadamw optimizer for the given configuration and training_steps."""
    cfg = common.get_optimizer_config(idx)

    decay = CustomCosineDecay(
        training_steps=training_steps,
        learning_rate=cfg["learning_rate"],
        min_learning_rate_mult=cfg["min_learning_rate_mult"],
        constant_fraction=cfg["constant_fraction"],
        warmup_fraction=cfg["warmup_fraction"])

    return NAdamWKeras(
        learning_rate=decay,
        beta1=cfg["beta1"],
        beta2=cfg["beta2"],
        epsilon=cfg["epsilon"],
        l2_weight_decay=cfg["l2_weight_decay"],
        adamw_weight_decay=cfg["adamw_weight_decay"],
    )
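# Hypothetical usage sketch (not part of the original source). It assumes that
# NAdamWKeras implements the tf.keras optimizer interface, in which case the
# object returned by keras_optimizer_for_idx can be passed straight to
# Model.compile; the tiny model below is purely illustrative.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10, activation="softmax")])
model.compile(
    optimizer=keras_optimizer_for_idx(idx=0, training_steps=10000),
    loss="sparse_categorical_crossentropy")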
def optimizer_for_idx(idx, training_steps):
    """Get a nadamw optimizer for the given configuration and training_steps.

  Unlike regular Optix functions, the update function returned here additionally
  takes a parameter argument.

  Args:
    idx: int The index into the learned optimizer list.
    training_steps: int total number of training steps that the model will be
      trained.

  Returns:
    An (init_fn, update_with_params_fn) tuple.
  """
    config = common.get_optimizer_config(idx)
    config["training_steps"] = training_steps
    config["use_bias_correction"] = True  # always true for now.
    hyper_params = jax_common.NAdamWHyperParams(**config)

    def init(params):
        # State: (grad accumulators, squared-grad accumulators, step count).
        zero_initial = tree_map(jnp.zeros_like, params)
        return zero_initial, zero_initial, 0

    def update_fn(grads, params, state):
        """Compute the update.

    Args:
      grads: pytree of ndarray
        Gradient values.
      params: pytree of ndarray
        Parameter values.
      state:
        A tuple of (gradient accumulators, squared gradient accumulators, idx)
    Returns:
      step: pytree of ndarray
        The step to be added to the parameter values.
      next_state:
        A tuple of (gradient accumulators, squared gradient accumulators, idx)
    """

        grad_acc, grad_sq_acc, idx = state

        def update_one(g, p, g_acc, g_sq_acc):
            s = jax_common.NAdamWParamState(g_acc, g_sq_acc)
            step, new_s = jax_common.nadamw_update(idx, hyper_params, p, s, g)
            return step, new_s

        # the following flattens, applies a map, extracts values out via zip,
        # then unflattens.
        flat_gs, tree_def = tree_flatten(grads)
        flat_ps, _ = tree_flatten(params)
        flat_s0, _ = tree_flatten(grad_acc)
        flat_s1, _ = tree_flatten(grad_sq_acc)

        next_param_states = tree_map(update_one, flat_gs, flat_ps, flat_s0,
                                     flat_s1)

        flat_step, flat_next_ss = zip(*next_param_states)
        flat_next_grad_acc, flat_next_grad_sq_acc = zip(*flat_next_ss)

        step = tree_unflatten(tree_def, flat_step)
        next_grad_acc = tree_unflatten(tree_def, flat_next_grad_acc)
        next_grad_sq_acc = tree_unflatten(tree_def, flat_next_grad_sq_acc)

        return step, (next_grad_acc, next_grad_sq_acc, idx + 1)

    return InitUpdateWithParams(init, update_fn)
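# Hypothetical usage sketch (not part of the original source): the returned
# pair is unpacked into an init function and an update function. Unlike plain
# Optix transforms, update_fn also takes the current parameters, and the step
# it returns is added to the parameters directly.
import jax
import jax.numpy as jnp
from jax import tree_util


def quadratic_loss(params):
    return sum(jnp.sum(p ** 2) for p in tree_util.tree_leaves(params))


init_fn, update_fn = optimizer_for_idx(0, training_steps=100)
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}
opt_state = init_fn(params)
for _ in range(100):
    grads = jax.grad(quadratic_loss)(params)
    step, opt_state = update_fn(grads, params, opt_state)
    params = tree_util.tree_map(lambda p, s: p + s, params, step)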
def optimizer_for_idx(idx, training_steps):
    """Get a OptimizerDef for the given configuration and training_steps."""
    config = common.get_optimizer_config(idx)
    config['training_steps'] = training_steps
    return NAdamWCosineDecay(**config)
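# Hypothetical usage sketch (not part of the original source). It assumes that
# NAdamWCosineDecay follows the (now deprecated) flax.optim.OptimizerDef
# interface, so the returned definition is bound to a parameter pytree with
# create() and advanced with apply_gradient(); the loss is illustrative only.
import jax
import jax.numpy as jnp


def quadratic_loss(params):
    return jnp.sum(params["w"] ** 2)


optimizer_def = optimizer_for_idx(0, training_steps=100)
optimizer = optimizer_def.create({"w": jnp.ones((3, 3))})
for _ in range(100):
    grads = jax.grad(quadratic_loss)(optimizer.target)
    optimizer = optimizer.apply_gradient(grads)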