def forward_ineq(Q, p, G, h, verbose=0, maxIter=100, dt=0.2):
    """Solve a batched QP that has only inequality constraints.

    Fastest variant of the dynamic solver: it runs ``maxIter`` explicit steps
    of size ``dt`` on the dual variables, projecting them onto ``lams >= 0``
    after every step. It then recovers the primal point in closed form from
    the final multipliers.

    Args:
        Q: quadratic cost, inverted with ``torch.inverse`` and used with
            ``bmm`` -- presumably (nBatch, nz, nz); confirm with callers.
        p: linear cost (unsqueezed to a column vector).
        G, h: inequality constraints ``G z <= h``.
        verbose: accepted for API compatibility; unused.
        maxIter: number of projected dual steps.
        dt: step size of each dual step.

    Returns:
        ``(zhat, lams, slacks)`` with the trailing singleton dim squeezed.
    """
    nineq, nz, neq, nBatch = get_sizes(G)
    G_t = G.transpose(-1, 1)
    Q_inv = torch.inverse(Q)
    p_col = p.unsqueeze(2)
    h_col = h.unsqueeze(2)

    G_Qinv = G.bmm(Q_inv)
    # Step direction is dual_mat @ lams + dual_vec, with
    # dual_mat = -G Q^{-1} G^T and dual_vec = -(h + G Q^{-1} p).
    dual_mat = -G_Qinv.bmm(G_t)
    dual_vec = -(h_col + G_Qinv.bmm(p_col))

    lams = Q.new_zeros(nBatch, nineq, 1)
    floor = Q.new_zeros(nBatch, nineq, 1)
    for _ in range(maxIter):
        raw_step = dt * (dual_mat.bmm(lams) + dual_vec)
        # Project onto lams >= 0 and apply the resulting change in place.
        lams.add_(torch.max(lams + raw_step, floor) - lams)

    zhat = -Q_inv.bmm(p_col + G_t.bmm(lams))
    slacks = h_col - G.bmm(zhat)
    return zhat.squeeze(2), lams.squeeze(2), slacks.squeeze(2)
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """ Solve KKT equations for the affine step

    Eliminates the primal variables with the cached factors of Q (``Q_LU``).
    It then solves the reduced system ``S w = -h`` (factors in ``S_LU``) for
    the stacked multipliers ``w = [dy; dz]``, and back-substitutes for ``dx``
    and ``ds``.

    NOTE(review): ``.btrs`` is not part of current PyTorch. It is assumed to
    be a batched LU solve from an old or patched torch -- TODO confirm.
    NOTE(review): ``(nBatch, 1, k) - (nBatch, k)`` expressions below (e.g.
    ``... .bmm(A.transpose(1, 2)) - ry`` and ``-rx - ... .bmm(G)``) broadcast
    to ``(nBatch, nBatch, k)`` under NumPy-style broadcasting. They only
    behave as intended for nBatch == 1 or on pre-broadcasting torch.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when there are no equality
        constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Q^{-1} rx
    invQ_rx = rx.btrs(Q_LU)
    # Right-hand side h of the reduced system: equality rows first, then
    # inequality rows (rs / d accounts for the eliminated slack block).
    if neq > 0:
        h = torch.cat(
            (invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)) - ry,
             invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)) + rs / d - rz), 2).squeeze(1)
    else:
        h = (invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)) + rs / d - rz).squeeze(1)

    # w = -S^{-1} h
    w = -(h.btrs(S_LU))

    # Back-substitution: dx = Q^{-1}(-rx - G^T dz - A^T dy), ds = (-rs - dz) / d.
    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A)
    g2 = -rs - w[:, neq:]

    dx = g1.btrs(Q_LU)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Unbatched variant built on Cholesky factors (``torch.potrf`` /
    ``torch.potrs``, default ``upper=True``).

    Returns:
        U_Q: upper Cholesky factor of Q.
        U_S: (neq + nineq)-square matrix. Its first neq rows hold the
            equality rows ``[U11, U12]`` of the upper Cholesky factor of S;
            the remaining rows are still zero.
        R: ``G Q^{-1} G^T - U12^T U12``. The trailing Cholesky block is
            presumably obtained by factoring ``R + D^{-1}`` once D is known
            (e.g. in ``factor_kkt``) -- TODO confirm.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    U_Q = torch.potrf(Q)
    # partial cholesky of S matrix
    U_S = torch.zeros(neq + nineq, neq + nineq).type_as(Q)

    G_invQ_GT = torch.mm(G, torch.potrs(G.t(), U_Q))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = torch.potrs(A.t(), U_Q)
        A_invQ_AT = torch.mm(A, invQ_AT)
        G_invQ_AT = torch.mm(G, invQ_AT)

        # TODO: torch.potrf sometimes says the matrix is not PSD but
        # numpy does? I filed an issue at
        # https://github.com/pytorch/pytorch/issues/199
        try:
            U11 = torch.potrf(A_invQ_AT)
        except RuntimeError:
            # np.linalg.cholesky returns the *lower* factor L (M = L L^T),
            # whereas torch.potrf returns the upper factor U = L^T. Transpose
            # so the fallback yields the same factor as the primary path.
            L11 = torch.Tensor(np.linalg.cholesky(
                A_invQ_AT.cpu().numpy())).type_as(A_invQ_AT)
            U11 = L11.t().contiguous()

        # TODO: torch.trtrs is currently not implemented on the GPU
        # and we are using gesv as a workaround.
        # Solves U11^T U12 = A Q^{-1} G^T (the transpose of G Q^{-1} A^T).
        U12 = torch.gesv(G_invQ_AT.t(), U11.t())[0]
        U_S[:neq, :neq] = U11
        U_S[:neq, neq:] = U12
        R -= torch.mm(U12.t(), U12)

    return U_Q, U_S, R
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the batched KKT equations for one step using cached LU factors.

    The primal block is eliminated with ``Q_LU`` (LU factors of Q). The
    stacked multipliers ``w = [dy; dz]`` then solve ``S w = -rhs`` with
    ``S_LU``, and ``dx`` / ``ds`` are recovered by back-substitution.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when there are no equality
        constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    def lu_apply(factors, vecs):
        # Solve against LU factors for a batch of right-hand-side vectors.
        return vecs.unsqueeze(2).lu_solve(*factors).squeeze(2)

    def row_times(vecs, mat):
        # Batched vector-matrix product (nBatch, m) x (nBatch, m, n) -> (nBatch, n).
        return vecs.unsqueeze(1).bmm(mat).squeeze(1)

    qinv_rx = lu_apply(Q_LU, rx)
    ineq_rhs = row_times(qinv_rx, G.transpose(1, 2)) + rs / d - rz
    if neq > 0:
        eq_rhs = row_times(qinv_rx, A.transpose(1, 2)) - ry
        rhs = torch.cat((eq_rhs, ineq_rhs), 1)
    else:
        rhs = ineq_rhs

    w = -lu_apply(S_LU, rhs)
    w_eq = w[:, :neq]
    w_ineq = w[:, neq:]

    grad_x = -rx - row_times(w_ineq, G)
    if neq > 0:
        grad_x -= row_times(w_eq, A)

    dx = lu_apply(Q_LU, grad_x)
    ds = (-rs - w_ineq) / d
    return dx, ds, w_ineq, (w_eq if neq > 0 else None)
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the unbatched KKT system from scratch.

    Builds ``H_ = blockdiag(Q, D)`` and ``A_ = [[G, I], [A, 0]]`` (only
    ``[G, I]`` without equality constraints). It then eliminates through
    ``S_ = A_ H_^{-1} A_^T`` with Cholesky factorizations (``torch.potrf`` /
    ``torch.potrs``).

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when ``neq == 0``.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    # H_ = blockdiag(Q, D) is the same whether or not there are equalities.
    H_ = torch.cat([
        torch.cat([Q, torch.zeros(nz, nineq).type_as(Q)], 1),
        torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)
    ], 0)
    g_ = torch.cat([rx, rs], 0)
    if neq > 0:
        A_ = torch.cat([
            torch.cat([G, torch.eye(nineq).type_as(Q)], 1),
            torch.cat([A, torch.zeros(neq, nineq).type_as(Q)], 1)
        ], 0)
        h_ = torch.cat([rz, ry], 0)
    else:
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        h_ = rz

    U_H_ = torch.potrf(H_)
    invH_A_ = torch.potrs(A_.t(), U_H_)
    invH_g_ = torch.potrs(g_.view(-1, 1), U_H_).view(-1)

    S_ = torch.mm(A_, invH_A_)
    U_S_ = torch.potrf(S_)
    # Subtract in 1-D and only then reshape to a column: an (m, 1) column
    # minus an (m,) vector would broadcast to an (m, m) matrix.
    t_ = (torch.mv(A_, invH_g_) - h_).view(-1, 1)
    w_ = -torch.potrs(t_, U_S_).view(-1)
    v_rhs = (-g_ - torch.mv(A_.t(), w_)).view(-1, 1)
    v_ = torch.potrs(v_rhs, U_H_).view(-1)

    return v_[:nz], v_[nz:], w_[:nineq], w_[nineq:] if neq > 0 else None
def solve_kkt(U_Q, d, G, A, U_S, rx, rs, rz, ry, dbg=False):
    """Solve the unbatched KKT equations for one step using Cholesky factors.

    ``U_Q`` and ``U_S`` are used with ``torch.potrs`` (default
    ``upper=True``), so they are presumably upper Cholesky factors of Q and
    of the reduced matrix S -- TODO confirm with the factoring code.

    If ``dbg`` is true, an interactive IPython shell is opened and the
    process then exits with status -1.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when ``neq == 0``.
    """
    nineq, nz, neq, _ = get_sizes(G, A)

    def chol_solve(vec, U):
        # Solve against a Cholesky factor for a single right-hand side.
        return torch.potrs(vec.view(-1, 1), U).view(-1)

    qinv_rx = chol_solve(rx, U_Q)
    ineq_rhs = torch.mv(G, qinv_rx) + rs / d - rz
    if neq > 0:
        rhs = torch.cat([torch.mv(A, qinv_rx) - ry, ineq_rhs], 0)
    else:
        rhs = ineq_rhs

    w = -chol_solve(rhs, U_S)
    w_eq = w[:neq]
    w_ineq = w[neq:]

    g1 = -rx - torch.mv(G.t(), w_ineq)
    if neq > 0:
        g1 -= torch.mv(A.t(), w_eq)
    dx = chol_solve(g1, U_Q)
    ds = (-rs - w_ineq) / d
    dz = w_ineq
    dy = w_eq if neq > 0 else None

    if dbg:
        import IPython
        import sys
        IPython.embed()
        sys.exit(-1)

    return dx, ds, dz, dy
def solve_kkt_ir(Q, D, G, A, rx, rs, rz, ry, niter=1):
    """Solve the KKT system with (inefficient) iterative refinement.

    First solves with the regularized matrices ``Q + eps*I`` and
    ``D + eps*I`` via ``factor_solve_kkt_reg``. It then computes the residual
    with ``kkt_resid_reg`` and solves for a correction from the negated
    residual, ``niter`` times.

    Returns:
        ``(dx, ds, dz, dy)``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    eps = 1e-7
    Q_tilde = Q + eps * torch.eye(nz).type_as(Q).repeat(nBatch, 1, 1)
    D_tilde = D + eps * torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)

    def reg_solve(bx, bs, bz, by):
        # One solve with the regularized factorization.
        return factor_solve_kkt_reg(Q_tilde, D_tilde, G, A, bx, bs, bz, by, eps)

    def residual(sol):
        # Residual of the current solution against the original RHS.
        sx, ss, sz, sy = sol
        return kkt_resid_reg(Q, D, G, A, eps, sx, ss, sz, sy, rx, rs, rz, ry)

    sol = reg_solve(rx, rs, rz, ry)
    resx, ress, resz, resy = residual(sol)
    for _ in range(niter):
        corr = reg_solve(-resx, -ress, -resz,
                         None if resy is None else -resy)
        sol = [None if cur is None else cur + delta
               for cur, delta in zip(sol, corr)]
        resx, ress, resz, resy = residual(sol)

    dx, ds, dz, dy = sol
    return dx, ds, dz, dy
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the batched KKT system from scratch (no cached factors).

    Builds ``H_ = blockdiag(Q, D)`` and ``A_ = [[G, I], [A, 0]]`` (only
    ``[G, I]`` without equalities). It then eliminates through
    ``S_ = A_ H_^{-1} A_^T`` using LU factorizations from ``lu_hack``.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q
    H_[:, -nineq:, -nineq:] = D
    # Batched identity so it can be concatenated with the 3-D G in both
    # branches (a 2-D eye cannot be concatenated with a 3-D tensor).
    eye_ineq = torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)
    if neq > 0:
        A_ = torch.cat([torch.cat([G, eye_ineq], 2),
                        torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q)], 2)], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat([G, eye_ineq], 2)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = lu_hack(H_)

    # lu_solve expects right-hand sides shaped (*, n, k), so batched vectors
    # are lifted to (nBatch, n, 1) and squeezed back afterwards.
    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    S_ = torch.bmm(A_, invH_A_)
    S_LU = lu_hack(S_)
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze(1)
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Returns:
        Q_LU: LU factorization of Q from ``lu_hack`` (unpacked with ``*``
            for ``lu_solve``).
        S_LU: ``[S_LU_data, S_LU_pivots]`` -- a partial LU factorization of
            the (neq + nineq)-square matrix S whose trailing nineq x nineq
            block is still zero.
        R: ``G Q^{-1} G^T`` minus the Schur correction from the equality
            block (for symmetric Q). It is presumably combined with D^{-1}
            later to complete ``S_LU`` -- TODO confirm.

    Raises:
        RuntimeError: if Q cannot be LU-factorized (the original error is
            chained as ``__cause__``).
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except Exception as err:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""") from err

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    # G Q^{-1} G^T for every shape. A former nz == 2 special case computed
    # G (Q^{-1} G)^T = G G^T Q^{-T}, which is not this product.
    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    # 1-based identity pivots; the first neq entries are overwritten below
    # with the pivots of the A Q^{-1} A^T factorization.
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        # Solving (P L U) X = P L gives X = U^{-1}.
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)
                           ).lu_solve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)),
                              1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
def forward_eq_conv(Q, p, G_, h_, A_, b_, verbose=0, maxIter=100, dt=0.2):
    """ DOES NOT WORK, ONLY HERE TO DOCUMENT OUR PROCESS

    Solves the QP problem by turning each equality constraint into two
    inequality constraints (``A x <= b`` and ``-A x <= -b``). It then runs
    the projected dual dynamics of the inequality-only solver.

    Returns:
        ``(zhat, lams, nus, slacks)``:
        - ``lams`` are the multipliers of the original inequalities.
        - ``nus`` are the recovered equality multipliers.
        - ``slacks`` covers all ``nineq + 2 * neq`` stacked constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G_, A_)
    # A x = b  <=>  A x <= b  and  -A x <= -b
    G = torch.cat((G_, A_, -A_), dim=1)
    h = torch.cat((h_, b_, -b_), dim=1)
    A_T = torch.transpose(G, -1, 1)
    Q_I = torch.inverse(Q)
    AQ_I = G.bmm(Q_I)
    D = -AQ_I.bmm(A_T)
    d = -(h.unsqueeze(2) + AQ_I.bmm(p.unsqueeze(2)))
    lams = torch.zeros(nBatch, nineq + neq + neq, 1).type_as(Q).to(Q.device)
    zeros = torch.zeros(nBatch, nineq + neq + neq, 1).type_as(Q).to(Q.device)
    for _ in range(maxIter):
        dlams = dt * (D.bmm(lams) + d)
        dlams = torch.max(lams + dlams, zeros) - lams
        lams.add_(dlams)
    zhat = -Q_I.bmm(p.unsqueeze(2) + A_T.bmm(lams))
    slacks = h.unsqueeze(2) - G.bmm(zhat)

    lams_all = lams.squeeze(2)
    lams_ineq = lams_all[:, :nineq]
    # The multiplier of A x = b is lam(A x <= b) - lam(-A x <= -b).
    # Slicing [:, -neq:] alone would keep only the second half, and for
    # neq == 0 it would return every multiplier.
    nus = lams_all[:, nineq:nineq + neq] - lams_all[:, nineq + neq:]
    return zhat.squeeze(2), lams_ineq, nus, slacks.squeeze(2)
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products

    Variant written against ``.btrf`` / ``.btrs``, which are not part of
    current PyTorch (assumed to be batched LU factor/solve from an old or
    patched torch -- TODO confirm).

    Returns:
        Q_LU: ``Q.btrf()``.
        S_LU: packed (neq + nineq)-square partial LU of S; the trailing
            nineq x nineq block is zeros.
        R: ``G Q^{-1} G^T`` minus the Schur correction from the equality
            block. NOTE: ``R`` aliases ``G_invQ_GT`` (no clone) and is
            updated in place.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    Q_LU = Q.btrf()

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of S matrix
    # that can be completed once D^{-1} is known.
    # This is done for a general matrix by decomposing
    # S using the Schur complement and then LU-factorizing
    # the matrices in the middle:
    #
    #   [ A B ] = [ I        0 ] [ A  0              ] [ I  A^{-1} B ]
    #   [ C D ]   [ C A^{-1} I ] [ 0  D - C A^{-1} B ] [ 0  I        ]

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrs(Q_LU))
    R = G_invQ_GT
    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrs(Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        # NOTE(review): the result of btrf() is reshaped directly, so here it
        # appears to be only the packed LU data (no pivots are used).
        LU_A_invQ_AT = A_invQ_AT.btrf()
        if neq == 1:
            LU_A_invQ_AT = LU_A_invQ_AT.view(-1, 1, 1)
        # Unit lower-triangular L: keep the lower triangle of the packed LU
        # and overwrite its diagonal with ones.
        M = torch.tril(torch.ones(neq, neq)).unsqueeze(0).expand(nBatch, neq, neq).type_as(Q)
        L_A_invQ_AT = M * LU_A_invQ_AT
        L_A_invQ_AT[torch.eye(neq).unsqueeze(0).expand(
            nBatch, neq, neq).type_as(Q).byte()] = 1.0
        # Upper-triangular U: upper triangle (with diagonal) of the packed LU.
        M = torch.triu(torch.ones(neq, neq)).unsqueeze(0).expand(nBatch, neq, neq).type_as(Q)
        U_A_invQ_AT = M * LU_A_invQ_AT

        S_LU_11 = LU_A_invQ_AT
        # Solving (L U) X = L gives X = U^{-1}.
        S_LU_21 = G_invQ_AT.bmm(L_A_invQ_AT.btrs(LU_A_invQ_AT))
        # T = (A Q^{-1} A^T)^{-1} (G Q^{-1} A^T)^T
        T = G_invQ_AT.transpose(1, 2).btrs(LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU = torch.cat(
            (torch.cat((S_LU_11, S_LU_12), 2),
             torch.cat(
                 (S_LU_21, torch.zeros(nBatch, nineq, nineq).type_as(Q)), 2)),
            1)
        R -= G_invQ_AT.bmm(T)
    else:
        S_LU = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    return Q_LU, S_LU, R
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """Solve the batched KKT equations for one step using cached LU factors.

    The primal block is eliminated with ``Q_LU`` (LU factors of Q, unpacked
    into ``lu_solve``). The stacked multipliers ``w = [dy; dz]`` then solve
    ``S w = -h`` with ``S_LU``, and ``dx`` / ``ds`` follow by
    back-substitution.

    All batched products are done per batch element with ``unsqueeze`` +
    ``bmm``. Special cases keyed on ``G.size(1) >= 2`` (which used
    ``matmul`` broadcasting and ``h.unsqueeze(0)``) mixed different batch
    elements together for nBatch > 1. For nBatch == 1 they reduced to
    exactly this computation.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when there are no equality
        constraints.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
    G_invQ_rx = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1)
    if neq > 0:
        h = torch.cat((invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry,
                       G_invQ_rx + rs / d - rz), 1)
    else:
        h = G_invQ_rx + rs / d - rz

    w = -(h.unsqueeze(2).lu_solve(*S_LU)).squeeze(2)

    # dx = Q^{-1}(-rx - G^T dz - A^T dy), ds = (-rs - dz) / d
    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)
    g2 = -rs - w[:, neq:]

    dx = g1.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
    ds = g2 / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the full KKT system from scratch (older variant).

    Builds ``H_ = blockdiag(Q, D)`` and ``A_ = [[G, I], [A, 0]]`` (only
    ``[G, I]`` without equalities). It then eliminates through
    ``S_ = A_ H_^{-1} A_^T`` using LU factorizations.

    NOTE(review): this variant appears to take a single unbatched Q, G and A
    (``Q.repeat(nBatch, 1, 1)``, ``A_.repeat(nBatch, 1, 1)``, ``A_.t()`` and
    ``torch.mm`` are used) with batched D and right-hand sides -- confirm.
    NOTE(review): ``.btrf`` / ``.btrs`` are not part of current PyTorch, and
    ``block`` is an external package imported lazily below.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    if neq > 0:
        # import IPython, sys; IPython.embed(); sys.exit(-1)
        # H_ = torch.cat([torch.cat([Q, torch.zeros(nz,nineq).type_as(Q)], 1),
        # torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        # A_ = torch.cat([torch.cat([G, torch.eye(nineq).type_as(Q)], 1),
        # torch.cat([A, torch.zeros(neq, nineq).type_as(Q)], 1)], 0)
        # g_ = torch.cat([rx, rs], 0)
        # h_ = torch.cat([rz, ry], 0)

        H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
        H_[:, :nz, :nz] = Q.repeat(nBatch, 1, 1)
        H_[:, -nineq:, -nineq:] = D

        from block import block
        # Block matrix [[G, I], [A, 0]].
        A_ = block(((G, 'I'),
                    (A, torch.zeros(neq, nineq).type_as(Q))))

        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
        H_[:, :nz, :nz] = Q.repeat(nBatch, 1, 1)
        H_[:, -nineq:, -nineq:] = D
        A_ = torch.cat([G, torch.eye(nineq).type_as(Q)], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = H_.btrf()

    # NOTE: rebinds the parameter name A to the batched copy of A_.
    A = A_.repeat(nBatch, 1, 1)
    invH_A_ = A.transpose(1, 2).btrs(H_LU)
    invH_g_ = g_.btrs(H_LU)

    S_ = torch.bmm(A, invH_A_)
    S_LU = S_.btrf()
    t_ = torch.mm(invH_g_, A_.t()) - h_
    # w_ = -S_^{-1} t_
    w_ = -t_.btrs(S_LU)
    t_ = -g_ - w_.mm(A_)
    # v_ = H_^{-1}(-g_ - A_^T w_) = [dx; ds]
    v_ = t_.btrs(H_LU)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
    """Factor and solve the batched KKT system from scratch (no cached factors).

    Builds ``H_ = blockdiag(Q, D)`` and ``A_ = [[G, I], [A, 0]]`` (only
    ``[G, I]`` without equalities). It then eliminates through
    ``S_ = A_ H_^{-1} A_^T`` using LU factorizations from ``lu_hack``.

    G is kept 3-D (no ``squeeze(0)``) and right-hand sides are passed to
    ``lu_solve`` as ``(nBatch, n, k)``. This replaces ad-hoc reshapes that
    solved ``A_`` (nineq + neq rows) against factors of size nz + nineq,
    and then fed a 4-D tensor to ``bmm``.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q)
    H_[:, :nz, :nz] = Q
    H_[:, -nineq:, -nineq:] = D
    # Batched identity coupling the slacks to the inequality rows.
    eye_ineq = torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)
    if neq > 0:
        A_ = torch.cat([torch.cat([G, eye_ineq], 2),
                        torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q)], 2)], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat([G, eye_ineq], 2)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = lu_hack(H_)

    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
    invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    S_ = torch.bmm(A_, invH_A_)
    S_LU = lu_hack(S_)
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze(1)
    v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products

    Batched variant written against the old ``btrifact`` / ``btrisolve`` /
    ``btriunpack`` API, which is no longer part of current PyTorch.

    Returns:
        Q_LU: ``Q.btrifact()`` (unpacked with ``*`` when solving).
        S_LU: ``[S_LU_data, S_LU_pivots]`` -- a partial LU factorization of
            the (neq + nineq)-square matrix S whose trailing nineq x nineq
            block is still zero.
        R: ``G Q^{-1} G^T`` minus the Schur correction from the equality
            block. It is presumably combined with D^{-1} later to complete
            ``S_LU`` -- TODO confirm.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    Q_LU = Q.btrifact()

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrisolve(*Q_LU))
    # Cloned so the in-place Schur update below leaves G_invQ_GT intact.
    R = G_invQ_GT.clone()
    # 1-based identity pivots for all of S; the first neq entries are
    # overwritten below with the pivots of the A Q^{-1} A^T factorization.
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrisolve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        # Top-left block of S: LU of A Q^{-1} A^T, unpacked into P, L, U.
        LU_A_invQ_AT = A_invQ_AT.btrifact()
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.btriunpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        # Solving (P L U) X = P L gives X = U^{-1}.
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).btrisolve(
            *LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        # T = (A Q^{-1} A^T)^{-1} (G Q^{-1} A^T)^T
        T = G_invQ_AT.transpose(1, 2).btrisolve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        # Bottom-right block stays zero until D is known.
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat(
            (S_LU_11, S_LU_12), 2), torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        # Subtract the Schur correction G Q^{-1} A^T T.
        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
def factor_solve_kkt_reg(Q_tilde, D, G, A, rx, rs, rz, ry, eps):
    """Factor and solve a regularized batched KKT system from scratch.

    Builds ``H_ = blockdiag(Q_tilde, D)`` and ``A_ = [[G, I], [A, 0]]``
    (only ``[G, I]`` without equalities). It eliminates through
    ``S_ = A_ H_^{-1} A_^T - eps * I`` using LU factorizations. Called by
    ``solve_kkt_ir`` with already-regularized ``Q_tilde`` and ``D``.

    NOTE(review): uses the old ``btrifact`` / ``btrisolve`` API, which is
    no longer part of current PyTorch.

    Returns:
        ``(dx, ds, dz, dy)``; ``dy`` is None when ``neq == 0``.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D
    if neq > 0:
        # H_ = torch.cat([torch.cat([Q, torch.zeros(nz,nineq).type_as(Q)], 1),
        # torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([
            torch.cat(
                [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2),
            torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        ], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat(
            [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = H_.btrifact()

    invH_A_ = A_.transpose(1, 2).btrisolve(*H_LU)
    invH_g_ = g_.btrisolve(*H_LU)

    S_ = torch.bmm(A_, invH_A_)
    # Regularization shift of the reduced matrix.
    S_ -= eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
    S_LU = S_.btrifact()
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    w_ = -t_.btrisolve(*S_LU)
    # NOTE(review): squeeze() without a dim also drops the batch dim when
    # nBatch == 1; the subtraction from g_ broadcasts it back.
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze()
    v_ = t_.btrisolve(*H_LU)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
def forward_eq_new(Q_, p_, G_, h_, A_, b_, verbose=0, maxIter=100, dt=0.2):
    """ Solves equality constraints by dynamically solving both primal and dual problems side-by-side

    ``x`` is split as ``x = x_pos - x_neg`` (doubling its size) so it can
    take negative values; ``Q``, ``p`` and ``A`` are expanded to match.
    Each iteration adds an update to ``x`` and ``y`` built from the gap
    ``g = x^T Q x + p^T x - b^T y`` and the residual terms below.

    NOTE(review): ``G_``, ``h_``, ``verbose`` and ``dt`` are accepted but not
    used; updates are applied with unit step. Confirm whether ``dt`` was
    meant to scale ``dx``/``dy`` as in the other dynamic solvers.

    Returns:
        ``(x, y, slacks)`` where ``x = x_pos - x_neg`` and ``slacks`` is all
        zeros.
    """
    neq, nz, _, nBatch = get_sizes(A_)
    # Q = [[Q_, -Q_], [-Q_, Q_]] acts on the split variable [x_pos; x_neg].
    Q__ = torch.cat((Q_, -Q_), dim=1)
    Q = torch.cat((Q__, -Q__), dim=2)
    p = torch.cat((p_, -p_), dim=1).unsqueeze(2)
    A = torch.cat((A_, -A_), dim=2)
    b = b_.unsqueeze(2)
    Q_T = torch.transpose(Q, -1, 1)
    A_T = torch.transpose(A, -1, 1)
    p_T = torch.transpose(p, -1, 1)
    b_T = torch.transpose(b, -1, 1)
    # Expand 'x' to allow for negative values
    x = torch.zeros(nBatch, 2 * nz, 1).type_as(Q).to(Q.device)
    y = torch.zeros(nBatch, neq, 1).type_as(Q).to(Q.device)
    for _ in range(maxIter):
        x_T = torch.transpose(x, -1, 1)
        # g has shape (nBatch, 1, 1): one scalar gap per batch element.
        g = x_T.bmm(Q).bmm(x) + p_T.bmm(x) - b_T.bmm(y)
        r = neg(Q.bmm(x) + p - A_T.bmm(y))
        # Scale by the per-batch scalar g via broadcasting. bmm of a
        # (nBatch, 1, 1) tensor with an (nBatch, k, 1) one is a shape error.
        dFx = g * (2 * Q.bmm(x) + p)
        dFy = g * b
        dx = -dFx - A_T.bmm(A.bmm(x) - b) - neg(x) - Q_T.bmm(r)
        dy = dFy + A.bmm(r)
        x.add_(dx)
        y.add_(dy)
    slacks = torch.zeros(nBatch, neq).type_as(Q).to(Q.device)
    return (x[:, :nz, :] - x[:, nz:, :]).squeeze(2), y.squeeze(2), slacks
def forward(Q, p, G, h, A, b, verbose=0, maxIter=100, dt=0.2):
    """Solve a batched QP by unrolling its dual dynamics for ``maxIter`` steps.

    The dual variables ``lams`` (inequalities) and ``nus`` (equalities)
    start at zero. Each step moves them by ``dt`` along the terms below;
    ``lams`` is projected onto ``lams >= 0``. The primal point is then
    recovered in closed form:
    ``zhat = -Q^{-1}(p + G^T lams + A^T nus)``.

    ``verbose`` is accepted for API compatibility and is unused.

    Returns:
        ``(zhat, lams, nus, slacks)`` with the trailing singleton dim squeezed.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    Q_inv = torch.inverse(Q)
    Gt = G.transpose(-1, 1)
    At = A.transpose(-1, 1)
    p_col = p.unsqueeze(2)
    h_col = h.unsqueeze(2)
    b_col = b.unsqueeze(2)

    neg_G_Qinv = -G.bmm(Q_inv)  # - G Q^{-1}
    neg_A_Qinv = -A.bmm(Q_inv)  # - A Q^{-1}
    K_ll = neg_G_Qinv.bmm(Gt)   # - G Q^{-1} G^T
    K_ln = neg_G_Qinv.bmm(At)   # - G Q^{-1} A^T
    K_nl = neg_A_Qinv.bmm(Gt)   # - A Q^{-1} G^T
    K_nn = neg_A_Qinv.bmm(At)   # - A Q^{-1} A^T
    c_l = neg_G_Qinv.bmm(p_col) - h_col  # - G Q^{-1} p - h
    c_n = neg_A_Qinv.bmm(p_col) - b_col  # - A Q^{-1} p - b

    lams = Q.new_zeros(nBatch, nineq, 1)
    nus = Q.new_zeros(nBatch, neq, 1)
    floor = Q.new_zeros(nBatch, nineq, 1)
    for _ in range(maxIter):
        # Both steps use the multipliers from before this iteration's update.
        step_l = dt * (K_ll.bmm(lams) + K_ln.bmm(nus) + c_l)
        step_n = dt * (K_nl.bmm(lams) + K_nn.bmm(nus) + c_n)
        lams.add_(torch.max(lams + step_l, floor) - lams)
        nus.add_(step_n)

    zhat = -Q_inv.bmm(p_col + Gt.bmm(lams) + At.bmm(nus))
    slacks = h_col - G.bmm(zhat)
    return zhat.squeeze(2), lams.squeeze(2), nus.squeeze(2), slacks.squeeze(2)
def forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps=1e-12, verbose=0, notImprovedLim=3,
            maxIter=20, solver=KKTSolvers.LU_PARTIAL):
    """Batched primal-dual interior-point solve of the QP.

    Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)

    Each iteration does the following:
    - computes the KKT residuals;
    - records the per-batch best iterate;
    - takes an affine (predictor) step;
    - takes a centering/corrector step with
      ``sig = (gap after the affine step / current gap) ** 3``;
    - moves along the combined direction with a 0.999-damped step to the
      boundary of ``s, z >= 0``.

    ``solver`` selects how the KKT systems are solved: full LU per
    iteration (``LU_FULL``), the cached partial LU (``LU_PARTIAL``), or
    unoptimized iterative refinement (``IR_UNOPT``).

    Stops early when any of these hold:
    - no batch element improved for ``notImprovedLim`` iterations;
    - every residual is below ``eps``;
    - every ``mu`` exceeds 1e32;
    - ``factor_kkt`` raises.

    Returns:
        ``(x, y, z, s)`` of the best iterate per batch element. All four are
        None if ``factor_kkt`` fails on the first iteration.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Find initial values
    if solver == KKTSolvers.LU_FULL:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = factor_solve_kkt(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    elif solver == KKTSolvers.LU_PARTIAL:
        d = torch.ones(nBatch, nineq).type_as(Q)
        factor_kkt(S_LU, R, d)
        x, s, z, y = solve_kkt(
            Q_LU, d, G, A, S_LU,
            p, torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if neq > 0 else None)
    elif solver == KKTSolvers.IR_UNOPT:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = solve_kkt_ir(
            Q, D, G, A, p,
            torch.zeros(nBatch, nineq).type_as(Q),
            -h, -b if b is not None else None)
    else:
        assert False

    # Make all of the slack variables >= 1.
    # (Per batch element: if its minimum is negative, shift so it becomes 1.)
    M = torch.min(s, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    s[I] -= M[I] - 1

    # Make all of the inequality dual variables >= 1.
    M = torch.min(z, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(maxIter):
        # affine scaling direction
        # Residuals: rx (stationarity), rs, rz (G x + s - h), ry (A x - b).
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            p
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        # mu: average complementarity s_i * z_i per batch element.
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except:
            # NOTE(review): bare except also swallows KeyboardInterrupt.
            return best['x'], best['y'], best['z'], best['s']

        if verbose == 1:
            print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
                i, pri_resid.mean(), dual_resid.mean(), mu.mean()))
        # Track, per batch element, the iterate with the smallest residual.
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        if nNotImproved == notImprovedLim or best['resids'].max() < eps or mu.min() > 1e32:
            if best['resids'].max() > 1. and verbose >= 0:
                print(INACC_ERR)
            return best['x'], best['y'], best['z'], best['s']

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
                Q, D, G, A, rx, rs, rz, ry)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(
                Q, D, G, A, rx, rs, rz, ry)
        else:
            assert False

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff),
                                    get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        # Corrector right-hand side: only rs is non-zero.
        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q) if neq > 0 else torch.Tensor()

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
                Q, D, G, A, rx, rs, rz, ry)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(
                Q, D, G, A, rx, rs, rz, ry)
        else:
            assert False

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        # Damped (0.999) step toward the boundary of s, z >= 0, capped at 1.
        alpha = torch.min(0.999 * torch.min(get_step(z, dz),
                                            get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    if best['resids'].max() > 1. and verbose >= 0:
        print(INACC_ERR)
    return best['x'], best['y'], best['z'], best['s']
def forward(inputs, Q, G, h, A, b, Q_LU, S_LU, R, verbose=False):
    """Batched primal-dual interior-point solve using cached KKT factors.

    Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)

    ``inputs`` enters the stationarity residual ``rx`` in the position of
    the linear cost term. Runs at most 20 predictor-corrector iterations and
    returns the per-batch best iterate seen.

    Returns:
        ``(x, y, z, s)`` of the best iterate per batch element. All four are
        None if ``factor_kkt`` fails on the first iteration.
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Find initial values
    d = torch.ones(nBatch, nineq).type_as(Q)
    factor_kkt(S_LU, R, d)
    x, s, z, y = solve_kkt(Q_LU, d, G, A, S_LU,
                           inputs, torch.zeros(nBatch, nineq).type_as(Q),
                           -h, -b if neq > 0 else None)

    # In every batch element whose minimum is negative, shift s (and z) so
    # that the minimum becomes 1. The per-batch minimum is reshaped to
    # (nBatch, 1) before repeating. Without that, repeat(1, nineq) yields a
    # (1, nBatch * nineq) mask that does not match s/z when nBatch > 1.
    M = torch.min(s, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    s[I] -= M[I] - 1

    M = torch.min(z, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M < 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(20):
        # affine scaling direction
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            inputs
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except Exception:
            # TODO: Move this below.
            print('=' * 70 + '\n' + "TODO: Remove try/except around factor_kkt!!!" + '\n')
            return best['x'], best['y'], best['z'], best['s']

        if verbose:
            # reshape(-1) so indexing the first batch element also works when
            # nBatch == 1 and squeeze() has produced 0-d tensors.
            print(
                'iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.
                format(i, pri_resid.reshape(-1)[0], dual_resid.reshape(-1)[0],
                       mu.reshape(-1)[0]))
        # Track, per batch element, the iterate with the smallest residual.
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        if nNotImproved == 3 or best['resids'].max() < 1e-12:
            return best['x'], best['y'], best['z'], best['s']

        # TODO: Move factorization back here.
        dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(Q_LU, d, G, A, S_LU,
                                                   rx, rs, rz, ry)

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        # Corrector right-hand side: only rs is non-zero.
        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q)

        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(Q_LU, d, G, A, S_LU,
                                                   rx, rs, rz, ry)

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        # Damped (0.999) step toward the boundary of s, z >= 0, capped at 1.
        alpha = torch.min(0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))

        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    return best['x'], best['y'], best['z'], best['s']
def forward(inputs_i, Q, G, A, b, h, U_Q, U_S, R, verbose=False):
    """Unbatched primal-dual interior-point solve using cached KKT factors.

    b = A z_0
    h = G z_0 + s_0
    U_Q, U_S, R = pre_factor_kkt(Q, G, A, nineq, neq)

    Runs at most 20 predictor-corrector iterations. It returns early when
    the combined residual stops improving (or drops below 1e-6), or when the
    step direction overflows or contains NaN.

    Returns:
        ``(x, y, z)``; ``y`` is None when ``neq == 0``.
    """
    import math

    nineq, nz, neq, _ = get_sizes(G, A)

    # find initial values
    d = torch.ones(nineq).type_as(Q)
    nb = -b if b is not None else None
    factor_kkt(U_S, R, d)
    x, s, z, y = solve_kkt(
        U_Q, d, G, A, U_S,
        inputs_i, torch.zeros(nineq).type_as(Q), -h, nb)

    # Shift s and z so that their minimum entry is at least 1.
    if torch.min(s) < 0:
        s -= torch.min(s) - 1
    if torch.min(z) < 0:
        z -= torch.min(z) - 1

    prev_resid = None
    for i in range(20):
        # affine scaling direction
        rx = (torch.mv(A.t(), y) if neq > 0 else 0.) + \
            torch.mv(G.t(), z) + torch.mv(Q, x) + inputs_i
        rs = z
        rz = torch.mv(G, x) + s - h
        ry = torch.mv(A, x) - b if neq > 0 else torch.Tensor([0.])
        mu = torch.dot(s, z) / nineq
        pri_resid = torch.norm(ry) + torch.norm(rz)
        dual_resid = torch.norm(rx)
        resid = pri_resid + dual_resid + nineq * mu
        d = z / s
        if verbose:
            print(("primal_res = {0:.5g}, dual_res = {1:.5g}, " +
                   "gap = {2:.5g}, kappa(d) = {3:.5g}").format(
                pri_resid, dual_resid, mu, min(d) / max(d)))
        improved = (prev_resid is None) or (resid < prev_resid + 1e-6)
        if not improved or resid < 1e-6:
            return x, y, z
        prev_resid = resid

        factor_kkt(U_S, R, d)
        dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
            U_Q, d, G, A, U_S, rx, rs, rz, ry)

        # compute centering directions
        alpha = min(min(get_step(z, dz_aff), get_step(s, ds_aff)), 1.0)
        sig = (torch.dot(s + alpha * ds_aff, z + alpha * dz_aff) /
               (torch.dot(s, z)))**3
        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
            U_Q, d, G, A, U_S, torch.zeros(nz).type_as(Q),
            (-mu * sig * torch.ones(nineq).type_as(Q) + ds_aff * dz_aff) / s,
            torch.zeros(nineq).type_as(Q), torch.zeros(neq).type_as(Q))

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        alpha = min(1.0, 0.999 * min(get_step(s, ds), get_step(z, dz)))
        # float() works for Python floats and 0-d tensors on any device;
        # np.isnan on a CUDA tensor raises instead of testing for NaN.
        dx_norm = float(torch.norm(dx))
        dz_norm = float(torch.norm(dz))
        if (math.isnan(dx_norm) or math.isnan(dz_norm) or
                dx_norm > 1e5 or dz_norm > 1e5):
            # Overflow, return early
            return x, y, z

        x += alpha * dx
        s += alpha * ds
        z += alpha * dz
        y = y + alpha * dy if neq > 0 else None

    return x, y, z
def pre_factor_kkt(Q, G, A):
    """Perform all one-time factorizations and cache relevant matrix products.

    Returns:
        Q_LU: LU factorization of Q from ``lu_hack`` (unpacked with ``*``
            for ``lu_solve``).
        S_LU: ``[S_LU_data, S_LU_pivots]`` -- a partial LU factorization of
            the (neq + nineq)-square matrix S whose trailing nineq x nineq
            block is still zero.
        R: ``G Q^{-1} G^T`` minus the Schur correction from the equality
            block. It is presumably combined with D^{-1} later to complete
            ``S_LU`` -- TODO confirm.

    Raises:
        RuntimeError: if Q cannot be LU-factorized (the original error is
            chained as ``__cause__``).
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except Exception as err:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""") from err

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    # 1-based identity pivots; the first neq entries are overwritten below
    # with the pivots of the A Q^{-1} A^T factorization.
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        # Solving (P L U) X = P L gives X = U^{-1}.
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).lu_solve(
            *LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        # sp_lu_solve is a project helper used in place of
        # G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT); presumably the
        # same solve -- TODO confirm.
        T = sp_lu_solve(G_invQ_AT.transpose(1, 2), *LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat(
            (S_LU_11, S_LU_12), 2), torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        # Subtract the Schur correction G Q^{-1} A^T T.
        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R