def backward(ctx, *dl_dzhat):
    # `*dl_dzhat` collects the single incoming gradient, hence the
    # `dl_dzhat[0]` below.
    zhats, Q, p, G, h, A, b = ctx.saved_tensors
    nBatch = extract_nBatch(Q, p, G, h, A, b)
    Q, Q_e = expandParam(Q, nBatch, 3)
    p, p_e = expandParam(p, nBatch, 2)
    G, G_e = expandParam(G, nBatch, 3)
    h, h_e = expandParam(h, nBatch, 2)
    A, A_e = expandParam(A, nBatch, 3)
    b, b_e = expandParam(b, nBatch, 2)

    # neq, nineq, nz = ctx.neq, ctx.nineq, ctx.nz
    neq, nineq = ctx.neq, ctx.nineq

    # The CVXPY path doesn't factor the KKT system during the forward pass,
    # so pre-factor it here. (`solver` is captured from the enclosing
    # QPFunction factory.)
    if solver == QPSolvers.CVXPY:
        ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)

    # Clamp here to avoid issues coming up when the slacks are too small.
    # TODO: A better fix would be to get lams and slacks from the
    # solver that don't have this issue.
    d = torch.clamp(ctx.lams, min=1e-8) / torch.clamp(ctx.slacks, min=1e-8)

    # Differentiate through the KKT conditions: solve the KKT system with
    # the incoming loss gradient as the right-hand side.
    pdipm_b.factor_kkt(ctx.S_LU, ctx.R, d)
    dx, _, dlam, dnu = pdipm_b.solve_kkt(
        ctx.Q_LU, d, G, A, ctx.S_LU,
        dl_dzhat[0], torch.zeros(nBatch, nineq).type_as(G),
        torch.zeros(nBatch, nineq).type_as(G),
        torch.zeros(nBatch, neq).type_as(G) if neq > 0 else torch.Tensor())

    # Assemble the parameter gradients. Parameters flagged `*_e` were
    # broadcast across the batch, so their gradients are reduced with .mean(0).
    dps = dx
    dGs = bger(dlam, zhats) + bger(ctx.lams, dx)
    if G_e:
        dGs = dGs.mean(0)
    dhs = -dlam
    if h_e:
        dhs = dhs.mean(0)
    if neq > 0:
        dAs = bger(dnu, zhats) + bger(ctx.nus, dx)
        dbs = -dnu
        if A_e:
            dAs = dAs.mean(0)
        if b_e:
            dbs = dbs.mean(0)
    else:
        dAs, dbs = None, None
    dQs = 0.5 * (bger(dx, zhats) + bger(zhats, dx))
    if Q_e:
        dQs = dQs.mean(0)
    if p_e:
        dps = dps.mean(0)

    grads = (dQs, dps, dGs, dhs, dAs, dbs)

    return grads
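# A minimal usage sketch (an illustration, not part of this file): autograd
# invokes the backward above when differentiating through a QPFunction solve.
# The QPFunction entry point follows qpth's public API; the random problem
# data below are assumptions for demonstration only.
#
# import torch
# from qpth.qp import QPFunction
#
# nBatch, nz, nineq = 4, 3, 5
# L = torch.randn(nBatch, nz, nz)
# Q = L.bmm(L.transpose(1, 2)) + 1e-3 * torch.eye(nz)  # batch of SPD objectives
# p = torch.randn(nBatch, nz, requires_grad=True)
# G = torch.randn(nBatch, nineq, nz)
# h = torch.ones(nBatch, nineq)       # z = 0 is strictly feasible
# e = torch.Tensor()                  # empty tensor: no equality constraints
#
# zhat = QPFunction()(Q, p, G, h, e, e)  # min_z 1/2 z'Qz + p'z  s.t.  Gz <= h
# zhat.sum().backward()                  # runs backward(); populates p.grad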
def test_lu_kkt_solver():
    Q, p, G, h, A, b, d, D, rx, rs, rz, ry = get_kkt_problem()

    # Reference: factor and solve the full KKT system in one shot.
    dx, ds, dz, dy = pdipm_b.factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry)

    # Split path: pre-factor the d-independent part once, then update the
    # factorization with the scaling d and solve. Both must agree.
    Q_LU, S_LU, R = pdipm_b.pre_factor_kkt(Q, G, A)
    pdipm_b.factor_kkt(S_LU, R, d)
    dx_, ds_, dz_, dy_ = pdipm_b.solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry)

    npt.assert_allclose(dx.numpy(), dx_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(ds.numpy(), ds_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(dz.numpy(), dz_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(dy.numpy(), dy_.numpy(), rtol=RTOL, atol=ATOL)
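# An illustrative sketch (assumed; the repo's actual get_kkt_problem may
# differ) of the kind of data the test above needs: a random batched KKT
# problem with an SPD Q and a strictly positive scaling d, plus its matrix
# form D = diag(d). The function name is hypothetical.

import torch

def get_kkt_problem_sketch(nBatch=2, nz=10, neq=1, nineq=3):
    L = torch.randn(nBatch, nz, nz)
    Q = L.bmm(L.transpose(1, 2)) + 1e-3 * torch.eye(nz)  # SPD objective
    G = torch.randn(nBatch, nineq, nz)
    A = torch.randn(nBatch, neq, nz)
    d = torch.rand(nBatch, nineq) + 0.5  # scaling bounded away from zero
    D = torch.diag_embed(d)              # batched diag(d)
    p = torch.randn(nBatch, nz)
    h = torch.randn(nBatch, nineq)
    b = torch.randn(nBatch, neq)
    # Random right-hand sides for the KKT solve.
    rx = torch.randn(nBatch, nz)
    rs = torch.randn(nBatch, nineq)
    rz = torch.randn(nBatch, nineq)
    ry = torch.randn(nBatch, neq)
    return Q, p, G, h, A, b, d, D, rx, rs, rz, ry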