Example #1
0
    def update(self, params, state, epoch, *args, **kwargs):
        """Perform one update of the algorithm.

        Evaluates the objective value and gradient at ``params``, derives a
        step size from the current slack variable, and returns the updated
        parameters together with the new solver state.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state. The fields
            ``slack``, ``iter_num`` and (when ``self.momentum != 0``)
            ``velocity`` are read.
          epoch: number of epoch (unused by this solver).
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.
        Return type:
          base.OptStep
        Returns:
          (params, state)
        Raises:
          ValueError: if ``self.lmbda == 1``.
        """
        del epoch  # unused
        if self.lmbda == 1:
            raise ValueError(
                'lmbda =1 was passed to SPSsqrt solver. This solver does not work with lmbda =1 because then the parameters are never updated! '
            )

        if self.has_aux:
            (value,
             aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
        else:
            value, grad = self._value_and_grad_fun(params, *args, **kwargs)
            aux = None

        # If slack hits zero, reset it to the current value; this stops the
        # method from halting. The selection is done with jnp.where instead
        # of a Python `if` so that it also works when `state.slack` is a
        # traced array (e.g. under jax.jit).
        slack = jnp.where(state.slack == 0.0, value, state.slack)
        sqrt_slack = jnp.sqrt(slack)
        grad_sqnorm = tree_l2_norm(grad, squared=True)

        # The mathematical expression of this update is:
        #   step = (value - (1 - lmbda/2) sqrt(s))_+ / (4 s ||grad||^2 + 1 - lmbda)
        #   w = w - 4 step s * grad
        #   s = (1 - lmbda) * sqrt(s) * (sqrt(s) + step)
        step = jax.nn.relu(value - (1 - self.lmbda / 2) * sqrt_slack) / (
            4 * slack * grad_sqnorm + 1 - self.lmbda)
        newslack = (1 - self.lmbda) * sqrt_slack * (sqrt_slack + step)
        step_size = 4 * slack * step

        if self.momentum == 0:
            new_params = tree_add_scalar_mul(params, -step_size, grad)
            new_velocity = None
        else:
            # new_v = momentum * v - step_size * grad
            # new_params = params + new_v
            new_velocity = tree_sub(
                tree_scalar_mul(self.momentum, state.velocity),
                tree_scalar_mul(step_size, grad))
            new_params = tree_add(params, new_velocity)

        new_state = SPSsqrtState(iter_num=state.iter_num + 1,
                                 value=value,
                                 slack=newslack,
                                 velocity=new_velocity,
                                 aux=aux)
        return base.OptStep(params=new_params, state=new_state)
  def update(self,
             params,
             state,
             epoch,
             *args,
             **kwargs):
    """Take one step of the solver and return the new parameters and state.

    The objective value and gradient are evaluated at ``params``; a step size
    is computed from the value, the current slack and the squared gradient
    norm; the parameters are then moved along the negative gradient, with
    momentum when ``self.momentum`` is non-zero.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int, current epoch; only consulted when ``self.lmbda_schedule``
        is truthy.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.

    Returns:
      (params, state)
    """
    outputs, grad = self._value_and_grad_fun(params, *args, **kwargs)
    if self.has_aux:
      value, aux = outputs
    else:
      value, aux = outputs, None

    # Experimental: decrease lambda slowly once enough epochs have passed.
    # The intuition is that in the early iterations SPS (self.lmbda=1) works
    # well, whereas in later iterations the slack helps stabilize.
    schedule_start = 10
    damping = self.lmbda
    if self.lmbda_schedule and epoch > schedule_start:
      damping = self.lmbda / (
          jnp.log(jnp.log(epoch - schedule_start + 1) + 1) + 1)

    # Mathematical description of this step size:
    #   step_size = (f_i(w^t) - (1 - lmbda) s)_+ / (||grad||^2 + 1 - lmbda)
    #   s_new = (1 - lmbda) * (s + step_size)
    residual = jax.nn.relu(value - (1 - damping) * state.slack)
    grad_sqnorm = tree_l2_norm(grad, squared=True)
    step_size = residual / (grad_sqnorm + 1 - damping)
    next_slack = (1 - damping) * (state.slack + step_size)

    if self.momentum == 0:
      next_velocity = None
      next_params = tree_add_scalar_mul(params, -step_size, grad)
    else:
      # v_new = momentum * v - step_size * grad;  params_new = params + v_new
      decayed_velocity = tree_scalar_mul(self.momentum, state.velocity)
      scaled_grad = tree_scalar_mul(step_size, grad)
      next_velocity = tree_sub(decayed_velocity, scaled_grad)
      next_params = tree_add(params, next_velocity)

    return base.OptStep(
        params=next_params,
        state=SPSDamState(
            iter_num=state.iter_num + 1,
            value=value,
            slack=next_slack,
            velocity=next_velocity,
            aux=aux))
  def update(self,
             params,
             state,
             epoch,
             *args,
             **kwargs):
    """Perform one update of the algorithm.

    Evaluates the objective value and gradient at ``params`` and takes a
    gradient step whose size is the smaller of a slack-based step and the
    ratio ``value / ||grad||^2``. The slack is then updated and clipped at
    zero.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int (unused by this solver).
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
    del epoch  # unused
    if self.has_aux:
      (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    else:
      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
      aux = None

    grad_sqnorm = tree_l2_norm(grad, squared=True)
    step1 = jax.nn.relu(value - state.slack + self.delta * self.lmbda) / (
        self.delta + grad_sqnorm)
    # With a zero gradient the ratio value / ||grad||^2 is undefined (0/0
    # yields NaN, which jnp.minimum would propagate into the parameters).
    # Treat that cap as inactive so that step1 alone decides the step.
    spsstep = jnp.where(grad_sqnorm > 0, value / grad_sqnorm, jnp.inf)
    step_size = jnp.minimum(step1, spsstep)
    # The slack update uses the uncapped step1, not the final step_size.
    newslack = jax.nn.relu(state.slack - self.lmbda * self.delta +
                           self.delta * step1)

    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, -step_size, grad)
      new_velocity = None
    else:
      # new_v = momentum * v - step_size * grad
      # new_params = params + new_v
      new_velocity = tree_sub(tree_scalar_mul(self.momentum, state.velocity),
                              tree_scalar_mul(step_size, grad))
      new_params = tree_add(params, new_velocity)
    new_state = SPSL1State(
        iter_num=state.iter_num + 1,
        value=value,
        slack=newslack,
        velocity=new_velocity,
        aux=aux)
    return base.OptStep(params=new_params, state=new_state)
  def update(self,
             params,
             state,
             data,
             *args,
             **kwargs):
    """Performs one iteration of the solver.

    Obtains an update direction from ``self._spsdiag_update`` and moves the
    parameters along it, scaled by ``self.learning_rate`` and, when
    ``self.momentum`` is non-zero, combined with the stored velocity.

    Side effects: ``aux['loss']`` and ``aux['accuracy']`` are replaced in
    place by their means, and a progress line is printed whenever
    ``state.iter_num`` is a multiple of 10.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      data: dict, forwarded unchanged to ``self._spsdiag_update``.
      *args: ignored.
      **kwargs: ignored.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """

    del args, kwargs  # unused
    (value, aux), update = self._spsdiag_update(params, data)
    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, self.learning_rate, update)
      new_velocity = None
    else:
      # new_v = momentum * v + learning_rate * update
      # new_params = params + new_v
      # The velocity accumulates along +update so that this branch agrees
      # with the momentum == 0 step above when the velocity is zero.
      # NOTE(review): the previous code subtracted the update here and then
      # unconditionally overwrote new_params with the plain
      # `params + learning_rate * update` step, so momentum never affected
      # the parameters. Confirm the sign convention of `_spsdiag_update`.
      new_velocity = tree_add(
          tree_scalar_mul(self.momentum, state.velocity),
          tree_scalar_mul(self.learning_rate, update))
      new_params = tree_add(params, new_velocity)

    # Reduce the auxiliary metrics to scalars before storing them.
    aux['loss'] = jnp.mean(aux['loss'])
    aux['accuracy'] = jnp.mean(aux['accuracy'])

    if state.iter_num % 10 == 0:
      print('Number of iterations', state.iter_num,
            '. Objective function value: ', value)

    new_state = StochasticPolyakState(
        iter_num=state.iter_num+1, value=value, velocity=new_velocity, aux=aux)
    return base.OptStep(params=new_params, state=new_state)