Code example #1
    def update(self, params, state, epoch, *args, **kwargs):
        """Perform one update of the algorithm.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: epoch number.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
        del epoch  # unused
        if self.lmbda == 1:
            raise ValueError(
                'lmbda=1 was passed to the SPSsqrt solver. This solver does '
                'not work with lmbda=1 because the parameters would then '
                'never be updated.')

        if self.has_aux:
            (value,
             aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
        else:
            value, grad = self._value_and_grad_fun(params, *args, **kwargs)
            aux = None

        # If slack hits zero, reset to be the current value.
        # This stops the method from halting.
        if state.slack == 0.0:
            state = state._replace(slack=value)
        ## The mathematical expression of this update is:
        # step = (value - (1 - lmbda/2) * sqrt(s))_+ / (4 * s * ||grad||^2 + 1 - lmbda)
        # w = w - 4 * step * s * grad
        # s = (1 - lmbda) * sqrt(s) * (sqrt(s) + step)
        step_size = jax.nn.relu(
            value - (1 - self.lmbda / 2) * jnp.sqrt(state.slack)) / (
                4 * state.slack * tree_l2_norm(grad, squared=True) + 1 -
                self.lmbda)
        newslack = (1 - self.lmbda) * jnp.sqrt(
            state.slack) * (jnp.sqrt(state.slack) + step_size)
        step_size = 4 * state.slack * step_size

        if self.momentum == 0:
            new_params = tree_add_scalar_mul(params, -step_size, grad)
            new_velocity = None
        else:
            # new_v = momentum * v - step_size * grad
            # new_params = params + new_v
            new_velocity = tree_sub(
                tree_scalar_mul(self.momentum, state.velocity),
                tree_scalar_mul(step_size, grad))
            new_params = tree_add(params, new_velocity)

        new_state = SPSsqrtState(iter_num=state.iter_num + 1,
                                 value=value,
                                 slack=newslack,
                                 velocity=new_velocity,
                                 aux=aux)
        return base.OptStep(params=new_params, state=new_state)
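A minimal, self-contained sketch of this SPSsqrt step on a scalar quadratic loss, reusing only the formulas from the comments above (the choices of lmbda, the initial slack, and the loss are illustrative, not part of the solver class):

import jax
import jax.numpy as jnp

lmbda = 0.5          # illustrative value; must not be 1 (see the check above)
w, slack = 2.0, 1.0  # illustrative starting point and initial slack

def loss(w):
    return 0.5 * w ** 2

for t in range(5):
    value, grad = jax.value_and_grad(loss)(w)
    # step = (f(w) - (1 - lmbda/2) * sqrt(s))_+ / (4 * s * ||grad||^2 + 1 - lmbda)
    step = jax.nn.relu(value - (1 - lmbda / 2) * jnp.sqrt(slack)) / (
        4 * slack * grad ** 2 + 1 - lmbda)
    new_slack = (1 - lmbda) * jnp.sqrt(slack) * (jnp.sqrt(slack) + step)
    # w = w - 4 * step * s * grad (using the slack from before the update)
    w = w - 4 * slack * step * grad
    slack = new_slack
    print(t, float(w), float(loss(w)), float(slack))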
Code example #2
  def update(self,
             params,
             state,
             epoch,
             *args,
             **kwargs):
    """Perform one update of the algorithm.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.

    Returns:
      (params, state)
    """
    if self.has_aux:
      (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    else:
      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
      aux = None

    # Currently experimenting with decreasing lambda slowly after many iterations.
    # The intuition behind this is that in the early iterations plain SPS
    # (self.lmbda=1) works well, but in later iterations the slack helps
    # stabilize the method.
    late_start = 10
    if self.lmbda_schedule and epoch > late_start:
      lmbdat = self.lmbda/(jnp.log(jnp.log(epoch-late_start+1)+1)+1)
    else:
      lmbdat = self.lmbda

    ## Mathematical description of this step:
    # step_size = (f_i(w^t) - (1 - lmbda) * s)_+ / (||grad||^2 + 1 - lmbda)
    # s = (1 - lmbda) * (s + step_size)
    step_size = jax.nn.relu(value - (1-lmbdat)*state.slack)/(tree_l2_norm(
        grad, squared=True) + 1 - lmbdat)
    newslack = (1 - lmbdat) * (state.slack + step_size)
    # new_params = tree_add_scalar_mul(params, -step_size, grad)

    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, -step_size, grad)
      new_velocity = None
    else:
      # new_v = momentum * v - step_size * grad
      # new_params = params + new_v
      new_velocity = tree_sub(
          tree_scalar_mul(self.momentum, state.velocity),
          tree_scalar_mul(step_size, grad))
      new_params = tree_add(params, new_velocity)

    new_state = SPSDamState(
        iter_num=state.iter_num + 1,
        value=value,
        slack=newslack,
        velocity=new_velocity,
        aux=aux)
    return base.OptStep(params=new_params, state=new_state)
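A small, self-contained sketch of how the experimental schedule above decays lmbda after late_start epochs (the base lmbda value is illustrative):

import jax.numpy as jnp

lmbda = 0.5      # illustrative base value
late_start = 10  # as in the update above

for epoch in [5, 11, 20, 50, 200]:
    if epoch > late_start:
        lmbdat = lmbda / (jnp.log(jnp.log(epoch - late_start + 1) + 1) + 1)
    else:
        lmbdat = lmbda
    print(epoch, float(lmbdat))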
Code example #3
    def update_arrays_CG(self, params, state, data, *args, **kwargs):
        """Perform the update using CG."""

        del kwargs  # unused
        batch_size = data['label'].shape[0]
        _, unravel_pytree = flatten_util.ravel_pytree(params)
        values = jnp.zeros((batch_size))

        @jax.jit
        def loss_sample(image, label):
            tmp_kwargs = {'data': {'image': image, 'label': label}}
            # compute a gradient on a single image/label pair
            if self.has_aux:
                # we only store the last value of aux
                (value_i, aux), grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
            else:
                value_i, grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
                aux = None
            grad_i_flatten, _ = flatten_util.ravel_pytree(grad_i)
            return value_i, aux, grad_i_flatten

        @jax.jit
        def matvec_array(u):
            """Computes the product  (J J^T +delta * I)u ."""
            out = grads @ (u @ grads) + self.delta * u
            return out

        # We add a new axis on data and labels so they have the correct
        # shape after vectorization by vmap, which removes the batch dimension
        ## Important: This is the bottleneck cost of this update!
        expand_data = jnp.expand_dims(data['image'], axis=1)
        expand_labels = jnp.expand_dims(data['label'], axis=1)
        values, aux, grads = jax.vmap(loss_sample,
                                      in_axes=(0, 0))(expand_data,
                                                      expand_labels)

        # Solve v = (J J^T + delta * I)^{-1} loss
        v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=20)
        ## Build the final update v = J^T (J J^T + delta * I)^{-1} loss
        v = v @ grads

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Number of iterations', state.iter_num,
                  '. Objective function value: ', value)

        new_state = SystemStochasticPolyakState(
            # iter_num=state.iter_num + 1, value=value, aux=aux)
            iter_num=state.iter_num + 1,
            aux=aux)
        return base.OptStep(params=new_params, state=new_state)
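A self-contained dense sketch of the same linear algebra on random data (the sizes, delta, and the random J standing in for the per-sample gradients are illustrative): it builds the (J J^T + delta * I) matvec, solves with jaxopt's CG as above, and applies the final J^T projection, without calling the solver class itself.

import jax
import jax.numpy as jnp
from jaxopt import linear_solve

batch_size, dim, delta = 4, 10, 1e-3
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
grads = jax.random.normal(key1, (batch_size, dim))  # stands in for the Jacobian J
values = jax.random.uniform(key2, (batch_size,))    # stands in for the per-sample losses

def matvec_array(u):
    # (J J^T + delta * I) u, written as J (J^T u) + delta * u
    return grads @ (u @ grads) + delta * u

v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=20)
v = v @ grads   # J^T (J J^T + delta * I)^{-1} loss
print(v.shape)  # (10,): a flat update of the same size as the raveled params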
Code example #4
    def update_arrays_lstsq(self, params, state, data, *args, **kwargs):
        """Perform the update using a least square solver."""

        del kwargs  # unused
        # This version makes use of the least-squares solver jnp.linalg.lstsq,
        # which has two problems:
        # 1. It is too slow because it computes a full SVD (overkill) to solve
        #    the system.
        # 2. It has no support for regularization.
        batch_size = data['label'].shape[0]
        _, unravel_pytree = flatten_util.ravel_pytree(params)
        values = jnp.zeros((batch_size))

        @jax.jit
        def loss_sample(image, label):
            tmp_kwargs = {'data': {'image': image, 'label': label}}
            # compute a gradient on a single image/label pair
            if self.has_aux:
                # we only store the last value of aux
                (value_i, aux), grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
            else:
                value_i, grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
                aux = None
            grad_i_flatten, _ = flatten_util.ravel_pytree(grad_i)
            return value_i, aux, grad_i_flatten

        # we add a new axis on data and labels so they have the correct
        # shape after vectorization by vmap, which removes the batch dimension
        expand_data = jnp.expand_dims(data['image'], axis=1)
        expand_labels = jnp.expand_dims(data['label'], axis=1)
        values, aux, grads = jax.vmap(loss_sample,
                                      in_axes=(0, 0))(expand_data,
                                                      expand_labels)

        # This is too slow; a faster implementation is needed.
        v = jnp.linalg.lstsq(grads, values)[0]

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Number of iterations', state.iter_num,
                  '. Objective function value: ', value)

        new_state = SystemStochasticPolyakState(
            # iter_num=state.iter_num + 1, value=value, aux=aux)
            iter_num=state.iter_num + 1,
            aux=aux)
        return base.OptStep(params=new_params, state=new_state)
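A short, self-contained check (with illustrative random data) that jnp.linalg.lstsq on the underdetermined system J v = loss returns the minimum-norm solution J^T (J J^T)^{-1} loss, i.e. the delta = 0 limit of the CG variant above:

import jax
import jax.numpy as jnp

batch_size, dim = 4, 10  # batch_size < dim, so the system is underdetermined
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
grads = jax.random.normal(key1, (batch_size, dim))
values = jax.random.uniform(key2, (batch_size,))

v_lstsq = jnp.linalg.lstsq(grads, values)[0]
v_minnorm = grads.T @ jnp.linalg.solve(grads @ grads.T, values)
print(jnp.allclose(v_lstsq, v_minnorm, atol=1e-4))  # expected: True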
Code example #5
    def update_jacrev_arrays_CG(self, params, state, data, *args, **kwargs):
        """Perform the update using jacrev and CG."""

        del args, kwargs  # unused
        # Currently the fastest implementation.
        batch_size = data['label'].shape[0]
        _, unravel_pytree = flatten_util.ravel_pytree(params)
        values = jnp.zeros((batch_size))

        @jax.jit
        def losses(params):
            # Currently, self.fun returns the losses BEFORE the mean reduction.
            return self.fun(params, data)[0]

        # TODO(rmgower): avoid recomputing the auxiliary output (metrics)
        aux = self.fun(params, data)[1]

        @jax.jit
        def matvec_array(u):  # Computes the product (J J^T + delta * I) u
            out = grads @ (u @ grads) + self.delta * u
            return out

        def jacobian_builder(losses, params):
            grads_tree = jax.jacrev(losses)(params)
            grads, _ = flatten_util.ravel_pytree(grads_tree)
            grads = jnp.reshape(grads,
                                (batch_size, grads.shape[0] // batch_size))
            return grads

        ## Important: This is the bottleneck cost of this update!
        grads = jacobian_builder(losses, params)

        values = losses(params)

        # Solve v = (J J^T + delta * I)^{-1} loss
        v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=10)

        ## Build the final update v = J^T (J J^T + delta * I)^{-1} loss
        v = v @ grads

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Number of iterations', state.iter_num,
                  '. Objective function value: ', value)

        new_state = SystemStochasticPolyakState(
            # iter_num=state.iter_num + 1, value=value, aux=aux)
            iter_num=state.iter_num + 1,
            aux=aux)
        return base.OptStep(params=new_params, state=new_state)
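A self-contained sketch of the jacobian_builder pattern above on a toy least-squares model (the model, data shapes, and parameters are all illustrative): jax.jacrev over the vector of per-sample losses gives a stacked gradient whose raveled form reshapes to a (batch_size, n_params) Jacobian. Keeping the parameters as a single flat array makes that reshape exact; the result is checked against per-sample gradients computed with vmap.

import jax
import jax.numpy as jnp
from jax import flatten_util

batch_size, n_features = 4, 3
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
images = jax.random.normal(key1, (batch_size, n_features))
labels = jax.random.normal(key2, (batch_size,))
params = jnp.ones(n_features)  # a single flat parameter array

def losses(params):
    preds = images @ params
    return 0.5 * (preds - labels) ** 2  # per-sample losses, no mean reduction

grads_tree = jax.jacrev(losses)(params)
grads, _ = flatten_util.ravel_pytree(grads_tree)
grads = jnp.reshape(grads, (batch_size, grads.shape[0] // batch_size))

per_sample = jax.vmap(jax.grad(lambda p, x, y: 0.5 * (x @ p - y) ** 2),
                      in_axes=(None, 0, 0))(params, images, labels)
print(jnp.allclose(grads, per_sample))  # expected: True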
Code example #6
    def init(self, init_params):
        """Initialize the ``(params, state)`` pair.

    Args:
      init_params: pytree containing the initial parameters.
    Return type: base.OptStep

    Returns:
      (params, state)
    """
        # state = SystemStochasticPolyakState(iter_num=0, value=jnp.inf, aux=None)
        state = SystemStochasticPolyakState(iter_num=0, aux=None)
        return base.OptStep(params=init_params, state=state)
Code example #7
  def update(self,
             params,
             state,
             epoch,
             *args,
             **kwargs):
    """Perform one update of the algorithm.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
    if self.has_aux:
      (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    else:
      value, grad = self._value_and_grad_fun(params, *args, **kwargs)
      aux = None

    gradnorm = tree_l2_norm(grad, squared=True)
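    ## Mathematical description of this step:
    # step1 = (f_i(w) - s + delta * lmbda)_+ / (delta + ||grad||^2)
    # step_size = min(step1, f_i(w) / ||grad||^2)
    # s = (s - lmbda * delta + delta * step1)_+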
    step1 = jax.nn.relu(value - state.slack + self.delta * self.lmbda) / (
        self.delta + gradnorm)
    spsstep = value / gradnorm
    step_size = jnp.minimum(step1, spsstep)
    newslack = jax.nn.relu(state.slack - self.lmbda * self.delta +
                           self.delta * step1)
    # new_params = tree_add_scalar_mul(params, -step_size, grad)

    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, -step_size, grad)
      new_velocity = None
    else:
      # new_v = momentum * v - step_size * grad
      # new_params = params + new_v
      new_velocity = tree_sub(tree_scalar_mul(self.momentum, state.velocity),
                              tree_scalar_mul(step_size, grad))
      new_params = tree_add(params, new_velocity)
    new_state = SPSL1State(
        iter_num=state.iter_num + 1,
        value=value,
        slack=newslack,
        velocity=new_velocity,
        aux=aux)
    return base.OptStep(params=new_params, state=new_state)
Code example #8
  def update(self,
             params,
             state,
             data,
             *args,
             **kwargs):
    """Performs one iteration of the optax solver.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      data: dict.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """

    del args, kwargs  # unused
    (value, aux), update = self._spsdiag_update(params, data)
    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, self.learning_rate, update)
      new_velocity = None
    else:
      new_velocity = tree_sub(
          tree_scalar_mul(self.momentum, state.velocity),
          tree_scalar_mul(self.learning_rate, update))
      new_params = tree_add(params, new_velocity)

    aux['loss'] = jnp.mean(aux['loss'])
    aux['accuracy'] = jnp.mean(aux['accuracy'])

    if state.iter_num % 10 == 0:
      print('Number of iterations', state.iter_num,
            '. Objective function value: ', value)

    new_state = StochasticPolyakState(
        iter_num=state.iter_num+1, value=value, velocity=new_velocity, aux=aux)
    return base.OptStep(params=new_params, state=new_state)
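A minimal sketch of the heavy-ball recursion used in the momentum branch above, on a toy pytree with illustrative values (the tree utilities are jaxopt's jaxopt.tree_util helpers used throughout these examples):

import jax.numpy as jnp
from jaxopt.tree_util import tree_add, tree_sub, tree_scalar_mul, tree_zeros_like

momentum, learning_rate = 0.9, 0.1  # illustrative values
params = {'w': jnp.ones(3), 'b': jnp.zeros(())}
update = {'w': jnp.full(3, 0.5), 'b': jnp.asarray(1.0)}
velocity = tree_zeros_like(params)

# new_v = momentum * v - learning_rate * update
velocity = tree_sub(tree_scalar_mul(momentum, velocity),
                    tree_scalar_mul(learning_rate, update))
# new_params = params + new_v
params = tree_add(params, velocity)
print(params)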
Code example #9
  def init(self,
           init_params):
    """Initialize the ``(params, state)`` pair.

    Args:
      init_params: pytree containing the initial parameters.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
    if self.momentum == 0:
      velocity = None
    else:
      velocity = tree_zeros_like(init_params)

    state = StochasticPolyakState(
        iter_num=0, value=jnp.inf, velocity=velocity, aux=None)
    return base.OptStep(params=init_params, state=state)
Code example #10
    def update_pytrees_CG(self, params, state, epoch, data, *args, **kwargs):
        """Solves one iteration of the system Polyak solver calling directly CG.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int.
      data: a batch of data.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type: base.OptStep

    Returns:
      (params, state)
    """
        del epoch, args, kwargs  # unused

        def losses(params):
            # Currently, self.fun returns the losses BEFORE the mean reduction.
            return self.fun(params, data)[0]

        # TODO(rmgower): avoid recomputing the auxiliary output (metrics)
        aux = self.fun(params, data)[1]

        # get Jacobian transpose operator
        Jt = jax.vjp(losses, params)[1]

        @jax.jit
        def matvec(u):
            """Matrix-vector product.

      Args:
        u: vector of length batch_size.

      Returns:
        The vector (J J^T + delta * I) u = J(J^T(u)) + delta * u.
      """
            ## Important: This is slow
            Jtu = Jt(u)  # evaluate Jacobian transpose vector product
            # evaluate Jacobian-vector product
            JJtu = jax.jvp(losses, (params, ), (Jtu[0], ))[1]
            deltau = self.delta * u
            return JJtu + deltau

        ## Solve the small linear system (J J^T + delta * I) x = -loss
        ## Warning: This is the bottleneck cost
        rhs = -losses(params)
        cg_sol = linear_solve.solve_cg(matvec, rhs, init=None, maxiter=20)

        ## Build the final solution w = w - J^T (J J^T + delta * I)^{-1} loss
        Jtsol = Jt(cg_sol)[0]
        new_params = tree_util.tree_add(params, Jtsol)

        if state.iter_num % 10 == 0:
            print('Number of iterations', state.iter_num,
                  '. Objective function value: ', jnp.mean(-rhs))
        new_state = SystemStochasticPolyakState(iter_num=state.iter_num + 1,
                                                aux=aux)

        return base.OptStep(params=new_params, state=new_state)
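A self-contained, matrix-free sketch of the same pattern on a toy residual function (the function, shapes, and delta are illustrative): jax.vjp provides the J^T operator, jax.jvp provides the J operator, and CG solves (J J^T + delta * I) x = -loss without ever materializing J.

import jax
import jax.numpy as jnp
from jaxopt import linear_solve

delta = 1e-3                 # illustrative regularization
params = {'w': jnp.ones(5)}  # illustrative pytree of parameters
targets = jnp.arange(3.0)

def losses(params):
    # Three per-sample losses; stands in for self.fun(params, data)[0].
    preds = jnp.stack([params['w'].sum(), (params['w'] ** 2).sum(), params['w'][0]])
    return 0.5 * (preds - targets) ** 2

# Jacobian-transpose operator via vjp.
Jt = jax.vjp(losses, params)[1]

def matvec(u):
    Jtu = Jt(u)[0]                                # J^T u, a pytree like params
    JJtu = jax.jvp(losses, (params,), (Jtu,))[1]  # J (J^T u), a length-3 vector
    return JJtu + delta * u

rhs = -losses(params)
cg_sol = linear_solve.solve_cg(matvec, rhs, init=None, maxiter=20)
step = Jt(cg_sol)[0]  # J^T (J J^T + delta * I)^{-1} (-loss)
new_params = jax.tree_util.tree_map(lambda p, s: p + s, params, step)
print(new_params['w'])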
Code example #11
    def update_pytrees_QP(self, params, state, epoch, data, *args, **kwargs):
        """Performs one iteration of the system Polyak solver.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int, epoch number.
      data: a batch of data.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type: base.OptStep

    Returns:
      (params, state)
    """

        del epoch, args, kwargs  # unused

        # The output of losses(params) is of shape size(batch).
        # Therefore, losses is a function from size(params) to size(batch).
        # Therefore the Jacobian is of shape size(batch) x size(params).
        def losses(params):
            # Currently, self.fun returns the losses BEFORE the mean reduction.
            return self.fun(params, data)[0]

        # TODO(rmgower): avoid recomputing the auxiliary output (metrics)
        aux = self.fun(params, data)[1]

        # Solves min_w 0.5 ||w - w^t||^2 s.t. A w = b,
        # where A is of shape size(batch) x size(params) and contains the gradients,
        #       b is of shape size(batch) with b[i] = (A w^t)[i] - loss_values[i],
        #       w = params.

        # This is equivalent to solving 0.5 w^T Q w + <c, w> s.t. A w = b,
        # where Q = Identity and c = -w^t.
        def matvec_Q(params_Q, u):
            del params_Q  # ignored
            return u

        def matvec_A(params_A, u):
            del params_A  # ignored
            return jax.jvp(losses, (params, ), (u, ))[1]

        # Since A is the Jacobian of losses, A w^t is a JVP.
        # This computes the JVP and loss values along the way.
        loss_values, Awt = jax.jvp(losses, (params, ), (params, ))
        b = Awt - loss_values
        # Rob: Double wrong! Check again
        c = tree_util.tree_scalar_mul(-1.0, params)
        params_obj = (None, c)
        params_eq = (None, b)

        # Rob: Solves for primal and dual variables,
        # thus solves very large linear system.
        qp = jaxopt.QuadraticProgramming(matvec_Q=matvec_Q,
                                         matvec_A=matvec_A,
                                         maxiter=10)
        res = qp.run(params_obj=params_obj, params_eq=params_eq).params
        # res contains both the primal and dual variables
        # but we only need the primal variable.
        new_params = res[0]

        new_state = SystemStochasticPolyakState(iter_num=state.iter_num + 1,
                                                aux=aux)
        return base.OptStep(params=new_params, state=new_state)
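A short derivation sketching why Q, c, and b above take this form, restating the comments in the code:

\[
f_i(w) \;\approx\; f_i(w^t) + \langle \nabla f_i(w^t),\, w - w^t \rangle = 0
\quad\Longrightarrow\quad
A w = A w^t - \mathrm{loss}(w^t) =: b,
\qquad A_{i,:} = \nabla f_i(w^t)^\top .
\]
\[
\tfrac{1}{2}\lVert w - w^t \rVert^2
= \tfrac{1}{2}\, w^\top I\, w + \langle -w^t,\, w \rangle + \mathrm{const}
\quad\Longrightarrow\quad
Q = I, \qquad c = -w^t .
\]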