def forward(ctx, Q_, p_, G_, h_, A_, b_):
    # Solver settings for the batched PDIPM path.
    eps = 1e-12
    verbose = 0
    notImprovedLim = 3
    maxIter = 20

    # Broadcast any un-batched parameters up to the batch size.
    nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_)
    Q, _ = expandParam(Q_, nBatch, 3)
    p, _ = expandParam(p_, nBatch, 2)
    G, _ = expandParam(G_, nBatch, 3)
    h, _ = expandParam(h_, nBatch, 2)
    A, _ = expandParam(A_, nBatch, 3)
    b, _ = expandParam(b_, nBatch, 2)

    _, nineq, nz = G.size()
    neq = A.size(1) if A.nelement() > 0 else 0
    assert neq > 0 or nineq > 0
    ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz

    if solver == QPSolvers.PDIPM_BATCHED:
        ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)
        zhats, ctx.nus, ctx.lams, ctx.slacks = pdipm_b.forward(
            Q, p, G, h, A, b, ctx.Q_LU, ctx.S_LU, ctx.R,
            eps, verbose, notImprovedLim, maxIter)
    elif solver == QPSolvers.CVXPY:
        # Solve each instance in the batch separately on the CPU.
        vals = torch.Tensor(nBatch).type_as(Q)
        zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
        lams = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
        nus = torch.Tensor(nBatch, ctx.neq).type_as(Q) \
            if ctx.neq > 0 else torch.Tensor()
        slacks = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
        for i in range(nBatch):
            Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)
            vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np(
                *[x.cpu().numpy() if x is not None else None
                  for x in (Q[i], p[i], G[i], h[i], Ai, bi)])
            zhats[i] = torch.Tensor(zhati)
            lams[i] = torch.Tensor(lami)
            slacks[i] = torch.Tensor(si)
            if neq > 0:
                nus[i] = torch.Tensor(nui)
        ctx.vals = vals
        ctx.lams = lams
        ctx.nus = nus
        ctx.slacks = slacks
    else:
        assert False

    ctx.save_for_backward(zhats, Q_, p_, G_, h_, A_, b_)
    return zhats
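# A minimal sketch of the broadcasting convention that the forward pass above
# assumes from expandParam (this reimplementation is hypothetical; the real
# helper is defined elsewhere in the module): a parameter passed without a
# leading batch dimension is expanded to nBatch copies, and the returned flag
# records whether expansion happened so backward() can mean-reduce the
# corresponding gradient.
def expandParam_sketch(X, nBatch, nDim):
    if X.ndimension() in (0, nDim) or X.nelement() == 0:
        return X, False  # already batched (or empty): pass through
    elif X.ndimension() == nDim - 1:
        return X.unsqueeze(0).expand(nBatch, *X.size()), True
    else:
        raise RuntimeError('Unexpected number of dimensions.')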
def backward(ctx, *dl_dzhat):
    zhats, Q, p, G, h, A, b = ctx.saved_tensors
    nBatch = extract_nBatch(Q, p, G, h, A, b)
    Q, Q_e = expandParam(Q, nBatch, 3)
    p, p_e = expandParam(p, nBatch, 2)
    G, G_e = expandParam(G, nBatch, 3)
    h, h_e = expandParam(h, nBatch, 2)
    A, A_e = expandParam(A, nBatch, 3)
    b, b_e = expandParam(b, nBatch, 2)
    neq, nineq = ctx.neq, ctx.nineq

    if solver == QPSolvers.CVXPY:
        # The CVXPY path never factored the KKT system; do it now.
        ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)

    # Clamp to avoid numerical issues when the slacks are too small.
    # TODO: a better fix would be to get lams and slacks from the
    # solver that don't have this issue.
    d = torch.clamp(ctx.lams, min=1e-8) / torch.clamp(ctx.slacks, min=1e-8)

    # Differentiate through the KKT conditions: solve the factored KKT
    # system with the incoming gradient on the right-hand side.
    pdipm_b.factor_kkt(ctx.S_LU, ctx.R, d)
    dx, _, dlam, dnu = pdipm_b.solve_kkt(
        ctx.Q_LU, d, G, A, ctx.S_LU,
        dl_dzhat[0],
        torch.zeros(nBatch, nineq).type_as(G),
        torch.zeros(nBatch, nineq).type_as(G),
        torch.zeros(nBatch, neq).type_as(G) if neq > 0 else torch.Tensor())

    # Assemble the parameter gradients; mean-reduce over the batch for
    # any parameter that was broadcast in the forward pass.
    dps = dx
    dGs = bger(dlam, zhats) + bger(ctx.lams, dx)
    if G_e:
        dGs = dGs.mean(0)
    dhs = -dlam
    if h_e:
        dhs = dhs.mean(0)
    if neq > 0:
        dAs = bger(dnu, zhats) + bger(ctx.nus, dx)
        dbs = -dnu
        if A_e:
            dAs = dAs.mean(0)
        if b_e:
            dbs = dbs.mean(0)
    else:
        dAs, dbs = None, None
    dQs = 0.5 * (bger(dx, zhats) + bger(zhats, dx))
    if Q_e:
        dQs = dQs.mean(0)
    if p_e:
        dps = dps.mean(0)

    grads = (dQs, dps, dGs, dhs, dAs, dbs)
    return grads
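# A hedged usage sketch: the forward/backward pair above is meant to be
# wrapped as a torch.autograd.Function. The class name `QPFunctionImpl` and
# the problem data below are hypothetical; `solver` must already be set to
# one of the QPSolvers members, as in the surrounding module.
import torch

class QPFunctionImpl(torch.autograd.Function):
    forward = staticmethod(forward)
    backward = staticmethod(backward)

# Solve a batch of QPs: min_z 0.5 z'Qz + p'z  s.t.  Gz <= h  (no equalities).
nBatch, nz, nineq = 4, 3, 5
L = torch.randn(nBatch, nz, nz)
Q = L.bmm(L.transpose(1, 2)) + 1e-3 * torch.eye(nz)  # positive definite
p = torch.randn(nBatch, nz)
G = torch.randn(nBatch, nineq, nz)
z0 = torch.randn(nBatch, nz)
h = G.bmm(z0.unsqueeze(2)).squeeze(2) + torch.rand(nBatch, nineq)  # feasible
zhat = QPFunctionImpl.apply(Q, p, G, h, torch.Tensor(), torch.Tensor())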
def test_lu_kkt_solver():
    Q, p, G, h, A, b, d, D, rx, rs, rz, ry = get_kkt_problem()

    dx, ds, dz, dy = pdipm_b.factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry)

    Q_LU, S_LU, R = pdipm_b.pre_factor_kkt(Q, G, A)
    pdipm_b.factor_kkt(S_LU, R, d)
    dx_, ds_, dz_, dy_ = pdipm_b.solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry)

    npt.assert_allclose(dx.numpy(), dx_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(ds.numpy(), ds_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(dz.numpy(), dz_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(dy.numpy(), dy_.numpy(), rtol=RTOL, atol=ATOL)
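# A dense numpy reference for the block structure of the KKT system that both
# code paths in the test above solve. The exact sign conventions of the
# residuals are an assumption here; the point is the layout: D = diag(d)
# couples the slack block to the inequality multipliers.
import numpy as np

def dense_kkt_sketch(Q, D, G, A):
    nz, nineq, neq = Q.shape[0], G.shape[0], A.shape[0]
    Z = lambda m, n: np.zeros((m, n))
    return np.block([
        [Q,            Z(nz, nineq),  G.T,             A.T],
        [Z(nineq, nz), D,             np.eye(nineq),   Z(nineq, neq)],
        [G,            np.eye(nineq), Z(nineq, nineq), Z(nineq, neq)],
        [A,            Z(neq, nineq), Z(neq, nineq),   Z(neq, neq)],
    ])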
def prof_instance(nz, nBatch, cuda=True):
    nineq, neq = 100, 0
    assert neq == 0

    # Random, feasible QP data: Q = L L^T + eps*I is positive definite and
    # h is chosen so that z0 is strictly feasible with slack s0 > 0.
    L = npr.rand(nBatch, nz, nz)
    Q = np.matmul(L, L.transpose((0, 2, 1))) + 1e-3 * np.eye(nz, nz)
    G = npr.randn(nBatch, nineq, nz)
    z0 = npr.randn(nBatch, nz)
    s0 = npr.rand(nBatch, nineq)
    p = npr.randn(nBatch, nz)
    h = np.matmul(G, np.expand_dims(z0, axis=2)).squeeze(2) + s0
    A = npr.randn(nBatch, neq, nz)
    b = np.matmul(A, np.expand_dims(z0, axis=2)).squeeze(2)

    # Baseline: solve each instance with Gurobi.
    zhat_g = []
    gurobi_time = 0.0
    for i in range(nBatch):
        m = gpy.Model()
        zhat = m.addVars(nz, lb=-gpy.GRB.INFINITY, ub=gpy.GRB.INFINITY)
        obj = 0.0
        for j in range(nz):
            for k in range(nz):
                obj += 0.5 * Q[i, j, k] * zhat[j] * zhat[k]
            obj += p[i, j] * zhat[j]
        m.setObjective(obj)
        for j in range(nineq):
            con = 0
            for k in range(nz):
                con += G[i, j, k] * zhat[k]
            m.addConstr(con <= h[i, j])
        m.setParam('OutputFlag', False)
        start = time.time()
        m.optimize()
        gurobi_time += time.time() - start
        t = np.zeros(nz)
        for j in range(nz):
            t[j] = zhat[j].x
        zhat_g.append(t)

    p, L, Q, G, z0, s0, h = [torch.Tensor(x) for x in [p, L, Q, G, z0, s0, h]]
    if cuda:
        p, L, Q, G, z0, s0, h = [x.cuda() for x in [p, L, Q, G, z0, s0, h]]
    if neq > 0:
        A = torch.Tensor(A)
        b = torch.Tensor(b)
    else:
        A, b = [torch.Tensor()] * 2
        if cuda:
            A = A.cuda()
            b = b.cuda()

    # Single-example PDIPM, one instance at a time.
    single_results = []
    start = time.time()
    for i in range(nBatch):
        A_i = A[i] if neq > 0 else A
        b_i = b[i] if neq > 0 else b
        U_Q, U_S, R = pdipm_s.pre_factor_kkt(Q[i], G[i], A_i)
        single_results.append(
            pdipm_s.forward(p[i], Q[i], G[i], A_i, b_i, h[i], U_Q, U_S, R))
    single_time = time.time() - start

    # Batched PDIPM over all instances at once.
    start = time.time()
    Q_LU, S_LU, R = pdipm_b.pre_factor_kkt(Q, G, A)
    zhat_b, nu_b, lam_b, s_b = pdipm_b.forward(p, Q, G, h, A, b, Q_LU, S_LU, R)
    batched_time = time.time() - start

    # The diff between Gurobi and PDIPM is usually between 1e-4 and 1e-5:
    # np.linalg.norm(zhat_g[0] - zhat_b[0].cpu().numpy())

    # Optionally warn if the single and batched solutions diverge:
    # zhat_diff = (single_results[0][0] - zhat_b[0]).norm()
    # lam_diff = (single_results[0][2] - lam_b[0]).norm()
    # eps = 0.1  # Pretty relaxed.
    # if zhat_diff > eps or lam_diff > eps:
    #     print("Warning: Single and batched solutions might not match.")
    #     print("  + zhat_diff: {}".format(zhat_diff))
    #     print("  + lam_diff: {}".format(lam_diff))
    #     print("  + (nz, neq, nineq, nBatch) = ({}, {}, {}, {})".format(
    #         nz, neq, nineq, nBatch))

    return gurobi_time, single_time, batched_time
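# A hypothetical driver for the profiler above: sweep problem sizes and
# compare Gurobi, per-instance PDIPM, and batched PDIPM wall-clock times.
if __name__ == '__main__':
    for nz in (10, 50, 100):
        g, s, b = prof_instance(nz, nBatch=64, cuda=False)
        print('nz={:4d}  gurobi={:.4f}s  single={:.4f}s  batched={:.4f}s'
              .format(nz, g, s, b))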
def forward(ctx, p_, Q_, G_, h_, A_, b_):
    ############################################################
    # The forward solver
    ############################################################
    eps = 1e-12
    verbose = 0
    notImprovedLim = 3
    maxIter = 20

    nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_)
    Q, _ = expandParam(Q_, nBatch, 3)
    p, _ = expandParam(p_, nBatch, 2)
    G, _ = expandParam(G_, nBatch, 3)
    h, _ = expandParam(h_, nBatch, 2)
    A, _ = expandParam(A_, nBatch, 3)
    b, _ = expandParam(b_, nBatch, 2)

    _, nineq, nz = G.size()
    ny = Q.shape[-2]
    neq = A.size(1) if A.nelement() > 0 else 0
    assert neq > 0 or nineq > 0
    ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz

    if solver == QPSolvers.PDIPM_BATCHED:
        ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)
        zhats, ctx.nus, ctx.lams, ctx.slacks = pdipm_b.forward(
            Q, p, G, h, A, b, ctx.Q_LU, ctx.S_LU, ctx.R,
            eps, verbose, notImprovedLim, maxIter)
    elif solver == QPSolvers.CVXPY:
        vals = torch.Tensor(nBatch).type_as(Q)
        zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
        lams = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
        nus = torch.Tensor(nBatch, ctx.neq).type_as(Q) \
            if ctx.neq > 0 else torch.Tensor()
        slacks = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
        for i in range(nBatch):
            Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)
            vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np(
                *[x.cpu().numpy() if x is not None else None
                  for x in (Q[i], p[i], G[i], h[i], Ai, bi)])
            zhats[i] = torch.Tensor(zhati)
            lams[i] = torch.Tensor(lami)
            slacks[i] = torch.Tensor(si)
            if neq > 0:
                nus[i] = torch.Tensor(nui)
        ctx.vals = vals
        ctx.lams = lams
        ctx.nus = nus
        ctx.slacks = slacks
    else:
        assert False

    # Unlike the ctx.save_for_backward variant above, this version hands
    # everything to the implicit-function machinery below.
    if ctx.lams is None:
        ctx.lams = torch.Tensor()
    if ctx.nus is None:
        ctx.nus = torch.Tensor()

    ############################################################
    # Define implicit functions
    ############################################################
    # Each function evaluates row j of one block of the KKT residual for
    # argv = (x, Q, G, h, A, b, y, l, n), where x = p, y = zhat, l = lams
    # and n = nus.
    def stationarity(j, argv):
        x, Q, G, h, A, b, y, l, n = argv
        x = x.unsqueeze(-1)
        y = y.unsqueeze(-1)
        h = h.unsqueeze(-1)
        b = b.unsqueeze(-1)
        n = n.unsqueeze(-1)
        l = l.unsqueeze(-1)
        Qy = torch.matmul(Q[:, j, :].unsqueeze(1), y[:, :, :]).squeeze(1)
        if neq == 0:
            ATn = torch.zeros_like(Qy)
        else:
            ATn = torch.matmul(
                torch.transpose(A, 1, 2)[:, j, :].unsqueeze(1),
                n[:, :, :]).squeeze(1)
        if nineq == 0:
            GTl = torch.zeros_like(Qy)
        else:
            GTl = torch.matmul(
                torch.transpose(G, 1, 2)[:, j, :].unsqueeze(1),
                l[:, :, :]).squeeze(1)
        # The 0 * mean() terms keep h and b in the autograd graph
        # without changing the residual's value.
        return Qy + x[:, j, :] + ATn + GTl + 0 * h.mean() + 0 * b.mean()

    def primal_fea(j, argv):
        x, Q, G, h, A, b, y, l, n = argv
        y = y.unsqueeze(-1)
        b = b.unsqueeze(-1)
        Ay = torch.matmul(A[:, j, :].unsqueeze(1), y[:, :, :]).squeeze(1)
        return Ay - b[:, j, :]

    def comp_slack(j, argv):
        x, Q, G, h, A, b, y, l, n = argv
        y = y.unsqueeze(-1)
        h = h.unsqueeze(-1)
        l = l.unsqueeze(-1)
        Gy = torch.matmul(G[:, j, :].unsqueeze(1), y[:, :, :]).squeeze(1)
        return l[:, j, :] * (Gy - h[:, j, :])

    # Map each residual block to its row range in the stacked KKT system.
    nf = ny + neq + nineq
    f_dict = {
        stationarity: (0, ny),
        primal_fea: (ny, ny + neq),
        comp_slack: (ny + neq, nf),
    }

    ############################################################
    # Construct the imp_struct
    ############################################################
    imp_struct = ImpStruct()
    imp_struct.x = p_
    imp_struct.params = [Q_, G_, h_, A_, b_]
    imp_struct.y = zhats
    imp_struct.duals = [ctx.lams, ctx.nus]
    imp_struct.other_inputs = []
    imp_struct.f_dict = f_dict
    ctx.imp_struct = imp_struct

    return zhats
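# A hedged sanity check (the helper name is hypothetical): at the solver's
# solution, every residual registered in f_dict should vanish. This assumes
# all parameters carry an explicit batch dimension, matching the argv layout
# the residual functions above expect.
def check_kkt_residuals(imp_struct, atol=1e-5):
    Q_, G_, h_, A_, b_ = imp_struct.params
    lams, nus = imp_struct.duals
    argv = (imp_struct.x, Q_, G_, h_, A_, b_, imp_struct.y, lams, nus)
    for f, (lo, hi) in imp_struct.f_dict.items():
        for j in range(hi - lo):  # row j within this residual block
            r = f(j, argv)
            assert r.abs().max() < atol, (f.__name__, j, float(r.abs().max()))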