예제 #1
def forward_ineq(Q, p, G, h, verbose=0, maxIter=100, dt=0.2):
    """Solve a batched QP with only inequality constraints (dynamic solver).

    Fastest variant for the dynamic solver.  The problem

        min_z 1/2 z^T Q z + p^T z   s.t.  G z <= h

    is solved by ``maxIter`` projected steps of size ``dt`` on the dual
    variables ``lams``.  The dual gradient is ``D lams + d`` with
    ``D = -G Q^{-1} G^T`` and ``d = -(h + G Q^{-1} p)``.  The primal point is
    then recovered as ``z = -Q^{-1} (p + G^T lams)``.

    Args:
        Q: batched quadratic term (inverted with ``torch.inverse`` and used
            with ``bmm``; presumably (nBatch, nz, nz) -- confirm with callers).
        p: linear term; unsqueezed to a column vector per batch element.
        G, h: inequality constraint matrix and right-hand side.
        verbose: unused; kept for interface compatibility.
        maxIter: number of dual update steps.
        dt: dual step size.

    Returns:
        (zhat, lams, slacks): primal solution, inequality duals and slacks
        ``h - G zhat``, each with the trailing singleton dimension removed.
    """
    nineq, _, _, nBatch = get_sizes(G)
    G_T = torch.transpose(G, -1, 1)
    Q_I = torch.inverse(Q)

    GQ_I = G.bmm(Q_I)
    D = -GQ_I.bmm(G_T)
    d = -(h.unsqueeze(2) + GQ_I.bmm(p.unsqueeze(2)))

    lams = torch.zeros(nBatch, nineq, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        # Gradient step on the dual, projected back onto lams >= 0.
        # (Previously the step was computed but never applied.)
        lams = torch.clamp(lams + dt * (D.bmm(lams) + d), min=0)

    zhat = -Q_I.bmm(p.unsqueeze(2) + G_T.bmm(lams))
    slacks = h.unsqueeze(2) - G.bmm(zhat)

    return zhat.squeeze(2), lams.squeeze(2), slacks.squeeze(2)
예제 #2
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the KKT equations for the affine step (``btrs`` batched API).

    Block elimination: the stacked duals ``w = [dy; dz]`` are solved against
    the factored Schur complement ``S_LU``, then ``dx`` is recovered with
    ``Q_LU``.

    Args:
        Q_LU: factorization of Q in the format consumed by ``Tensor.btrs``.
        d: diagonal of D (used as ``rs / d`` and ``g2 / d``).
        G, A: inequality / equality constraint matrices (batched).
        S_LU: factorization of the Schur complement, same format as Q_LU.
        rx, rs, rz, ry: KKT residuals; ry is only used when neq > 0.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.btrs(Q_LU)
    # (nBatch, 1, n) x (nBatch, n, m) -> (nBatch, 1, m); squeeze the middle
    # dim so the result lines up with the (nBatch, m) residuals instead of
    # broadcasting to (nBatch, nBatch, m).
    h_ineq = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz
    if neq > 0:
        h_eq = invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry
        h = torch.cat((h_eq, h_ineq), 1)
    else:
        h = h_ineq

    w = -(h.btrs(S_LU))

    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)
    g2 = -rs - w[:, neq:]

    dx = g1.btrs(Q_LU)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
예제 #3
def pre_factor_kkt(Q, G, A):
    """Perform the one-time (unbatched) Cholesky factorizations for the KKT solve.

    Builds an upper-triangular partial Cholesky factor ``U_S`` of

        S = [ A Q^{-1} A^T        A Q^{-1} G^T           ]
            [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    whose lower-right block is left as zeros, to be completed once D is known.

    Returns:
        (U_Q, U_S, R): upper Cholesky factor of Q, the partial factor of S and
        ``R = G Q^{-1} G^T - U12^T U12`` (the D-independent part of the
        remaining block).
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    U_Q = torch.potrf(Q)
    # Partial Cholesky factor of S.
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD but
        # numpy does? I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except RuntimeError:
            # numpy returns the LOWER factor L (M = L L^T); transpose it so
            # U11 follows potrf's upper-triangular convention (M = U^T U).
            U11 = torch.Tensor(np.linalg.cholesky(
                A_invQ_AT.cpu().numpy())).type_as(A_invQ_AT).t()

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R
예제 #4
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the KKT equations for the affine step (batched LU version).

    Block elimination: the stacked duals ``w = [dy; dz]`` are solved against
    the factored Schur complement ``S_LU``, then ``dx`` is recovered with
    ``Q_LU``.

    Args:
        Q_LU: LU factorization of Q, unpacked as ``(LU_data, LU_pivots)``
            into ``lu_solve``.
        d: diagonal of D (used as ``rs / d`` and ``g2 / d``).
        G, A: inequality / equality constraint matrices (batched).
        S_LU: LU factorization of the Schur complement, same format.
        rx, rs, rz, ry: KKT residuals; ry is only used when neq > 0.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
    h_ineq = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz
    if neq > 0:
        # Equality rows come first, matching the layout of S_LU.
        h_eq = invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry
        h = torch.cat((h_eq, h_ineq), 1)
    else:
        h = h_ineq
    w = -(h.unsqueeze(2).lu_solve(*S_LU)).squeeze(2)

    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)
    g2 = -rs - w[:, neq:]
    dx = g1.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
예제 #5
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the full (unbatched) KKT system with Cholesky.

    With ``H = blkdiag(Q, D)`` and constraint matrix ``A_ = [[G, I], [A, 0]]``
    (or ``[G, I]`` without equalities), the system is solved through the
    Schur complement ``S = A_ H^{-1} A_^T``.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # H and g are identical with or without equality constraints.
    H_ = torch.cat([
        torch.cat([Q, torch.zeros(nz, nineq).type_as(Q)], 1),
        torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)
    ], 0)
    g_ = torch.cat([rx, rs], 0)
    if neq > 0:
        A_ = torch.cat([
            torch.cat([G, torch.eye(nineq).type_as(Q)], 1),
            torch.cat([A, torch.zeros(neq, nineq).type_as(Q)], 1)
        ], 0)
        h_ = torch.cat([rz, ry], 0)
    else:
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        h_ = rz

    U_H_ = torch.potrf(H_)

    invH_A_ = torch.potrs(A_.t(), U_H_)
    invH_g_ = torch.potrs(g_.view(-1, 1), U_H_).view(-1)

    S_ = torch.mm(A_, invH_A_)
    U_S_ = torch.potrf(S_)
    # Keep every right-hand side an (n, 1) column: mixing (n, 1) with (n,)
    # would broadcast to an (n, n) matrix.
    t_ = torch.mv(A_, invH_g_).view(-1, 1) - h_.view(-1, 1)
    w_ = -torch.potrs(t_, U_S_).view(-1)
    v_ = torch.potrs(-g_.view(-1, 1) - torch.mv(A_.t(), w_).view(-1, 1),
                     U_H_).view(-1)

    # Rows of A_ are [G; A], so w_ is laid out as [dz; dy].
    dy = w_[nineq:] if neq > 0 else None
    return v_[:nz], v_[nz:], w_[:nineq], dy
예제 #6
def solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry, dbg=False):
    """Solve the KKT equations for the affine step (unbatched Cholesky version).

    Args:
        U_Q: Cholesky factor of Q (consumed by ``torch.potrs``).
        d: diagonal of D (used as ``rs / d`` and ``g2 / d``).
        G, A: inequality / equality constraint matrices.
        U_S: Cholesky factor of the Schur complement.
        rx, rs, rz, ry: KKT residuals; ry is only used when neq > 0.
        dbg: accepted for backward compatibility; has no effect.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    invQ_rx = torch.potrs(rx.view(-1, 1), U_Q).view(-1)
    h_ineq = torch.mv(G, invQ_rx) + rs / d - rz
    if neq > 0:
        # Equality rows come first, matching the layout of U_S.
        h = torch.cat([torch.mv(A, invQ_rx) - ry, h_ineq], 0)
    else:
        h = h_ineq

    w = -torch.potrs(h.view(-1, 1), U_S).view(-1)

    g1 = -rx - torch.mv(G.t(), w[neq:])
    if neq > 0:
        g1 -= torch.mv(A.t(), w[:neq])
    g2 = -rs - w[neq:]

    dx = torch.potrs(g1.view(-1, 1), U_Q).view(-1)
    ds = g2 / d
    dz = w[neq:]
    dy = w[:neq] if neq > 0 else None

    return dx, ds, dz, dy
예제 #7
def solve_kkt_ir(Q, D, G, A, rx, rs, rz, ry, niter=1):
    """Solve the KKT system with (inefficient) iterative refinement.

    First solves a regularized system built from ``Q + eps I`` and
    ``D + eps I`` via ``factor_solve_kkt_reg``.  Then, ``niter`` times, it
    solves again against the negated residual reported by ``kkt_resid_reg``
    and adds the correction to the current solution.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    eps = 1e-7
    Q_tilde = Q + eps * torch.eye(nz).type_as(Q).repeat(nBatch, 1, 1)
    D_tilde = D + eps * torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)

    dx, ds, dz, dy = factor_solve_kkt_reg(
        Q_tilde, D_tilde, G, A, rx, rs, rz, ry, eps)
    resx, ress, resz, resy = kkt_resid_reg(Q, D, G, A, eps,
                                           dx, ds, dz, dy, rx, rs, rz, ry)
    for _ in range(niter):
        ddx, dds, ddz, ddy = factor_solve_kkt_reg(
            Q_tilde, D_tilde, G, A, -resx, -ress, -resz,
            -resy if resy is not None else None,
            eps)
        dx, ds, dz, dy = [v + dv if v is not None else None
                          for v, dv in zip((dx, ds, dz, dy), (ddx, dds, ddz, ddy))]
        resx, ress, resz, resy = kkt_resid_reg(Q, D, G, A, eps,
                                               dx, ds, dz, dy, rx, rs, rz, ry)

    return dx, ds, dz, dy
예제 #8
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the full batched KKT system with LU (``lu_hack``).

    With ``H = blkdiag(Q, D)`` and ``A_ = [[G, I], [A, 0]]`` (or ``[G, I]``
    without equalities), the system is solved through the Schur complement
    ``S = A_ H^{-1} A_^T``.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q
    H_[:, -nineq:, -nineq:] = D
    # The slack identity must carry the batch dimension to be concatenated
    # with the batched G (this was missing in the inequality-only branch).
    eye_ineq = torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)
    g_ = torch.cat([rx, rs], 1)
    if neq > 0:
        A_ = torch.cat([torch.cat([G, eye_ineq], 2),
                        torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q)], 2)], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat([G, eye_ineq], 2)
        h_ = rz

    H_LU = lu_hack(H_)

    # lu_solve expects right-hand sides shaped (*, m, k): vectors are
    # unsqueezed to a single column and squeezed back afterwards.
    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    S_ = torch.bmm(A_, invH_A_)
    S_LU = lu_hack(S_)
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
    # squeeze(1) (not squeeze()) so nBatch == 1 keeps its batch dimension.
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze(1)
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
예제 #9
파일: batch.py 프로젝트: lopa23/flim_optcrf
def pre_factor_kkt(Q, G, A):
    """Perform the one-time batched LU factorizations and cache KKT products.

    Factors Q with ``lu_hack`` and builds a partial LU decomposition of

        S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
        #   [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    whose lower-right block is left as zeros, to be completed once D^{-1}
    is known.

    Returns:
        (Q_LU, S_LU, R) with ``S_LU = [S_LU_data, S_LU_pivots]`` and
        ``R = G Q^{-1} G^T - G Q^{-1} A^T (A Q^{-1} A^T)^{-1} A Q^{-1} G^T``
        (just ``G Q^{-1} G^T`` without equality constraints).

    Raises:
        RuntimeError: if Q cannot be LU-factorized.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)
    try:
        Q_LU = lu_hack(Q)
    except Exception as err:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""") from err

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.
    # NOTE: a former special case for G.size(2) == 2 computed
    # G (Q^{-1} G)^T, which is not G Q^{-1} G^T; the general form below is
    # valid for every nz.
    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    # Identity (1-based) pivots; the first neq entries are overwritten below.
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        # (P L) solved against P L U gives U^{-1}.
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).lu_solve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)),
                              1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
예제 #10
def forward_eq_conv(Q, p, G_, h_, A_, b_, verbose=0, maxIter=100, dt=0.2):
    """Solve a batched QP by turning each equality into two inequalities.

    Each equality ``A_ z = b_`` is rewritten as ``A_ z <= b_`` and
    ``-A_ z <= -b_``.  The resulting inequality-only problem is solved with
    projected dual steps, as in the inequality solver.

    Args:
        Q, p: quadratic and linear cost terms (batched; used with ``bmm``).
        G_, h_: inequality constraints.
        A_, b_: equality constraints.
        verbose: unused; kept for interface compatibility.
        maxIter: number of dual update steps.
        dt: dual step size.

    Returns:
        (zhat, lams, nus, slacks).
        - ``lams`` holds the inequality duals.
        - ``nus`` holds the equality duals, recovered as ``mu_plus - mu_minus``
          from the two split inequalities (stationarity gives
          ``A^T mu_plus - A^T mu_minus = A^T nu``).
        - ``slacks`` covers all stacked rows ``[G_; A_; -A_]``.
    """
    nineq, nz, neq, nBatch = get_sizes(G_, A_)

    G = torch.cat((G_, A_, -A_), dim=1)
    h = torch.cat((h_, b_, -b_), dim=1)
    G_T = torch.transpose(G, -1, 1)
    Q_I = torch.inverse(Q)

    GQ_I = G.bmm(Q_I)
    D = -GQ_I.bmm(G_T)
    d = -(h.unsqueeze(2) + GQ_I.bmm(p.unsqueeze(2)))

    lams = torch.zeros(nBatch, nineq + neq + neq, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        # Projected dual step (previously computed but never applied).
        lams = torch.clamp(lams + dt * (D.bmm(lams) + d), min=0)

    zhat = -Q_I.bmm(p.unsqueeze(2) + G_T.bmm(lams))
    slacks = h.unsqueeze(2) - G.bmm(zhat)

    all_lams = lams.squeeze(2)
    mu_plus = all_lams[:, nineq:nineq + neq]
    mu_minus = all_lams[:, nineq + neq:]
    # Explicit slices also give an empty (nBatch, 0) result when neq == 0,
    # where the old ``[:, -neq:]`` returned every multiplier.
    nus = mu_plus - mu_minus
    return zhat.squeeze(2), all_lams[:, :nineq], nus, slacks.squeeze(2)
예제 #11
def pre_factor_kkt(Q, G, A):
    """Perform the one-time batched factorizations (``btrf``/``btrs`` API).

    Returns:
        (Q_LU, S_LU, R).
        - ``Q_LU`` is ``Q.btrf()``.
        - ``S_LU`` is a packed partial LU of the Schur matrix S below (its
          lower-right block is zeros).
        - ``R = G Q^{-1} G^T - G Q^{-1} A^T (A Q^{-1} A^T)^{-1} A Q^{-1} G^T``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    Q_LU = Q.btrf()

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    # We compute a partial LU decomposition of S matrix
    # that can be completed once D^{-1} is known.
    # This is done for a general matrix by decomposing
    # S using the Schur complement and then LU-factorizing
    # the matrices in the middle:
    #   [ A B ] = [ I            0 ] [ A     0              ] [ I    A^{-1} B ]
    #   [ C D ]   [ C A^{-1}     I ] [ 0     D - C A^{-1} B ] [ 0    I        ]

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrs(Q_LU))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrs(Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = A_invQ_AT.btrf()
        if neq == 1:
            LU_A_invQ_AT = LU_A_invQ_AT.view(-1, 1, 1)

        # Unpack the packed LU: L = strict lower triangle + unit diagonal,
        # U = upper triangle including the diagonal.
        strict_lower = torch.tril(torch.ones(neq, neq), -1).unsqueeze(0).expand(
            nBatch, neq, neq).type_as(Q)
        upper = torch.triu(torch.ones(neq, neq)).unsqueeze(0).expand(
            nBatch, neq, neq).type_as(Q)
        eye_neq = torch.eye(neq).unsqueeze(0).expand(nBatch, neq, neq).type_as(Q)
        L_A_invQ_AT = strict_lower * LU_A_invQ_AT + eye_neq
        U_A_invQ_AT = upper * LU_A_invQ_AT

        S_LU_11 = LU_A_invQ_AT
        # L solved against (L U) gives U^{-1}.
        # NOTE(review): assumes btrf did no row pivoting — confirm.
        S_LU_21 = G_invQ_AT.bmm(L_A_invQ_AT.btrs(LU_A_invQ_AT))
        T = G_invQ_AT.transpose(1, 2).btrs(LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU = torch.cat(
            (torch.cat((S_LU_11, S_LU_12), 2),
             torch.cat((S_LU_21, torch.zeros(nBatch, nineq, nineq).type_as(Q)), 2)),
            1)
        R -= G_invQ_AT.bmm(T)
    else:
        S_LU = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    return Q_LU, S_LU, R
예제 #12
파일: batch.py 프로젝트: lopa23/flim_optcrf
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the KKT equations for the affine step (batched LU version).

    Args:
        Q_LU: LU factorization of Q, unpacked as ``(LU_data, LU_pivots)``
            into ``lu_solve``.
        d: diagonal of D (used as ``rs / d`` and ``g2 / d``).
        G, A: inequality / equality constraint matrices (batched).
        S_LU: LU factorization of the Schur complement, same format.
        rx, rs, rz, ry: KKT residuals; ry is only used when neq > 0.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    # NOTE: earlier shape-dependent branches (on G.size(1), h.size(0),
    # g1.dim()) were either identical duplicates or produced mis-shaped
    # products, e.g. a 2-D ``w`` matmul'd with a 3-D ``G``.  The standard
    # batched formulation below is used for every shape.
    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
    h_ineq = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz
    if neq > 0:
        h_eq = invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry
        h = torch.cat((h_eq, h_ineq), 1)
    else:
        h = h_ineq

    w = -(h.unsqueeze(2).lu_solve(*S_LU)).squeeze(2)

    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)
    g2 = -rs - w[:, neq:]
    dx = g1.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None
    return dx, ds, dz, dy
예제 #13
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the full KKT system (``btrf``/``btrs`` API).

    In this variant Q, G and A appear to be shared across the batch: Q is
    repeated to nBatch, and A_ is used with 2-D ``mm``.  Only D and the
    residuals carry a batch dimension.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # H and g are identical with or without equality constraints.
    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q.repeat(nBatch, 1, 1)
    H_[:, -nineq:, -nineq:] = D
    g_ = torch.cat([rx, rs], 1)
    if neq > 0:
        from block import block
        A_ = block(((G, 'I'), (A, torch.zeros(neq, nineq).type_as(Q))))
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        h_ = rz

    H_LU = H_.btrf()

    A_batched = A_.repeat(nBatch, 1, 1)
    invH_A_ = A_batched.transpose(1, 2).btrs(H_LU)
    invH_g_ = g_.btrs(H_LU)

    S_ = torch.bmm(A_batched, invH_A_)
    S_LU = S_.btrf()
    t_ = torch.mm(invH_g_, A_.t()) - h_
    w_ = -t_.btrs(S_LU)
    t_ = -g_ - w_.mm(A_)
    v_ = t_.btrs(H_LU)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
예제 #14
파일: batch.py 프로젝트: lopa23/flim_optcrf
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the full batched KKT system with LU (``lu_hack``).

    With ``H = blkdiag(Q, D)`` and ``A_ = [[G, I], [A, 0]]`` (or ``[G, I]``
    without equalities), the system is solved through the Schur complement
    ``S = A_ H^{-1} A_^T``.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    # NOTE: the previously "changed" lines solved A_ (not A_^T) against H
    # and bmm'd a 2-D w_ / 4-D invH_A_, which only type-checks for
    # degenerate sizes; the consistent batched formulation is used instead.
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q
    H_[:, -nineq:, -nineq:] = D
    # Slack identity with a batch dimension so it can be concatenated with G.
    eye_ineq = torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)
    g_ = torch.cat([rx, rs], 1)
    if neq > 0:
        A_ = torch.cat([torch.cat([G, eye_ineq], 2),
                        torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q)], 2)], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat([G, eye_ineq], 2)
        h_ = rz

    H_LU = lu_hack(H_)

    # lu_solve expects right-hand sides shaped (*, m, k).
    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    S_ = torch.bmm(A_, invH_A_)
    S_LU = lu_hack(S_)
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze(1)
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
예제 #15
def pre_factor_kkt(Q, G, A):
    """Perform the one-time batched factorizations (``btrifact`` API).

    Returns:
        (Q_LU, S_LU, R).
        - ``Q_LU = Q.btrifact()``.
        - ``S_LU = [S_LU_data, S_LU_pivots]`` is a partial LU of the Schur
          matrix S below, whose lower-right block is zeros.
        - ``R = G Q^{-1} G^T - G Q^{-1} A^T (A Q^{-1} A^T)^{-1} A Q^{-1} G^T``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    Q_LU = Q.btrifact()

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrisolve(*Q_LU))
    R = G_invQ_GT.clone()
    # Identity (1-based) pivots; the first neq entries are overwritten below.
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrisolve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = A_invQ_AT.btrifact()
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.btriunpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        # (P L) solved against P L U gives U^{-1}.
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).btrisolve(
            *LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).btrisolve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat(
            (S_LU_11, S_LU_12), 2), torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
예제 #16
파일: batch.py 프로젝트: jdily/qpth
def factor_solve_kkt_reg(Q_tilde, D, G, A, rx, rs, rz, ry, eps):
    """Factor and solve a regularized batched KKT system (``btrifact`` API).

    Same block elimination as the unregularized solver.  The difference is
    that the Schur complement ``A_ H^{-1} A_^T`` has ``eps * I`` subtracted
    before it is factored.

    Args:
        Q_tilde: (regularized) quadratic block placed in H.
        D: (regularized) slack block placed in H.
        G, A: inequality / equality constraint matrices (batched).
        rx, rs, rz, ry: right-hand sides; ry is only used when neq > 0.
        eps: regularization subtracted from the Schur complement diagonal.

    Returns:
        (dx, ds, dz, dy); dy is None when there are no equality constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D
    eye_ineq = torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
    g_ = torch.cat([rx, rs], 1)
    if neq > 0:
        A_ = torch.cat([
            torch.cat([G, eye_ineq], 2),
            torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        ], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat([G, eye_ineq], 2)
        h_ = rz

    H_LU = H_.btrifact()

    invH_A_ = A_.transpose(1, 2).btrisolve(*H_LU)
    invH_g_ = g_.btrisolve(*H_LU)

    S_ = torch.bmm(A_, invH_A_)
    S_ -= eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
    S_LU = S_.btrifact()
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.btrisolve(*S_LU)
    # squeeze(1) (not squeeze()) keeps the batch dim when nBatch == 1.
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze(1)
    v_ = t_.btrisolve(*H_LU)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
예제 #17
def forward_eq_new(Q_, p_, G_, h_, A_, b_, verbose=0, maxIter=100, dt=0.2):
    """Solve an equality-constrained QP by evolving primal and dual together.

    The primal variable is expanded as ``z = x[:nz] - x[nz:]`` so that the
    expanded ``x`` can be driven towards non-negativity (the ``neg(x)``
    term).  Both ``x`` and the equality duals ``y`` take explicit steps of
    size ``dt`` for ``maxIter`` iterations.  Only ``A_ z = b_`` is enforced;
    ``G_``, ``h_`` and ``verbose`` are unused.

    ``neg`` is a helper defined elsewhere in this module (presumably the
    negative part of its argument -- confirm).

    Returns:
        (z, y, slacks).  ``slacks`` is an all-zero (nBatch, neq) tensor.
    """
    neq, nz, _, nBatch = get_sizes(A_)

    # Expanded problem over x = [x+; x-] with z = x+ - x-.
    Q__ = torch.cat((Q_, -Q_), dim=1)
    Q = torch.cat((Q__, -Q__), dim=2)
    p = torch.cat((p_, -p_), dim=1).unsqueeze(2)
    A = torch.cat((A_, -A_), dim=2)
    b = b_.unsqueeze(2)

    Q_T = torch.transpose(Q, -1, 1)
    A_T = torch.transpose(A, -1, 1)
    p_T = torch.transpose(p, -1, 1)
    b_T = torch.transpose(b, -1, 1)

    # Expand 'x' to allow for negative values
    x = torch.zeros(nBatch, 2 * nz, 1).type_as(Q).to(Q.device)
    y = torch.zeros(nBatch, neq, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        x_T = torch.transpose(x, -1, 1)

        # g = x^T Q x + p^T x - b^T y: primal/dual objective gap.
        g = x_T.bmm(Q).bmm(x) + p_T.bmm(x) - b_T.bmm(y)
        r = neg(Q.bmm(x) + p - A_T.bmm(y))
        dFx = g.bmm(2 * Q.bmm(x) + p)
        dFy = g.bmm(b)
        dx = -dFx - A_T.bmm(A.bmm(x) - b) - neg(x) - Q_T.bmm(r)
        dy = dFy + A.bmm(r)

        # Apply the step (previously dx/dy were computed but never used).
        x = x + dt * dx
        y = y + dt * dy

    slacks = torch.zeros(nBatch, neq).type_as(Q).to(Q.device)

    return (x[:, :nz, :] - x[:, nz:, :]).squeeze(2), y.squeeze(2), slacks
예제 #18
def forward(Q, p, G, h, A, b, verbose=0, maxIter=100, dt=0.2):
    """Solve a batched QP by unrolling its dual dynamics for ``maxIter`` steps.

    The dynamics are modeled as an RNN with ``maxIter`` loops.  Each step
    updates the inequality duals ``lams`` (projected onto ``lams >= 0``) and
    the equality duals ``nus`` (unconstrained) from the same previous
    iterate.  The primal point is then recovered as
    ``z = -Q^{-1} (p + G^T lams + A^T nus)``.

    Args:
        Q, p: quadratic and linear cost terms (batched; used with ``bmm``).
        G, h: inequality constraints.
        A, b: equality constraints.
        verbose: unused; kept for interface compatibility.
        maxIter: number of dual update steps.
        dt: dual step size.

    Returns:
        (zhat, lams, nus, slacks) with ``slacks = h - G zhat``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Base inverses and transposes
    Q_I = torch.inverse(Q)
    G_T = torch.transpose(G, -1, 1)
    A_T = torch.transpose(A, -1, 1)

    # Intermediate matrix expressions
    GQ_I = -G.bmm(Q_I)  # - G Q^{-1}
    AQ_I = -A.bmm(Q_I)  # - A Q^{-1}
    GA = GQ_I.bmm(A_T)  # - G Q^{-1} A^T
    GG = GQ_I.bmm(G_T)  # - G Q^{-1} G^T
    AA = AQ_I.bmm(A_T)  # - A Q^{-1} A^T
    AG = AQ_I.bmm(G_T)  # - A Q^{-1} G^T
    la_d = GQ_I.bmm(p.unsqueeze(2)) - h.unsqueeze(2)  # - G Q^{-1} p - h
    nu_d = AQ_I.bmm(p.unsqueeze(2)) - b.unsqueeze(2)  # - A Q^{-1} p - b

    lams = torch.zeros(nBatch, nineq, 1).type_as(Q).to(Q.device)
    nus = torch.zeros(nBatch, neq, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        # Both steps use the previous (lams, nus); previously they were
        # computed but never applied.
        dlams = dt * (GG.bmm(lams) + GA.bmm(nus) + la_d)
        dnus = dt * (AG.bmm(lams) + AA.bmm(nus) + nu_d)
        lams = torch.clamp(lams + dlams, min=0)
        nus = nus + dnus

    zhat = -Q_I.bmm(p.unsqueeze(2) + G_T.bmm(lams) + A_T.bmm(nus))
    slacks = h.unsqueeze(2) - G.bmm(zhat)

    return zhat.squeeze(2), lams.squeeze(2), nus.squeeze(2), slacks.squeeze(2)
예제 #19
def forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps=1e-12, verbose=0, notImprovedLim=3,
            maxIter=20, solver=KKTSolvers.LU_PARTIAL):
    """Solve a batch of QPs with a primal-dual interior point method.

    Each batch element solves

        min_x 1/2 x^T Q x + p^T x   s.t.  G x + s = h,  s >= 0,  A x = b

    using predictor-corrector steps: an affine step, then a centering /
    corrector step with ``sigma = (mu_aff / mu)^3``.

    Args:
        Q, p, G, h, A, b: batched problem data.
        Q_LU, S_LU, R: cached factorizations from
            ``Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)``.  If any of them is
            None they are computed here.  ``S_LU`` is refactored in place by
            ``factor_kkt`` on every iteration.
        eps: stop once every batch element's total residual is below this.
        verbose: 1 prints per-iteration residuals; >= 0 prints a warning
            when the returned solution is inaccurate.
        notImprovedLim: stop after this many consecutive iterations in which
            no batch element improved its residual.
        maxIter: maximum number of interior point iterations.
        solver: KKT backend, a ``KKTSolvers`` member.

    Returns:
        (x, y, z, s): the iterate with the smallest total residual seen, per
        batch element.  y is None without equality constraints.  All four are
        None if the very first refactorization fails or maxIter == 0.

    Raises:
        ValueError: if ``solver`` is not a supported ``KKTSolvers`` member.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)
    if Q_LU is None or S_LU is None or R is None:
        Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)

    inaccurate_msg = ('Warning: returning an inaccurate and potentially '
                      'incorrect solution; some residual is large. '
                      'The problem may be infeasible or difficult.')

    # Find initial values
    if solver == KKTSolvers.LU_FULL:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = factor_solve_kkt(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    elif solver == KKTSolvers.LU_PARTIAL:
        d = torch.ones(nBatch, nineq).type_as(Q)
        factor_kkt(S_LU, R, d)
        x, s, z, y = solve_kkt(
            Q_LU, d, G, A, S_LU,
            p, torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if neq > 0 else None)
    elif solver == KKTSolvers.IR_UNOPT:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = solve_kkt_ir(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    else:
        raise ValueError('Unknown KKT solver: {}'.format(solver))

    # Make all of the slack variables >= 1.
    M = torch.min(s, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    s[I] -= M[I] - 1

    # Make all of the inequality dual variables >= 1.
    M = torch.min(z, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(maxIter):
        # affine scaling direction
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            p
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except RuntimeError:
            # Refactorization failed (presumably a numerically singular
            # system); fall back to the best iterate found so far.
            return best['x'], best['y'], best['z'], best['s']

        if verbose == 1:
            print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
                i, pri_resid.mean(), dual_resid.mean(), mu.mean()))
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            # Per-batch-element mask of residuals that improved.
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        if nNotImproved == notImprovedLim or best['resids'].max() < eps or mu.min() > 1e32:
            if best['resids'].max() > 1. and verbose >= 0:
                print(inaccurate_msg)
            return best['x'], best['y'], best['z'], best['s']

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
                Q, D, G, A, rx, rs, rz, ry)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(
                Q, D, G, A, rx, rs, rz, ry)
        else:
            raise ValueError('Unknown KKT solver: {}'.format(solver))

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff),
                                    get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q) if neq > 0 else torch.Tensor()

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
                Q, D, G, A, rx, rs, rz, ry)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(
                Q, D, G, A, rx, rs, rz, ry)
        else:
            raise ValueError('Unknown KKT solver: {}'.format(solver))

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        # Step at most 1, and stay strictly inside s, z > 0.
        alpha = torch.min(0.999 * torch.min(get_step(z, dz),
                                            get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    if best['resids'] is not None and best['resids'].max() > 1. and verbose >= 0:
        print(inaccurate_msg)
    return best['x'], best['y'], best['z'], best['s']
예제 #20
def forward(inputs, Q, G, h, A, b, Q_LU, S_LU, R, verbose=False):
    """Solve a batch of QPs with a primal-dual interior-point method.

    For every batch element the loop drives these KKT residuals to zero::

        rx = Q x + inputs + G^T z + A^T y   (dual residual)
        rz = G x + s - h                    (inequality residual, slack s)
        ry = A x - b                        (equality residual, only if neq > 0)

    i.e. the problem ``min 1/2 x^T Q x + inputs^T x  s.t.  G x <= h, A x = b``.
    Each iteration takes an affine-scaling (predictor) step plus a centering
    (corrector) step.

    Args:
        inputs: linear cost term; assumed (nBatch, nz) -- TODO confirm.
        Q, G, A: batched matrices (they are multiplied with ``torch.bmm``).
        h, b: right-hand sides of the inequality / equality constraints.
        Q_LU, S_LU, R: one-time factorizations from ``pre_factor_kkt(Q, G, A)``.
            If any of them is ``None`` they are computed here. ``S_LU`` is
            passed to ``factor_kkt`` every iteration, whose return value is
            unused, so it is presumably updated in place.
        verbose: print the residuals of the first batch element each iteration.

    Returns:
        ``(x, y, z, s)``: per batch element, the iterate with the smallest
        combined residual seen so far. ``y`` is ``None`` when neq == 0.
    """
    if Q_LU is None or S_LU is None or R is None:
        Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Initial point: solve the KKT system with D = I.
    d = torch.ones(nBatch, nineq).type_as(Q)
    factor_kkt(S_LU, R, d)
    x, s, z, y = solve_kkt(Q_LU, d, G, A, S_LU, inputs,
                           torch.zeros(nBatch, nineq).type_as(Q), -h,
                           -b if neq > 0 else None)

    # If a batch element has a negative slack (resp. dual), shift all of its
    # slacks (resp. duals) so that the smallest one becomes 1.
    # keepdim=True keeps the per-row minimum as an (nBatch, 1) column;
    # without it, repeat(1, nineq) would build a (1, nBatch * nineq) mask.
    M = torch.min(s, 1, keepdim=True)[0].repeat(1, nineq)
    I = M < 0
    s[I] -= M[I] - 1

    M = torch.min(z, 1, keepdim=True)[0].repeat(1, nineq)
    I = M < 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(20):
        # Residuals of the KKT conditions at the current iterate
        # (these are also the right-hand side of the affine-scaling solve).
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            inputs
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        # TODO: move this factorization next to the affine solve below and
        # drop the try/except once factor_kkt failures are handled properly.
        try:
            factor_kkt(S_LU, R, d)
        except Exception:
            print('=' * 70 + '\n' +
                  "TODO: Remove try/except around factor_kkt!!!" + '\n')
            if best['resids'] is None:
                # Nothing recorded yet: return the current iterate rather
                # than a tuple of Nones.
                return x, y, z, s
            return best['x'], best['y'], best['z'], best['s']

        if verbose:
            # reshape(-1) also works when nBatch == 1 and squeeze() has
            # produced 0-dim tensors.
            print(
                'iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.
                format(i, pri_resid.reshape(-1)[0],
                       dual_resid.reshape(-1)[0], mu.reshape(-1)[0]))

        # Track the best iterate per batch element.
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        # Stop after 3 iterations without improvement or on convergence.
        if nNotImproved == 3 or best['resids'].max() < 1e-12:
            return best['x'], best['y'], best['z'], best['s']

        # Affine-scaling (predictor) direction.
        dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(Q_LU, d, G, A, S_LU, rx, rs,
                                                   rz, ry)

        # Centering parameter: sigma = (gap after the affine step / gap)^3,
        # with the affine step length capped at 1.
        alpha = torch.min(torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        # Centering-corrector direction.
        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q)
        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(Q_LU, d, G, A, S_LU, rx, rs,
                                                   rz, ry)

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        # Damped step (factor 0.999), capped at 1; get_step presumably
        # returns the largest step keeping z and s positive.
        alpha = torch.min(0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    return best['x'], best['y'], best['z'], best['s']
# Example #21
def forward(inputs_i, Q, G, A, b, h, U_Q, U_S, R, verbose=False):
    """Solve a single (unbatched) QP with a primal-dual interior-point method.

    The loop drives these KKT residuals to zero::

        rx = Q x + inputs_i + G^T z + A^T y
        rz = G x + s - h
        ry = A x - b            (only when neq > 0)

    i.e. ``min 1/2 x^T Q x + inputs_i^T x  s.t.  G x <= h, A x = b``.
    The original note on the data: ``b = A z_0`` and ``h = G z_0 + s_0``.

    Args:
        inputs_i: linear cost term (combined with ``torch.mv`` results).
        Q, G, A: unbatched matrices; b, h: constraint right-hand sides.
        U_Q, U_S, R: one-time factorizations, expected to come from
            ``pre_factor_kkt``. ``U_S`` is passed to ``factor_kkt`` each
            iteration, whose return value is unused, so it is presumably
            updated in place.
        verbose: print residuals, gap and min(d)/max(d) every iteration.

    Returns:
        ``(x, y, z)``. The loop stops early when the residual stops improving,
        drops below 1e-6, or the step norms blow up (NaN or > 1e5).
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # Initial point: solve the KKT system with D = I.
    d = torch.ones(nineq).type_as(Q)
    nb = -b if b is not None else None
    factor_kkt(U_S, R, d)
    x, s, z, y = solve_kkt(U_Q, d, G, A, U_S, inputs_i,
                           torch.zeros(nineq).type_as(Q), -h, nb)

    # If any slack (resp. dual) is negative, shift all of them so the
    # smallest one becomes 1.
    if torch.min(s) < 0:
        s -= torch.min(s) - 1
    if torch.min(z) < 0:
        z -= torch.min(z) - 1

    prev_resid = None
    for i in range(20):
        # Residuals of the KKT conditions (right-hand side of the affine solve).
        rx = (torch.mv(A.t(), y) if neq > 0 else 0.) + \
            torch.mv(G.t(), z) + torch.mv(Q, x) + inputs_i
        rs = z
        rz = torch.mv(G, x) + s - h
        # Placeholder when there are no equality constraints; created with
        # Q's dtype/device so it mixes safely with the other residuals.
        ry = torch.mv(A, x) - b if neq > 0 else torch.zeros(1).type_as(Q)
        mu = torch.dot(s, z) / nineq
        pri_resid = torch.norm(ry) + torch.norm(rz)
        dual_resid = torch.norm(rx)
        resid = pri_resid + dual_resid + nineq * mu
        d = z / s
        if verbose:
            print(("primal_res = {0:.5g}, dual_res = {1:.5g}, " +
                   "gap = {2:.5g}, kappa(d) = {3:.5g}").format(
                       pri_resid, dual_resid, mu,
                       d.min() / d.max()))
        improved = (prev_resid is None) or (resid < prev_resid + 1e-6)
        if not improved or resid < 1e-6:
            return x, y, z
        prev_resid = resid

        # Affine-scaling (predictor) direction.
        factor_kkt(U_S, R, d)
        dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(U_Q, d, G, A, U_S, rx, rs,
                                                   rz, ry)

        # Centering parameter sigma = (gap after affine step / gap)^3.
        alpha = min(min(get_step(z, dz_aff), get_step(s, ds_aff)), 1.0)
        sig = (torch.dot(s + alpha * ds_aff, z + alpha * dz_aff) /
               (torch.dot(s, z)))**3
        # Centering-corrector direction.
        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
            U_Q, d, G, A, U_S,
            torch.zeros(nz).type_as(Q),
            (-mu * sig * torch.ones(nineq).type_as(Q) + ds_aff * dz_aff) / s,
            torch.zeros(nineq).type_as(Q), torch.zeros(neq).type_as(Q))

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        alpha = min(1.0, 0.999 * min(get_step(s, ds), get_step(z, dz)))
        dx_norm = torch.norm(dx)
        dz_norm = torch.norm(dz)
        # torch.isnan works for CPU and CUDA tensors alike (np.isnan does not).
        if torch.isnan(dx_norm) or dx_norm > 1e5 or dz_norm > 1e5:
            # Overflow, return early
            return x, y, z

        x += alpha * dx
        s += alpha * ds
        z += alpha * dz
        y = y + alpha * dy if neq > 0 else None

    return x, y, z
# Example #22
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    The KKT Schur-complement matrix is

        S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
            [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    Only D changes between interior-point iterations, so everything except
    the lower-right block is factored here (see the 'Block LU factorization'
    part of the qpth website). The lower-right block of ``S_LU`` is left as
    zeros, to be completed once D^{-1} is known.

    Args:
        Q, G, A: batched matrices (used with ``torch.bmm``).

    Returns:
        Q_LU: LU factors of Q as returned by ``lu_hack``. They are unpacked
            with ``*`` into ``lu_solve``.
        S_LU: ``[S_LU_data, S_LU_pivots]``, the partial block LU of S.
            The pivots are int32 and 1-based.
        R: ``G Q^{-1} G^T`` minus ``G Q^{-1} A^T T`` when neq > 0, where T
            solves with the LU of ``A Q^{-1} A^T``. This assumes
            ``sp_lu_solve`` behaves like ``Tensor.lu_solve``; verify.

    Raises:
        RuntimeError: if Q cannot be LU-factorized.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except Exception as err:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""") from err

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    # Identity (1-based) pivots; the equality block overwrites the first neq.
    S_LU_pivots = torch.arange(1, 1 + neq + nineq, dtype=torch.int32,
                               device=Q.device).unsqueeze(0).repeat(nBatch, 1)
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).lu_solve(
            *LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = sp_lu_solve(G_invQ_AT.transpose(1, 2), *LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        # Lower-right block is completed later, once D is known.
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        # No equality constraints: S is just the (still unknown) G-block.
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R