Example #1
    def step(self, closure):
        """Performs a single optimization step.
        Arguments:
            closure (callable): a closure that reevaluates the model and returns the loss
        """
        closure = torch.enable_grad()(closure)

        for group in self.param_groups:
            params = group['params']

            alpha = group['alpha']
            affine_invariant = group['affine_invariant']
            L = group['L']
            lambd = group['lambd']
            subsolver = group['subsolver']

            g = tuple_to_vec.tuple_to_vector(
                torch.autograd.grad(closure(), list(params),
                                    create_graph=True))

            if subsolver is None:
                x = exact(g, params, lambd)
            else:
                raise NotImplementedError()

            if affine_invariant:
                G = L * g.dot(-tuple_to_vec.tuple_to_vector(x))
                alpha = ((torch.sqrt(1 + 2 * G) - 1) / G).item()

            with torch.no_grad():
                for i, p in enumerate(params):
                    p.add_(x[i], alpha=alpha)
        return None
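For reference, a minimal self-contained sketch of two ingredients used above: extracting a flat gradient from a loss closure, and the affine-invariant step size. `tuple_to_vec.tuple_to_vector` is replaced by a plain `torch.cat`, and the model, data, `L`, and the toy direction `x_flat = -g` are placeholders for illustration, not taken from the library.

import torch

model = torch.nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)

def closure():
    return torch.nn.functional.mse_loss(model(data), target)

closure = torch.enable_grad()(closure)
params = list(model.parameters())

# Flat gradient, analogous to tuple_to_vec.tuple_to_vector(autograd.grad(...)).
g = torch.cat([t.reshape(-1) for t in
               torch.autograd.grad(closure(), params, create_graph=True)])

# Affine-invariant step size for the toy direction x_flat = -g:
# alpha = (sqrt(1 + 2G) - 1) / G with G = L * <g, -x>.
L = 1.0
x_flat = -g
G = L * g.dot(-x_flat)
alpha = ((torch.sqrt(1 + 2 * G) - 1) / G).item()
print(alpha)

Note that the closure only returns the loss; `step` differentiates it itself via `torch.autograd.grad`, so no `loss.backward()` call is expected inside it.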
Example #2
def iterative(params, closure, L, subsolver, subsolver_args, max_iters,
              rel_acc):
    x = torch.zeros_like(tuple_to_vec.tuple_to_vector(list(params)),
                         requires_grad=True)
    optimizer = subsolver([x], **subsolver_args)

    for _ in range(max_iters):
        optimizer.zero_grad()
        Hx, df = derivatives.flat_hvp(closure, list(params), x)

        for p in params:
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

        x.grad = df + Hx + x.mul(L * x.norm() / 2.)
        optimizer.step()

        if x.grad.norm() < rel_acc * df.norm():
            return True, tuple_to_vec.rollup_vector(x.detach(), list(params))

    return False, tuple_to_vec.rollup_vector(x.detach(), list(params))
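The loop above runs the generic `subsolver` on the cubic model m(x) = <df, x> + 1/2 <H x, x> + L/6 ||x||^3, whose gradient is exactly the expression written into `x.grad`. Below is a toy replication on an explicit 2x2 problem; the dense `H`, the vector `g`, and `torch.optim.SGD` as the subsolver are assumptions made for illustration, with `derivatives.flat_hvp` replaced by a plain matrix-vector product.

import torch

H = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
g = torch.tensor([1.0, -0.5])
L = 1.0

x = torch.zeros(2, requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.2)
for _ in range(200):
    optimizer.zero_grad()
    with torch.no_grad():
        # gradient of <g, x> + 1/2 <H x, x> + L/6 ||x||^3
        x.grad = g + H.mv(x) + x * (L * x.norm() / 2.)
    optimizer.step()

# the model gradient is (approximately) zero at the result
print((g + H.mv(x.detach()) + x.detach() * (L * x.detach().norm() / 2.)).norm())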
Example #3
    def step(self, closure):
        """Solves a subproblem.
        Arguments:
            closure (callable): a closure that reevaluates the model and returns the loss.
        """
        closure = torch.enable_grad()(closure)

        for group in self.param_groups:
            params = group['params']

            L = group['L']
            subsolver = group['subsolver']
            max_iters = group['max_iters']
            subsolver_args = group['subsolver_args']
            max_iters_outer = group['max_iters_outer']

            df = tuple_to_vec.tuple_to_vector(
                torch.autograd.grad(closure(), list(params), create_graph=True))

            v_flat = torch.zeros_like(df)
            v = tuple_to_vec.rollup_vector(v_flat, list(params))

            if subsolver is None:
                full_hessian = derivatives.flat_hessian(df, list(params)).to(torch.double)
                eigenvalues, eigenvectors = torch.linalg.eigh(full_hessian)

            for _ in range(max_iters_outer):
                D3xx, Hx = derivatives.third_derivative_vec(
                    closure, list(params), v, flat=True)
                with torch.no_grad():
                    D3xx = D3xx.to(torch.double)
                    Hx = Hx.to(torch.double)
                    Lv3 = v_flat * (v_flat.norm().square() * L)
                    g = df.add(Hx).add(D3xx, alpha=0.5).add(Lv3)

                if self._check_stopping_condition(closure, params, v, g.norm()):
                    self._add_v(params, v)
                    return True

                with torch.no_grad():
                    c = g.div(2. + math.sqrt(2)).sub(Hx).sub(Lv3)

                if subsolver is None:
                    v_flat = exact(L, c.detach(), T=eigenvalues, U=eigenvectors)
                else:
                    v_flat = iterative(params, closure, L, c.detach(),
                                       subsolver, subsolver_args, max_iters)
                with torch.no_grad():
                    v = tuple_to_vec.rollup_vector(v_flat, list(params))

            self._add_v(params, v)
        return False
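Under the reading that `derivatives.third_derivative_vec` returns the pair (D^3 f(x)[v, v], Hessian-vector product), the vector `g = df.add(Hx).add(D3xx, alpha=0.5).add(Lv3)` assembled above is the gradient in v of the quartically regularized third-order Taylor model below; this is an interpretation of the code, not a statement from the source.

\[
\Omega_L(v) = \langle \nabla f(x), v \rangle + \tfrac{1}{2} \langle \nabla^2 f(x)\, v, v \rangle + \tfrac{1}{6} D^3 f(x)[v, v, v] + \tfrac{L}{4} \lVert v \rVert^4,
\]
\[
\nabla \Omega_L(v) = \nabla f(x) + \nabla^2 f(x)\, v + \tfrac{1}{2} D^3 f(x)[v, v] + L \lVert v \rVert^2 v.
\]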
Example #4
def exact(params, closure, L, tol=1e-10):
    df = tuple_to_vec.tuple_to_vector(
        torch.autograd.grad(closure(), list(params), create_graph=True))
    H = derivatives.flat_hessian(df, list(params))

    c = df.clone().detach().to(torch.double)
    A = H.clone().detach().to(torch.double)

    if c.dim() != 1:
        raise ValueError(f"`c` must be a vector, but c = {c}")

    if A.dim() > 2:
        raise ValueError(f"`A` must be a matrix, but A = {A}")

    if c.size()[0] != A.size()[0]:
        raise ValueError("`c` and `A` must have the same first dimension")

    if (A.t() - A).max() > 0.1:
        raise ValueError("`A` is not symmetric")

    T, U = torch.linalg.eigh(A)
    ct = U.t().mv(c)

    def inv(T, L, tau):
        return (T + L / 2 * tau).reciprocal()

    def dual(tau):
        return L/12 * tau.pow(3) + 1/2 * inv(T, L, tau).mul(ct.square()).sum()

    tau_best = line_search.ray_line_search(dual,
                                           eps=tol,
                                           middle_point=torch.tensor([2.]),
                                           left_point=torch.tensor([0.]))

    invert = inv(T, L, tau_best)
    x = -U.mv(invert.mul(ct).type_as(U))

    if not (c + L / 2 * x.norm() * x + A.mv(x)).abs().max().item() < 0.01:
        raise ValueError('obtained `x` is not optimal')

    return tuple_to_vec.rollup_vector(x, list(params))
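A small self-contained check of the same construction (an illustration, not the library routine). The stationarity condition of min_x <c, x> + 1/2 <A x, x> + L/6 ||x||^3 is c + A x + L/2 ||x|| x = 0, so x(tau) = -U (T + L/2 tau)^{-1} U^T c with tau = ||x||. Here tau is found by bisection on ||x(tau)|| - tau instead of `line_search.ray_line_search`, and `A`, `c`, `L` are arbitrary placeholders.

import torch

torch.manual_seed(0)
n, L = 5, 2.0
B = torch.randn(n, n, dtype=torch.double)
A = B @ B.t() + torch.eye(n, dtype=torch.double)  # symmetric positive definite
c = torch.randn(n, dtype=torch.double)

T, U = torch.linalg.eigh(A)
ct = U.t().mv(c)

def x_of_tau(tau):
    return -U.mv(ct / (T + L / 2 * tau))

lo, hi = 0.0, 1.0
while x_of_tau(hi).norm() > hi:  # bracket the root of ||x(tau)|| = tau
    hi *= 2.0
for _ in range(100):
    mid = (lo + hi) / 2.0
    if x_of_tau(mid).norm() > mid:
        lo = mid
    else:
        hi = mid

x = x_of_tau(hi)
print((c + A.mv(x) + L / 2 * x.norm() * x).abs().max())  # ~0: x is stationary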
Example #5
def iterative(params, closure, L, c, subsolver, subsolver_args, max_iters):
    x = torch.zeros_like(tuple_to_vec.tuple_to_vector(
        list(params)), requires_grad=True)
    optimizer = subsolver([x], **subsolver_args)

    for _ in range(max_iters):
        optimizer.zero_grad()
        Hx, __ = derivatives.flat_hvp(closure, list(params), x)

        for p in params:
            if p.grad is not None:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

        x.grad = c + Hx + x.mul(L * x.norm() ** 2)
        optimizer.step()

    return x.detach()
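A quick consistency check (illustration only, with dense placeholders `H` and `c`): the tensor written into `x.grad` above is the gradient of phi(x) = <c, x> + 1/2 <H x, x> + L/4 ||x||^4, as autograd confirms.

import torch

torch.manual_seed(0)
H = torch.randn(3, 3, dtype=torch.double)
H = (H + H.t()) / 2  # symmetric, like a Hessian
c = torch.randn(3, dtype=torch.double)
L = 1.5
x = torch.randn(3, dtype=torch.double, requires_grad=True)

phi = c.dot(x) + 0.5 * x.dot(H.mv(x)) + L / 4 * x.norm() ** 4
(auto,) = torch.autograd.grad(phi, x)
manual = (c + H.mv(x) + x * (L * x.norm() ** 2)).detach()
print(torch.allclose(auto, manual))  # True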
Example #6
    def _check_stopping_condition(self, closure, params, v, g_norm):
        self._add_v(params, v)
        df_norm = tuple_to_vec.tuple_to_vector(
            torch.autograd.grad(closure(), list(params))).norm()
        self._add_v(params, v, alpha=-1)
        return g_norm <= 1/6 * df_norm
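Written out, with `g_norm` read as the norm of the gradient g of the regularized model at v, and `df_norm` as the norm of the objective's gradient at the temporarily shifted point params + v, the helper accepts the step once

\[
\lVert g \rVert \le \tfrac{1}{6}\, \lVert \nabla f(x + v) \rVert .
\]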