예제 #1
0
def forward_ineq(Q, p, G, h, verbose=0, maxIter=100, dt=0.2):
    """Solve an inequality-only QP with the dynamic (iterative) solver.

    The multipliers are advanced ``maxIter`` times with step size ``dt`` and
    clamped at zero after every step; the primal point and the slacks are
    then recovered from the final multipliers.

    Args:
        Q: batched quadratic cost matrices (inverted with ``torch.inverse``).
        p: linear cost vectors, one row per batch element.
        G: batched inequality constraint matrices.
        h: inequality right-hand sides, one row per batch element.
        verbose: accepted for interface compatibility; not used here.
        maxIter: number of multiplier updates to run.
        dt: step size of each update.

    Returns:
        ``(zhat, lams, slacks)`` with the trailing singleton dimension
        squeezed out of each tensor.
    """
    nineq, _, _, nBatch = get_sizes(G)
    p_col = p.unsqueeze(2)
    h_col = h.unsqueeze(2)

    G_t = torch.transpose(G, -1, 1)
    Q_inv = torch.inverse(Q)
    G_Qinv = G.bmm(Q_inv)

    # Linear dynamics of the multipliers: d(lams)/dt = D lams + d.
    D = -G_Qinv.bmm(G_t)
    d = -(h_col + G_Qinv.bmm(p_col))

    lams = torch.zeros(nBatch, nineq, 1).type_as(Q).to(Q.device)
    floor = torch.zeros(nBatch, nineq, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        proposal = lams + dt * (D.bmm(lams) + d)
        # Apply the step as an increment so the multipliers never go negative.
        lams.add_(torch.max(proposal, floor) - lams)

    zhat = -Q_inv.bmm(p_col + G_t.bmm(lams))
    slacks = h_col - G.bmm(zhat)

    return zhat.squeeze(2), lams.squeeze(2), slacks.squeeze(2)
예제 #2
0
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the KKT equations for the affine step.

    Uses the pre-computed factorizations ``Q_LU`` and ``S_LU`` (both consumed
    by ``btrs``) to solve first for the dual directions and then for the
    primal and slack directions.

    Args:
        Q_LU: factorization of Q as accepted by ``Tensor.btrs``.
        d: element-wise scaling that divides ``rs`` and ``g2``.
        G, A: batched inequality / equality constraint matrices.
        S_LU: factorization of the reduced system as accepted by ``btrs``.
        rx, rs, rz, ry: right-hand-side residuals.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when there are no equality
        constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)
    has_eq = neq > 0

    invQ_rx_row = rx.btrs(Q_LU).unsqueeze(1)
    ineq_part = invQ_rx_row.bmm(G.transpose(1, 2)) + rs / d - rz
    if has_eq:
        eq_part = invQ_rx_row.bmm(A.transpose(1, 2)) - ry
        h = torch.cat((eq_part, ineq_part), 2).squeeze(1)
    else:
        h = ineq_part.squeeze(1)

    w = -(h.btrs(S_LU))
    # The first ``neq`` columns of w belong to the equalities, the rest to
    # the inequalities.
    w_ineq = w[:, neq:]

    g1 = -rx - w_ineq.unsqueeze(1).bmm(G)
    if has_eq:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A)
    g2 = -rs - w_ineq

    dx = g1.btrs(Q_LU)
    ds = g2 / d
    dz = w_ineq
    dy = w[:, :neq] if has_eq else None

    return dx, ds, dz, dy
예제 #3
0
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Computes the Cholesky factor of ``Q`` and the part of the Cholesky factor
    of the reduced system ``S`` that does not depend on ``D``.

    Args:
        Q: quadratic cost matrix (2-D; factorized with ``torch.potrf``).
        G: inequality constraint matrix.
        A: equality constraint matrix (only read when ``neq > 0``).

    Returns:
        ``(U_Q, U_S, R)`` where ``U_Q`` is the Cholesky factor of ``Q``,
        ``U_S`` holds the pre-computed top block row of the factor of ``S``
        (zeros elsewhere), and ``R`` is ``G Q^{-1} G^T`` minus the
        contribution of the equality block.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T           ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    U_Q = torch.potrf(Q)
    # partial cholesky of S matrix
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD but
        # numpy does? I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except RuntimeError:
            # Only factorization failures trigger the fallback; a bare
            # ``except`` would also swallow KeyboardInterrupt/SystemExit.
            # numpy returns the LOWER factor L (A = L L^T) whereas potrf
            # returns the UPPER factor U (A = U^T U), so transpose to keep
            # U11 upper-triangular like the non-fallback path.
            L11 = torch.Tensor(np.linalg.cholesky(
                A_invQ_AT.cpu().numpy())).type_as(A_invQ_AT)
            U11 = L11.t().contiguous()

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        # U12 = U11^{-T} (A Q^{-1} G^T)
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R
예제 #4
0
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the KKT equations for the affine step.

    Args:
        Q_LU: ``(LU_data, LU_pivots)`` pair for Q, unpacked into ``lu_solve``.
        d: element-wise scaling that divides ``rs`` and ``g2``.
        G, A: batched inequality / equality constraint matrices.
        S_LU: ``(LU_data, LU_pivots)`` pair for the reduced system.
        rx, rs, rz, ry: right-hand-side residuals, one row per batch element.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)
    has_eq = neq > 0

    def _lu_apply(rhs, factors):
        # Solve with a single right-hand side per batch element.
        return rhs.unsqueeze(2).lu_solve(*factors).squeeze(2)

    def _row_times(vec, mat):
        # Batched row-vector times matrix product.
        return vec.unsqueeze(1).bmm(mat).squeeze(1)

    invQ_rx = _lu_apply(rx, Q_LU)
    ineq_part = _row_times(invQ_rx, G.transpose(1, 2)) + rs / d - rz
    if has_eq:
        eq_part = _row_times(invQ_rx, A.transpose(1, 2)) - ry
        h = torch.cat((eq_part, ineq_part), 1)
    else:
        h = ineq_part

    w = -_lu_apply(h, S_LU)
    # Leading ``neq`` columns: equality duals; remaining: inequality duals.
    w_ineq = w[:, neq:]

    g1 = -rx - _row_times(w_ineq, G)
    if has_eq:
        g1 -= _row_times(w[:, :neq], A)
    g2 = -rs - w_ineq

    dx = _lu_apply(g1, Q_LU)
    ds = g2 / d
    dz = w_ineq
    dy = w[:, :neq] if has_eq else None

    return dx, ds, dz, dy
예제 #5
0
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factorize and solve the full KKT system for a single (unbatched) QP.

    Builds the block-diagonal matrix ``H_ = diag(Q, D)`` and the stacked
    constraint matrix ``A_``, Cholesky-factorizes ``H_`` and the Schur
    complement ``A_ H_^{-1} A_^T``, and solves for the step.

    Returns:
        ``(dx, ds, dz, dy)`` sliced out of the primal and dual solutions;
        ``dy`` is ``None`` when there are no equality constraints.
    """
    nineq, nz, neq, _ = get_sizes(G, A)
    has_eq = neq > 0

    # H_ = [[Q, 0], [0, D]] -- identical whether or not equalities exist.
    upper = torch.cat([Q, torch.zeros(nz, nineq).type_as(Q)], 1)
    lower = torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)
    H_ = torch.cat([upper, lower], 0)
    g_ = torch.cat([rx, rs], 0)

    # A_ = [[G, I], [A, 0]] (the second block row only with equalities).
    ineq_rows = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
    if has_eq:
        eq_rows = torch.cat([A, torch.zeros(neq, nineq).type_as(Q)], 1)
        A_ = torch.cat([ineq_rows, eq_rows], 0)
        h_ = torch.cat([rz, ry], 0)
    else:
        A_ = ineq_rows
        h_ = rz

    g_col = g_.view(-1, 1)
    U_H_ = torch.potrf(H_)

    invH_A_ = torch.potrs(A_.t(), U_H_)
    invH_g_ = torch.potrs(g_col, U_H_).view(-1)

    U_S_ = torch.potrf(torch.mm(A_, invH_A_))
    t_ = torch.mv(A_, invH_g_).view(-1, 1) - h_
    w_ = -torch.potrs(t_, U_S_).view(-1)
    v_ = torch.potrs(-g_col - torch.mv(A_.t(), w_), U_H_).view(-1)

    dx = v_[:nz]
    ds = v_[nz:]
    dz = w_[:nineq]
    dy = w_[nineq:] if has_eq else None
    return dx, ds, dz, dy
예제 #6
0
def solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry, dbg=False):
    """Solve the KKT equations for the affine step (single, unbatched QP).

    Args:
        U_Q: Cholesky factor of Q as accepted by ``torch.potrs``.
        d: element-wise scaling that divides ``rs`` and ``g2``.
        G, A: inequality / equality constraint matrices.
        U_S: Cholesky factor of the reduced system.
        rx, rs, rz, ry: right-hand-side residual vectors.
        dbg: when true, drop into an IPython shell after the solve and then
            terminate the process with exit status -1.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when ``neq == 0``.
    """
    nineq, nz, neq, _ = get_sizes(G, A)
    has_eq = neq > 0

    def _chol_solve(vec, factor):
        # potrs wants a column; hand back a flat vector.
        return torch.potrs(vec.view(-1, 1), factor).view(-1)

    invQ_rx = _chol_solve(rx, U_Q)
    h = torch.mv(G, invQ_rx) + rs / d - rz
    if has_eq:
        h = torch.cat([torch.mv(A, invQ_rx) - ry, h], 0)

    w = -_chol_solve(h, U_S)

    g1 = -rx - torch.mv(G.t(), w[neq:])
    if has_eq:
        g1 -= torch.mv(A.t(), w[:neq])
    g2 = -rs - w[neq:]

    dx = _chol_solve(g1, U_Q)
    ds = g2 / d
    dz = w[neq:]
    dy = w[:neq] if has_eq else None

    if dbg:
        # Interactive inspection hook; never returns to the caller.
        import IPython
        import sys
        IPython.embed()
        sys.exit(-1)

    return dx, ds, dz, dy
예제 #7
0
def solve_kkt_ir(Q, D, G, A, rx, rs, rz, ry, niter=1):
    """Inefficient iterative refinement of the KKT solve.

    Solves a regularized system (``eps`` added to the diagonals of ``Q`` and
    ``D``) and then applies ``niter`` correction solves driven by the
    residual reported by ``kkt_resid_reg``.

    Returns:
        The refined ``(dx, ds, dz, dy)``; an entry that is ``None`` stays
        ``None`` through the refinement.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    eps = 1e-7
    eye_z = torch.eye(nz).type_as(Q).repeat(nBatch, 1, 1)
    eye_s = torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)
    Q_tilde = Q + eps * eye_z
    D_tilde = D + eps * eye_s

    dx, ds, dz, dy = factor_solve_kkt_reg(
        Q_tilde, D_tilde, G, A, rx, rs, rz, ry, eps)
    resx, ress, resz, resy = kkt_resid_reg(
        Q, D, G, A, eps, dx, ds, dz, dy, rx, rs, rz, ry)

    for _ in range(niter):
        neg_resy = None if resy is None else -resy
        ddx, dds, ddz, ddy = factor_solve_kkt_reg(
            Q_tilde, D_tilde, G, A, -resx, -ress, -resz, neg_resy, eps)

        current = (dx, ds, dz, dy)
        correction = (ddx, dds, ddz, ddy)
        dx, ds, dz, dy = [
            None if cur is None else cur + delta
            for cur, delta in zip(current, correction)
        ]

        resx, ress, resz, resy = kkt_resid_reg(
            Q, D, G, A, eps, dx, ds, dz, dy, rx, rs, rz, ry)

    return dx, ds, dz, dy
예제 #8
0
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factorize and solve the full batched KKT system with LU.

    Builds ``H_ = diag(Q, D)`` and the stacked constraint matrix
    ``A_ = [[G, I], [A, 0]]`` for every batch element, LU-factorizes ``H_``
    and the Schur complement ``A_ H_^{-1} A_^T``, and solves for the step.

    Args:
        Q, D: batched diagonal blocks of ``H_``.
        G, A: batched inequality / equality constraint matrices (``A`` is
            only read when ``neq > 0``).
        rx, rs, rz, ry: right-hand-side residuals, one row per batch element.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q
    H_[:, -nineq:, -nineq:] = D

    # [G, I] must be built batched in BOTH branches: concatenating the 3-D G
    # with a 2-D identity along dim 1 (as the neq == 0 path used to do)
    # fails, and A_ is used as a 3-D tensor below.
    eye_s = torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)
    ineq_rows = torch.cat([G, eye_s], 2)
    g_ = torch.cat([rx, rs], 1)
    if neq > 0:
        eq_rows = torch.cat(
            [A, torch.zeros(nBatch, neq, nineq).type_as(Q)], 2)
        A_ = torch.cat([ineq_rows, eq_rows], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = ineq_rows
        h_ = rz

    H_LU = lu_hack(H_)

    # Vector right-hand sides are passed to lu_solve as (nBatch, n, 1)
    # columns, matching how the other lu_solve-based solvers call it.
    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    S_ = torch.bmm(A_, invH_A_)
    S_LU = lu_hack(S_)
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
    # squeeze(1) drops only the row dimension added by unsqueeze(1); a bare
    # squeeze() would also drop the batch dimension when nBatch == 1.
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze(1)
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
예제 #9
0
파일: batch.py 프로젝트: lopa23/flim_optcrf
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Args:
        Q: batched quadratic cost matrices, factorized with ``lu_hack``.
        G: batched inequality constraint matrices.
        A: batched equality constraint matrices (only read when ``neq > 0``).

    Returns:
        ``(Q_LU, S_LU, R)``: the LU factorization of ``Q``, a
        ``[data, pivots]`` list holding the partial LU factorization of the
        reduced system ``S``, and the matrix ``R`` built from
        ``G Q^{-1} G^T`` (minus the equality contribution when ``neq > 0``).

    Raises:
        RuntimeError: if ``Q`` cannot be LU-factorized.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must not be
        # reported as a factorization failure.
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""")

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    # NOTE(review): the special case below was flagged "something funny
    # here" by its author; it solves against G itself instead of G^T when
    # G's last dimension is 2. Kept as-is -- confirm the intended shapes
    # before relying on it.
    if G.size(2) == 2:
        G_invQ_GT = torch.matmul(G, G.lu_solve(*Q_LU).transpose(1, 2))
    else:
        G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))

    R = G_invQ_GT.clone()
    # Start from the identity permutation (1-based pivots).
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        # (P L U)^{-1} (P L) = U^{-1}
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)
                           ).lu_solve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        # The (2, 2) block is left as zeros at this stage.
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)),
                              1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
예제 #10
0
def forward_eq_conv(Q, p, G_, h_, A_, b_, verbose=0, maxIter=100, dt=0.2):
    """ DOES NOT WORK, ONLY HERE TO DOCUMENT OUR PROCESS
        Rewrites every equality constraint as a pair of opposite inequality
        constraints and then runs the inequality-only dynamic solver on the
        enlarged problem.

        Returns ``(zhat, lams, nus, slacks)`` where ``lams`` are the first
        ``nineq`` multipliers and ``nus`` the last ``neq`` multipliers of the
        enlarged problem.
    """
    nineq, nz, neq, nBatch = get_sizes(G_, A_)
    n_dual = nineq + neq + neq

    # Stack G z <= h, A z <= b and -A z <= -b into one inequality system.
    G = torch.cat((G_, A_, -A_), dim=1)
    h = torch.cat((h_, b_, -b_), dim=1)
    p_col = p.unsqueeze(2)
    h_col = h.unsqueeze(2)

    G_t = torch.transpose(G, -1, 1)
    Q_inv = torch.inverse(Q)
    G_Qinv = G.bmm(Q_inv)

    D = -G_Qinv.bmm(G_t)
    d = -(h_col + G_Qinv.bmm(p_col))

    lams = torch.zeros(nBatch, n_dual, 1).type_as(Q).to(Q.device)
    floor = torch.zeros(nBatch, n_dual, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        proposal = lams + dt * (D.bmm(lams) + d)
        # Increment form of the projection onto lams >= 0.
        lams.add_(torch.max(proposal, floor) - lams)

    zhat = -Q_inv.bmm(p_col + G_t.bmm(lams))
    slacks = h_col - G.bmm(zhat)

    lam_rows = lams.squeeze(2)
    return (zhat.squeeze(2), lam_rows[:, :nineq], lam_rows[:, -neq:],
            slacks.squeeze(2))
예제 #11
0
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products

    Factorizes ``Q`` with ``btrf`` and assembles the part of the LU
    factorization of the reduced system ``S`` that does not depend on
    ``D^{-1}``.

    Returns:
        ``(Q_LU, S_LU, R)``: the factorization of ``Q``, the partially
        filled LU data for ``S`` (its lower-right ``nineq x nineq`` block is
        zeros), and ``R = G Q^{-1} G^T`` minus ``G Q^{-1} A^T T`` when
        equality constraints are present.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    Q_LU = Q.btrf()

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of S matrix
    # that can be completed once D^{-1} is known.
    # This is done for a general matrix by decomposing
    # S using the Schur complement and then LU-factorizing
    # the matrices in the middle:
    #
    #   [ A B ] = [ I            0 ] [ A     0              ] [ I    A^{-1} B ]
    #   [ C D ]   [ C A^{-1}     I ] [ 0     D - C A^{-1} B ] [ 0    I        ]

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrs(Q_LU))
    # R aliases G_invQ_GT (no copy), so the in-place update below also
    # modifies G_invQ_GT; it is not used again afterwards.
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrs(Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = A_invQ_AT.btrf()
        # NOTE(review): presumably btrf drops dimensions for 1x1 systems in
        # the torch version this targeted; the view restores a
        # (nBatch, 1, 1) layout -- confirm against that version.
        if neq == 1:
            LU_A_invQ_AT = LU_A_invQ_AT.view(-1, 1, 1)
        # Lower-triangular mask (diagonal included) to extract L from the
        # packed LU data.
        M = torch.tril(torch.ones(neq,
                                  neq)).unsqueeze(0).expand(nBatch, neq,
                                                            neq).type_as(Q)
        L_A_invQ_AT = M * LU_A_invQ_AT
        # Overwrite the diagonal of L with ones (unit lower-triangular).
        L_A_invQ_AT[torch.eye(neq).unsqueeze(0).expand(
            nBatch, neq, neq).type_as(Q).byte()] = 1.0
        # Upper-triangular mask (diagonal included) to extract U.
        M = torch.triu(torch.ones(neq,
                                  neq)).unsqueeze(0).expand(nBatch, neq,
                                                            neq).type_as(Q)
        U_A_invQ_AT = M * LU_A_invQ_AT

        # Assemble the 2x2 block layout of the partial LU data of S.
        S_LU_11 = LU_A_invQ_AT
        S_LU_21 = G_invQ_AT.bmm(L_A_invQ_AT.btrs(LU_A_invQ_AT))
        # T = (A Q^{-1} A^T)^{-1} (G Q^{-1} A^T)^T
        T = G_invQ_AT.transpose(1, 2).btrs(LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU = torch.cat(
            (torch.cat((S_LU_11, S_LU_12), 2),
             torch.cat(
                 (S_LU_21, torch.zeros(nBatch, nineq, nineq).type_as(Q)), 2)),
            1)
        # Remove the equality block's contribution from R.
        R -= G_invQ_AT.bmm(T)
    else:
        # No equality constraints: only the (still empty) inequality block.
        S_LU = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    return Q_LU, S_LU, R
예제 #12
0
파일: batch.py 프로젝트: lopa23/flim_optcrf
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the KKT equations for the affine step.

    Args:
        Q_LU: ``(LU_data, LU_pivots)`` pair for Q, unpacked into ``lu_solve``.
        d: element-wise scaling that divides ``rs`` and ``g2``.
        G, A: batched inequality / equality constraint matrices.
        S_LU: ``(LU_data, LU_pivots)`` pair for the reduced system.
        rx, rs, rz, ry: right-hand-side residuals.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # The original code branched on G.size(1) here but both branches were
    # identical, so a single solve is sufficient.
    invQ_rx = rx.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)

    if neq > 0:
        h = torch.cat((invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry,
                       invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz), 1)
    else:
        h = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz

    # NOTE(review): the first branch treats the rows of h as the columns of
    # a single right-hand-side matrix and keeps only row 0 of the result;
    # this looks like a shape workaround -- confirm the intended shapes.
    if G.size(1) >= 2 and h.size(0) > 1:
        w = -(h.unsqueeze(0).lu_solve(*S_LU)).squeeze(2)
        w = w[:, 0, :]
    else:
        w = -(h.unsqueeze(2).lu_solve(*S_LU)).squeeze(2)

    if G.size(1) >= 2:
        g1 = -rx - w[:, neq:].matmul(G).squeeze(1)
    else:
        g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)

    g2 = -rs - w[:, neq:]
    # g1 may come out 3-D from the matmul path above; solve accordingly.
    if g1.dim() > 2:
        dx = g1.transpose(1, 2).lu_solve(*Q_LU).transpose(1, 2).squeeze(2)
    else:
        dx = g1.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None
    return dx, ds, dz, dy
예제 #13
0
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factorize and solve the full KKT system with batched LU (btrf/btrs).

    ``Q`` and the stacked constraint matrix are shared by the whole batch
    and are replicated ``nBatch`` times; ``D`` and the residuals are
    per-batch.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # H_ = diag(Q, D) for every batch element.
    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q.repeat(nBatch, 1, 1)
    H_[:, -nineq:, -nineq:] = D
    g_ = torch.cat([rx, rs], 1)

    if neq > 0:
        from block import block
        # A_ = [[G, I], [A, 0]]
        A_ = block(((G, 'I'), (A, torch.zeros(neq, nineq).type_as(Q))))
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        h_ = rz

    H_LU = H_.btrf()

    A_batched = A_.repeat(nBatch, 1, 1)
    invH_A_ = A_batched.transpose(1, 2).btrs(H_LU)
    invH_g_ = g_.btrs(H_LU)

    # Schur complement A_ H_^{-1} A_^T and its factorization.
    schur = torch.bmm(A_batched, invH_A_)
    S_LU = schur.btrf()

    dual_rhs = torch.mm(invH_g_, A_.t()) - h_
    w_ = -dual_rhs.btrs(S_LU)
    primal_rhs = -g_ - w_.mm(A_)
    v_ = primal_rhs.btrs(H_LU)

    dy = w_[:, nineq:] if neq > 0 else None
    return v_[:, :nz], v_[:, nz:], w_[:, :nineq], dy
예제 #14
0
파일: batch.py 프로젝트: lopa23/flim_optcrf
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factorize and solve the full KKT system with LU (locally modified).

    Builds ``H_`` with ``Q`` and ``D`` on its diagonal blocks and a stacked
    constraint matrix ``A_``, LU-factorizes ``H_`` and ``A_ H_^{-1}``-based
    Schur system via ``lu_hack``, and solves for the step.

    NOTE(review): several lines below were changed from the upstream
    version to work around shape errors (see the inline notes); they
    presumably assume a batch of size one -- confirm before reuse.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # H_ = diag(Q, D) for every batch element.
    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q
    H_[:, -nineq:, -nineq:] = D
    # Local addition: drop the leading dimension of G when it has size 1.
    G=G.squeeze(0)#added
    if neq > 0:
        A_ = torch.cat([torch.cat([G, torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)], 2),
                        torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q)], 2)], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        #print(G.size(),torch.eye(nineq).type_as(Q).size())
        # A_ = [G, I], concatenated along dim 1.
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = lu_hack(H_)
    # lu / pv are unpacked but not used afterwards.
    lu, pv=H_LU
    #
   # print(A_.size())
    # Upstream solved against A_.transpose(1, 2); this version solves
    # against A_ directly and then inserts a dimension at position 1.
    invH_A_ = A_.lu_solve(*H_LU).unsqueeze(1)#changed from  A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)



    # Local addition: add a leading dimension to A_ and swap dims 1 and 2.
    A_=A_.unsqueeze(0).transpose(1,2)#added
    S_ = torch.bmm(A_, invH_A_)
    S_LU = lu_hack(S_)
    # slu / pv are unpacked but not used afterwards.
    slu, pv=S_LU

    #print(A_.size(),invH_g_.unsqueeze(1).size(), torch.bmm(invH_g_.unsqueeze(1), A_).size(),h_.size())
    # Upstream: torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    t_ = torch.bmm(A_,invH_g_.unsqueeze(1)).squeeze(1) - h_#changed from torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_

    # Changed locally: the solve result is squeezed along dim 2.
    w_ = -t_.lu_solve(*S_LU).squeeze(2)#changed
    #print(g_.size(),w_.size(),A_.size())
    t_ = -g_ - w_.bmm(A_).squeeze()
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    # Split the stacked solutions back into their components.
    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
예제 #15
0
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Returns:
        ``(Q_LU, S_LU, R)``: the ``btrifact`` factorization of ``Q``, a
        ``[data, pivots]`` list with the partial LU factorization of the
        reduced system ``S``, and ``R`` (a copy of ``G Q^{-1} G^T`` with the
        equality contribution subtracted when ``neq > 0``).
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    Q_LU = Q.btrifact()

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # Only the part of the LU decomposition of S that is independent of
    # D^{-1} is computed here; it is completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrisolve(*Q_LU))
    R = G_invQ_GT.clone()

    # Identity permutation (1-based), one row per batch element.
    base_pivots = torch.IntTensor(range(1, 1 + neq + nineq))
    S_LU_pivots = base_pivots.unsqueeze(0).repeat(nBatch, 1).type_as(Q).int()

    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrisolve(*Q_LU)
        S11 = torch.bmm(A, invQ_AT)  # A Q^{-1} A^T
        S21 = torch.bmm(G, invQ_AT)  # G Q^{-1} A^T

        S11_LU = S11.btrifact()
        perm, lower, upper = torch.btriunpack(*S11_LU)
        perm = perm.type_as(S11)

        # (P L U)^{-1} (P L) = U^{-1}
        upper_inv = (perm.bmm(lower)).btrisolve(*S11_LU)
        T = S21.transpose(1, 2).btrisolve(*S11_LU)

        top_row = torch.cat((S11_LU[0], upper.bmm(T)), 2)
        bottom_row = torch.cat(
            (S21.bmm(upper_inv),
             torch.zeros(nBatch, nineq, nineq).type_as(Q)), 2)
        S_LU_data = torch.cat((top_row, bottom_row), 1)
        S_LU_pivots[:, :neq] = S11_LU[1]

        R -= S21.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    return Q_LU, [S_LU_data, S_LU_pivots], R
예제 #16
0
파일: batch.py 프로젝트: jdily/qpth
def factor_solve_kkt_reg(Q_tilde, D, G, A, rx, rs, rz, ry, eps):
    """Factorize and solve the regularized full KKT system with batched LU.

    Same layout as the unregularized full solve, except that ``eps`` times
    the identity is subtracted from the Schur complement before it is
    factorized.

    Args:
        Q_tilde, D: batched diagonal blocks of ``H_``.
        G, A: batched inequality / equality constraint matrices.
        rx, rs, rz, ry: right-hand-side residuals, one row per batch element.
        eps: regularization weight applied to the Schur complement.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is ``None`` when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    def _batched_eye(size):
        return torch.eye(size).type_as(Q_tilde).repeat(nBatch, 1, 1)

    # H_ = diag(Q_tilde, D)
    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D

    # A_ = [[G, I], [A, 0]]; the second block row exists only with equalities.
    A_ = torch.cat([G, _batched_eye(nineq)], 2)
    g_ = torch.cat([rx, rs], 1)
    h_ = rz
    if neq > 0:
        eq_rows = torch.cat(
            [A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        A_ = torch.cat([A_, eq_rows], 1)
        h_ = torch.cat([rz, ry], 1)

    H_LU = H_.btrifact()
    A_t = A_.transpose(1, 2)

    invH_A_ = A_t.btrisolve(*H_LU)
    invH_g_ = g_.btrisolve(*H_LU)

    S_ = torch.bmm(A_, invH_A_)
    S_ -= eps * _batched_eye(neq + nineq)
    S_LU = S_.btrifact()

    dual_rhs = torch.bmm(invH_g_.unsqueeze(1), A_t).squeeze(1) - h_
    w_ = -dual_rhs.btrisolve(*S_LU)
    primal_rhs = -g_ - w_.unsqueeze(1).bmm(A_).squeeze()
    v_ = primal_rhs.btrisolve(*H_LU)

    dy = w_[:, nineq:] if neq > 0 else None
    return v_[:, :nz], v_[:, nz:], w_[:, :nineq], dy
예제 #17
0
def forward_eq_new(Q_, p_, G_, h_, A_, b_, verbose=0, maxIter=100, dt=0.2):
    """ Solves an equality-constrained problem by evolving the primal and
        dual variables side by side.

        The primal variable is doubled in size (see the construction of
        ``Q``, ``p`` and ``A`` below) and the result is the difference of
        its two halves. ``G_``, ``h_``, ``verbose`` and ``dt`` are accepted
        but not used. The returned slacks are all zeros.
    """
    neq, nz, _, nBatch = get_sizes(A_)

    # Doubled problem data: Q = [[Q_, -Q_], [-Q_, Q_]], p = [p_, -p_],
    # A = [A_, -A_].
    Q_rows = torch.cat((Q_, -Q_), dim=1)
    Q = torch.cat((Q_rows, -Q_rows), dim=2)
    p = torch.cat((p_, -p_), dim=1).unsqueeze(2)
    A = torch.cat((A_, -A_), dim=2)
    b = b_.unsqueeze(2)

    Q_T = torch.transpose(Q, -1, 1)
    A_T = torch.transpose(A, -1, 1)
    p_T = torch.transpose(p, -1, 1)
    b_T = torch.transpose(b, -1, 1)

    # Expand 'x' to allow for negative values
    x = torch.zeros(nBatch, 2 * nz, 1).type_as(Q).to(Q.device)
    y = torch.zeros(nBatch, neq, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        x_T = torch.transpose(x, -1, 1)
        Qx = Q.bmm(x)

        g = x_T.bmm(Q).bmm(x) + p_T.bmm(x) - b_T.bmm(y)
        r = neg(Qx + p - A_T.bmm(y))
        dFx = g.bmm(2 * Qx + p)
        dFy = g.bmm(b)

        x_step = -dFx - A_T.bmm(A.bmm(x) - b) - neg(x) - Q_T.bmm(r)
        y_step = dFy + A.bmm(r)

        x.add_(x_step)
        y.add_(y_step)

    slacks = torch.zeros(nBatch, neq).type_as(Q).to(Q.device)
    primal = (x[:, :nz, :] - x[:, nz:, :]).squeeze(2)

    return primal, y.squeeze(2), slacks
예제 #18
0
def forward(Q, p, G, h, A, b, verbose=0, maxIter=100, dt=0.2):
    """ Solves a given QP problem by modeling its dynamic representation
        as an RNN with 'maxIter' loops.

        Each loop advances the inequality multipliers (kept non-negative)
        and the equality multipliers by one step of size ``dt``; the primal
        point and slacks are recovered from the final multipliers.

        Returns ``(zhat, lams, nus, slacks)`` with the trailing singleton
        dimension squeezed out.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    p_col = p.unsqueeze(2)
    h_col = h.unsqueeze(2)
    b_col = b.unsqueeze(2)

    # Base inverses and transposes
    Q_inv = torch.inverse(Q)
    G_t = torch.transpose(G, -1, 1)
    A_t = torch.transpose(A, -1, 1)

    # Negated constraint-times-inverse products and the four coupling blocks
    neg_G_Qinv = -G.bmm(Q_inv)        # - G Q^{-1}
    neg_A_Qinv = -A.bmm(Q_inv)        # - A Q^{-1}
    lam_from_nu = neg_G_Qinv.bmm(A_t)   # - G Q^{-1} A^T
    lam_from_lam = neg_G_Qinv.bmm(G_t)  # - G Q^{-1} G^T
    nu_from_nu = neg_A_Qinv.bmm(A_t)    # - A Q^{-1} A^T
    nu_from_lam = neg_A_Qinv.bmm(G_t)   # - A Q^{-1} G^T
    lam_bias = neg_G_Qinv.bmm(p_col) - h_col  # - G Q^{-1} p - h
    nu_bias = neg_A_Qinv.bmm(p_col) - b_col   # - A Q^{-1} p - b

    lams = torch.zeros(nBatch, nineq, 1).type_as(Q).to(Q.device)
    nus = torch.zeros(nBatch, neq, 1).type_as(Q).to(Q.device)
    floor = torch.zeros(nBatch, nineq, 1).type_as(Q).to(Q.device)

    for _ in range(maxIter):
        # Both increments are computed from the same (old) state.
        lam_step = dt * (lam_from_lam.bmm(lams) + lam_from_nu.bmm(nus)
                         + lam_bias)
        nu_step = dt * (nu_from_lam.bmm(lams) + nu_from_nu.bmm(nus)
                        + nu_bias)
        # Project the inequality multipliers onto lams >= 0.
        lam_step = torch.max(lams + lam_step, floor) - lams
        lams.add_(lam_step)
        nus.add_(nu_step)

    zhat = -Q_inv.bmm(p_col + G_t.bmm(lams) + A_t.bmm(nus))
    slacks = h_col - G.bmm(zhat)

    return zhat.squeeze(2), lams.squeeze(2), nus.squeeze(2), slacks.squeeze(2)
예제 #19
0
def forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps=1e-12, verbose=0, notImprovedLim=3,
            maxIter=20, solver=KKTSolvers.LU_PARTIAL):
    """Run the batched primal-dual interior point iterations for a QP.

    Expects the pre-computed factorizations from
    ``Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)``.

    Each iteration computes the residuals, refreshes the KKT factorization
    for the current ``d = z / s``, records the best iterate seen so far per
    batch element, and then takes an affine-scaling step followed by a
    centering/corrector step.

    Args:
        Q, p, G, h, A, b: batched problem data.
        Q_LU, S_LU, R: outputs of ``pre_factor_kkt``.
        eps: stop once the largest best residual falls below this value.
        verbose: ``1`` prints per-iteration residuals; a negative value
            suppresses the inaccuracy warning.
        notImprovedLim: stop after this many consecutive iterations in which
            no batch element improved.
        maxIter: maximum number of iterations.
        solver: which ``KKTSolvers`` variant to use for the linear solves.

    Returns:
        ``(x, y, z, s)`` of the best iterate found. All entries are ``None``
        if ``factor_kkt`` fails on the very first iteration.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Find initial values
    if solver == KKTSolvers.LU_FULL:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = factor_solve_kkt(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    elif solver == KKTSolvers.LU_PARTIAL:
        d = torch.ones(nBatch, nineq).type_as(Q)
        factor_kkt(S_LU, R, d)
        x, s, z, y = solve_kkt(
            Q_LU, d, G, A, S_LU,
            p, torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if neq > 0 else None)
    elif solver == KKTSolvers.IR_UNOPT:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = solve_kkt_ir(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    else:
        assert False

    # Make all of the slack variables >= 1.
    M = torch.min(s, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    s[I] -= M[I] - 1

    # Make all of the inequality dual variables >= 1.
    M = torch.min(z, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(maxIter):
        # affine scaling direction
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            p
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except Exception:
            # Best effort: if the factorization breaks down, hand back the
            # best iterate so far. ``Exception`` (not a bare ``except``)
            # lets KeyboardInterrupt/SystemExit propagate.
            return best['x'], best['y'], best['z'], best['s']

        if verbose == 1:
            print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
                i, pri_resid.mean(), dual_resid.mean(), mu.mean()))
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            # Per batch element: keep the iterate with the smaller residual.
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        if nNotImproved == notImprovedLim or best['resids'].max() < eps or mu.min() > 1e32:
            if best['resids'].max() > 1. and verbose >= 0:
                print(INACC_ERR)
            return best['x'], best['y'], best['z'], best['s']

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
                Q, D, G, A, rx, rs, rz, ry)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(
                Q, D, G, A, rx, rs, rz, ry)
        else:
            assert False

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff),
                                    get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q) if neq > 0 else torch.Tensor()

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
                Q, D, G, A, rx, rs, rz, ry)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(
                Q, D, G, A, rx, rs, rz, ry)
        else:
            assert False

        # Combine the affine and corrector directions and take a damped
        # step (factor 0.999, capped at 1).
        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        alpha = torch.min(0.999 * torch.min(get_step(z, dz),
                                            get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    if best['resids'].max() > 1. and verbose >= 0:
        print(INACC_ERR)
    return best['x'], best['y'], best['z'], best['s']
예제 #20
0
def forward(inputs, Q, G, h, A, b, Q_LU, S_LU, R, verbose=False):
    """Run up to 20 batched predictor/corrector iterations and return the
    best iterate seen for every batch element.

    The residuals computed below are
        rx = A^T y + G^T z + Q x + inputs
        rz = G x + s - h
        ry = A x - b
    so this presumably solves  min 1/2 x^T Q x + inputs^T x
    s.t. G x <= h, A x = b  (confirm against the caller).

    Args:
        inputs: linear term, used as a batched 2-D tensor.
        Q, G, A: batched (3-D) problem matrices (they are fed to
            ``torch.bmm``). ``A`` is only touched when ``neq > 0``.
        h, b: right-hand sides of the inequality / equality residuals.
        Q_LU, S_LU, R: pre-computed factorization data, i.e.
            ``Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)``.  ``S_LU`` is handed
            to ``factor_kkt`` every iteration (presumably updated in place).
        verbose: when truthy, print the residuals of batch element 0 at
            every iteration.

    Returns:
        Tuple ``(x, y, z, s)`` holding, per batch element, the iterate with
        the smallest ``pri_resid + dual_resid + nineq * mu`` observed.  All
        four entries are ``None`` if ``factor_kkt`` fails on the very first
        iteration.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Find initial values: one KKT solve with unit scaling d = 1.
    d = torch.ones(nBatch, nineq).type_as(Q)
    factor_kkt(S_LU, R, d)
    x, s, z, y = solve_kkt(Q_LU, d, G, A, S_LU, inputs,
                           torch.zeros(nBatch, nineq).type_as(Q), -h,
                           -b if neq > 0 else None)

    # Shift s and z so that, for every batch row whose minimum is negative,
    # the smallest entry becomes exactly 1 (keeps the start strictly positive).
    M = torch.min(s, 1)[0].repeat(1, nineq)
    I = M < 0
    s[I] -= M[I] - 1

    M = torch.min(z, 1)[0].repeat(1, nineq)
    I = M < 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(20):
        # affine scaling direction
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
             torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
             torch.bmm(x.unsqueeze(1), Q.transpose(1,2)).squeeze(1) + \
             inputs
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except Exception:
            # Factorization failed: give up and hand back the best iterate
            # so far.  Only ordinary errors are caught so that
            # KeyboardInterrupt / SystemExit still propagate.
            # TODO: Move this below.
            print('=' * 70 + '\n' +
                  "TODO: Remove try/except around factor_kkt!!!" + '\n')
            return best['x'], best['y'], best['z'], best['s']

        if verbose:
            print(
                'iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.
                format(i, pri_resid[0], dual_resid[0], mu[0]))
        if best['resids'] is None:
            # First iteration: everything is the best seen so far.
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            # Per-batch-element bookkeeping: overwrite only the rows whose
            # combined residual improved.
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        # Stop after 3 consecutive iterations without any batch element
        # improving, or once every element is below tolerance.
        if nNotImproved == 3 or best['resids'].max() < 1e-12:
            return best['x'], best['y'], best['z'], best['s']

        # TODO: Move factorization back here.
        dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(Q_LU, d, G, A, S_LU, rx, rs,
                                                   rz, ry)

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        # Centering parameter: cubed ratio of the complementarity after the
        # affine step to the current complementarity.
        sig = (t3 / t4)**3

        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q)
        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(Q_LU, d, G, A, S_LU, rx, rs,
                                                   rz, ry)

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        # Step length: 0.999 of the get_step() bound, capped at 1.
        alpha = torch.min(0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))

        # Broadcast the per-element step length over each variable block.
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    return best['x'], best['y'], best['z'], best['s']
예제 #21
0
def forward(inputs_i, Q, G, A, b, h, U_Q, U_S, R, verbose=False):
    """Single-problem (non-batched) predictor/corrector iteration.

    Expected relations, as stated by the original author:
        b = A z_0
        h = G z_0 + s_0
        U_Q, U_S, R = pre_factor_kkt(Q, G, A, nineq, neq)

    Runs at most 20 iterations.  Returns ``(x, y, z)`` as soon as the
    combined residual stops improving, drops below 1e-6, or the step
    norms blow up; otherwise returns the iterate after the last update.
    ``y`` is ``None`` after any update when ``neq > 0`` is false.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # find initial values: one KKT solve with unit scaling
    d = torch.ones(nineq).type_as(Q)
    neg_b = None if b is None else -b
    factor_kkt(U_S, R, d)
    x, s, z, y = solve_kkt(U_Q, d, G, A, U_S, inputs_i,
                           torch.zeros(nineq).type_as(Q), -h, neg_b)

    # Shift s / z in place so that their smallest entry becomes 1 whenever
    # it started out negative.
    if torch.min(s) < 0:
        s -= torch.min(s) - 1
    if torch.min(z) < 0:
        z -= torch.min(z) - 1

    has_eq = neq > 0
    report_fmt = ("primal_res = {0:.5g}, dual_res = {1:.5g}, "
                  "gap = {2:.5g}, kappa(d) = {3:.5g}")
    last_resid = None
    for _it in range(20):
        # affine scaling direction: residuals of the current iterate
        eq_dual = torch.mv(A.t(), y) if has_eq else 0.
        rx = eq_dual + torch.mv(G.t(), z) + torch.mv(Q, x) + inputs_i
        rs = z
        rz = torch.mv(G, x) + s - h
        if has_eq:
            ry = torch.mv(A, x) - b
        else:
            ry = torch.Tensor([0.])
        mu = torch.dot(s, z) / nineq
        pri_resid = torch.norm(ry) + torch.norm(rz)
        dual_resid = torch.norm(rx)
        resid = pri_resid + dual_resid + nineq * mu
        d = z / s
        if verbose:
            print(report_fmt.format(pri_resid, dual_resid, mu,
                                    min(d) / max(d)))

        # Stop when the residual failed to stay within 1e-6 of the previous
        # one, or when it is already small enough.
        stalled = last_resid is not None and not (resid < last_resid + 1e-6)
        if stalled or resid < 1e-6:
            return x, y, z
        last_resid = resid

        factor_kkt(U_S, R, d)
        dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
            U_Q, d, G, A, U_S, rx, rs, rz, ry)

        # compute centering directions
        aff_step_z = get_step(z, dz_aff)
        aff_step_s = get_step(s, ds_aff)
        alpha = min(min(aff_step_z, aff_step_s), 1.0)
        gap_after = torch.dot(s + alpha * ds_aff, z + alpha * dz_aff)
        sig = (gap_after / (torch.dot(s, z)))**3

        cor_rx = torch.zeros(nz).type_as(Q)
        cor_rs = (-mu * sig * torch.ones(nineq).type_as(Q) +
                  ds_aff * dz_aff) / s
        cor_rz = torch.zeros(nineq).type_as(Q)
        cor_ry = torch.zeros(neq).type_as(Q)
        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
            U_Q, d, G, A, U_S, cor_rx, cor_rs, cor_rz, cor_ry)

        # combined direction = affine + corrector
        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if has_eq else None

        full_step_s = get_step(s, ds)
        full_step_z = get_step(z, dz)
        alpha = min(1.0, 0.999 * min(full_step_s, full_step_z))

        dx_norm = torch.norm(dx)
        dz_norm = torch.norm(dz)
        if np.isnan(dx_norm) or dx_norm > 1e5 or dz_norm > 1e5:
            # Overflow, return early
            return x, y, z

        x += alpha * dx
        s += alpha * ds
        z += alpha * dz
        y = y + alpha * dy if has_eq else None

    return x, y, z
예제 #22
0
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Args:
        Q, G, A: batched (3-D) problem matrices; ``A`` is only used when
            ``get_sizes`` reports ``neq > 0``.

    Returns:
        Tuple ``(Q_LU, S_LU, R)``:
          * ``Q_LU`` -- result of ``lu_hack(Q)``, unpacked into
            ``Tensor.lu_solve`` below (presumably LU data + pivots).
          * ``S_LU`` -- ``[S_LU_data, S_LU_pivots]``, a partially filled LU
            factorization of S whose lower-right block is left as zeros.
          * ``R`` -- ``G Q^{-1} G^T`` minus the equality-block correction
            ``G_invQ_AT @ T`` (no correction when ``neq > 0`` is false).

    Raises:
        RuntimeError: if ``lu_hack(Q)`` fails; the original error is
            attached as ``__cause__``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except Exception as err:
        # Catch only ordinary errors (so Ctrl-C still interrupts) and keep
        # the underlying failure chained for debugging.
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""") from err

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    # Start from identity pivots 1..(neq + nineq); the first neq entries are
    # overwritten below with the pivots of the equality block.
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)
        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        # (P L U)^{-1} (P L) = U^{-1}
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).lu_solve(
            *LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        # NOTE(review): sp_lu_solve is presumably a drop-in for
        # G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT) -- confirm.
        T = sp_lu_solve(G_invQ_AT.transpose(1, 2), *LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        # Lower-right block is left as zeros here; it depends on D^{-1}.
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat(
            (S_LU_11, S_LU_12), 2), torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]
        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)
    S_LU = [S_LU_data, S_LU_pivots]

    return Q_LU, S_LU, R