import warnings
from numbers import Number

import numpy as np
import torch
from torch import nn
from scipy import optimize

# `utils` (with init_lipschitz, bmul, bdot, bdiv and the `closure`
# decorator) is this package's own helper module and is assumed to be
# imported at the top of the file alongside the imports above.


def test_init_lipschitz():
    # Relies on module-level fixtures X, w, y (data matrix, weights, targets).
    criterion = nn.MSELoss(reduction='none')

    @closure
    def loss_fun(X):
        return criterion(X.mv(w), y)

    L = utils.init_lipschitz(loss_fun,
                             X.detach().clone().requires_grad_(True))
    print(L)
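# Note on the closure convention used throughout this module: every
# `minimize_*` routine below expects `closure(x)` to return a pair
# (fval, grad), with fval of shape (batch_size,) and grad shaped like x,
# and expects `closure(x, return_jac=False)` to return fval alone. The
# function below is a hand-rolled sketch of what the `closure` decorator
# is assumed to produce on a toy objective; the decorator's exact
# behavior may differ.
def _example_closure(x, return_jac=True):
    fval = (x ** 2).flatten(1).sum(dim=-1)  # toy batch objective
    if not return_jac:
        return fval
    grad = torch.autograd.grad(fval.sum(), x)[0]
    return fval, grad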
def minimize_pgd_madry(closure, x0, prox, lmo, step=None, max_iter=200,
                       prox_args=(), callback=None):
    """Performs projected gradient descent, Madry-style: step in the
    direction given by the LMO, then project back with ``prox``.

    Args:
      closure: callable returning (fval, grad) on a batch.
      x0: torch.Tensor of shape (batch_size, *). Starting point.
      prox: callable. Projection onto the constraint set.
      lmo: callable. Linear minimization oracle over the constraint set.
      step: None, Number, or torch.Tensor of shape (batch_size,).
      max_iter: int. Number of iterations to perform.
      prox_args: tuple (optional). Extra arguments for prox.
      callback: callable (optional). Called on locals() at each iteration;
        the loop exits early if it returns False.
    """
    x = x0.detach().clone()
    batch_size = x.size(0)

    if step is None:
        # Estimate the Lipschitz constant.
        # TODO: this is not the optimal step-size (if there even is one.)
        # I don't recommend using this.
        L = utils.init_lipschitz(closure, x0)
        step_size = 1. / L
    elif isinstance(step, Number):
        step_size = torch.ones(batch_size, device=x.device) * step
    elif isinstance(step, torch.Tensor):
        step_size = step
    else:
        raise ValueError(
            f"step must be a number or a torch Tensor, got {step} instead")

    for it in range(max_iter):
        x.requires_grad = True
        _, grad = closure(x)
        with torch.no_grad():
            update_direction, _ = lmo(-grad, x)
            # The LMO returns atom - x; add x back to get the atom itself.
            update_direction += x
            x = prox(x + utils.bmul(step_size, update_direction),
                     step_size, *prox_args)
        if callback is not None:
            if callback(locals()) is False:
                break

    fval, grad = closure(x)
    return optimize.OptimizeResult(x=x, nit=it, fval=fval, grad=grad)
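# A minimal usage sketch for minimize_pgd_madry on an l-infinity ball of
# radius eps centered at the origin. The hand-rolled `lmo`/`prox` below
# are illustrative stand-ins for the package's constraint objects, and
# the toy objective is arbitrary.
def _example_pgd_madry():
    eps = .3
    batch_size, n = 4, 10

    def lmo(neg_grad, x):
        # Atom of the eps-ball most aligned with -grad, returned as
        # (atom - x, max_step), matching the convention above.
        atom = eps * torch.sign(neg_grad)
        return atom - x, torch.ones(batch_size)

    def prox(x, step_size=None):
        # Projection onto the eps-ball.
        return x.clamp(-eps, eps)

    def closure(delta, return_jac=True):
        fval = ((delta - 1.) ** 2).flatten(1).sum(dim=-1)
        if not return_jac:
            return fval
        grad = torch.autograd.grad(fval.sum(), delta)[0]
        return fval, grad

    delta0 = torch.zeros(batch_size, n)
    return minimize_pgd_madry(closure, delta0, prox, lmo,
                              step=2 * eps / 20, max_iter=20)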
def minimize_pgd(closure, x0, prox, step='backtracking', max_iter=200,
                 max_iter_backtracking=1000, backtracking_factor=.6,
                 tol=1e-8, *prox_args, callback=None):
    """Performs Projected Gradient Descent on a batch of objectives of the
    form f(x) + g(x). We assume access to the gradient of f through
    ``closure``, and to the proximal operator of g through ``prox``.

    Args:
      closure: callable

      x0: torch.Tensor of shape (batch_size, *).

      prox: callable
        proximal operator of g

      step: 'backtracking' or float or torch.Tensor of shape (batch_size,) or None.
        step size to be used. If None, it is estimated once at the start
        using a line search. If 'backtracking', it is re-estimated at each
        iteration using backtracking line search.

      max_iter: int
        number of iterations to perform.

      max_iter_backtracking: int
        max number of iterations in the backtracking line search

      backtracking_factor: float
        factor by which to multiply the step sizes during line search

      tol: float
        stops the algorithm when the certificate is smaller than tol for
        all datapoints in the batch

      prox_args: tuple
        (optional) additional args for prox

      callback: callable
        (optional) Any callable called on locals() at the end of each
        iteration. Often used for logging.
    """
    x = x0.detach().clone()
    batch_size = x.size(0)

    if step is None:
        # Estimate the Lipschitz constant.
        L = utils.init_lipschitz(closure, x0)
        step_size = 1. / L
    elif step == 'backtracking':
        L = 1.8 * utils.init_lipschitz(closure, x0)
        step_size = 1. / L
    elif isinstance(step, Number):
        step_size = step * torch.ones(batch_size, device=x.device)
    else:
        raise ValueError("step must be a number, 'backtracking' or None.")

    for it in range(max_iter):
        x.requires_grad = True
        fval, grad = closure(x)
        x_next = prox(x - utils.bmul(step_size, grad), step_size, *prox_args)
        update_direction = x_next - x

        if step == 'backtracking':
            step_size *= 1.1
            mask = torch.ones(batch_size, dtype=bool, device=x.device)
            with torch.no_grad():
                for _ in range(max_iter_backtracking):
                    # Sufficient-decrease condition, checked per datapoint.
                    f_next = closure(x_next, return_jac=False)
                    rhs = (fval + utils.bdot(grad, update_direction)
                           + utils.bmul(utils.bdot(update_direction,
                                                   update_direction),
                                        1. / (2. * step_size)))
                    mask = f_next > rhs
                    if not mask.any():
                        break
                    step_size[mask] *= backtracking_factor
                    # Recompute the candidate with the shrunken step sizes.
                    # Note: prox gets the full-batch step_size here, not
                    # step_size[mask], since its first argument is the
                    # full batch.
                    x_next = prox(x - utils.bmul(step_size, grad),
                                  step_size, *prox_args)
                    update_direction[mask] = x_next[mask] - x[mask]
                else:
                    warnings.warn("Maximum number of line-search iterations "
                                  "reached.")

        with torch.no_grad():
            cert = torch.norm(
                utils.bmul(update_direction,
                           1. / step_size).reshape(batch_size, -1),
                dim=-1)
            x.copy_(x_next)

        if (cert < tol).all():
            break

        if callback is not None:
            if callback(locals()) is False:
                break

    fval, grad = closure(x)
    return optimize.OptimizeResult(x=x, nit=it, fval=fval, grad=grad,
                                   certificate=cert)
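# A minimal usage sketch for minimize_pgd on a batch lasso-style problem
#   min_x ||A x - b||^2 + alpha * ||x||_1,
# where the closure handles f = ||A x - b||^2 and the prox handles
# g = alpha * ||.||_1 via soft-thresholding. All names and sizes here are
# illustrative.
def _example_pgd():
    batch_size, n = 4, 10
    A = torch.randn(batch_size, n, n)
    b = torch.randn(batch_size, n)
    alpha = .1

    def closure(x, return_jac=True):
        residual = torch.bmm(A, x.unsqueeze(-1)).squeeze(-1) - b
        fval = (residual ** 2).sum(dim=-1)
        if not return_jac:
            return fval
        grad = torch.autograd.grad(fval.sum(), x)[0]
        return fval, grad

    def prox(x, step_size):
        # Soft-thresholding: the prox of alpha * ||.||_1 with step size s
        # shrinks each coordinate by alpha * s.
        threshold = (alpha * step_size).unsqueeze(-1)
        return x.sign() * torch.relu(x.abs() - threshold)

    x0 = torch.zeros(batch_size, n)
    return minimize_pgd(closure, x0, prox, step='backtracking')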
def minimize_three_split(closure, x0, prox1=None, prox2=None, tol=1e-6,
                         max_iter=1000, verbose=0, callback=None,
                         line_search=True, step=None,
                         max_iter_backtracking=100, backtracking_factor=0.7,
                         h_Lipschitz=None, *args_prox):
    """Davis-Yin three operator splitting method.

    This algorithm can solve problems of the form

        minimize_x f(x) + g(x) + h(x)

    where f is a smooth function and g and h are (possibly non-smooth)
    functions for which the proximal operator is known.

    Remark: this method returns x = prox1(...). If g and h are two
    indicator functions, this method only guarantees that x is feasible
    for the first. Therefore, if one of the constraints is a hard
    constraint, make sure to pass it to prox1.

    Args:
      closure: callable
        Returns the function values and gradient of the objective function.
        With return_jac=False, returns only the function values. Function
        values are of shape (batch_size,); the gradient has the shape of x.

      x0: torch.Tensor of shape (batch_size, *)
        Initial guess.

      prox1: callable or None
        prox1(x, step_size, *args) returns the proximal operator of g at x
        with parameter step_size. step_size can be a scalar or of shape
        (batch_size,).

      prox2: callable or None
        prox2(x, step_size, *args) returns the proximal operator of h at x
        with parameter step_size. step_size can be a scalar or of shape
        (batch_size,).

      tol: float
        Tolerance of the stopping criterion.

      max_iter: int
        Maximum number of iterations.

      verbose: int
        Verbosity level, from 0 (no output) to 2 (output on each iteration).

      callback: callable (optional)
        Called with locals() at each step of the algorithm. The algorithm
        will exit if callback returns False.

      line_search: boolean
        Whether to perform line search to estimate the step sizes.

      step: float or torch.Tensor of shape (batch_size,) or None
        Starting value(s) for the line-search procedure. If None, step
        sizes are estimated for each datapoint in the batch.

      max_iter_backtracking: int
        Maximum number of backtracking iterations. Used in line search.

      backtracking_factor: float
        The amount to backtrack by during line search.

      h_Lipschitz: float (optional)
        Lipschitz constant of h. Unused in this implementation.

      args_prox: iterable (optional)
        Extra arguments passed to the prox functions.

    Returns:
      res: OptimizeResult
        The optimization result represented as a
        ``scipy.optimize.OptimizeResult`` object. Important attributes are:
        ``x`` the solution tensor, ``success`` a Boolean flag indicating if
        the optimizer exited successfully and ``message`` which describes
        the cause of the termination. See `scipy.optimize.OptimizeResult`
        for a description of other attributes.
""" success = torch.zeros(x0.size(0), dtype=bool) if not max_iter_backtracking > 0: raise ValueError("Line search iterations need to be greater than 0") LS_EPS = np.finfo(np.float).eps if prox1 is None: @torch.no_grad() def prox1(x, s=None, *args): return x if prox2 is None: @torch.no_grad() def prox2(x, s=None, *args): return x x = x0.detach().clone().requires_grad_(True) batch_size = x.size(0) if step is None: line_search = True step_size = 1.0 / utils.init_lipschitz(closure, x) elif isinstance(step, Number): step_size = step * torch.ones( batch_size, device=x.device, dtype=x.dtype) else: raise ValueError("step must be float or None.") z = prox2(x, step_size, *args_prox) z = z.clone().detach() z.requires_grad_(True) fval, grad = closure(z) x = prox1(z - utils.bmul(step_size, grad), step_size, *args_prox) u = torch.zeros_like(x) for it in range(max_iter): z.requires_grad_(True) fval, grad = closure(z) with torch.no_grad(): x = prox1(z - utils.bmul(step_size, u + grad), step_size, *args_prox) incr = x - z norm_incr = torch.norm(incr.view(incr.size(0), -1), dim=-1) rhs = fval + utils.bdot(grad, incr) + ((norm_incr**2) / (2 * step_size)) ls_tol = closure(x, return_jac=False) mask = torch.bitwise_and(norm_incr > 1e-7, line_search) ls = mask.detach().clone() # TODO: optimize code in this loop using mask for it_ls in range(max_iter_backtracking): if not (mask.any()): break rhs[mask] = fval[mask] + utils.bdot(grad[mask], incr[mask]) rhs[mask] += utils.bmul(norm_incr[mask]**2, 1. / (2 * step_size[mask])) ls_tol[mask] = closure(x, return_jac=False)[mask] - rhs[mask] mask &= (ls_tol > LS_EPS) step_size[mask] *= backtracking_factor z = prox2(x + utils.bmul(step_size, u), step_size, *args_prox) u += utils.bmul(x - z, 1. / step_size) certificate = utils.bmul(norm_incr, 1. / step_size) if callback is not None: if callback(locals()) is False: break success = torch.bitwise_and(certificate < tol, it > 0) if success.all(): break return optimize.OptimizeResult(x=x, success=success, nit=it, fval=fval, certificate=certificate)
def minimize_alternating_fw_prox(closure, x0, y0, prox=None, lmo=None,
                                 lipschitz=1e-3, step='sublinear',
                                 line_search=None, max_iter=200,
                                 callback=None, *args, **kwargs):
    r"""Implements the algorithm from [Garber et al. 2018]
    https://arxiv.org/abs/1802.05581 to solve the problem

    .. math::
      \min_{x, y} f(x + y) + R_x(x) + R_y(y).

    We suppose that $f$ is $L$-smooth and that we have access to the
    following operators:

    - a generalized LMO for $R_y$:

    .. math::
      \text{gLMO}(\nabla f(x_t + y_t)) = \text{argmin}_w \; R_y(w)
        + \langle w, \nabla f(x_t + y_t) \rangle

    - a prox operator for $R_x$:

    .. math::
      \text{prox}(\nabla f(x_t + y_t)) = \text{argmin}_v \; R_x(v)
        + \langle v, \nabla f(x_t + y_t) \rangle
        + \frac{\gamma_t L}{2} \|v + w_t - (x_t + y_t)\|^2

    Args:
      closure: callable
        returns the function values and gradient on a batch

      x0: torch.Tensor of shape (batch_size, *)
        starting point for x

      y0: torch.Tensor of shape (batch_size, *)
        starting point for y

      prox: function
        proximal operator for R_x

      lmo: function
        generalized LMO operator for R_y. If R_y is an indicator function,
        it reduces to the usual LMO operator.

      lipschitz: float
        initial guess for the Lipschitz constant of f

      step: float or 'sublinear'
        step-size scheme to be used.

      line_search: callable or None
        optional line search, called on locals() to return the step sizes.

      max_iter: int
        max number of iterations.

      callback: callable (optional)
        Any callable called on locals() at the end of each iteration.
        Often used for logging.

    Returns:
      result: optimize.OptimizeResult object
        Holds the result of the optimization, and certificates of
        convergence.
    """
    x = x0.detach().clone()
    y = y0.detach().clone()
    batch_size = x.size(0)

    if x.shape != y.shape:
        raise ValueError(
            f"x, y should have the same shape. Got {x.shape}, {y.shape}.")

    if not (isinstance(step, Number) or step == 'sublinear'):
        raise ValueError(
            f"step must be a float or 'sublinear', got {step} instead.")

    if isinstance(step, Number):
        step_size = step * torch.ones(batch_size, device=x.device,
                                      dtype=x.dtype)

    if lmo is None:
        raise ValueError("lmo must be provided.")

    if prox is None:
        # With no R_x, the prox reduces to the identity.
        @torch.no_grad()
        def prox(v, s=None):
            return v

    # TODO: add error catching for L0
    Lt = lipschitz

    for it in range(max_iter):
        if step == 'sublinear':
            step_size = 2. / (it + 2) * torch.ones(batch_size,
                                                   device=x.device)

        x.requires_grad_(True)
        y.requires_grad_(True)
        z = x + y
        fval, grad = closure(z)

        # Estimate the Lipschitz constant with a backtracking line search.
        Lt = utils.init_lipschitz(closure, z, L0=Lt)

        y_update, max_step_size = lmo(-grad, y)
        with torch.no_grad():
            w = y_update + y
            prox_step_size = utils.bmul(step_size, Lt)
            v = prox(z - w - utils.bdiv(grad, prox_step_size),
                     prox_step_size)

        with torch.no_grad():
            if line_search is None:
                step_size = torch.min(step_size, max_step_size)
            else:
                step_size = line_search(locals())

            y += utils.bmul(step_size, y_update)
            x_update = v - x
            x += utils.bmul(step_size, x_update)

        if callback is not None:
            if callback(locals()) is False:
                break

    fval, grad = closure(x + y)
    # TODO: add a certificate of optimality
    result = optimize.OptimizeResult(x=x, y=y, nit=it, fval=fval, grad=grad,
                                     certificate=None)
    return result
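# A minimal usage sketch for minimize_alternating_fw_prox in the spirit
# of robust PCA: recover M ~ x + y with x sparse (R_x = alpha * ||.||_1,
# handled by a soft-thresholding prox) and y low-rank (R_y = indicator of
# a nuclear-norm ball, whose LMO is a scaled top singular pair). All
# names, sizes and constants are illustrative, and the prox assumes the
# gamma_t * L parameterization used in the loop above.
def _example_alternating_fw_prox():
    batch_size, m, n = 2, 8, 8
    M = torch.randn(batch_size, m, n)
    alpha, radius = .1, 5.

    def closure(z, return_jac=True):
        fval = ((z - M) ** 2).flatten(1).sum(dim=-1)
        if not return_jac:
            return fval
        grad = torch.autograd.grad(fval.sum(), z)[0]
        return fval, grad

    def prox(v, gamma_L):
        # Solves argmin_u alpha * ||u||_1 + (gamma_L / 2) * ||u - v||^2:
        # soft-thresholding with threshold alpha / gamma_L.
        threshold = (alpha / gamma_L).view(-1, 1, 1)
        return v.sign() * torch.relu(v.abs() - threshold)

    def lmo(neg_grad, y):
        # LMO over the nuclear-norm ball: the atom radius * u1 v1^T built
        # from the top singular pair of -grad, returned as
        # (atom - y, max_step).
        U, _, Vh = torch.linalg.svd(neg_grad, full_matrices=False)
        atom = radius * torch.einsum('bi,bj->bij', U[:, :, 0], Vh[:, 0, :])
        return atom - y, torch.ones(y.size(0))

    x0 = torch.zeros_like(M)
    y0 = torch.zeros_like(M)
    return minimize_alternating_fw_prox(closure, x0, y0, prox=prox, lmo=lmo,
                                        step='sublinear', max_iter=100)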