Example #1
    def forward(ctx, Q_, p_, G_, h_, A_, b_):
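        # Interior-point solver settings: convergence tolerance and
        # iteration limits for the batched PDIPM solver.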
        eps = 1e-12
        verbose = 0
        notImprovedLim = 3
        maxIter = 20
        nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_)
        Q, _ = expandParam(Q_, nBatch, 3)
        p, _ = expandParam(p_, nBatch, 2)
        G, _ = expandParam(G_, nBatch, 3)
        h, _ = expandParam(h_, nBatch, 2)
        A, _ = expandParam(A_, nBatch, 3)
        b, _ = expandParam(b_, nBatch, 2)

        _, nineq, nz = G.size()
        neq = A.size(1) if A.nelement() > 0 else 0
        assert (neq > 0 or nineq > 0)
        ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz

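        # Solve the batched QP either with the batched PDIPM solver or by
        # looping a per-instance CVXPY solve over the batch on the CPU.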
        if solver == QPSolvers.PDIPM_BATCHED:
            ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)
            zhats, ctx.nus, ctx.lams, ctx.slacks = pdipm_b.forward(
                Q, p, G, h, A, b, ctx.Q_LU, ctx.S_LU, ctx.R, eps, verbose,
                notImprovedLim, maxIter)
        elif solver == QPSolvers.CVXPY:
            vals = torch.Tensor(nBatch).type_as(Q)
            zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
            lams = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
            nus = torch.Tensor(nBatch, ctx.neq).type_as(Q) \
                if ctx.neq > 0 else torch.Tensor()
            slacks = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
            for i in range(nBatch):
                Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)
                vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np(
                    *[x.cpu().numpy() if x is not None else None
                      for x in (Q[i], p[i], G[i], h[i], Ai, bi)])
                zhats[i] = torch.Tensor(zhati)
                lams[i] = torch.Tensor(lami)
                slacks[i] = torch.Tensor(si)
                if neq > 0:
                    nus[i] = torch.Tensor(nui)

            ctx.vals = vals
            ctx.lams = lams
            ctx.nus = nus
            ctx.slacks = slacks
        else:
            assert False

        ctx.save_for_backward(zhats, Q_, p_, G_, h_, A_, b_)
        return zhats
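
A minimal usage sketch (not part of the snippet above): in qpth this forward pass is wrapped by QPFunction, so a batch of QPs can be solved and differentiated like any other autograd op. The construction of h from a known-feasible point follows the pattern in Example #4 below; treat the exact wrapper signature as an assumption.

import torch
from qpth.qp import QPFunction

nBatch, nz, nineq = 4, 3, 5
L = torch.randn(nBatch, nz, nz)
Q = L.bmm(L.transpose(1, 2)) + 1e-3 * torch.eye(nz)  # PSD objective terms
p = torch.randn(nBatch, nz, requires_grad=True)
G = torch.randn(nBatch, nineq, nz)
z0, s0 = torch.randn(nBatch, nz), torch.rand(nBatch, nineq)
h = G.bmm(z0.unsqueeze(2)).squeeze(2) + s0  # G z0 <= h holds by construction
e = torch.Tensor()                          # empty A, b: no equality constraints

zhats = QPFunction(verbose=False)(Q, p, G, h, e, e)  # calls forward() above
zhats.sum().backward()                               # gradients flow back into p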
Example #2
    def backward(ctx, *dl_dzhat):
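        # Backward pass: differentiate the QP solution by solving one more
        # KKT system at the optimum (the OptNet approach of Amos & Kolter).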
        zhats, Q, p, G, h, A, b = ctx.saved_tensors
        nBatch = extract_nBatch(Q, p, G, h, A, b)
        Q, Q_e = expandParam(Q, nBatch, 3)
        p, p_e = expandParam(p, nBatch, 2)
        G, G_e = expandParam(G, nBatch, 3)
        h, h_e = expandParam(h, nBatch, 2)
        A, A_e = expandParam(A, nBatch, 3)
        b, b_e = expandParam(b, nBatch, 2)

        # neq, nineq, nz = ctx.neq, ctx.nineq, ctx.nz
        neq, nineq = ctx.neq, ctx.nineq

        if solver == QPSolvers.CVXPY:
            ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)

        # Clamp here to avoid issues coming up when the slacks are too small.
        # TODO: A better fix would be to get lams and slacks from the
        # solver that don't have this issue.
        d = torch.clamp(ctx.lams, min=1e-8) / torch.clamp(ctx.slacks, min=1e-8)

        pdipm_b.factor_kkt(ctx.S_LU, ctx.R, d)
        dx, _, dlam, dnu = pdipm_b.solve_kkt(
            ctx.Q_LU, d, G, A, ctx.S_LU, dl_dzhat[0],
            torch.zeros(nBatch, nineq).type_as(G),
            torch.zeros(nBatch, nineq).type_as(G),
            torch.zeros(nBatch, neq).type_as(G) if neq > 0 else torch.Tensor())

        dps = dx
        dGs = bger(dlam, zhats) + bger(ctx.lams, dx)
        if G_e:
            dGs = dGs.mean(0)
        dhs = -dlam
        if h_e:
            dhs = dhs.mean(0)
        if neq > 0:
            dAs = bger(dnu, zhats) + bger(ctx.nus, dx)
            dbs = -dnu
            if A_e:
                dAs = dAs.mean(0)
            if b_e:
                dbs = dbs.mean(0)
        else:
            dAs, dbs = None, None
        dQs = 0.5 * (bger(dx, zhats) + bger(zhats, dx))
        if Q_e:
            dQs = dQs.mean(0)
        if p_e:
            dps = dps.mean(0)

        grads = (dQs, dps, dGs, dhs, dAs, dbs)

        return grads
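
For reference, the gradient assembly above implements the OptNet formulas: with $d_x, d_\lambda, d_\nu$ from the extra KKT solve, $\hat z$ the saved solution, and $\lambda^\star, \nu^\star$ the optimal duals,

$$
\begin{aligned}
\nabla_p \ell &= d_x, & \nabla_h \ell &= -d_\lambda, & \nabla_b \ell &= -d_\nu,\\
\nabla_G \ell &= d_\lambda \hat z^\top + \lambda^\star d_x^\top, &
\nabla_A \ell &= d_\nu \hat z^\top + \nu^\star d_x^\top, &
\nabla_Q \ell &= \tfrac{1}{2}\bigl(d_x \hat z^\top + \hat z\, d_x^\top\bigr),
\end{aligned}
$$

where each outer product is taken per batch element (bger) and gradients for expanded parameters are averaged over the batch.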
Example #3
def test_lu_kkt_solver():
    Q, p, G, h, A, b, d, D, rx, rs, rz, ry = get_kkt_problem()

    dx, ds, dz, dy = pdipm_b.factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry)

    Q_LU, S_LU, R = pdipm_b.pre_factor_kkt(Q, G, A)
    pdipm_b.factor_kkt(S_LU, R, d)
    dx_, ds_, dz_, dy_ = pdipm_b.solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry)

    npt.assert_allclose(dx.numpy(), dx_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(ds.numpy(), ds_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(dz.numpy(), dz_.numpy(), rtol=RTOL, atol=ATOL)
    npt.assert_allclose(dy.numpy(), dy_.numpy(), rtol=RTOL, atol=ATOL)
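
get_kkt_problem and the RTOL/ATOL tolerances are defined elsewhere in the test module. A plausible sketch of that fixture, assuming the batched conventions used above (a leading nBatch dimension on every tensor, D = diag(d) per batch element):

import torch
import numpy.testing as npt
from qpth.solvers.pdipm import batch as pdipm_b  # import path as in qpth.qp; an assumption here

RTOL, ATOL = 1e-5, 1e-5  # hypothetical tolerances

def get_kkt_problem(nBatch=2, nz=5, nineq=4, neq=2):
    L = torch.randn(nBatch, nz, nz)
    Q = L.bmm(L.transpose(1, 2)) + 1e-3 * torch.eye(nz)  # PSD Hessians
    p = torch.randn(nBatch, nz)
    G = torch.randn(nBatch, nineq, nz)
    h = torch.randn(nBatch, nineq)
    A = torch.randn(nBatch, neq, nz)
    b = torch.randn(nBatch, neq)
    d = torch.rand(nBatch, nineq) + 1e-2  # positive slack scaling
    D = torch.diag_embed(d)               # batched diag(d)
    rx, rs = torch.randn(nBatch, nz), torch.randn(nBatch, nineq)
    rz, ry = torch.randn(nBatch, nineq), torch.randn(nBatch, neq)
    return Q, p, G, h, A, b, d, D, rx, rs, rz, ry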
Example #4
File: prof.py  Project: jdily/qpth
def prof_instance(nz, nBatch, cuda=True):
    nineq, neq = 100, 0
    assert (neq == 0)
    L = npr.rand(nBatch, nz, nz)
    Q = np.matmul(L, L.transpose((0, 2, 1))) + 1e-3 * np.eye(nz, nz)
    G = npr.randn(nBatch, nineq, nz)
    z0 = npr.randn(nBatch, nz)
    s0 = npr.rand(nBatch, nineq)
    p = npr.randn(nBatch, nz)
    h = np.matmul(G, np.expand_dims(z0, axis=2)).squeeze(2) + s0
    A = npr.randn(nBatch, neq, nz)
    b = np.matmul(A, np.expand_dims(z0, axis=2)).squeeze(2)

    zhat_g = []
    gurobi_time = 0.0
    for i in range(nBatch):
        m = gpy.Model()
        zhat = m.addVars(nz, lb=-gpy.GRB.INFINITY, ub=gpy.GRB.INFINITY)

        obj = 0.0
        for j in range(nz):
            for k in range(nz):
                obj += 0.5 * Q[i, j, k] * zhat[j] * zhat[k]
            obj += p[i, j] * zhat[j]
        m.setObjective(obj)
        for j in range(nineq):
            con = 0
            for k in range(nz):
                con += G[i, j, k] * zhat[k]
            m.addConstr(con <= h[i, j])
        m.setParam('OutputFlag', False)
        start = time.time()
        m.optimize()
        gurobi_time += time.time() - start
        t = np.zeros(nz)
        for j in range(nz):
            t[j] = zhat[j].x
        zhat_g.append(t)

    p, L, Q, G, z0, s0, h = [torch.Tensor(x) for x in [p, L, Q, G, z0, s0, h]]
    if cuda:
        p, L, Q, G, z0, s0, h = [x.cuda() for x in [p, L, Q, G, z0, s0, h]]
    if neq > 0:
        A = torch.Tensor(A)
        b = torch.Tensor(b)
    else:
        A, b = [torch.Tensor()] * 2
    if cuda:
        A = A.cuda()
        b = b.cuda()


    single_results = []
    start = time.time()
    for i in range(nBatch):
        A_i = A[i] if neq > 0 else A
        b_i = b[i] if neq > 0 else b
        U_Q, U_S, R = pdipm_s.pre_factor_kkt(Q[i], G[i], A_i)
        single_results.append(
            pdipm_s.forward(p[i], Q[i], G[i], A_i, b_i, h[i], U_Q, U_S, R))
    single_time = time.time() - start

    start = time.time()
    Q_LU, S_LU, R = pdipm_b.pre_factor_kkt(Q, G, A)
    zhat_b, nu_b, lam_b, s_b = pdipm_b.forward(p, Q, G, h, A, b, Q_LU, S_LU, R)
    batched_time = time.time() - start

    # Usually between 1e-4 and 1e-5:
    # print('Diff between gurobi and pdipm: ',
    #       np.linalg.norm(zhat_g[0]-zhat_b[0].cpu().numpy()))

    # zhat_diff = (single_results[0][0] - zhat_b[0]).norm()
    # lam_diff = (single_results[0][2] - lam_b[0]).norm()
    # eps = 0.1 # Pretty relaxed.
    # if zhat_diff > eps or lam_diff > eps:
    #     print('===========')
    #     print("Warning: Single and batched solutions might not match.")
    #     print("  + zhat_diff: {}".format(zhat_diff))
    #     print("  + lam_diff: {}".format(lam_diff))
    #     print("  + (nz, neq, nineq, nBatch) = ({}, {}, {}, {})".format(
    #         nz, neq, nineq, nBatch))
    #     print('===========')

    return gurobi_time, single_time, batched_time
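
A small driver sketch (hypothetical, not part of prof.py): time the three solvers over a few problem sizes. Assumes a CUDA device and a Gurobi license are available.

if __name__ == '__main__':
    for nz in (10, 50, 100):
        g_t, s_t, b_t = prof_instance(nz, nBatch=64, cuda=True)
        print('nz={:4d}  gurobi={:.4f}s  single={:.4f}s  batched={:.4f}s'.format(
            nz, g_t, s_t, b_t))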
Example #5
    def forward(ctx, p_, Q_, G_, h_, A_, b_):

        ############################################################
        #  The forward solver
        ############################################################
        eps = 1e-12
        verbose = 0
        notImprovedLim = 3
        maxIter = 20
        nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_)
        Q, _ = expandParam(Q_, nBatch, 3)
        p, _ = expandParam(p_, nBatch, 2)
        G, _ = expandParam(G_, nBatch, 3)
        h, _ = expandParam(h_, nBatch, 2)
        A, _ = expandParam(A_, nBatch, 3)
        b, _ = expandParam(b_, nBatch, 2)

        _, nineq, nz = G.size()
        ny = Q.shape[-2]
        neq = A.size(1) if A.nelement() > 0 else 0
        assert (neq > 0 or nineq > 0)
        ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz

        if solver == QPSolvers.PDIPM_BATCHED:
            ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A)
            zhats, ctx.nus, ctx.lams, ctx.slacks = pdipm_b.forward(
                Q, p, G, h, A, b, ctx.Q_LU, ctx.S_LU, ctx.R, eps, verbose,
                notImprovedLim, maxIter)
        elif solver == QPSolvers.CVXPY:
            vals = torch.Tensor(nBatch).type_as(Q)
            zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
            lams = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
            nus = torch.Tensor(nBatch, ctx.neq).type_as(Q) \
                if ctx.neq > 0 else torch.Tensor()
            slacks = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
            for i in range(nBatch):
                Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)
                vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np(
                    *[x.cpu().numpy() if x is not None else None
                      for x in (Q[i], p[i], G[i], h[i], Ai, bi)])
                zhats[i] = torch.Tensor(zhati)
                lams[i] = torch.Tensor(lami)
                slacks[i] = torch.Tensor(si)
                if neq > 0:
                    nus[i] = torch.Tensor(nui)

            ctx.vals = vals
            ctx.lams = lams
            ctx.nus = nus
            ctx.slacks = slacks
        else:
            assert False

        # ctx.save_for_backward(zhats, Q_, p_, G_, h_, A_, b_)
        if ctx.lams is None:
            ctx.lams = torch.Tensor()
        if ctx.nus is None:
            ctx.nus = torch.Tensor()

        ############################################################
        # Define implicit functions
        ############################################################
        def stationarity(j, argv):
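            # Stationarity residual, row j: (Q y + x + A^T n + G^T l)_j,
            # with x = p and y the primal solution zhats.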
            # argv = tuple([x]) + tuple(params) + tuple([y]) + tuple(duals)
            x, Q, G, h, A, b, y, l, n = argv
            x = x.unsqueeze(-1)
            y = y.unsqueeze(-1)
            h = h.unsqueeze(-1)
            b = b.unsqueeze(-1)
            n = n.unsqueeze(-1)
            l = l.unsqueeze(-1)


            Qy = torch.matmul(Q[:, j, :].unsqueeze(1), y[:, :, :]).squeeze(1)
            if neq == 0:
                ATn = torch.zeros_like(Qy)
            else:
                ATn = torch.matmul(
                    torch.transpose(A, 1, 2)[:, j, :].unsqueeze(1),
                    n[:, :, :]).squeeze(1)
            if nineq == 0:
                GTl = torch.zeros_like(Qy)
            else:
                GTl = torch.matmul(
                    torch.transpose(G, 1, 2)[:, j, :].unsqueeze(1),
                    l[:, :, :]).squeeze(1)

            # The zero-weighted terms keep h and b in the computation graph
            # without changing the residual's value.
            return Qy + x[:, j, :] + ATn + GTl + 0 * h.mean() + 0 * b.mean()

        def primal_fea(j, argv):
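            # Primal feasibility residual, row j: (A y - b)_j.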
            x, Q, G, h, A, b, y, l, n = argv
            x = x.unsqueeze(-1)
            y = y.unsqueeze(-1)
            h = h.unsqueeze(-1)
            b = b.unsqueeze(-1)
            n = n.unsqueeze(-1)
            l = l.unsqueeze(-1)
            Ay = torch.matmul(A[:, j, :].unsqueeze(1), y[:, :, :]).squeeze(1)
            return Ay - b[:, j, :]

        def comp_slack(j, argv):
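            # Complementary slackness residual, row j: l_j * (G y - h)_j.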
            x, Q, G, h, A, b, y, l, n = argv
            x = x.unsqueeze(-1)
            y = y.unsqueeze(-1)
            h = h.unsqueeze(-1)
            b = b.unsqueeze(-1)
            n = n.unsqueeze(-1)
            l = l.unsqueeze(-1)
            Gy = torch.matmul(G[:, j, :].unsqueeze(1), y[:, :, :]).squeeze(1)
            return l[:, j, :] * (Gy - h[:, j, :])

        nf = ny + neq + nineq
        f_dict = {
            stationarity: (0, ny),
            primal_fea: (ny, ny + neq),
            comp_slack: (ny + neq, nf)
        }

        ############################################################
        # Construct the imp_struct
        ############################################################
        imp_struct = ImpStruct()
        imp_struct.x = p_
        imp_struct.params = [Q_, G_, h_, A_, b_]
        imp_struct.y = zhats
        imp_struct.duals = [ctx.lams, ctx.nus]
        imp_struct.other_inputs = []
        imp_struct.f_dict = f_dict
        ctx.imp_struct = imp_struct

        # ctx.y = zhats
        # ctx.x = p_
        # ctx.duals = [ctx.lams, ctx.nus]
        # ctx.params = [Q_, G_, h_, A_, b_]

        return zhats
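
The three residual functions gathered in f_dict encode the KKT conditions of the QP, stacked into nf = ny + neq + nineq scalar equations whose root is the primal-dual optimum:

$$
\begin{aligned}
Q z^\star + p + A^\top \nu^\star + G^\top \lambda^\star &= 0 && \text{(stationarity)}\\
A z^\star - b &= 0 && \text{(primal feasibility)}\\
\lambda^\star_i\,(G z^\star - h)_i &= 0 && \text{(complementary slackness)}
\end{aligned}
$$

Differentiating this system implicitly is what lets the backward pass recover gradients without unrolling the solver.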