示例#1
0
    def backward(ctx, *dl_dzhat):
        """Backward pass: gradients of the loss w.r.t. (Q, p, G, h, A, b).

        Solves one KKT-type linear system through ``pdipm_b``, using the
        incoming gradient w.r.t. the solution as the right-hand side. The
        parameter gradients are then assembled from the returned ``dx``,
        ``dlam`` and ``dnu`` together with the values cached on ``ctx`` by
        ``forward``.

        Args:
            ctx: Autograd context filled in by ``forward``. This method reads
                ``saved_tensors`` (zhats, Q, p, G, h, A, b), ``neq``,
                ``nineq``, ``lams``, ``slacks``, ``nus`` (only when
                ``neq > 0``) and ``Q_LU`` / ``S_LU`` / ``R``. The last three
                are (re)built here when ``solver == QPSolvers.CVXPY``.
            *dl_dzhat: Incoming output gradients. Only ``dl_dzhat[0]`` (the
                gradient w.r.t. the solution) is used.

        Returns:
            ``(dQ, dp, dG, dh, dA, db)`` in forward-input order. ``dA`` and
            ``db`` are ``None`` when there are no equality constraints
            (``neq == 0``).
        """
        zhats, Q, p, G, h, A, b = ctx.saved_tensors
        nBatch = extract_nBatch(Q, p, G, h, A, b)
        # Broadcast every parameter to a leading batch dimension. The *_e
        # flags presumably record which parameters were broadcast (i.e. given
        # without a batch dim) -- TODO confirm against expandParam.
        Q, Q_e = expandParam(Q, nBatch, 3)
        p, p_e = expandParam(p, nBatch, 2)
        G, G_e = expandParam(G, nBatch, 3)
        h, h_e = expandParam(h, nBatch, 2)
        A, A_e = expandParam(A, nBatch, 3)
        b, b_e = expandParam(b, nBatch, 2)

        neq, nineq = ctx.neq, ctx.nineq

        def unexpand(grad, expanded):
            # A parameter shared across the batch influences every sample's
            # solution. Its gradient is therefore the SUM of the per-sample
            # gradients, as in PyTorch's own broadcasting backward. The
            # previous mean(0) under-scaled it by a factor of nBatch.
            return grad.sum(0) if expanded else grad

        if solver == QPSolvers.CVXPY:
            # Build the KKT pre-factorization here; presumably the CVXPY
            # forward path does not leave it on ctx -- verify in forward.
            ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)

        # Clamp here to avoid issues coming up when the slacks are too small.
        # TODO: A better fix would be to get lams and slacks from the
        # solver that don't have this issue.
        d = torch.clamp(ctx.lams, min=1e-8) / torch.clamp(ctx.slacks, min=1e-8)

        pdipm_b.factor_kkt(ctx.S_LU, ctx.R, d)
        # Only the solution gradient is a non-zero right-hand side; the
        # remaining right-hand-side terms are zero tensors.
        dx, _, dlam, dnu = pdipm_b.solve_kkt(
            ctx.Q_LU, d, G, A, ctx.S_LU, dl_dzhat[0],
            torch.zeros(nBatch, nineq).type_as(G),
            torch.zeros(nBatch, nineq).type_as(G),
            # Without equality constraints an empty placeholder is passed;
            # keep it on G's dtype/device rather than the global default.
            torch.zeros(nBatch, neq).type_as(G) if neq > 0
            else torch.Tensor().type_as(G))

        # Symmetrized combination of bger(dx, zhats) and its transpose-order
        # counterpart (bger is presumably a batched outer product).
        dQs = unexpand(0.5 * (bger(dx, zhats) + bger(zhats, dx)), Q_e)
        dps = unexpand(dx, p_e)
        dGs = unexpand(bger(dlam, zhats) + bger(ctx.lams, dx), G_e)
        dhs = unexpand(-dlam, h_e)
        if neq > 0:
            dAs = unexpand(bger(dnu, zhats) + bger(ctx.nus, dx), A_e)
            dbs = unexpand(-dnu, b_e)
        else:
            dAs, dbs = None, None

        return dQs, dps, dGs, dhs, dAs, dbs
示例#2
0
def test_lu_kkt_solver():
    """The pre-factored LU KKT pipeline must agree with the one-shot solver.

    The same KKT system from ``get_kkt_problem`` is solved twice:

    * in one shot with ``pdipm_b.factor_solve_kkt``, which is given ``D``;
    * with ``pre_factor_kkt`` -> ``factor_kkt`` -> ``solve_kkt``, which is
      given ``d``. ``D`` is presumably the matrix form of ``d``; confirm in
      ``get_kkt_problem``.

    All four returned components must match to within ``RTOL`` / ``ATOL``.
    """
    Q, _p, G, _h, A, _b, d, D, rx, rs, rz, ry = get_kkt_problem()

    # Reference solution from the single-call solver.
    ref_x, ref_s, ref_z, ref_y = pdipm_b.factor_solve_kkt(
        Q, D, G, A, rx, rs, rz, ry)

    # Candidate solution: factor once, refresh with d, then solve.
    Q_LU, S_LU, R = pdipm_b.pre_factor_kkt(Q, G, A)
    pdipm_b.factor_kkt(S_LU, R, d)
    lu_x, lu_s, lu_z, lu_y = pdipm_b.solve_kkt(
        Q_LU, d, G, A, S_LU, rx, rs, rz, ry)

    # Compare component by component, in x, s, z, y order.
    for expected, actual in ((ref_x, lu_x), (ref_s, lu_s),
                             (ref_z, lu_z), (ref_y, lu_y)):
        npt.assert_allclose(expected.numpy(), actual.numpy(),
                            rtol=RTOL, atol=ATOL)