import math

import torch

# `tuple_to_vec`, `derivatives`, and `line_search` are assumed to be the
# library's own helper modules (flattening parameter tuples into vectors,
# Hessian-/third-derivative-vector products, and one-dimensional line search).


def step(self, closure):
    """Performs a single optimization step.

    Arguments:
        closure (callable): a closure that reevaluates the model and returns the loss
    """
    closure = torch.enable_grad()(closure)
    for group in self.param_groups:
        params = group['params']
        alpha = group['alpha']
        affine_invariant = group['affine_invariant']
        L = group['L']
        lambd = group['lambd']
        subsolver = group['subsolver']

        # Flat gradient; create_graph=True keeps it differentiable for the
        # second-order computations inside the subsolver.
        g = tuple_to_vec.tuple_to_vector(
            torch.autograd.grad(closure(), list(params), create_graph=True))

        if subsolver is None:
            x = exact(g, params, lambd)
        else:
            raise NotImplementedError()

        if affine_invariant:
            # Affine-invariant step size: alpha = (sqrt(1 + 2G) - 1) / G
            # with G = L * <g, -x>.
            G = L * g.dot(-tuple_to_vec.tuple_to_vector(x))
            alpha = ((torch.sqrt(1 + 2 * G) - 1) / G).item()

        with torch.no_grad():
            for i, p in enumerate(params):
                p.add_(x[i], alpha=alpha)
    return None

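# A minimal usage sketch, assuming this `step` lives in a torch.optim.Optimizer
# subclass (the class name `CubicNewton` and all hyperparameter values below
# are illustrative, not the library's actual API or defaults). The closure only
# returns the loss: `step` computes gradients itself via torch.autograd.grad,
# so the closure must not call loss.backward().

model = torch.nn.Linear(10, 1)
X, y = torch.randn(32, 10), torch.randn(32, 1)
optimizer = CubicNewton(model.parameters(), L=1.0)  # hypothetical constructor

def closure():
    return torch.nn.functional.mse_loss(model(X), y)

for _ in range(20):
    optimizer.step(closure)
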
def iterative(params, closure, L, subsolver, subsolver_args, max_iters, rel_acc):
    """Approximately minimizes the cubic-regularized model with a first-order
    subsolver, driving the flat step vector x with a hand-assembled gradient."""
    x = torch.zeros_like(tuple_to_vec.tuple_to_vector(list(params)),
                         requires_grad=True)
    optimizer = subsolver([x], **subsolver_args)

    for _ in range(max_iters):
        optimizer.zero_grad()
        Hx, df = derivatives.flat_hvp(closure, list(params), x)

        # Clear stale gradients on the model parameters touched by the
        # Hessian-vector product.
        for p in params:
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

        # Gradient of the model <df, x> + 1/2 <x, Hx> + L/6 ||x||^3.
        x.grad = df + Hx + x.mul(L * x.norm() / 2.)
        optimizer.step()

        if x.grad.norm() < rel_acc * df.norm():
            return True, tuple_to_vec.rollup_vector(x.detach(), list(params))
    return False, tuple_to_vec.rollup_vector(x.detach(), list(params))

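# Sanity-check sketch for the hand-built gradient above: for the cubic model
# m(x) = <g, x> + 1/2 <x, Hx> + L/6 ||x||^3, autograd's gradient equals
# g + Hx + (L/2) ||x|| x, which is exactly what `iterative` assigns to x.grad.
# All tensors below are synthetic test data, not library code.

torch.manual_seed(0)
n, L = 5, 2.0
g = torch.randn(n)
B = torch.randn(n, n)
H = B + B.t()                      # symmetric Hessian stand-in
x = torch.randn(n, requires_grad=True)

m = g.dot(x) + 0.5 * x.dot(H.mv(x)) + L / 6 * x.norm() ** 3
m.backward()

manual = g + H.mv(x.detach()) + x.detach().mul(L * x.detach().norm() / 2)
assert torch.allclose(x.grad, manual, atol=1e-6)
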
def step(self, closure):
    """Solves a subproblem.

    Arguments:
        closure (callable): a closure that reevaluates the model and returns the loss.
    """
    closure = torch.enable_grad()(closure)
    for group in self.param_groups:
        params = group['params']
        L = group['L']
        subsolver = group['subsolver']
        max_iters = group['max_iters']
        subsolver_args = group['subsolver_args']
        max_iters_outer = group['max_iters_outer']

        df = tuple_to_vec.tuple_to_vector(
            torch.autograd.grad(closure(), list(params), create_graph=True))

        v_flat = torch.zeros_like(df)
        v = tuple_to_vec.rollup_vector(v_flat, list(params))

        if subsolver is None:
            # Precompute the eigendecomposition once; the exact subsolver
            # reuses it on every outer iteration.
            full_hessian = derivatives.flat_hessian(df, list(params)).to(torch.double)
            eigenvalues, eigenvectors = torch.linalg.eigh(full_hessian)

        for _ in range(max_iters_outer):
            D3xx, Hx = derivatives.third_derivative_vec(
                closure, list(params), v, flat=True)
            with torch.no_grad():
                D3xx = D3xx.to(torch.double)
                Hx = Hx.to(torch.double)
                Lv3 = v_flat * (v_flat.norm().square() * L)
                # Gradient of the regularized third-order Taylor model:
                # df + Hv + 1/2 D3f[v, v] + L ||v||^2 v.
                g = df.add(Hx).add(D3xx, alpha=0.5).add(Lv3)

            if self._check_stopping_condition(closure, params, v, g.norm()):
                self._add_v(params, v)
                return True

            with torch.no_grad():
                # Linear term of the next inner subproblem.
                c = g.div(2. + math.sqrt(2)).sub(Hx).sub(Lv3)

            if subsolver is None:
                v_flat = exact(L, c.detach(), T=eigenvalues, U=eigenvectors)
            else:
                v_flat = iterative(params, closure, L, c.detach(),
                                   subsolver, subsolver_args, max_iters)
            with torch.no_grad():
                v = tuple_to_vec.rollup_vector(v_flat, list(params))

        self._add_v(params, v)
    return False

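# The directional derivatives combined into g above can be reproduced with
# nested autograd. A minimal sketch on a synthetic separable quartic, where
# every derivative is known in closed form (test data only;
# `derivatives.third_derivative_vec` is assumed to compute the analogous
# quantities for the model's loss):

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)   # current iterate
v = torch.randn(4)                       # candidate step

def f(x):
    return (x ** 4).sum()                # separable quartic test function

df = torch.autograd.grad(f(w), w, create_graph=True)[0]       # 4 w^3
Hv = torch.autograd.grad(df.dot(v), w, create_graph=True)[0]  # 12 w^2 * v
D3vv = torch.autograd.grad(Hv.dot(v), w)[0]                   # 24 w * v^2

assert torch.allclose(df, 4 * w.detach() ** 3)
assert torch.allclose(Hv, 12 * w.detach() ** 2 * v)
assert torch.allclose(D3vv, 24 * w.detach() * v ** 2)

L = 3.0
g = df + Hv + 0.5 * D3vv + L * v.norm() ** 2 * v  # model gradient, as in `step`
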
def exact(params, closure, L, tol=1e-10):
    """Solves the cubic-regularized Newton subproblem
    min_x <c, x> + 1/2 <x, Ax> + L/6 ||x||^3 exactly via eigendecomposition."""
    df = tuple_to_vec.tuple_to_vector(
        torch.autograd.grad(closure(), list(params), create_graph=True))
    H = derivatives.flat_hessian(df, list(params))

    c = df.clone().detach().to(torch.double)
    A = H.clone().detach().to(torch.double)

    if c.dim() != 1:
        raise ValueError(f"`c` must be a vector, but c = {c}")
    if A.dim() > 2:
        raise ValueError(f"`A` must be a matrix, but A = {A}")
    if c.size()[0] != A.size()[0]:
        raise ValueError("`c` and `A` must have the same first dimension")
    if (A.t() - A).max() > 0.1:
        raise ValueError("`A` is not symmetric")

    T, U = torch.linalg.eigh(A)
    ct = U.t().mv(c)

    def inv(T, L, tau):
        # Diagonal resolvent (T + L/2 * tau)^(-1) in the eigenbasis.
        return (T + L / 2 * tau).reciprocal()

    def dual(tau):
        # One-dimensional dual of the subproblem in tau = ||x||.
        return L / 12 * tau.pow(3) + 1 / 2 * inv(T, L, tau).mul(ct.square()).sum()

    tau_best = line_search.ray_line_search(
        dual, eps=tol,
        middle_point=torch.tensor([2.]),
        left_point=torch.tensor([0.]))

    invert = inv(T, L, tau_best)
    x = -U.mv(invert.mul(ct).type_as(U))

    # Verify first-order optimality: c + Ax + (L/2)||x|| x = 0.
    if not (c + L / 2 * x.norm() * x + A.mv(x)).abs().max().item() < 0.01:
        raise ValueError('obtained `x` is not optimal')
    return tuple_to_vec.rollup_vector(x, list(params))

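# A self-contained sketch of the same subproblem solve on synthetic data:
# after diagonalizing A, the stationarity condition c + Ax + (L/2)||x|| x = 0
# reduces to the scalar fixed point tau = ||x(tau)|| with
# x(tau) = -U (T + L/2 tau)^(-1) U^T c. Plain bisection stands in for the
# library's `line_search.ray_line_search` on the dual; the final check mirrors
# the optimality test in `exact`.

torch.manual_seed(0)
n, L = 6, 2.0
B = torch.randn(n, n, dtype=torch.double)
A = B + B.t()                                    # symmetric, possibly indefinite
c = torch.randn(n, dtype=torch.double)

T, U = torch.linalg.eigh(A)
ct = U.t().mv(c)

def x_norm(tau):                                 # ||x(tau)|| on the valid domain
    return (ct / (T + L / 2 * tau)).norm().item()

lo = max(0.0, (-2 * T.min() / L).item()) + 1e-9  # just right of the pole
hi = lo + 1.0
while x_norm(hi) > hi:                           # grow until tau >= ||x(tau)||
    hi = lo + 2 * (hi - lo)
for _ in range(200):                             # bisection on ||x(tau)|| - tau
    mid = (lo + hi) / 2
    lo, hi = (mid, hi) if x_norm(mid) > mid else (lo, mid)

tau = (lo + hi) / 2
x = -U.mv(ct / (T + L / 2 * tau))
assert (c + A.mv(x) + L / 2 * x.norm() * x).abs().max().item() < 1e-6
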
def iterative(params, closure, L, c, subsolver, subsolver_args, max_iters):
    """Approximately minimizes the inner model with a first-order subsolver;
    `subsolver` is an optimizer class, instantiated here on the flat step x."""
    x = torch.zeros_like(tuple_to_vec.tuple_to_vector(list(params)),
                         requires_grad=True)
    optimizer = subsolver([x], **subsolver_args)

    for _ in range(max_iters):
        optimizer.zero_grad()
        Hx, __ = derivatives.flat_hvp(closure, list(params), x)

        # Clear stale gradients on the model parameters touched by the
        # Hessian-vector product.
        for p in params:
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

        # Gradient of the model <c, x> + 1/2 <x, Hx> + L/4 ||x||^4.
        x.grad = c + Hx + x.mul(L * x.norm() ** 2)
        optimizer.step()
    return x.detach()

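# Same sanity check as in the cubic case, now for the quartic regularizer:
# for m(x) = <c, x> + 1/2 <x, Hx> + L/4 ||x||^4, autograd's gradient is
# c + Hx + L ||x||^2 x, matching the x.grad assigned in this `iterative`.
# Synthetic test data only.

torch.manual_seed(1)
n, L = 5, 2.0
c = torch.randn(n)
B = torch.randn(n, n)
H = B + B.t()
x = torch.randn(n, requires_grad=True)

m = c.dot(x) + 0.5 * x.dot(H.mv(x)) + L / 4 * x.norm() ** 4
m.backward()

manual = c + H.mv(x.detach()) + L * x.detach().norm() ** 2 * x.detach()
assert torch.allclose(x.grad, manual, atol=1e-6)
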
def _check_stopping_condition(self, closure, params, v, g_norm):
    """Checks whether the inner model gradient is small relative to the
    objective gradient at the candidate point x + v."""
    # Temporarily shift the parameters by v, measure ||df(x + v)||, shift back.
    self._add_v(params, v)
    df_norm = tuple_to_vec.tuple_to_vector(
        torch.autograd.grad(closure(), list(params))).norm()
    self._add_v(params, v, alpha=-1)
    return g_norm <= 1 / 6 * df_norm