Example #1
    def line_search(kwargs):
        x = kwargs['x']
        y = kwargs['y']
        w = kwargs['w']
        v = kwargs['v']
        q = w + v
        z = x + y
        # M comes from the enclosing scope.
        B = M - z
        A = q - z

        # Closed-form step size <A, B> / <A, A> (projection of B onto A),
        # capped at 1.
        step_size = torch.clamp(utils.bdiv(utils.bdot(A, B), utils.bdot(A, A)),
                                max=1.)
        assert (step_size >= 0).all()
        return step_size
Example #2
def backtracking_pgd(closure,
                     prox,
                     step_size,
                     x,
                     grad,
                     increase=1.01,
                     decrease=.6,
                     max_iter_backtracking=1000):
    """Batch-wise backtracking line search for a proximal-gradient step.

    Shrinks (or grows) each datapoint's step size until the candidate point
    satisfies the sufficient-decrease condition, and returns the adapted step
    sizes together with the accepted candidate.
    """
    batch_size = x.size(0)

    for _ in range(max_iter_backtracking):
        with torch.no_grad():
            x_candidate = prox(x - utils.bmul(step_size, grad), step_size)

        lhs = closure(x_candidate, return_jac=False)
        # Quadratic upper bound at x:
        # f(x) - <grad, x - x_candidate> + ||x - x_candidate||^2 / (2 * step_size)
        rhs = (closure(x, return_jac=False) -
               utils.bdot(grad, x - x_candidate) + utils.bmul(
                   1. / (2 * step_size),
                   torch.norm(
                       (x - x_candidate).view(batch_size, -1), dim=-1) ** 2))

        need_to_backtrack = lhs > rhs
        if not need_to_backtrack.any():
            break

        # Shrink the step sizes that violate the sufficient-decrease condition;
        # grow the ones that already satisfy it.
        step_size[need_to_backtrack] *= decrease
        step_size[~need_to_backtrack] *= increase

    return step_size, x_candidate
Example #3
    def prox(self, x, step_size=None):
        """
        Projects `x` batch-wise onto the cone constraint.
        Args:
          x: torch.Tensor of shape (batch_size, *)
            batch of vectors to project
          step_size: Any
            Not used
        Returns:
          proj_x: torch.Tensor of shape (batch_size, *)
            batch-wise projection of `x` onto the cone constraint.
        """
        batch_size = x.size(0)
        uTx = utils.bdot(self.directions, x)
        p_u = self.proj_u(x)
        p_orth_u = x - p_u
        norm_p_orth_u = torch.norm(p_orth_u.reshape(batch_size, -1), dim=-1)
        # Three cases of the projection: x already in the cone (keep it),
        # x in the polar cone (project to 0), otherwise project onto the
        # boundary of the cone.
        identity_idx = (norm_p_orth_u <= self.alpha * uTx)
        zero_idx = (self.alpha * norm_p_orth_u <= -uTx)
        project_idx = ~torch.logical_or(identity_idx, zero_idx)

        res = x.detach().clone()
        res[zero_idx] = 0.
        res[project_idx] = utils.bmul(
            (self.alpha * norm_p_orth_u[project_idx] + uTx[project_idx]) /
            (1. + self.alpha**2), (self.alpha * utils.bmul(
                p_orth_u[project_idx], 1 / norm_p_orth_u[project_idx]) +
                                   self.directions[project_idx]))
        return res
Example #4
    def proj_u(self, x, step_size=None):
        """
        Projects x onto self.directions batch-wise.
        Args:
          x: torch.Tensor of shape (batch_size, *)
            vectors to project
          step_size: Any
            Not used
        Returns:
          proj_x: torch.Tensor of shape (batch_size, *)
            batch-wise projection of x onto self.directions
        """

        return utils.bmul(utils.bdot(x, self.directions), self.directions)
Example #5
def test_cone_constraint():
    # Standard second order cone
    u = torch.tensor([[0., 0., 1.]])
    cos_alpha = .5

    cone = constraints.Cone(u, cos_alpha)

    for inp, correct_prox in [
        (torch.tensor([[1., 0, 0]]), torch.tensor([[.5, 0, .5]])),
        (torch.tensor([[0, 1., 0]]), torch.tensor([[0, .5, .5]])), (u, u),
        (-u, torch.zeros_like(u))
    ]:
        assert cone.prox(inp).eq(correct_prox).all()

    # Moreau decomposition: x = proj_x + (x - proj_x) where these two vectors are orthogonal
    for _ in range(10):
        x = torch.rand(*u.shape)
        proj_x = cone.prox(x)
        assert utils.bdot(x - proj_x, proj_x).allclose(
            torch.zeros(x.size(0)), atol=4e-7)
Example #6
def minimize_frank_wolfe(closure,
                         x0,
                         lmo,
                         step='sublinear',
                         max_iter=200,
                         callback=None):
    """Performs the Frank-Wolfe algorithm on a batch of objectives of the form
      min_x f(x)
      s.t. x in C

    where we have access to the Linear Minimization Oracle (LMO) of the constraint set C,
    and the gradient of f through closure.

    Args:
      closure: callable
        gives function values and the jacobian of f.

      x0: torch.Tensor of shape (batch_size, *).
        initial guess

      lmo: callable
        Returns update_direction, max_step_size

      step: float or 'sublinear'
        step-size scheme to be used.

      max_iter: int
        max number of iterations.

      callback: callable
        (optional) Any callable called on locals() at the end of each iteration.
        Often used for logging.
    """
    x = x0.detach().clone()
    batch_size = x.size(0)
    if not (isinstance(step, Number) or step == 'sublinear'):
        raise ValueError("step must be a float or 'sublinear'.")

    if isinstance(step, Number):
        step_size = step * torch.ones(
            batch_size, device=x.device, dtype=x.dtype)

    cert = np.inf * torch.ones(batch_size, device=x.device)

    for it in range(max_iter):

        x.requires_grad = True
        fval, grad = closure(x)
        update_direction, max_step_size = lmo(-grad, x)
        cert = utils.bdot(-grad, update_direction)

        if step == 'sublinear':
            step_size = 2. / (it + 2) * torch.ones(
                batch_size, dtype=x.dtype, device=x.device)

        with torch.no_grad():
            step_size = torch.min(step_size, max_step_size)
            x += utils.bmul(update_direction, step_size)

        if callback is not None:
            if callback(locals()) is False:
                break

    fval, grad = closure(x)
    return optimize.OptimizeResult(x=x,
                                   nit=it,
                                   fval=fval,
                                   grad=grad,
                                   certificate=cert)
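
A minimal usage sketch, not from the library: the batched quadratic closure and the hand-written L1-ball LMO below are hypothetical helpers for illustration only, and the call assumes minimize_frank_wolfe (and the utils helpers it uses) is importable from the surrounding module.

import torch

# Hypothetical helpers: minimize 0.5 * ||x - target||^2 over the unit L1 ball,
# batch-wise.
target = torch.tensor([[0.8, -0.3], [0.2, 0.9]])

def quadratic_closure(x, return_jac=True):
    val = 0.5 * ((x - target) ** 2).view(x.size(0), -1).sum(dim=-1)
    if not return_jac:
        return val
    return val, x - target  # analytic gradient of the quadratic

@torch.no_grad()
def l1_ball_lmo(minus_grad, x, radius=1.):
    # Vertex of the L1 ball most aligned with -grad, returned as the
    # update direction s - x, together with a unit maximum step size.
    batch_size = x.size(0)
    flat = minus_grad.view(batch_size, -1)
    idx = flat.abs().argmax(dim=-1)
    rows = torch.arange(batch_size)
    s = torch.zeros_like(flat)
    s[rows, idx] = radius * flat[rows, idx].sign()
    return s.view_as(x) - x, torch.ones(batch_size, dtype=x.dtype)

res = minimize_frank_wolfe(quadratic_closure, torch.zeros_like(target),
                           l1_ball_lmo, step='sublinear', max_iter=100)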
Example #7
def minimize_pgd(closure,
                 x0,
                 prox,
                 step='backtracking',
                 max_iter=200,
                 max_iter_backtracking=1000,
                 backtracking_factor=.6,
                 tol=1e-8,
                 *prox_args,
                 callback=None):
    """
    Performs Projected Gradient Descent on a batch of objectives of the form
      f(x) + g(x).
    We assume access to the gradient of f through closure,
    and to the proximal operator of g through prox.

    Args:
      closure: callable

      x0: torch.Tensor of shape (batch_size, *).

      prox: callable
        proximal operator of g

      step: 'backtracking' or float or None.
        step size to be used. If None, it will be estimated at the beginning
        from the Lipschitz constant of the gradient.
        If 'backtracking', it will be adjusted at each iteration using
        backtracking line search.

      max_iter: int
        number of iterations to perform.

      max_iter_backtracking: int
        max number of iterations in the backtracking line search

      backtracking_factor: float
        factor by which to multiply the step sizes during line search

      tol: float
        stops the algorithm when the certificate is smaller than tol
        for all datapoints in the batch

      prox_args: tuple
        (optional) additional args for prox

      callback: callable
        (optional) Any callable called on locals() at the end of each iteration.
        Often used for logging.
    """
    x = x0.detach().clone()
    batch_size = x.size(0)

    if step is None:
        # estimate lipschitz constant
        L = utils.init_lipschitz(closure, x0)
        step_size = 1. / L

    elif step == 'backtracking':
        L = 1.8 * utils.init_lipschitz(closure, x0)
        step_size = 1. / L

    elif type(step) == float:
        step_size = step * torch.ones(batch_size, device=x.device)

    else:
        raise ValueError("step must be float or backtracking or None")

    for it in range(max_iter):
        x.requires_grad = True

        fval, grad = closure(x)

        x_next = prox(x - utils.bmul(step_size, grad), step_size, *prox_args)
        update_direction = x_next - x

        if step == 'backtracking':
            step_size *= 1.1
            mask = torch.ones(batch_size, dtype=bool, device=x.device)

            with torch.no_grad():
                for _ in range(max_iter_backtracking):
                    f_next = closure(x_next, return_jac=False)
                    rhs = (fval + utils.bdot(grad, update_direction) +
                           utils.bmul(
                               utils.bdot(update_direction, update_direction),
                               1. / (2. * step_size)))
                    mask = f_next > rhs

                    if not mask.any():
                        break

                    step_size[mask] *= backtracking_factor
                    x_next = prox(x - utils.bmul(step_size, grad),
                                  step_size, *prox_args)
                    update_direction[mask] = x_next[mask] - x[mask]
                else:
                    warnings.warn("Maximum number of line-search iterations "
                                  "reached.")

        with torch.no_grad():
            cert = torch.norm(
                utils.bmul(update_direction,
                           1. / step_size).view(batch_size, -1), dim=-1)
            x.copy_(x_next)
            if (cert < tol).all():
                break

        if callback is not None:
            if callback(locals()) is False:
                break

    fval, grad = closure(x)
    return optimize.OptimizeResult(x=x,
                                   nit=it,
                                   fval=fval,
                                   grad=grad,
                                   certificate=cert)
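
A minimal usage sketch, not from the library: the quadratic closure and the soft-thresholding prox below are hypothetical helpers for illustration only; the call assumes minimize_pgd and its utils helpers are importable from the surrounding module.

import torch

# Hypothetical helpers: batched quadratic smooth term f, L1 penalty g handled
# through its prox (soft-thresholding).
target = torch.tensor([[1.5, -0.2], [0.3, 2.0]])

def quadratic_closure(x, return_jac=True):
    val = 0.5 * ((x - target) ** 2).view(x.size(0), -1).sum(dim=-1)
    if not return_jac:
        return val
    return val, x - target

@torch.no_grad()
def l1_prox(x, step_size, lam=0.1):
    # Batch-wise soft-thresholding; step_size has shape (batch_size,).
    thresh = (lam * step_size).view(-1, *([1] * (x.dim() - 1)))
    return torch.sign(x) * torch.clamp(x.abs() - thresh, min=0.)

res = minimize_pgd(quadratic_closure, torch.zeros_like(target), l1_prox,
                   step='backtracking')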
Example #8
def minimize_three_split(closure,
                         x0,
                         prox1=None,
                         prox2=None,
                         tol=1e-6,
                         max_iter=1000,
                         verbose=0,
                         callback=None,
                         line_search=True,
                         step=None,
                         max_iter_backtracking=100,
                         backtracking_factor=0.7,
                         h_Lipschitz=None,
                         *args_prox):
    """Davis-Yin three operator splitting method.
    This algorithm can solve problems of the form

                minimize_x f(x) + g(x) + h(x)

    where f is a smooth function and g and h are (possibly non-smooth)
    functions for which the proximal operator is known.

    Remark: this method returns x = prox1(...). If g and h are two indicator
      functions, this method only guarantees that x is feasible for the first.
      Therefore, if one of the constraints is a hard constraint,
      make sure to pass it to prox1.

    Args:
      closure: callable
        Returns the function values and gradient of the objective function.
        With return_jac=False, returns only the function values.
        Function values have shape (batch_size,); the gradient has the same
        shape as x.

      x0 : torch.Tensor(shape: (batch_size, *))
        Initial guess

      prox1 : callable or None
        prox1(x, step_size, *args) returns the proximal operator of g at x
        with parameter step_size.
        step_size can be a scalar or of shape (batch_size,).

      prox2 : callable or None
        prox2(x, step_size, *args) returns the proximal operator of h at x
        with parameter step_size.
        step_size can be a scalar or of shape (batch_size,).

      tol: float
        Tolerance of the stopping criterion.

      max_iter : int
        Maximum number of iterations.

      verbose : int
        Verbosity level, from 0 (no output) to 2 (output on each iteration)

      callback : callable.
        callback function (optional).
        Called with locals() at each step of the algorithm.
        The algorithm will exit if callback returns False.

      line_search : boolean
        Whether to perform line-search to estimate the step sizes.

      step : float or None
        Starting value(s) for the line-search procedure.
        If None, step sizes will be estimated for each datapoint in the batch.

      max_iter_backtracking: int
        maximum number of backtracking iterations. Used in line search.

      backtracking_factor: float
        the amount to backtrack by during line search.

      args_prox: iterable
        (optional) Extra arguments passed to the prox functions.


    Returns:
      res : OptimizeResult
        The optimization result represented as a
        ``scipy.optimize.OptimizeResult`` object. Important attributes are:
        ``x`` the solution tensor, ``success`` a Boolean flag indicating if
        the optimizer exited successfully and ``message`` which describes
        the cause of the termination. See `scipy.optimize.OptimizeResult`
        for a description of other attributes.
    """

    success = torch.zeros(x0.size(0), dtype=bool)
    if not max_iter_backtracking > 0:
        raise ValueError("Line search iterations need to be greater than 0")

    LS_EPS = np.finfo(np.float64).eps

    if prox1 is None:

        @torch.no_grad()
        def prox1(x, s=None, *args):
            return x

    if prox2 is None:

        @torch.no_grad()
        def prox2(x, s=None, *args):
            return x

    x = x0.detach().clone().requires_grad_(True)
    batch_size = x.size(0)

    if step is None:
        line_search = True
        step_size = 1.0 / utils.init_lipschitz(closure, x)

    elif isinstance(step, Number):
        step_size = step * torch.ones(
            batch_size, device=x.device, dtype=x.dtype)

    else:
        raise ValueError("step must be float or None.")

    z = prox2(x, step_size, *args_prox)
    z = z.clone().detach()
    z.requires_grad_(True)

    fval, grad = closure(z)

    x = prox1(z - utils.bmul(step_size, grad), step_size, *args_prox)
    u = torch.zeros_like(x)

    for it in range(max_iter):
        z.requires_grad_(True)
        fval, grad = closure(z)
        with torch.no_grad():
            x = prox1(z - utils.bmul(step_size, u + grad), step_size,
                      *args_prox)
            incr = x - z
            norm_incr = torch.norm(incr.view(incr.size(0), -1), dim=-1)
            rhs = fval + utils.bdot(grad, incr) + ((norm_incr**2) /
                                                   (2 * step_size))
            ls_tol = closure(x, return_jac=False)
            mask = torch.bitwise_and(norm_incr > 1e-7, line_search)
            ls = mask.detach().clone()
            # TODO: optimize code in this loop using mask
            for it_ls in range(max_iter_backtracking):
                if not (mask.any()):
                    break
                rhs[mask] = fval[mask] + utils.bdot(grad[mask], incr[mask])
                rhs[mask] += utils.bmul(norm_incr[mask]**2,
                                        1. / (2 * step_size[mask]))

                ls_tol[mask] = closure(x, return_jac=False)[mask] - rhs[mask]
                mask &= (ls_tol > LS_EPS)
                step_size[mask] *= backtracking_factor

            z = prox2(x + utils.bmul(step_size, u), step_size, *args_prox)
            u += utils.bmul(x - z, 1. / step_size)
            certificate = utils.bmul(norm_incr, 1. / step_size)

        if callback is not None:
            if callback(locals()) is False:
                break

        success = torch.bitwise_and(certificate < tol, it > 0)
        if success.all():
            break

    return optimize.OptimizeResult(x=x,
                                   success=success,
                                   nit=it,
                                   fval=fval,
                                   certificate=certificate)
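
A minimal usage sketch, not from the library: as in the sketches above, a batched quadratic plays the smooth term f, while a box indicator (prox1) and an L1 term (prox2) are handled through hypothetical prox helpers; the call assumes minimize_three_split and its utils helpers are importable from the surrounding module.

import torch

# Hypothetical helpers for the Davis-Yin splitting above.
target = torch.tensor([[1.2, -0.7], [0.4, 0.9]])

def quadratic_closure(x, return_jac=True):
    val = 0.5 * ((x - target) ** 2).view(x.size(0), -1).sum(dim=-1)
    if not return_jac:
        return val
    return val, x - target

@torch.no_grad()
def box_prox(x, step_size=None, *args):
    # Projection onto the box [-0.5, 0.5]; does not depend on step_size.
    return torch.clamp(x, min=-0.5, max=0.5)

@torch.no_grad()
def l1_prox(x, step_size, *args, lam=0.05):
    # Batch-wise soft-thresholding with per-datapoint step sizes.
    thresh = (lam * step_size).view(-1, *([1] * (x.dim() - 1)))
    return torch.sign(x) * torch.clamp(x.abs() - thresh, min=0.)

res = minimize_three_split(quadratic_closure, torch.zeros_like(target),
                           prox1=box_prox, prox2=l1_prox)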
Example #9
    def is_feasible(self, x, rtol=1e-5, atol=1e-7):
        cosines = utils.bdot(x, self.directions)
        return (abs(cosines) >=
                utils.bnorm(x) * self.cos_angle * (1. + rtol) + atol)
Example #10
    def fw_gap(self, grad, iterate):
        """Batch-wise Frank-Wolfe gap <-grad, update_direction> at `iterate`."""
        update_direction, _ = self.lmo(-grad, iterate)
        return utils.bdot(-grad, update_direction)