Ejemplo n.º 1
0
    def update(self, params, state, epoch, *args, **kwargs):
        """Run a single SPSsqrt step.

    Args:
      params: pytree of model parameters.
      state: named tuple holding the solver state.
      epoch: epoch number (ignored by this solver).
      *args: extra positional arguments forwarded to ``fun``.
      **kwargs: extra keyword arguments forwarded to ``fun``.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
        del epoch  # unused
        if self.lmbda == 1:
            raise ValueError(
                'lmbda =1 was passed to SPSsqrt solver. This solver does not work with lmbda =1 because then the parameters are never updated! '
            )

        if self.has_aux:
            (value,
             aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
        else:
            aux = None
            value, grad = self._value_and_grad_fun(params, *args, **kwargs)

        # Reset the slack to the current loss if it ever hits exactly zero;
        # otherwise the method would halt (zero slack forces a zero step).
        if state.slack == 0.0:
            state = state._replace(slack=value)

        # Update rule:
        #   step = (value - (1-lmbda/2) sqrt(s))_+ / (4s||grad||^2 + 1 - lmbda)
        #   w = w - 4 step s * grad
        #   s = (1-lmbda)*sqrt(s)*(sqrt(s) + step)
        sqrt_slack = jnp.sqrt(state.slack)
        numer = jax.nn.relu(value - (1 - self.lmbda / 2) * sqrt_slack)
        denom = (4 * state.slack * tree_l2_norm(grad, squared=True) + 1 -
                 self.lmbda)
        step_size = numer / denom
        newslack = (1 - self.lmbda) * sqrt_slack * (sqrt_slack + step_size)
        step_size = 4 * state.slack * step_size

        if self.momentum == 0:
            new_velocity = None
            new_params = tree_add_scalar_mul(params, -step_size, grad)
        else:
            # new_v = momentum * v - step_size * grad
            # new_params = params + new_v
            scaled_velocity = tree_scalar_mul(self.momentum, state.velocity)
            new_velocity = tree_sub(scaled_velocity,
                                    tree_scalar_mul(step_size, grad))
            new_params = tree_add(params, new_velocity)

        new_state = SPSsqrtState(iter_num=state.iter_num + 1,
                                 value=value,
                                 slack=newslack,
                                 velocity=new_velocity,
                                 aux=aux)
        return base.OptStep(params=new_params, state=new_state)
Ejemplo n.º 2
0
  def update(self, params, state, epoch, *args, **kwargs):
    """Perform one SPSDam update.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int, used by the optional lambda schedule.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.

    Returns:
      (params, state)
    """
    if self.has_aux:
      (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    else:
      aux = None
      value, grad = self._value_and_grad_fun(params, *args, **kwargs)

    # Experimental: decay lambda slowly after `late_start` epochs. The
    # intuition is that plain SPS (self.lmbda = 1) works well in the early
    # iterations, while in later iterations the slack helps stabilize.
    late_start = 10
    lmbdat = self.lmbda
    if self.lmbda_schedule and epoch > late_start:
      lmbdat = self.lmbda/(jnp.log(jnp.log(epoch-late_start+1)+1)+1)

    # Step size:
    #   step_size = (f_i(w^t) - (1-lmbda) s)_+ / (||grad||^2 + 1 - lmbda)
    grad_sqnorm = tree_l2_norm(grad, squared=True)
    step_size = jax.nn.relu(value - (1 - lmbdat) * state.slack) / (
        grad_sqnorm + 1 - lmbdat)
    newslack = (1 - lmbdat) * (state.slack + step_size)

    if self.momentum == 0:
      new_velocity = None
      new_params = tree_add_scalar_mul(params, -step_size, grad)
    else:
      # new_v = momentum * v - step_size * grad; new_params = params + new_v
      scaled_velocity = tree_scalar_mul(self.momentum, state.velocity)
      new_velocity = tree_sub(scaled_velocity,
                              tree_scalar_mul(step_size, grad))
      new_params = tree_add(params, new_velocity)

    new_state = SPSDamState(
        iter_num=state.iter_num + 1,
        value=value,
        slack=newslack,
        velocity=new_velocity,
        aux=aux)
    return base.OptStep(params=new_params, state=new_state)
Ejemplo n.º 3
0
    def update_arrays_CG(self, params, state, data, *args, **kwargs):
        """Perform the update using CG.

        Solves v = J^T (J J^T + delta*I)^{-1} losses with conjugate gradient,
        where J stacks the per-sample flattened gradients, then applies
        params <- params - v.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          data: dict with 'image' and 'label' arrays, batched on axis 0.
          *args: additional positional arguments passed to ``fun``.
          **kwargs: additional keyword arguments (unused).

        Returns:
          base.OptStep containing ``(params, state)``.
        """
        del kwargs  # unused
        _, unravel_pytree = flatten_util.ravel_pytree(params)

        @jax.jit
        def loss_sample(image, label):
            # Compute value/gradient on a single image/label pair.
            tmp_kwargs = {'data': {'image': image, 'label': label}}
            if self.has_aux:
                # we only store the last value of aux
                (value_i, aux), grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
            else:
                value_i, grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
                aux = None
            grad_i_flatten, _ = flatten_util.ravel_pytree(grad_i)
            return value_i, aux, grad_i_flatten

        @jax.jit
        def matvec_array(u):
            """Computes the product (J J^T + delta * I)u."""
            out = grads @ (u @ grads) + self.delta * u
            return out

        # We add a new axis on data and labels so they have the correct
        # shape after vectorization by vmap, which removes the batch dimension.
        ## Important: This is the bottleneck cost of this update!
        expand_data = jnp.expand_dims(data['image'], axis=1)
        expand_labels = jnp.expand_dims(data['label'], axis=1)
        # Bug fix: the original ran this vmap twice (once for values/aux, once
        # to re-extract grads), doubling the dominant cost of the update. One
        # pass already yields all three outputs.
        values, aux, grads = jax.vmap(loss_sample,
                                      in_axes=(0, 0))(expand_data,
                                                      expand_labels)

        # Solving v = (J J^T + delta * I)^{-1} losses
        v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=20)
        ## Builds final update v = J^T (J J^T + delta * I)^{-1} losses
        v = v @ grads

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Number of iterations', state.iter_num,
                  '. Objective function value: ', value)

        new_state = SystemStochasticPolyakState(
            iter_num=state.iter_num + 1,
            aux=aux)
        return base.OptStep(params=new_params, state=new_state)
  def update(self,
             params,
             state,
             data,
             *args,
             **kwargs):
    """Performs one iteration of the optax solver.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      data: dict of batch data passed to the SPS-diagonal update.
      *args: additional positional arguments (unused).
      **kwargs: additional keyword arguments (unused).
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
    del args, kwargs  # unused
    (value, aux), update = self._spsdiag_update(params, data)
    if self.momentum == 0:
      new_params = tree_add_scalar_mul(params, self.learning_rate, update)
      new_velocity = None
    else:
      # new_v = momentum * v - learning_rate * update
      # new_params = params + new_v
      # NOTE(review): the sign of the update term here is opposite to the
      # momentum-free branch above (+learning_rate there); confirm which
      # direction ``update`` points in.
      new_velocity = tree_sub(
          tree_scalar_mul(self.momentum, state.velocity),
          tree_scalar_mul(self.learning_rate, update))
      new_params = tree_add(params, new_velocity)

    # Bug fix: the original unconditionally recomputed
    # ``new_params = params + learning_rate * update`` at this point,
    # discarding the momentum step computed above and making momentum a
    # no-op on the parameters.

    aux['loss'] = jnp.mean(aux['loss'])
    aux['accuracy'] = jnp.mean(aux['accuracy'])

    if state.iter_num % 10 == 0:
      print('Number of iterations', state.iter_num,
            '. Objective function value: ', value)

    new_state = StochasticPolyakState(
        iter_num=state.iter_num+1, value=value, velocity=new_velocity, aux=aux)
    return base.OptStep(params=new_params, state=new_state)
Ejemplo n.º 5
0
    def update_arrays_lstsq(self, params, state, data, *args, **kwargs):
        """Perform the update using a least square solver.

        Solves J v = losses in the least-squares sense (J stacks per-sample
        flattened gradients) and applies params <- params - v.

        Known limitations of jnp.linalg.lstsq here:
        1. It is too slow because it computes a full SVD (overkill) to solve
           the systems.
        2. It has no support for regularization.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          data: dict with 'image' and 'label' arrays, batched on axis 0.
          *args: additional positional arguments passed to ``fun``.
          **kwargs: additional keyword arguments (unused).

        Returns:
          base.OptStep containing ``(params, state)``.
        """
        del kwargs  # unused
        _, unravel_pytree = flatten_util.ravel_pytree(params)

        @jax.jit
        def loss_sample(image, label):
            # Compute value/gradient on a single image/label pair.
            tmp_kwargs = {'data': {'image': image, 'label': label}}
            if self.has_aux:
                # we only store the last value of aux
                (value_i, aux), grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
            else:
                value_i, grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
                aux = None
            grad_i_flatten, _ = flatten_util.ravel_pytree(grad_i)
            return value_i, aux, grad_i_flatten

        # We add a new axis on data and labels so they have the correct
        # shape after vectorization by vmap, which removes the batch dimension.
        # Bug fix: the original ran this vmap twice (once for values/aux, once
        # to re-extract grads), doubling the dominant cost of the update.
        expand_data = jnp.expand_dims(data['image'], axis=1)
        expand_labels = jnp.expand_dims(data['label'], axis=1)
        values, aux, grads = jax.vmap(loss_sample,
                                      in_axes=(0, 0))(expand_data,
                                                      expand_labels)

        # This is too slow. Need faster implementation
        v = jnp.linalg.lstsq(grads, values)[0]

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Number of iterations', state.iter_num,
                  '. Objective function value: ', value)

        new_state = SystemStochasticPolyakState(
            iter_num=state.iter_num + 1,
            aux=aux)
        return base.OptStep(params=new_params, state=new_state)
Ejemplo n.º 6
0
    def update_jacrev_arrays_CG(self, params, state, data, *args, **kwargs):
        """Perform the update using jacrev and CG.

        Builds the per-sample Jacobian J with jax.jacrev, solves
        v = J^T (J J^T + delta*I)^{-1} losses with conjugate gradient, and
        applies params <- params - v. Currently the fastest implementation.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          data: dict with a 'label' array batched on axis 0, consumed by
            ``self.fun``.
          *args: additional positional arguments (unused).
          **kwargs: additional keyword arguments (unused).

        Returns:
          base.OptStep containing ``(params, state)``.
        """
        del args, kwargs  # unused
        batch_size = data['label'].shape[0]
        _, unravel_pytree = flatten_util.ravel_pytree(params)

        @jax.jit
        def losses(params):
            # Currently, self.fun returns the losses BEFORE the mean reduction.
            return self.fun(params, data)[0]

        # Fix for the old TODO: evaluate self.fun once and reuse the result
        # for both the per-sample losses and the auxiliary metrics. The
        # original called self.fun separately for `aux` and for `values`.
        fun_out = self.fun(params, data)
        values = fun_out[0]
        aux = fun_out[1]

        @jax.jit
        def matvec_array(u):  # Computes the product (J J^T + delta * I)u
            out = grads @ (u @ grads) + self.delta * u
            return out

        def jacobian_builder(losses, params):
            # Flatten the Jacobian pytree into a (batch, n_params) matrix.
            grads_tree = jax.jacrev(losses)(params)
            grads, _ = flatten_util.ravel_pytree(grads_tree)
            grads = jnp.reshape(grads,
                                (batch_size, int(grads.shape[0] / batch_size)))
            return grads

        ## Important: This is the bottleneck cost of this update!
        grads = jacobian_builder(losses, params)

        # Solving  v = (J J^T + delta * I)^{-1} losses
        v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=10)

        ## Builds final update v = J^T (J J^T + delta * I)^{-1} losses
        v = v @ grads

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Number of iterations', state.iter_num,
                  '. Objective function value: ', value)

        new_state = SystemStochasticPolyakState(
            iter_num=state.iter_num + 1,
            aux=aux)
        return base.OptStep(params=new_params, state=new_state)
Ejemplo n.º 7
0
  def update(self, params, state, epoch, *args, **kwargs):
    """Perform one SPSL1 update.

    Args:
      params: pytree containing the parameters.
      state: named tuple containing the solver state.
      epoch: int.
      *args: additional positional arguments to be passed to ``fun``.
      **kwargs: additional keyword arguments to be passed to ``fun``.
    Return type:
      base.OptStep
    Returns:
      (params, state)
    """
    if self.has_aux:
      (value, aux), grad = self._value_and_grad_fun(params, *args, **kwargs)
    else:
      aux = None
      value, grad = self._value_and_grad_fun(params, *args, **kwargs)

    gradnorm = tree_l2_norm(grad, squared=True)
    # Damped step, capped by the plain SPS step value/||grad||^2.
    step1 = jax.nn.relu(value - state.slack + self.delta * self.lmbda) / (
        self.delta + gradnorm)
    spsstep = value / gradnorm
    step_size = jnp.minimum(step1, spsstep)
    newslack = jax.nn.relu(
        state.slack - self.lmbda * self.delta + self.delta * step1)

    if self.momentum == 0:
      new_velocity = None
      new_params = tree_add_scalar_mul(params, -step_size, grad)
    else:
      # new_v = momentum * v - step_size * grad; new_params = params + new_v
      scaled_velocity = tree_scalar_mul(self.momentum, state.velocity)
      new_velocity = tree_sub(scaled_velocity,
                              tree_scalar_mul(step_size, grad))
      new_params = tree_add(params, new_velocity)

    new_state = SPSL1State(
        iter_num=state.iter_num + 1,
        value=value,
        slack=newslack,
        velocity=new_velocity,
        aux=aux)
    return base.OptStep(params=new_params, state=new_state)
Ejemplo n.º 8
0
def projection_halfspace(x, a, b):
  r"""Project ``x`` onto the halfspace ``{y : <a, y> <= b}``.

  Computes ``argmin_{y, dot(a, y) <= b} ||y - x||``, i.e.
    y = x - max(<a, x> - b, 0) / <a, a> * a

  Args:
    x: pytree to project.
    a: pytree defining the halfspace normal.
    b: scalar offset of the halfspace.

  Returns:
    y: the projection of ``x`` (same structure as ``x``)
  """
  violation = tree_util.tree_vdot(a, x) - b
  scale = jax.nn.relu(violation) / tree_util.tree_vdot(a, a)
  return tree_util.tree_add_scalar_mul(x, -scale, a)
Ejemplo n.º 9
0
def projection_hyperplane(a, b, x=None):
  r"""Project ``x`` onto the hyperplane ``{y : <a, y> = b}``.

  Computes ``argmin_{y, dot(a, y) = b} ||y - x||``, which is equivalent to
    y = x - (<a, x> - b) / <a, a> * a
  When ``x`` is None the origin is projected, giving
    y = b / <a, a> * a

  Args:
    a: pytree defining the hyperplane normal.
    b: scalar offset of the hyperplane.
    x: optional pytree to project; defaults to the origin.

  Returns:
    y: the projection (same structure as ``a``)
  """
  denom = tree_util.tree_vdot(a, a)
  if x is None:
    return tree_util.tree_scalar_mul(b / denom, a)
  scale = (tree_util.tree_vdot(a, x) - b) / denom
  return tree_util.tree_add_scalar_mul(x, -scale, a)