Example #1
    def __call__(self, A, b):
        # reuse a cached LU factorization of A when this tensor has been
        # seen before; otherwise factorize and cache it
        lup = self.cache.get(id(A))
        if lup is None:
            lup = A.lu()
            self.cache[id(A)] = lup

        if b.ndim == 1:
            # torch.lu_solve expects a matrix right-hand side, so lift the
            # vector to shape (n, 1) and squeeze the result back down
            return torch.lu_solve(b.unsqueeze(-1), *lup).squeeze(-1)
        else:
            return torch.lu_solve(b, *lup)
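
A hypothetical usage sketch of the solver above (the wrapper class name CachedLUSolver and its __init__ are assumptions; note that keying the cache on id(A) is fragile, since Python may reuse the id of a garbage-collected tensor):

import torch

class CachedLUSolver:
    """Assumed wrapper around the __call__ shown above."""

    def __init__(self):
        self.cache = {}  # maps id(A) -> (LU, pivots)

    def __call__(self, A, b):
        lup = self.cache.get(id(A))
        if lup is None:
            lup = A.lu()
            self.cache[id(A)] = lup
        if b.ndim == 1:
            return torch.lu_solve(b.unsqueeze(-1), *lup).squeeze(-1)
        return torch.lu_solve(b, *lup)

solver = CachedLUSolver()
A = torch.randn(6, 6) + 6 * torch.eye(6)  # well-conditioned test matrix
b_vec = torch.randn(6)
b_mat = torch.randn(6, 2)
x = solver(A, b_vec)   # factorizes A once and caches the factors
X = solver(A, b_mat)   # cache hit: only the triangular solves run
assert torch.allclose(A @ x, b_vec, atol=1e-4)
assert torch.allclose(A @ X, b_mat, atol=1e-4)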
Example #2
    def forward(ctx, A, b):
        # factorize A once; the factors are saved and reused in backward
        A_LU, pivots = torch.lu(A)
        x = torch.lu_solve(b, A_LU, pivots)

        ctx.save_for_backward(A_LU, pivots, x)

        return x
Example #3
def _exp_pade_generic(A, m=7):
    """
    Minimal, inefficient implementation of the [m/m]-Padé approximation of the
    matrix exponential.
    """
    LU = torch.lu(_pade_poly(-A, m))
    return torch.lu_solve(_pade_poly(A, m), *LU)
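
The _pade_poly helper is not shown in the snippet. A plausible minimal version, assuming the standard [m/m] Padé numerator coefficients for the exponential, c_j = (2m-j)! m! / ((2m)! j! (m-j)!), followed by a numerical check against torch.matrix_exp (run together with _exp_pade_generic above):

from math import factorial

import torch

def _pade_poly(A, m):
    # assumed implementation: evaluates p_m(A) = sum_j c_j A^j
    n = A.shape[-1]
    result = torch.zeros_like(A)
    A_pow = torch.eye(n, dtype=A.dtype, device=A.device)
    for j in range(m + 1):
        c = (factorial(2 * m - j) * factorial(m)
             / (factorial(2 * m) * factorial(j) * factorial(m - j)))
        result = result + c * A_pow
        A_pow = A_pow @ A
    return result

A = 0.1 * torch.randn(5, 5, dtype=torch.double)  # small norm: Padé is accurate
assert torch.allclose(_exp_pade_generic(A), torch.matrix_exp(A), atol=1e-10)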
Example #4
    def _compute_weights(self):
        if self._target_dim > 1:
            # factorize A once, then back-substitute one column at a time
            self.nodes = torch.zeros(self.N, self._target_dim,
                                     dtype=self.di.dtype, device=self.device)
            lu_data = torch.lu(self.A)
            for i in range(self._target_dim):
                self.nodes[:, i] = torch.lu_solve(self.di[:, i].unsqueeze(-1),
                                                  *lu_data).squeeze()
        else:
            # torch.solve takes the right-hand side first: solves A x = di
            self.nodes = torch.solve(self.di, self.A)[0]
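
The per-column loop is not strictly necessary: torch.lu_solve accepts a matrix of right-hand sides, so every target dimension can be solved in one call against the same factorization. A minimal sketch with stand-in shapes:

import torch

A = torch.randn(5, 5) + 5 * torch.eye(5)  # stand-in for self.A
D = torch.randn(5, 3)                     # stand-in for self.di (3 columns)

LU = torch.lu(A)
X_loop = torch.stack([torch.lu_solve(D[:, i].unsqueeze(-1), *LU).squeeze()
                      for i in range(3)], dim=1)
X_once = torch.lu_solve(D, *LU)           # all columns in a single call
assert torch.allclose(X_loop, X_once)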
Example #5
    def backward(ctx, grad_x):
        A_LU, pivots, x = ctx.saved_tensors

        # Math (for x = A^{-1} b):
        #   A^T grad_b = grad_x
        #   grad_A = -grad_b * x^T
        # Reusing the LU factors of A (rather than A^T) below is exact
        # only when A is symmetric.

        grad_b = torch.lu_solve(grad_x, A_LU, pivots)
        grad_A = -torch.matmul(grad_b, x.view(1, -1))

        return grad_A, grad_b
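
Examples #2 and #5 are the forward and backward halves of a custom torch.autograd.Function. A minimal sketch assembling them (the class name LUSolve is an assumption; the test matrix is made symmetric positive definite so that reusing A's factors in backward is exact, per the note above), verified with gradcheck:

import torch

class LUSolve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        A_LU, pivots = torch.lu(A)
        x = torch.lu_solve(b, A_LU, pivots)
        ctx.save_for_backward(A_LU, pivots, x)
        return x

    @staticmethod
    def backward(ctx, grad_x):
        A_LU, pivots, x = ctx.saved_tensors
        grad_b = torch.lu_solve(grad_x, A_LU, pivots)
        grad_A = -torch.matmul(grad_b, x.view(1, -1))
        return grad_A, grad_b

# symmetric positive definite test matrix, double precision for gradcheck
M = torch.randn(4, 4, dtype=torch.double)
A = (M @ M.t() + 4 * torch.eye(4, dtype=torch.double)).requires_grad_()
b = torch.randn(4, 1, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(LUSolve.apply, (A, b))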
Example #6
    def prox(self, t, nu, warm_start, pool, cache):
        XtX = cache['XtX']
        XtY = cache['XtY']
        n = cache['n']

        # solve (X^T X + 1/(2t) I) x = X^T Y + 1/(2t) nu
        A_LU = torch.lu(XtX + 1. /
                        (2 * t) * torch.eye(n).unsqueeze(0).double())
        b = XtY + 1. / (2 * t) * torch.from_numpy(nu)
        x = torch.lu_solve(b, *A_LU)

        return x.numpy()
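
This solve is the proximal operator of f(w) = ||Xw - Y||^2: setting the gradient of ||Xw - Y||^2 + 1/(2t)||w - nu||^2 to zero gives (X^T X + 1/(2t) I) w = X^T Y + 1/(2t) nu, which is exactly the system factorized above. A standalone check with hypothetical shapes:

import numpy as np
import torch

t, n = 0.5, 4
X = torch.randn(10, n, dtype=torch.double)
Y = torch.randn(10, 1, dtype=torch.double)
nu = np.random.randn(n, 1)

XtX = (X.t() @ X).unsqueeze(0)
XtY = (X.t() @ Y).unsqueeze(0)
A_LU = torch.lu(XtX + 1. / (2 * t) * torch.eye(n).unsqueeze(0).double())
b = XtY + 1. / (2 * t) * torch.from_numpy(nu)
w = torch.lu_solve(b, *A_LU)[0]

# the gradient of ||Xw - Y||^2 + 1/(2t)||w - nu||^2 vanishes at w
grad = 2 * X.t() @ (X @ w - Y) + (w - torch.from_numpy(nu)) / t
assert torch.allclose(grad, torch.zeros_like(grad), atol=1e-8)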
Example #7
    def solve(self, t):
        sigma = self.sigma
        Lambda, U = self.Lambda, self.U
        LU = self.LU

        # Solve linear system (see eq. (11))
        y = Lambda @ U.t() @ t
        z = torch.lu_solve(y, *LU)

        alpha = 1.0 / sigma * (t - U @ z)

        return alpha
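
Eq. (11) is not reproduced in the snippet, but the structure matches a Woodbury-identity solve of (sigma*I + U Lambda U^T) alpha = t, assuming self.LU holds the LU factors of the small matrix (sigma*I + Lambda U^T U). A numerical check of that reading:

import torch

sigma = 0.5
n, m = 8, 3
U = torch.randn(n, m, dtype=torch.double)
Lambda = torch.diag(torch.rand(m, dtype=torch.double) + 0.1)
t = torch.randn(n, 1, dtype=torch.double)

LU = torch.lu(sigma * torch.eye(m, dtype=torch.double) + Lambda @ U.t() @ U)
y = Lambda @ U.t() @ t
z = torch.lu_solve(y, *LU)
alpha = 1.0 / sigma * (t - U @ z)

# alpha solves the full n x n system without ever factorizing it
K = sigma * torch.eye(n, dtype=torch.double) + U @ Lambda @ U.t()
assert torch.allclose(K @ alpha, t)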
Example #8
    def solve_kkt_be(self, rx, rs, rz, ry):
        # Solve the KKT system by block elimination (Schur complement).
        b1 = torch.cat((rx, rs), dim=1)
        if ry is not None:
            b2 = torch.cat((rz, ry), dim=1)
        else:
            b2 = rz
        A11 = self.J[:, :self.nx + self.nineq, :self.nx + self.nineq]
        A12 = self.J[:, :self.nx + self.nineq, self.nx + self.nineq:]
        A21 = torch.transpose(A12, dim0=2, dim1=1)
        # eliminate the (1,1) block; a Cholesky factorization could be
        # substituted here when A11 is symmetric positive definite
        U_A11, U_A11_piv = self.lu_factorize(A11)
        u = torch.lu_solve(b1, U_A11, U_A11_piv)
        v = torch.lu_solve(A12, U_A11, U_A11_piv)
        # Schur complement S_neg = A21 A11^{-1} A12, then back-substitute
        S_neg = torch.bmm(A21, v)
        U_S_neg, U_S_neg_piv = self.lu_factorize(S_neg)
        w = torch.lu_solve(b2, U_S_neg, U_S_neg_piv)
        t = torch.lu_solve(A21, U_S_neg, U_S_neg_piv)
        x2 = -(w - torch.bmm(t, u))
        x1 = u - torch.bmm(v, x2)
        dx = x1[:, :self.nx, :]
        ds = x1[:, self.nx:, :]
        if ry is not None:
            dz = x2[:, :-self.neq, :]
            dy = x2[:, -self.neq:, :]
        else:
            dz = x2
            dy = None
        return (dx, ds, dz, dy)
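
The elimination above solves the 2x2 block system [[A11, A12], [A21, 0]] [x1; x2] = [b1; b2] through the Schur complement S_neg = A21 A11^{-1} A12. A non-batched standalone check of the same sequence of solves:

import torch

n1, n2 = 5, 3
A11 = torch.randn(n1, n1, dtype=torch.double) + n1 * torch.eye(n1, dtype=torch.double)
A12 = torch.randn(n1, n2, dtype=torch.double)
A21 = A12.t().contiguous()
b1 = torch.randn(n1, 1, dtype=torch.double)
b2 = torch.randn(n2, 1, dtype=torch.double)

LU11 = torch.lu(A11)
u = torch.lu_solve(b1, *LU11)
v = torch.lu_solve(A12, *LU11)
S_neg = A21 @ v
LUS = torch.lu(S_neg)
w = torch.lu_solve(b2, *LUS)
t = torch.lu_solve(A21, *LUS)
x2 = -(w - t @ u)
x1 = u - v @ x2

K = torch.cat([torch.cat([A11, A12], dim=1),
               torch.cat([A21, torch.zeros(n2, n2, dtype=torch.double)], dim=1)], dim=0)
assert torch.allclose(K @ torch.cat([x1, x2]), torch.cat([b1, b2]), atol=1e-8)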
Example #9
    def backward(self, y, log_df_dz):
        with torch.no_grad():
            # reassemble the packed LU factors: strictly lower L, strictly
            # upper U, and the diagonal sign_s * exp(log_s)
            LU = self.L * self.L_mask + self.U * self.U_mask + torch.diag(
                self.sign_s * torch.exp(self.log_s))

            y_reshape = y.view(y.size(0), y.size(1), -1)
            y_reshape = torch.lu_solve(y_reshape, LU.unsqueeze(0),
                                       self.pivots.unsqueeze(0))
            y = y_reshape.view(y.size())
            y = y.contiguous()

        # log|det W| = sum(log_s) for each spatial location
        num_pixels = np.prod(y.size()) // (y.size(0) * y.size(1))
        log_df_dz -= torch.sum(self.log_s, dim=0) * num_pixels

        return y, log_df_dz
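
A hedged check of what the inverse pass above relies on: for a weight packed as L + U (unit-diagonal L below the diagonal, U with diagonal sign_s * exp(log_s)), torch.lu_solve inverts W = P L U directly from the packed factors, and log|det W| = sum(log_s). With an identity permutation for simplicity:

import torch

n = 4
L = torch.randn(n, n, dtype=torch.double).tril(-1)
log_s = torch.randn(n, dtype=torch.double)
sign_s = torch.randn(n, dtype=torch.double).sign()
U = torch.randn(n, n, dtype=torch.double).triu(1) + torch.diag(sign_s * torch.exp(log_s))

LU_packed = L + U
pivots = torch.arange(1, n + 1, dtype=torch.int32)  # identity permutation (1-indexed)

W = (L + torch.eye(n, dtype=torch.double)) @ U      # P = I here
y = torch.randn(n, 2, dtype=torch.double)
x = torch.lu_solve(y, LU_packed, pivots)
assert torch.allclose(W @ x, y, atol=1e-8)
assert torch.allclose(torch.slogdet(W)[1], log_s.sum())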
Example #10
def factor_solve_kkt_reg(Q_tilde, D, G, A, rx, rs, rz, ry, eps):
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D
    if neq > 0:
        A_ = torch.cat([
            torch.cat(
                [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)],
                2),
            torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        ], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat(
            [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = lu_hack(H_)

    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    S_ = torch.bmm(A_, invH_A_)
    S_ -= eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
    S_LU = lu_hack(S_)
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze(1)
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
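
One detail worth noting: the snippet mixes the Tensor method form A_.transpose(1, 2).lu_solve(*H_LU) with the functional torch.lu_solve(...). The two are interchangeable (here lu_hack is assumed to return the same (LU, pivots) pair as torch.lu):

import torch

A = torch.randn(2, 4, 4, dtype=torch.double) + 4 * torch.eye(4, dtype=torch.double)
b = torch.randn(2, 4, 1, dtype=torch.double)
LU = torch.lu(A)
assert torch.allclose(b.lu_solve(*LU), torch.lu_solve(b, *LU))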
Example #11
def xpbatch_lu_solve(lu_and_piv, b):
    """ solve ax = b given an LU factorization of a

    :param lu_and_piv: tuple (LU, piv) of numpy arrays holding the packed
        LU factors and the pivot indices
    :param b: right-hand side as a numpy array
    :return: solution x as a numpy array
    """
    LU, piv = lu_and_piv
    # torch.lu_solve expects 1-indexed (LAPACK-style) pivots as produced by
    # torch.lu; 0-indexed pivots (e.g. from scipy.linalg.lu_factor) would
    # need to be shifted by one first
    b = torch.from_numpy(b).float()
    LU = torch.from_numpy(LU).float()
    piv = torch.from_numpy(piv).int()
    x = torch.lu_solve(b, LU, piv)
    return x.cpu().numpy()
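
A hypothetical round trip for the function above: factor with torch.lu, pass the factors through numpy, and solve. As noted in the code, pivots coming from scipy.linalg.lu_factor are 0-indexed and would need a +1 shift before being used here:

import numpy as np
import torch

A = torch.randn(4, 4) + 4 * torch.eye(4)
LU, piv = torch.lu(A)
b = np.random.randn(4, 2).astype(np.float32)

x = xpbatch_lu_solve((LU.numpy(), piv.numpy()), b)
assert np.allclose(A.numpy() @ x, b, atol=1e-4)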
Example #12
def newton_exact(f, g, x_guess, opt_params, ls_method, ls_params):
    """
    This function performs unconstrained minimization using Newton's method to compute the search direction

    INPUTS:
        f < function > : objective function f(x) -> f
        g < function > : gradient function g(x) -> g
        x_guess < tensor > : initial x
        opt_params < dict{
            'ep_g' < float > : conv. tolerance on gradient
            'ep_a' < float > : absolute tolerance
            'ep_r' < float > : relative tolerance
            'Hessian' < function > : function that returns the Hessian
            'iter_lim' < int > : iteration limit
        } > : dictionary of optimization settings
        ls_method < str > : indicates which method to use with line search
        ls_params < dict > : dictionary with parameters to use for line search
    """
    ep_g = opt_params['ep_g']
    ep_a = opt_params['ep_a']
    ep_r = opt_params['ep_r']
    H = opt_params['Hessian']
    iter_lim = opt_params['iter_lim']

    # initializations
    x_k = x_guess
    x_hist = [x_k]
    f_k = f(x_guess)
    f_hist = [f_k]
    k = 0
    conv_count = 0

    # number of iterations meeting the rel./abs. tolerance before stopping
    conv_count_max = 2

    while k < iter_lim:
        k += 1

        # compute gradient
        g_k = g(x_k)

        # check for gradient convergence
        if torch.norm(g_k) <= ep_g:
            converge = True
            message = "Exact Newton converged due to grad. tolerance."
            break

        # invert Hessian and find search direction
        H_k = H(x_k)
        H_LU, pivots, infos = torch.lu(H_k.reshape(
            [1, H_k.shape[0], -1]), get_infos=True)

        # check whether the LU factorization failed
        if infos.nonzero().size(0) != 0:
            converge = False
            message = "Hessian LU factorization failed."
            break

        # LU solve is designed for batch operations, hence the [0]
        delta_k = torch.lu_solve(-g_k.unsqueeze(0), H_LU, pivots)[0]

        # ensure p_k is a descent direction (p_k^T g_k < 0)
        if torch.matmul(delta_k.t(), g_k) < 0:
            p_k = delta_k
        else:
            p_k = -delta_k

        # perform line search
        alf, ls_converge, ls_message = line_search(f, x_k, g_k, p_k,
                                                   ls_method=ls_method, ls_params=ls_params)
        if not ls_converge:
            converge = ls_converge
            message = ls_message
            break

        # compute x_(k+1)
        x_k1, f_k1 = search_step(f, x_k, alf, p_k)

        # check relative and absolute convergence criteria
        if rel_abs_convergence(f_k, f_k1, ep_a, ep_r):
            conv_count += 1

        x_k = x_k1
        f_k = f_k1

        x_hist.append(x_k)
        f_hist.append(f_k)

        if conv_count >= conv_count_max:
            converge = True
            message = "Exact Newton converged due to abs. rel. tolerance."
            break
    else:
        # while-else: runs only if the loop exits without a break,
        # i.e. the iteration limit was exhausted
        converge = False
        message = "Exact Newton iteration limit reached."

    return x_k, f_k, x_hist, f_hist, converge, message
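
A minimal illustration of the Newton step at the heart of the loop, on a quadratic f(x) = 1/2 x^T H x - c^T x whose exact minimizer satisfies H x = c (the names H_k and c are stand-ins):

import torch

H_k = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
c = torch.tensor([[1.0], [1.0]])
x_k = torch.zeros(2, 1)
g_k = H_k @ x_k - c                     # gradient of f at x_k

H_LU, pivots, infos = torch.lu(H_k.unsqueeze(0), get_infos=True)
assert infos.abs().sum().item() == 0    # factorization succeeded

# batched solve, hence the unsqueeze/[0] as in the snippet
delta_k = torch.lu_solve(-g_k.unsqueeze(0), H_LU, pivots)[0]
x_next = x_k + delta_k                  # one full Newton step (step length 1)
assert torch.allclose(H_k @ x_next, c)  # lands on the exact minimizer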
Example #13
def quad_search(f, x_k, g_k, p_k, ls_params):
    """
    This function performs approximate quadratic line search

    INPUTS:
        f < function > : objective function f(x) -> f
        x_k < tensor > : current best guess for f(x) minimum
        g_k < tensor > : gradient evaluated at x_k
        p_k < tensor > : search direction
        ls_params < dict{
            'alf' < float > : initial guess for step-length
            'mu' < float > : small positive constant used in "Armijo suff. decrease condition"
            'rho' < float > : step-size discount coefficient
            'iter_lim' < int > : iteration limit for solver
            'alf_lower_coeff' : coefficient for determining point one in quad_search
            'alf_upper_coeff' : coefficient for determining point three in quad_search
        } > : dictionary with parameters to use for line search

    RETURNS:
        alf_new < float > : computed search length
        converge < bool > : bool indicating whether line search converged
        message < string > : string with output from back tracking method
    """
    mu = ls_params['mu']
    iter_lim = ls_params['iter_lim']
    alf_new = ls_params['alf']
    alf_coeff1 = ls_params['alf_lower_coeff']
    alf_coeff2 = ls_params['alf_upper_coeff']
    n_iter = 0
    while not armijo_suff_decrease(f, x_k, g_k, p_k, alf_new,
                                   mu) and n_iter < iter_lim:
        # three sample points around the current step length
        a1 = alf_new
        a2 = alf_new * alf_coeff1
        a3 = alf_new * alf_coeff2
        f1 = f(x_k + a1 * p_k)
        f2 = f(x_k + a2 * p_k)
        f3 = f(x_k + a3 * p_k)

        # fit f(a) ~ 1/2 c0 a^2 + c1 a + c2 through the three samples
        A = torch.tensor([[1 / 2 * a1 ** 2, a1, 1],
                          [1 / 2 * a2 ** 2, a2, 1],
                          [1 / 2 * a3 ** 2, a3, 1]])
        b = torch.tensor([[f1], [f2], [f3]])

        A_LU, pivots, infos = torch.lu(A.reshape([1, A.shape[0], -1]),
                                       get_infos=True)

        if infos.nonzero().size(0) != 0:
            # return directly so the post-loop branch cannot overwrite
            # the failure message
            return alf_new, False, "Quadratic approx was not possible."

        coeff = torch.lu_solve(b.unsqueeze(0), A_LU, pivots)[0]
        # minimizer of the fitted quadratic: a* = -c1 / c0
        alf_new = float(-coeff[1] / coeff[0])
        n_iter += 1

    if n_iter == iter_lim:
        converge = False
        message = "Quadratic approx. line search iteration limit reached."
    else:
        converge = True
        message = "Quadratic approx. line search converged."

    return alf_new, converge, message
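
The quadratic-interpolation step in isolation: sample a parabola at three step lengths and recover its minimizer from the fitted coefficients, exactly as the loop above does (phi below is a stand-in objective):

import torch

phi = lambda a: (a - 0.3) ** 2          # parabola with minimizer at 0.3
a1, a2, a3 = 1.0, 0.1, 2.0

A = torch.tensor([[1 / 2 * a1 ** 2, a1, 1.0],
                  [1 / 2 * a2 ** 2, a2, 1.0],
                  [1 / 2 * a3 ** 2, a3, 1.0]])
b = torch.tensor([[phi(a1)], [phi(a2)], [phi(a3)]])

A_LU, pivots = torch.lu(A.unsqueeze(0))
coeff = torch.lu_solve(b.unsqueeze(0), A_LU, pivots)[0]
alf_new = float(-coeff[1] / coeff[0])   # minimizer of 1/2 c0 a^2 + c1 a + c2
assert abs(alf_new - 0.3) < 1e-5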