Example #1
0
    def update(self, params, state, epoch, *args, **kwargs):
        """Perform one SPSsqrt update.

        With ``s`` the slack stored in ``state`` and ``g = grad fun(params)``,
        the update is:

            step = (value - (1 - lmbda/2) sqrt(s))_+ / (4 s ||g||^2 + 1 - lmbda)
            w    = w - 4 s step g   (or a momentum step when momentum != 0)
            s    = (1 - lmbda) sqrt(s) (sqrt(s) + step)

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          epoch: number of epoch (not used by this update).
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.

        Returns:
          base.OptStep holding the new (params, state).

        Raises:
          ValueError: if ``self.lmbda == 1``, which this solver rejects.
        """
        del epoch  # unused
        if self.lmbda == 1:
            raise ValueError(
                'lmbda =1 was passed to SPSsqrt solver. This solver does not work with lmbda =1 because then the parameters are never updated! '
            )

        if not self.has_aux:
            value, grad = self._value_and_grad_fun(params, *args, **kwargs)
            aux = None
        else:
            (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)

        # A zero slack would halt the method, so restart it from the current
        # loss value.
        if state.slack == 0.0:
            state = state._replace(slack=value)

        lmbda = self.lmbda
        slack = state.slack
        sqrt_slack = jnp.sqrt(slack)

        numerator = jax.nn.relu(value - (1 - lmbda / 2) * sqrt_slack)
        denominator = 4 * slack * tree_l2_norm(grad, squared=True) + 1 - lmbda
        step = numerator / denominator
        new_slack = (1 - lmbda) * sqrt_slack * (sqrt_slack + step)
        # The parameter step is scaled by 4 * s (see the formula above).
        step_size = 4 * slack * step

        if self.momentum == 0:
            new_velocity = None
            new_params = tree_add_scalar_mul(params, -step_size, grad)
        else:
            # Momentum step: v <- momentum * v - step_size * g;  w <- w + v.
            new_velocity = tree_sub(tree_scalar_mul(self.momentum, state.velocity),
                                    tree_scalar_mul(step_size, grad))
            new_params = tree_add(params, new_velocity)

        return base.OptStep(
            params=new_params,
            state=SPSsqrtState(iter_num=state.iter_num + 1,
                               value=value,
                               slack=new_slack,
                               velocity=new_velocity,
                               aux=aux))
  def update(self,
             params,
             state,
             epoch,
             *args,
             **kwargs):
    """Perform one SPSDam update.

    With ``s`` the slack stored in ``state``, ``g = grad fun(params)`` and
    ``lmbdat`` the (possibly scheduled) damping parameter:

        step = (value - (1 - lmbdat) s)_+ / (||g||^2 + 1 - lmbdat)
        w    = w - step g   (or a momentum step when momentum != 0)
        s    = (1 - lmbdat) (s + step)

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int, used only by the optional lmbda schedule.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.

    Returns:
      base.OptStep holding the new (params, state).
    """
    if not self.has_aux:
      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
      aux = None
    else:
      (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)

    # Experimental schedule that slowly decreases lambda after many epochs.
    # Intuition: SPS (self.lmbda = 1) works well in early iterations, while in
    # later iterations the slack helps stabilize the method.
    late_start = 10
    lmbdat = self.lmbda
    if self.lmbda_schedule and epoch > late_start:
      lmbdat = self.lmbda/(jnp.log(jnp.log(epoch-late_start+1)+1)+1)

    damping = 1 - lmbdat
    grad_sq_norm = tree_l2_norm(grad, squared=True)
    step_size = jax.nn.relu(value - damping * state.slack) / (
        grad_sq_norm + 1 - lmbdat)
    new_slack = damping * (state.slack + step_size)

    if self.momentum == 0:
      new_velocity = None
      new_params = tree_add_scalar_mul(params, -step_size, grad)
    else:
      # Momentum step: v <- momentum * v - step_size * g;  w <- w + v.
      new_velocity = tree_sub(
          tree_scalar_mul(self.momentum, state.velocity),
          tree_scalar_mul(step_size, grad))
      new_params = tree_add(params, new_velocity)

    return base.OptStep(
        params=new_params,
        state=SPSDamState(
            iter_num=state.iter_num + 1,
            value=value,
            slack=new_slack,
            velocity=new_velocity,
            aux=aux))
  def update(self,
             params,
             state,
             epoch,
             *args,
             **kwargs):
    """Perform one SPSL1 update.

    With ``s`` the slack stored in ``state`` and ``g = grad fun(params)``:

        step1 = (value - s + delta * lmbda)_+ / (delta + ||g||^2)
        step  = min(step1, value / ||g||^2)
        s     = (s - lmbda * delta + delta * step1)_+

    and the parameters move by ``-step * g`` (or through the momentum
    velocity when ``self.momentum != 0``). When ``||g||^2 == 0`` the plain SPS
    term is treated as ``+inf`` so that ``step == step1``; the parameters do
    not move in that case because the gradient is zero.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int (not used by this update).
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
    if self.has_aux:
      (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    else:
      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
      aux = None

    gradnorm = tree_l2_norm(grad, squared=True)
    step1 = jax.nn.relu(value - state.slack + self.delta * self.lmbda) / (
        self.delta + gradnorm)
    # Plain SPS step value / ||g||^2. With a vanishing gradient, the raw
    # division gives +inf (value > 0), nan (value == 0) or -inf (value < 0).
    # The last two survive jnp.minimum and turn the parameters into nan
    # (nan * 0 or inf * 0). Treat that case as +inf so that step_size == step1.
    # The inner where keeps the division itself free of 0/0.
    grad_is_zero = gradnorm == 0
    safe_gradnorm = jnp.where(grad_is_zero, 1.0, gradnorm)
    spsstep = jnp.where(grad_is_zero, jnp.inf, value / safe_gradnorm)
    step_size = jnp.minimum(step1, spsstep)
    newslack = jax.nn.relu(state.slack - self.lmbda * self.delta +
                           self.delta * step1)

    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, -step_size, grad)
      new_velocity = None
    else:
      # new_v = momentum * v - step_size * grad
      # new_params = params + new_v
      new_velocity = tree_sub(tree_scalar_mul(self.momentum, state.velocity),
                              tree_scalar_mul(step_size, grad))
      new_params = tree_add(params, new_velocity)
    new_state = SPSL1State(
        iter_num=state.iter_num + 1,
        value=value,
        slack=newslack,
        velocity=new_velocity,
        aux=aux)
    return base.OptStep(params=new_params, state=new_state)
  def update(self,
             params,
             state,
             data,
             *args,
             **kwargs):
    """Performs one iteration of the solver.

    The search direction and loss come from ``self._spsdiag_update``. The
    parameters then move by ``learning_rate * update``, optionally through a
    momentum velocity.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      data: batch of data forwarded to ``self._spsdiag_update``.
      *args: unused.
      **kwargs: unused.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """

    del args, kwargs  # unused
    (value, aux), update = self._spsdiag_update(params, data)
    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, self.learning_rate, update)
      new_velocity = None
    else:
      # The plain step is params + learning_rate * update, so the velocity
      # uses the same sign: new_v = momentum * v + learning_rate * update.
      # With momentum == 0 this reduces exactly to the plain step.
      new_velocity = tree_add(
          tree_scalar_mul(self.momentum, state.velocity),
          tree_scalar_mul(self.learning_rate, update))
      new_params = tree_add(params, new_velocity)

    # Reduce the per-example metrics returned in aux to scalars.
    aux['loss'] = jnp.mean(aux['loss'])
    aux['accuracy'] = jnp.mean(aux['accuracy'])

    if state.iter_num % 10 == 0:
      print('Number of iterations', state.iter_num,
            '. Objective function value: ', value)

    new_state = StochasticPolyakState(
        iter_num=state.iter_num+1, value=value, velocity=new_velocity, aux=aux)
    return base.OptStep(params=new_params, state=new_state)
def projection_hyperplane(a, b, x=None):
  r"""Project ``x`` onto the hyperplane ``{y : <a, y> = b}``.

  Computes ``argmin_{y, dot(a, y) = b} ||y - x||``, whose closed form is

      y = x - (<a, x> - b) / <a, a> * a

  Args:
    a: pytree, normal vector of the hyperplane. The formula divides by
      ``<a, a>``, so ``a`` should be nonzero.
    b: scalar offset of the hyperplane.
    x: pytree to project. If ``None``, the origin is projected, which gives
      ``b / <a, a> * a``.

  Returns:
    y: pytree with the same structure as ``x`` (or as ``a`` when ``x`` is
      ``None``).
  """
  if x is not None:
    residual = tree_util.tree_vdot(a, x) - b
    coeff = residual / tree_util.tree_vdot(a, a)
    return tree_util.tree_add_scalar_mul(x, -coeff, a)
  return tree_util.tree_scalar_mul(b / tree_util.tree_vdot(a, a), a)
 def least_square_regularizor_1d(a, b, delta):
   """Closed-form minimizer of ``(<a, x> + b)**2 + delta * ||x||**2``.

   Setting the gradient to zero shows the solution lies along ``a``:
   ``x = -b / (<a, a> + delta) * a``. Note the ``+ b`` sign convention; to
   solve ``(<a, x> - b)**2 + delta * ||x||**2`` instead, pass ``-b``.

   Args:
     a: pytree (the linear functional).
     b: scalar.
     delta: scalar regularization weight.

   Returns:
     Pytree with the same structure as ``a``.
   """
   regularized_sq_norm = tree_vdot(a, a) + delta
   return tree_scalar_mul(-b / regularized_sq_norm, a)
Example #7
0
    def update_pytrees_QP(self, params, state, epoch, data, *args, **kwargs):
        """Performs one iteration of the system Polyak solver.

        Each per-example loss is linearized around the current parameters
        ``w^t``. The next iterate is the projection of ``w^t`` onto the
        resulting affine set, found by solving the QP

            min_w 0.5 ||w - w^t||^2   s.t.   A w = b,

        where ``A`` (size(batch) x size(params)) is the Jacobian of the
        per-example losses and ``b = A w^t - losses(w^t)``.

        Args:
          params: pytree containing the parameters (``w^t``).
          state: named tuple containing the solver state.
          epoch: int, epoch number (unused).
          data: a batch of data, forwarded to ``self.fun``.
          *args: unused.
          **kwargs: unused.
        Return type: base.OptStep

        Returns:
          (params, state)
        """
        del epoch, args, kwargs  # unused

        def per_example_losses(w):
            # self.fun currently returns the losses BEFORE the mean reduction,
            # so this maps size(params) -> size(batch).
            return self.fun(w, data)[0]

        # TODO(rmgower): avoid recomputing the auxiliary output (metrics)
        aux = self.fun(params, data)[1]

        # The objective expands as 0.5 w^T w - <w^t, w> + const, so the QP
        # has Q = Identity and c = -w^t.
        def identity_matvec(params_Q, u):
            del params_Q  # ignored
            return u

        def jacobian_matvec(params_A, u):
            del params_A  # ignored: A is implicitly the Jacobian at params.
            return jax.jvp(per_example_losses, (params, ), (u, ))[1]

        # A w^t is a JVP with tangent w^t; its primal output gives the loss
        # values, so both come from a single call.
        loss_values, jac_times_params = jax.jvp(per_example_losses,
                                                (params, ), (params, ))
        b = jac_times_params - loss_values
        # NOTE(review): the original author flagged c for re-checking; the
        # expansion above supports c = -w^t.
        c = tree_util.tree_scalar_mul(-1.0, params)

        # Solves for both primal and dual variables, i.e. a potentially very
        # large linear system.
        solver = jaxopt.QuadraticProgramming(matvec_Q=identity_matvec,
                                             matvec_A=jacobian_matvec,
                                             maxiter=10)
        kkt_solution = solver.run(params_obj=(None, c),
                                  params_eq=(None, b)).params
        # Keep only the primal variable; the dual variables are discarded.
        new_params = kkt_solution[0]

        return base.OptStep(
            params=new_params,
            state=SystemStochasticPolyakState(iter_num=state.iter_num + 1,
                                              aux=aux))