Example #1
0
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """ Solve KKT equations for the affine step"""

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.btrisolve(*Q_LU)  # Q-1 rx
    if neq > 0:
        # A Q-1 rx - ry
        # G Q-1 rx + rs / d - rz
        h = torch.cat([invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry,
                       invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz], 1)
    else:
        h = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz

    w = -(h.btrisolve(*S_LU))  # S-1 h =

    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)  # -rx - GT w = -rx -GT S-1 h
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)  # - AT w = -AT S-1 h
    g2 = -rs - w[:, neq:]

    dx = g1.btrisolve(*Q_LU)  # Q-1 g1 = - Q-1 AT S-1 h
    ds = g2 / d  # g2 / d = (-rs - w) / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
Example #2
0
def sparse_solve_kkt_ir_inverse(H_,
                                A_,
                                C_tilde,
                                Q_tilde,
                                D_tilde,
                                G,
                                A,
                                F_tilde,
                                rx,
                                rs,
                                rz,
                                ry,
                                niter=1):
    """Inefficient iterative refinement."""
    ns = nineq, nz, neq, nBatch = get_sizes(G, A)
    eps = 1e-7

    dx, ds, dz, dy = sparse_solve_kkt_inverse(H_, A_, C_tilde, rx, rs, rz, ry,
                                              ns)
    resx, ress, resz, resy = kkt_resid_reg(Q_tilde, D_tilde, G, A, F_tilde,
                                           eps, dx, ds, dz, dy, rx, rs, rz, ry)
    for k in range(niter):
        ddx, dds, ddz, ddy = sparse_solve_kkt_inverse(
            H_, A_, C_tilde, -resx, -ress, -resz,
            -resy if resy is not None else None, ns)
        dx, ds, dz, dy = [
            v + dv if v is not None else None
            for v, dv in zip((dx, ds, dz, dy), (ddx, dds, ddz, ddy))
        ]
        resx, ress, resz, resy = kkt_resid_reg(Q_tilde, D_tilde, G, A, F_tilde,
                                               eps, dx, ds, dz, dy, rx, rs, rz,
                                               ry)

    return dx, ds, dz, dy
Example #3
0
def solve_kkt_ir(Q, D, G, A, F, rx, rs, rz, ry, niter=1):
    """Inefficient iterative refinement."""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    eps = 1e-7
    Q_tilde = Q + eps * torch.eye(nz).type_as(Q).repeat(nBatch, 1, 1)
    D_tilde = D + eps * torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)

    # XXX Might not workd for batch size > 1
    C_tilde = -eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
    if F is not None:  # XXX inverted sign for F below
        C_tilde[:, :nineq, :nineq] -= F
    F_tilde = C_tilde[:, :nineq, :nineq]

    dx, ds, dz, dy = factor_solve_kkt_reg(
        Q_tilde, D_tilde, G, A, C_tilde, rx, rs, rz, ry, eps)
    resx, ress, resz, resy = kkt_resid_reg(Q, D, G, A, F_tilde, eps,
                        dx, ds, dz, dy, rx, rs, rz, ry)
    for k in range(niter):
        ddx, dds, ddz, ddy = factor_solve_kkt_reg(Q_tilde, D_tilde, G, A, C_tilde,
                                                  -resx, -ress, -resz,
                                                  -resy if resy is not None else None,
                                                  eps)
        dx, ds, dz, dy = [v + dv if v is not None else None
                          for v, dv in zip((dx, ds, dz, dy), (ddx, dds, ddz, ddy))]
        resx, ress, resz, resy = kkt_resid_reg(Q, D, G, A, F_tilde, eps,
                            dx, ds, dz, dy, rx, rs, rz, ry)

    return dx, ds, dz, dy
Example #4
0
def solve_kkt_inverse(Q_tilde, D, G, A, C_tilde, rx, rs, rz, ry, eps):
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D
    if neq > 0:
        # H_ = torch.cat([torch.cat([Q, torch.zeros(nz,nineq).type_as(Q)], 1),
        # torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([
            torch.cat(
                [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)],
                2),
            torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        ], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat(
            [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    full_mat = torch.cat(
        [torch.cat([H_, A_.transpose(1, 2)], 2),
         torch.cat([A_, C_tilde], 2)], 1)
    full_res = torch.cat([g_, h_], 1)
    sol = torch.bmm(binverse(full_mat), full_res.unsqueeze(2)).squeeze(2)

    dx = sol[:, :nz]
    ds = sol[:, nz:nz + nineq]
    dz = sol[:, nz + nineq:nz + nineq + nineq]
    dy = sol[:, nz + nineq + nineq:] if neq > 0 else None

    return dx, ds, dz, dy
Example #5
0
def solve_kkt_ir_inverse(Q, D, G, A, F, rx, rs, rz, ry, niter=1):
    """Inefficient iterative refinement."""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    eps = 1e-7
    Q_tilde = Q + eps * torch.eye(nz).type_as(Q).repeat(nBatch, 1, 1)
    D_tilde = D + eps * torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)

    # TODO Test batche size > 1
    # XXX Shouldn't the sign below be positive? (Since its going to be subtracted later)
    C_tilde = -eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(
        nBatch, 1, 1)
    if F is not None:  # XXX inverted sign for F below
        C_tilde[:, :nineq, :nineq] -= F
    F_tilde = C_tilde[:, :nineq, :nineq]

    dx, ds, dz, dy = solve_kkt_inverse(Q_tilde, D_tilde, G, A, C_tilde, rx, rs,
                                       rz, ry, eps)
    resx, ress, resz, resy = kkt_resid_reg(Q, D, G, A, F_tilde, eps, dx, ds,
                                           dz, dy, rx, rs, rz, ry)
    for k in range(niter):
        ddx, dds, ddz, ddy = solve_kkt_inverse(
            Q_tilde, D_tilde, G, A, C_tilde, -resx, -ress, -resz,
            -resy if resy is not None else None, eps)
        dx, ds, dz, dy = [
            v + dv if v is not None else None
            for v, dv in zip((dx, ds, dz, dy), (ddx, dds, ddz, ddy))
        ]
        resx, ress, resz, resy = kkt_resid_reg(Q, D, G, A, F_tilde, eps, dx,
                                               ds, dz, dy, rx, rs, rz, ry)

    return dx, ds, dz, dy
Example #6
0
def sparse_solve_kkt_ir(Q, D, G, A, F, rx, rs, rz, ry, niter=1):
    """Inefficient iterative refinement."""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    eps = 1e-7
    Q_tilde = Q + eps * torch.eye(nz).type_as(Q).repeat(nBatch, 1, 1)
    D_tilde = D + eps * torch.eye(nineq).type_as(Q).repeat(nBatch, 1, 1)

    # TODO Test batche size > 1
    # XXX Shouldn't the sign below be positive? (Since its going to be subtracted later)
    C_tilde = -eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(
        nBatch, 1, 1)
    if F is not None:  # XXX inverted sign for F below
        C_tilde[:, :nineq, :nineq] -= F
    F_tilde = C_tilde[:, :nineq, :nineq]

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D
    if neq > 0:
        # H_ = torch.cat([torch.cat([Q, torch.zeros(nz,nineq).type_as(Q)], 1),
        # torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([
            torch.cat(
                [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)],
                2),
            torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        ], 1)
    else:
        A_ = torch.cat(
            [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2)
    spH_ = csc_matrix(H_.squeeze(0).numpy())
    A_ = A_.squeeze(0).numpy()
    spA_ = csc_matrix(A_)
    spC_tilde = csc_matrix(C_tilde.squeeze(0).numpy())

    dx, ds, dz, dy = sparse_factor_solve_kkt_reg(spH_, A_, spA_, spC_tilde, rx,
                                                 rs, rz, ry, neq, nineq, nz)
    resx, ress, resz, resy = sparse_kkt_resid_reg(Q, D, G, A, F_tilde, eps, dx,
                                                  ds, dz, dy, rx, rs, rz, ry)
    for k in range(niter):
        ddx, dds, ddz, ddy = sparse_factor_solve_kkt_reg(
            spH_, A_, spA_, spC_tilde, -resx, -ress, -resz,
            -resy if resy is not None else None, neq, nineq, nz)
        dx, ds, dz, dy = [
            v + dv if v is not None else None
            for v, dv in zip((dx, ds, dz, dy), (ddx, dds, ddz, ddy))
        ]
        resx, ress, resz, resy = sparse_kkt_resid_reg(Q, D, G, A, F_tilde, eps,
                                                      dx, ds, dz, dy, rx, rs,
                                                      rz, ry)

    return dx, ds, dz, dy
Example #7
0
def pre_factor_kkt(Q, G, F, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = btrifact_hack(Q)
    except:
        raise RuntimeError("""
lcp Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""")

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).btrisolve(*Q_LU)) + F
    R = G_invQ_GT.clone()
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).btrisolve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = btrifact_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.btriunpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)
                           ).btrisolve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).btrisolve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)),
                              1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
Example #8
0
def factor_solve_kkt_reg(Q_tilde, D, G, A, C_tilde, rx, rs, rz, ry, eps):
    nineq, nz, neq, nBatch = get_sizes(G, A)

    H_ = torch.zeros(nBatch, nz + nineq, nz + nineq).type_as(Q_tilde)
    H_[:, :nz, :nz] = Q_tilde
    H_[:, -nineq:, -nineq:] = D
    if neq > 0:
        # H_ = torch.cat([torch.cat([Q, torch.zeros(nz,nineq).type_as(Q)], 1),
        # torch.cat([torch.zeros(nineq, nz).type_as(Q), D], 1)], 0)
        A_ = torch.cat([
            torch.cat(
                [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)],
                2),
            torch.cat([A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
        ], 1)
        g_ = torch.cat([rx, rs], 1)
        h_ = torch.cat([rz, ry], 1)
    else:
        A_ = torch.cat(
            [G, torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)], 2)
        g_ = torch.cat([rx, rs], 1)
        h_ = rz

    H_LU = btrifact_hack(H_)

    invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)  # H-1 AT
    invH_g_ = g_.lu_solve(*H_LU)  # H-1 g

    S_ = torch.bmm(A_, invH_A_)  # A H-1 AT
    # A H-1 AT + C_tilde
    S_ -= C_tilde
    S_LU = btrifact_hack(S_)
    # [(H-1 g)T AT]T - h = A H-1 g - h
    t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
    # w = (A H-1 AT + C_tilde)-1 (A H-1 g - h) <= Av - eps I w = h
    w_ = -t_.lu_solve(*S_LU)
    # Shouldn't it be just g (no minus)?
    # (Doesn't seem to make a difference, though...)
    t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze()  # -g - AT w
    v_ = t_.lu_solve(*H_LU)  # v = H-1 (-g - AT w)

    dx = v_[:, :nz]
    ds = v_[:, nz:]
    dz = w_[:, :nineq]
    dy = w_[:, nineq:] if neq > 0 else None

    return dx, ds, dz, dy
Example #9
0
def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
    """ Solve KKT equations for the affine step"""

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]

    nineq, nz, neq, nBatch = get_sizes(G, A)

    invQ_rx = rx.T.lu_solve(*Q_LU).view(1, -1)  # Q-1 rx
    if neq > 0:
        # A Q-1 rx - ry
        # G Q-1 rx + rs / d - rz
        h = torch.cat([
            invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry,
            invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs /
            (d + 1e-30) - rz
        ], 1)
    else:
        h = invQ_rx.unsqueeze(1).bmm(G.transpose(
            1, 2)).squeeze(1) + rs / (d + 1e-30) - rz

    S_LU[0] = S_LU[0] + 1e-30
    Q_LU[0][0] = Q_LU[0][0] + 1e-30

    w = -(h.T.lu_solve(*S_LU).view(1, -1))  # S-1 h =

    if np.isnan(w.sum()).item() > 0 or np.isnan(h.sum()).item() > 0:
        # import pdb
        # print(h)
        # print(w)
        # pdb.set_trace()

        h = h * 0
        w = w * 0

    g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(
        1)  # -rx - GT w = -rx -GT S-1 h
    if neq > 0:
        g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)  # - AT w = -AT S-1 h
    g2 = -rs - w[:, neq:]

    dx = g1.T.lu_solve(*Q_LU).view(1, -1)  # Q-1 g1 = - Q-1 AT S-1 h
    ds = g2 / (d + 1e-30)  # g2 / d = (-rs - w) / d
    dz = w[:, neq:]
    dy = w[:, :neq] if neq > 0 else None

    return dx, ds, dz, dy
Example #10
0
def forward(Q, p, G, h, A, b, F, Q_LU, S_LU, R,
            eps=1e-12, verbose=-1, not_improved_lim=3,
            max_iter=20, solver=KKTSolvers.LU_PARTIAL):
    """
    Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
    """
    nineq, nz, neq, batch_size = get_sizes(G, A)

    # Find initial values
    if solver == KKTSolvers.LU_FULL:
        D = torch.eye(nineq).repeat(batch_size, 1, 1).type_as(Q)
        reg_eps = 1e-7
        Q_tilde = Q + reg_eps * torch.eye(nz).type_as(Q).repeat(batch_size, 1, 1)
        D_tilde = D + reg_eps * torch.eye(nineq).type_as(Q).repeat(batch_size, 1, 1)

        if neq > 0:
            A_ = torch.cat([torch.cat([G, torch.eye(nineq).type_as(Q_tilde).repeat(batch_size, 1, 1)], 2),
                            torch.cat([A, torch.zeros(batch_size, neq, nineq).type_as(Q_tilde)], 2)], 1)
        else:
            A_ = torch.cat([G, torch.eye(nineq).type_as(Q_tilde).unsqueeze(0)], 2)

        C_tilde = reg_eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(batch_size, 1, 1)
        if F is not None:
            C_tilde[:, :nineq, :nineq] += F
        ns = [nineq, nz, neq, batch_size]
        x, s, z, y = factor_solve_kkt(
            Q_tilde, D_tilde, A_, C_tilde, p,
            torch.zeros(batch_size, nineq).type_as(Q),
            -h, -b if b is not None else None, ns)
    elif solver == KKTSolvers.LU_PARTIAL:
        d = torch.ones(batch_size, nineq).type_as(Q)
        factor_kkt(S_LU, R, d)
        x, s, z, y = solve_kkt(
            Q_LU, d, G, A, S_LU,
            p, torch.zeros(batch_size, nineq).type_as(Q),
            -h, -b if neq > 0 else None)
    elif solver == KKTSolvers.IR_UNOPT:
        D = torch.eye(nineq).repeat(batch_size, 1, 1).type_as(Q)
        x, s, z, y = solve_kkt_ir(
            Q, D, G, A, F, p,
            torch.zeros(batch_size, nineq).type_as(Q),
            -h, -b if b is not None else None)
    else:
        raise NotImplementedError('Specified KKTSolver not implemented.')

    M = torch.min(s, 1)[0].repeat(1, nineq)
    I = M <= 0
    s[I] -= M[I] - 1

    M = torch.min(z, 1)[0].repeat(1, nineq)
    I = M <= 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(max_iter):
        # affine scaling direction
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            p
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        if F is not None:  # XXX (Inverted sign for F below)
            rz -= torch.bmm(z.unsqueeze(1), F.transpose(1, 2)).squeeze(1)
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        if solver == KKTSolvers.LU_PARTIAL:
            try:
                factor_kkt(S_LU, R, d)
            except:
                return best['x'], best['y'], best['z'], best['s']

        if verbose == 1:
            print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
                i, pri_resid.mean(), dual_resid.mean(), mu.mean()))
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        if nNotImproved == not_improved_lim or best['resids'].max() < eps or mu.min() > 1e100:
            if best['resids'].max() > 1. and verbose >= 0:
                print(INACC_ERR)
                print(best['resids'].max())
            return best['x'], best['y'], best['z'], best['s']

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            D_tilde = D + reg_eps * torch.eye(nineq).type_as(Q).repeat(batch_size, 1, 1)
            dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
                Q_tilde, D_tilde, A_, C_tilde, rx, rs, rz, ry, ns)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(
                Q, D, G, A, F, rx, rs, rz, ry)
        else:
            raise NotImplementedError('Specified KKTSolver not implemented.')

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff),
                                    get_step(s, ds_aff)),
                          torch.ones(batch_size).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        rx = torch.zeros(batch_size, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(batch_size, nineq).type_as(Q)
        ry = torch.zeros(batch_size, neq).type_as(Q)

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            D_tilde = D + reg_eps * torch.eye(nineq).type_as(Q).repeat(batch_size, 1, 1)
            dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
                Q_tilde, D_tilde, A_, C_tilde, rx, rs, rz, ry, ns)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
                Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(
                Q, D, G, A, F, rx, rs, rz, ry)
        else:
            raise NotImplementedError('Specified KKTSolver not implemented.')

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        alpha = torch.min(0.999 * torch.min(get_step(z, dz),
                                            get_step(s, ds)),
                          torch.ones(batch_size).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    if best['resids'].max() > 1. and verbose >= 0:
        print(INACC_ERR)
        print(best['resids'].max())
    return best['x'], best['y'], best['z'], best['s']
Example #11
0
def forward(Q,
            p,
            G,
            h,
            A,
            b,
            F,
            Q_LU,
            S_LU,
            R,
            eps=1e-12,
            verbose=-1,
            not_improved_lim=3,
            max_iter=20):
    """
    Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
    """
    nineq, nz, neq, batch_size = get_sizes(G, A)

    # Find initial values
    d = Q.new_ones(batch_size, nineq)
    factor_kkt(S_LU, R, d)
    x, s, z, y = solve_kkt(Q_LU, d, G, A, S_LU, p,
                           Q.new_zeros(batch_size, nineq), -h,
                           -b if neq > 0 else None)

    # Make all of the slack variables >= 1.
    M = torch.min(s, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M <= 0
    s[I] -= M[I] - 1

    # Make all of the inequality dual variables >= 1.
    M = torch.min(z, 1)[0]
    M = M.view(M.size(0), 1).repeat(1, nineq)
    I = M <= 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(max_iter):
        # affine scaling direction
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            p
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h \
            - torch.bmm(z.unsqueeze(1), F.transpose(1, 2)).squeeze(1)
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        try:
            factor_kkt(S_LU, R, d)
        except:
            return best['x'], best['y'], best['z'], best['s']

        if verbose > 0:
            print(
                'iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.
                format(i, pri_resid.mean(), dual_resid.mean(), mu.mean()))
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            # best['resids'][I] = resids[I]
            # best['x'][I_nz] = x[I_nz]
            # best['z'][I_nineq] = z[I_nineq]
            # best['s'][I_nineq] = s[I_nineq]
            best['resids'].masked_scatter_(I, resids[I])
            best['x'].masked_scatter_(I_nz, x[I_nz])
            best['z'].masked_scatter_(I_nineq, z[I_nineq])
            best['s'].masked_scatter_(I_nineq, s[I_nineq])
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        if nNotImproved == not_improved_lim or best['resids'].max().item(
        ) < eps or mu.min().item() > 1e100:
            if best['resids'].max() > 1. and verbose >= 0:
                print(INACC_ERR)
            return best['x'], best['y'], best['z'], best['s']

        dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(Q_LU, d, G, A, S_LU, rx, rs,
                                                   rz, ry)

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
                          torch.ones(batch_size).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        rx = Q.new_zeros(batch_size, nz)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = Q.new_zeros(batch_size, nineq)
        ry = Q.new_zeros(batch_size, neq)

        dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(Q_LU, d, G, A, S_LU, rx, rs,
                                                   rz, ry)

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        alpha = torch.min(0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
                          Q.new_ones(batch_size))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    if best['resids'].max() > 1. and verbose >= 0:
        print(INACC_ERR)
        print(best['resids'].max())
    return best['x'], best['y'], best['z'], best['s']
Example #12
0
def forward(Q,
            p,
            G,
            h,
            A,
            b,
            F,
            Q_LU,
            S_LU,
            R,
            eps=1e-12,
            verbose=-1,
            notImprovedLim=3,
            maxIter=20,
            solver=KKTSolvers.LU_PARTIAL):
    """
    Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
    """
    nineq, nz, neq, nBatch = get_sizes(G, A)

    # Find initial values
    if solver == KKTSolvers.LU_FULL:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        reg_eps = 1e-7
        Q_tilde = Q + reg_eps * torch.eye(nz).type_as(Q).repeat(nBatch, 1, 1)
        D_tilde = D + reg_eps * torch.eye(nineq).type_as(Q).repeat(
            nBatch, 1, 1)

        if neq > 0:
            A_ = torch.cat([
                torch.cat([
                    G,
                    torch.eye(nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
                ], 2),
                torch.cat(
                    [A, torch.zeros(nBatch, neq, nineq).type_as(Q_tilde)], 2)
            ], 1)
        else:
            A_ = torch.cat(
                [G, torch.eye(nineq).type_as(Q_tilde).unsqueeze(0)], 2)

        C_tilde = reg_eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(
            nBatch, 1, 1)
        if F is not None:
            C_tilde[:, :nineq, :nineq] += F
        ns = [nineq, nz, neq, nBatch]
        x, s, z, y = factor_solve_kkt(Q_tilde, D_tilde, A_, C_tilde, p,
                                      torch.zeros(nBatch, nineq).type_as(Q),
                                      -h, -b if b is not None else None, ns)
    elif solver == KKTSolvers.SP_LU_FULL:
        # TODO Have it work for batches
        D = eye(nineq, format='csc')
        reg_eps = 1e-7
        Q_tilde = csc_matrix(
            Q.squeeze(0).numpy()) + reg_eps * eye(nz, format='csc')
        D_tilde = D * (1 + reg_eps)

        if neq > 0:
            A_ = vstack([
                hstack([
                    csc_matrix(G.squeeze(0).numpy()),
                    eye(nineq, format='csc')
                ],
                       format='csc'),
                hstack([
                    csc_matrix(A.squeeze(0).numpy()),
                    csc_matrix((neq, nineq))
                ],
                       format='csc')
            ],
                        format='csc')
        else:
            A_ = hstack([G, eye(nineq, format='csc')])

        # XXX
        C_tilde = reg_eps * np.eye(neq + nineq)
        if F is not None:
            C_tilde[:nineq, :nineq] += F.squeeze(0).numpy()
        C_tilde = csc_matrix(C_tilde)
        ns = [nineq, nz, neq, nBatch]
        x, s, z, y = sparse_factor_solve_kkt(
            Q_tilde, D_tilde, A_, C_tilde, p,
            torch.zeros(nBatch, nineq).type_as(Q), -h,
            -b if b is not None else None, ns)
    elif solver == KKTSolvers.LU_PARTIAL:
        # XXX
        reg_eps = 1e-7
        d = torch.ones(nBatch, nineq).type_as(Q)  # * (1 + reg_eps)
        factor_kkt(S_LU, R, d)
        x, s, z, y = solve_kkt(Q_LU, d, G, A, S_LU, p,
                               torch.zeros(nBatch, nineq).type_as(Q), -h,
                               -b if neq > 0 else None)
    elif solver == KKTSolvers.IR_UNOPT:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = solve_kkt_ir(Q, D, G, A, F, p,
                                  torch.zeros(nBatch, nineq).type_as(Q), -h,
                                  -b if b is not None else None)
    elif solver == KKTSolvers.SP_IR_UNOPT:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = sparse_solve_kkt_ir(Q, D, G, A, F, p,
                                         torch.zeros(nBatch, nineq).type_as(Q),
                                         -h, -b if b is not None else None)
    elif solver == KKTSolvers.IR_INVERSE:
        D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
        x, s, z, y = solve_kkt_ir_inverse(
            Q, D, G, A, F, p,
            torch.zeros(nBatch, nineq).type_as(Q), -h,
            -b if b is not None else None)
    elif solver == KKTSolvers.SP_IR_INVERSE:
        reg_eps = 1e-7
        D = eye(nineq)
        D_tilde = D + reg_eps * eye(nineq)

        Q_tilde = csc_matrix(Q.squeeze(0).numpy()) + reg_eps * eye(nz)
        H_ = block_diag([Q_tilde, D_tilde], format='csc')
        if neq > 0:
            A_ = vstack([
                hstack([
                    csc_matrix(G.squeeze(0).numpy()),
                    eye(nineq, format='csc')
                ],
                       format='csc'),
                hstack([
                    csc_matrix(A.squeeze(0).numpy()),
                    csc_matrix((neq, nineq))
                ],
                       format='csc')
            ],
                        format='csc')
        else:
            A_ = hstack([G, eye(nineq, format='csc')])

        # TODO Test batche size > 1
        # XXX Shouldn't the sign below be positive? (Since its going to be subtracted later)
        C_tilde = -eps * eye(neq + nineq, format='csc')
        if F is not None:  # XXX inverted sign for F below
            C_tilde[:nineq, :nineq] -= F.squeeze(0).numpy()
        F_tilde = C_tilde[:nineq, :nineq]
        # C_tilde = csc_matrix(C_tilde.squeeze(0).numpy())

        x, s, z, y = sparse_solve_kkt_ir_inverse(
            H_, A_, C_tilde, Q_tilde, D_tilde, G, A, F_tilde, p,
            torch.zeros(nBatch, nineq).type_as(Q), -h,
            -b if b is not None else None)
    else:
        assert False

    M = torch.min(s, 1)[0].repeat(1, nineq)
    I = M <= 0
    s[I] -= M[I] - 1

    M = torch.min(z, 1)[0].repeat(1, nineq)
    I = M <= 0
    z[I] -= M[I] - 1

    best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    nNotImproved = 0

    for i in range(maxIter):
        # affine scaling direction
        rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
            torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
            torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
            p
        rs = z
        rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
        if F is not None:  # XXX (Inverted sign for F below)
            rz -= torch.bmm(z.unsqueeze(1), F.transpose(1, 2)).squeeze(1)
        ry = torch.bmm(x.unsqueeze(1), A.transpose(
            1, 2)).squeeze(1) - b if neq > 0 else 0.0
        mu = torch.abs((s * z).sum(1).squeeze() / nineq)
        z_resid = torch.norm(rz, 2, 1).squeeze()
        y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
        pri_resid = y_resid + z_resid
        dual_resid = torch.norm(rx, 2, 1).squeeze()
        resids = pri_resid + dual_resid + nineq * mu

        d = z / s
        if solver == KKTSolvers.LU_PARTIAL:
            try:
                factor_kkt(S_LU, R, d)
            except:
                return best['x'], best['y'], best['z'], best['s']

        if verbose == 1:
            print(
                'iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.
                format(i, pri_resid.mean(), dual_resid.mean(), mu.mean()))
        if best['resids'] is None:
            best['resids'] = resids
            best['x'] = x.clone()
            best['z'] = z.clone()
            best['s'] = s.clone()
            best['y'] = y.clone() if y is not None else None
            nNotImproved = 0
        else:
            I = resids < best['resids']
            if I.sum() > 0:
                nNotImproved = 0
            else:
                nNotImproved += 1
            I_nz = I.repeat(nz, 1).t()
            I_nineq = I.repeat(nineq, 1).t()
            best['resids'][I] = resids[I]
            best['x'][I_nz] = x[I_nz]
            best['z'][I_nineq] = z[I_nineq]
            best['s'][I_nineq] = s[I_nineq]
            if neq > 0:
                I_neq = I.repeat(neq, 1).t()
                best['y'][I_neq] = y[I_neq]
        if nNotImproved == notImprovedLim or best['resids'].max(
        ) < eps or mu.min() > 1e100:
            if best['resids'].max() > 1. and verbose >= 0:
                print(INACC_ERR)
                print(best['resids'].max())
            return best['x'], best['y'], best['z'], best['s']

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            D_tilde = D + reg_eps * torch.eye(nineq).type_as(Q).repeat(
                nBatch, 1, 1)
            dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
                Q_tilde, D_tilde, A_, C_tilde, rx, rs, rz, ry, ns)
        elif solver == KKTSolvers.SP_LU_FULL:
            D = diags(d.squeeze(0).numpy())
            D_tilde = D + reg_eps * eye(nineq, format='csc')

            dx_aff, ds_aff, dz_aff, dy_aff = sparse_factor_solve_kkt(
                Q_tilde, D_tilde, A_, C_tilde, rx, rs, rz, ry, ns)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(Q_LU, d, G, A, S_LU, rx,
                                                       rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(
                Q, D, G, A, F, rx, rs, rz, ry)
        elif solver == KKTSolvers.SP_IR_UNOPT:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = sparse_solve_kkt_ir(
                Q, D, G, A, F, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_INVERSE:
            D = bdiag(d)
            dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir_inverse(
                Q, D, G, A, F, rx, rs, rz, ry)
        elif solver == KKTSolvers.SP_IR_INVERSE:
            D = diags(d.squeeze(0).numpy())
            D_tilde = D + reg_eps * eye(nineq)
            # H_ = block_diag([Q_tilde.squeeze(0).numpy(), D_tilde.squeeze(0).numpy()], format='csc')
            H_ = block_diag([Q_tilde, D_tilde], format='csc')
            dx_aff, ds_aff, dz_aff, dy_aff = sparse_solve_kkt_ir_inverse(
                H_, A_, C_tilde, Q_tilde, D_tilde, G, A, F_tilde, rx, rs, rz,
                ry)
        else:
            assert False

        # compute centering directions
        alpha = torch.min(torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        t1 = s + alpha_nineq * ds_aff
        t2 = z + alpha_nineq * dz_aff
        t3 = torch.sum(t1 * t2, 1).squeeze()
        t4 = torch.sum(s * z, 1).squeeze()
        sig = (t3 / t4)**3

        rx = torch.zeros(nBatch, nz).type_as(Q)
        rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
        rz = torch.zeros(nBatch, nineq).type_as(Q)
        ry = torch.zeros(nBatch, neq).type_as(Q)

        if solver == KKTSolvers.LU_FULL:
            D = bdiag(d)
            D_tilde = D + reg_eps * torch.eye(nineq).type_as(Q).repeat(
                nBatch, 1, 1)
            dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
                Q_tilde, D_tilde, A_, C_tilde, rx, rs, rz, ry, ns)
        elif solver == KKTSolvers.SP_LU_FULL:
            D = diags(d.squeeze(0).numpy())
            D_tilde = D + reg_eps * eye(nineq, format='csc')

            dx_cor, ds_cor, dz_cor, dy_cor = sparse_factor_solve_kkt(
                Q_tilde, D_tilde, A_, C_tilde, rx, rs, rz, ry, ns)
        elif solver == KKTSolvers.LU_PARTIAL:
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(Q_LU, d, G, A, S_LU, rx,
                                                       rs, rz, ry)
        elif solver == KKTSolvers.IR_UNOPT:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(
                Q, D, G, A, F, rx, rs, rz, ry)
        elif solver == KKTSolvers.SP_IR_UNOPT:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = sparse_solve_kkt_ir(
                Q, D, G, A, F, rx, rs, rz, ry)
        elif solver == KKTSolvers.IR_INVERSE:
            D = bdiag(d)
            dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir_inverse(
                Q, D, G, A, F, rx, rs, rz, ry)
        elif solver == KKTSolvers.SP_IR_INVERSE:
            D = diags(d.squeeze(0).numpy())
            D_tilde = D + reg_eps * eye(nineq)
            # H_ = block_diag([Q_tilde.squeeze(0).numpy(), D_tilde.squeeze(0).numpy()], format='csc')
            H_ = block_diag([Q_tilde, D_tilde], format='csc')
            dx_cor, ds_cor, dz_cor, dy_cor = sparse_solve_kkt_ir_inverse(
                H_, A_, C_tilde, Q_tilde, D_tilde, G, A, F_tilde, rx, rs, rz,
                ry)
        else:
            assert False

        dx = dx_aff + dx_cor
        ds = ds_aff + ds_cor
        dz = dz_aff + dz_cor
        dy = dy_aff + dy_cor if neq > 0 else None
        alpha = torch.min(0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
                          torch.ones(nBatch).type_as(Q))
        alpha_nineq = alpha.repeat(nineq, 1).t()
        alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
        alpha_nz = alpha.repeat(nz, 1).t()

        x += alpha_nz * dx
        s += alpha_nineq * ds
        z += alpha_nineq * dz
        y = y + alpha_neq * dy if neq > 0 else None

    if best['resids'].max() > 1. and verbose >= 0:
        print(INACC_ERR)
        print(best['resids'].max())
    return best['x'], best['y'], best['z'], best['s']