Example No. 1
    def update_arrays_CG(self, params, state, data, *args, **kwargs):
        """Perform the update using CG."""

        del kwargs  # unused
        batch_size = data['label'].shape[0]
        _, unravel_pytree = flatten_util.ravel_pytree(params)

        @jax.jit
        def loss_sample(image, label):
            tmp_kwargs = {'data': {'image': image, 'label': label}}
            # compute a gradient on a single image/label pair
            if self.has_aux:
                # aux is stacked across the batch by the outer vmap
                (value_i, aux), grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
            else:
                value_i, grad_i = self._value_and_grad_fun(
                    params, *args, **tmp_kwargs)
                aux = None
            grad_i_flatten, _ = flatten_util.ravel_pytree(grad_i)
            return value_i, aux, grad_i_flatten

        @jax.jit
        def matvec_array(u):
            """Computes the product  (J J^T +delta * I)u ."""
            out = grads @ (u @ grads) + self.delta * u
            return out

        # We add a new axis on data and labels so they have the correct
        # shape after vectorization by vmap, which removes the batch dimension
        ## Important: This is the bottleneck cost of this update!
        expand_data = jnp.expand_dims(data['image'], axis=1)
        expand_labels = jnp.expand_dims(data['label'], axis=1)
        values, aux, grads = jax.vmap(loss_sample,
                                      in_axes=(0, 0))(expand_data,
                                                      expand_labels)

        # Solve for v = (J J^T + delta * I)^{-1} loss
        v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=20)
        ## Builds the final update v = J^T (J J^T + delta * I)^{-1} loss
        v = v @ grads

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Iteration:', state.iter_num,
                  'Objective function value:', value)

        new_state = SystemStochasticPolyakState(iter_num=state.iter_num + 1,
                                                aux=aux)
        return base.OptStep(params=new_params, state=new_state)
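
The update above is the stochastic Polyak step w <- w - J^T (J J^T + delta * I)^{-1} loss, where J stacks the per-sample gradients. Below is a minimal self-contained sketch of the same matrix-free solve on a toy least-squares problem; the data, the value of delta, and the use of jax.scipy.sparse.linalg.cg in place of linear_solve.solve_cg are illustrative assumptions, not taken from the code above.

# Sketch of the (J J^T + delta * I)^{-1} solve; all data here is toy data.
import jax
import jax.numpy as jnp

delta = 1e-3  # assumed regularization strength
w = jax.random.normal(jax.random.PRNGKey(0), (5,))    # flat parameter vector
X = jax.random.normal(jax.random.PRNGKey(1), (8, 5))  # batch of 8 inputs
y = jnp.ones(8)                                       # targets

def per_sample_loss(w, x, yi):
    return 0.5 * (x @ w - yi) ** 2  # scalar loss for one sample

# Per-sample losses and the (batch_size, n_params) gradient matrix J.
losses = jax.vmap(per_sample_loss, in_axes=(None, 0, 0))(w, X, y)
J = jax.vmap(jax.grad(per_sample_loss), in_axes=(None, 0, 0))(w, X, y)

def matvec(u):
    # (J J^T + delta * I) u without forming J J^T explicitly.
    return J @ (J.T @ u) + delta * u

v, _ = jax.scipy.sparse.linalg.cg(matvec, losses, maxiter=20)
w_new = w - J.T @ v  # w <- w - J^T (J J^T + delta * I)^{-1} loss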
Example No. 2
    def update_jacrev_arrays_CG(self, params, state, data, *args, **kwargs):
        """Perform the update using jacrev and CG."""

        del args, kwargs  # unused
        # Currently the fastest implementation.
        batch_size = data['label'].shape[0]
        _, unravel_pytree = flatten_util.ravel_pytree(params)

        @jax.jit
        def losses(params):
            # Currently, self.fun returns the losses BEFORE the mean reduction.
            return self.fun(params, data)[0]

        # TODO(rmgower): avoid recomputing the auxiliary output (metrics)
        aux = self.fun(params, data)[1]

        @jax.jit
        def matvec_array(u):  # Computes the product (J J^T + delta * I) u
            out = grads @ (u @ grads) + self.delta * u
            return out

        def jacobian_builder(losses, params):
            grads_tree = jax.jacrev(losses)(params)
            grads, _ = flatten_util.ravel_pytree(grads_tree)
            grads = jnp.reshape(grads,
                                (batch_size, grads.shape[0] // batch_size))
            return grads

        ## Important: This is the bottleneck cost of this update!
        grads = jacobian_builder(losses, params)

        values = losses(params)

        # Solve for v = (J J^T + delta * I)^{-1} loss
        v = linear_solve.solve_cg(matvec_array, values, init=None, maxiter=10)

        ## Builds the final update v = J^T (J J^T + delta * I)^{-1} loss
        v = v @ grads

        v_tree = unravel_pytree(v)
        new_params = tree_util.tree_add_scalar_mul(params, -1.0, v_tree)
        value = jnp.mean(values)

        if state.iter_num % 10 == 0:
            print('Iteration:', state.iter_num,
                  'Objective function value:', value)

        new_state = SystemStochasticPolyakState(iter_num=state.iter_num + 1,
                                                aux=aux)
        return base.OptStep(params=new_params, state=new_state)
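
For context, jax.jacrev applied to a function returning the vector of per-sample losses yields the full (batch_size, n_params) Jacobian in one reverse-mode sweep; the ravel-and-reshape step in jacobian_builder is only needed because params is a pytree. A short sketch under the assumption of flat parameters, with toy data and illustrative names:

# Sketch of the jacrev-based Jacobian build for flat parameters.
import jax
import jax.numpy as jnp

batch_size, n_params = 8, 5
w = jax.random.normal(jax.random.PRNGKey(0), (n_params,))
X = jax.random.normal(jax.random.PRNGKey(1), (batch_size, n_params))
y = jnp.ones(batch_size)

def losses(w):
    # Vector of per-sample losses, BEFORE any mean reduction.
    return 0.5 * (X @ w - y) ** 2

# One reverse-mode sweep gives the whole Jacobian; no reshape needed here.
J = jax.jacrev(losses)(w)
assert J.shape == (batch_size, n_params)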
Example No. 3
    def update_pytrees_CG(self, params, state, epoch, data, *args, **kwargs):
        """Solves one iteration of the system Polyak solver by calling CG directly.

        Args:
          params: pytree containing the parameters.
          state: named tuple containing the solver state.
          epoch: int.
          data: a batch of data.
          *args: additional positional arguments to be passed to ``fun``.
          **kwargs: additional keyword arguments to be passed to ``fun``.

        Returns:
          A ``base.OptStep``, i.e. a ``(params, state)`` named tuple.
        """
        del epoch, args, kwargs  # unused

        def losses(params):
            # Currently, self.fun returns the losses BEFORE the mean reduction.
            return self.fun(params, data)[0]

        # TODO(rmgower): avoid recomputing the auxiliary output (metrics)
        aux = self.fun(params, data)[1]

        # get Jacobian transpose operator
        Jt = jax.vjp(losses, params)[1]

        @jax.jit
        def matvec(u):
            """Matrix-vector product.

            Args:
              u: vector of length batch_size.

            Returns:
              The vector (J J^T + delta * I) u = J(J^T u) + delta * u.
            """
            ## Important: This is slow
            Jtu = Jt(u)  # evaluate Jacobian-transpose vector product
            # evaluate Jacobian-vector product
            JJtu = jax.jvp(losses, (params,), (Jtu[0],))[1]
            deltau = self.delta * u
            return JJtu + deltau

        ## Solve the small linear system (J J^T + delta * I) x = -loss
        ## Warning: This is the bottleneck cost
        rhs = -losses(params)
        cg_sol = linear_solve.solve_cg(matvec, rhs, init=None, maxiter=20)

        ## Builds the final solution w = w - J^T (J J^T + delta * I)^{-1} loss
        Jtsol = Jt(cg_sol)[0]
        new_params = tree_util.tree_add(params, Jtsol)

        if state.iter_num % 10 == 0:
            print('Iteration:', state.iter_num,
                  'Objective function value:', jnp.mean(-rhs))
        new_state = SystemStochasticPolyakState(iter_num=state.iter_num + 1,
                                                aux=aux)

        return base.OptStep(params=new_params, state=new_state)
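
Unlike the two array-based variants, this one never materializes the Jacobian: matvec composes a vjp (u -> J^T u) with a jvp (v -> J v). The sketch below checks, on hypothetical toy data, that this matrix-free product matches the explicit (J J^T + delta * I) u; only the jax.vjp/jax.jvp pattern mirrors the code above, everything else is illustrative.

# Sketch verifying the matrix-free (J J^T + delta * I) u product.
import jax
import jax.numpy as jnp

delta = 1e-3  # assumed regularization strength
w = jax.random.normal(jax.random.PRNGKey(0), (5,))
X = jax.random.normal(jax.random.PRNGKey(1), (8, 5))
y = jnp.ones(8)

def losses(w):
    return 0.5 * (X @ w - y) ** 2  # per-sample losses

_, Jt = jax.vjp(losses, w)  # Jt(u) returns a one-element tuple with J^T u
u = jax.random.normal(jax.random.PRNGKey(2), (8,))

# Matrix-free: J(J^T u) + delta * u via one vjp and one jvp.
matrix_free = jax.jvp(losses, (w,), (Jt(u)[0],))[1] + delta * u

# Explicit: build J with jacrev and form the product directly.
J = jax.jacrev(losses)(w)
explicit = J @ (J.T @ u) + delta * u

assert jnp.allclose(matrix_free, explicit, atol=1e-5)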