Example #1
def factor_kkt(S_LU, R, d):
    """ Factor the U22 block that we can only do after we know D. """
    nineq = d.size(0)
    neq = S_LU[1].size(0) - nineq
    global factor_kkt_eye
    if factor_kkt_eye is None or factor_kkt_eye.size() != d.size():
        factor_kkt_eye = torch.eye(nineq).type_as(R).bool()
    T = R.clone()
    T[factor_kkt_eye] += (1. / d).squeeze().view(-1)

    T_LU = lu_hack(T)

    # TODO: Don't use pivoting in most cases because
    # torch.lu_unpack is inefficient here:
    oldPivotsPacked = S_LU[1][-nineq:] - neq
    oldPivots, _, _ = torch.lu_unpack(T_LU[0],
                                      oldPivotsPacked,
                                      unpack_data=False)
    newPivotsPacked = T_LU[1]
    newPivots, _, _ = torch.lu_unpack(T_LU[0],
                                      newPivotsPacked,
                                      unpack_data=False)

    # Re-pivot the S_LU_21 block.
    if neq > 0:
        S_LU_21 = S_LU[0][-nineq:, :neq]
        S_LU[0][-nineq:, :neq] = newPivots.T.mm(oldPivots.mm(S_LU_21))

    # Add the new S_LU_22 block pivots.
    S_LU[1][-nineq:] = newPivotsPacked + neq

    # Add the new S_LU_22 block.
    S_LU[0][-nineq:, -nineq:] = T_LU[0]
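A quick aside, not part of the snippet above: torch.lu_unpack can unpack only the pivots into a permutation matrix, which is what the re-pivoting above relies on. A minimal check, assuming the legacy torch.lu API used throughout these examples:

import torch

A = torch.randn(5, 5, dtype=torch.float64)
LU, pivots = torch.lu(A)

# Full unpack: A == P @ L @ U
P, L, U = torch.lu_unpack(LU, pivots)
assert torch.allclose(P @ L @ U, A)

# Pivots-only unpack (cheaper; L and U are not materialized)
P_only, _, _ = torch.lu_unpack(LU, pivots, unpack_data=False)
assert torch.equal(P, P_only)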
Example #2
    def __init__(self, channels=3, permutation=None):
        super(GeneralizedChannelPermute, self).__init__()
        self.__delattr__('permutation')

        # Sample a random orthogonal matrix
        W, _ = torch.qr(torch.randn(channels, channels))

        # Construct the partially pivoted LU-form and the pivots
        LU, pivots = W.lu()

        # Convert the pivots into the permutation matrix
        if permutation is None:
            P, _, _ = torch.lu_unpack(LU, pivots)
        else:
            if len(permutation) != channels:
                raise ValueError(
                    'Keyword argument "permutation" expected to have {} elements but {} found.'
                    .format(channels, len(permutation)))
            P = torch.eye(channels,
                          channels)[permutation.type(dtype=torch.int64)]

        # We register the permutation matrix so that the model can be serialized
        self.register_buffer('permutation', P)

        # NOTE: For this implementation I have chosen to store the parameters densely, rather than
        # storing L, U, and s separately
        self.LU = torch.nn.Parameter(LU)
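The constructor above only stores the permutation buffer and the dense LU parameter. One plausible way to rebuild the convolution weight from them is sketched below; the method name is hypothetical and this is not the original class's forward pass (torch is assumed to be imported as in the snippets):

    def _reconstructed_weight(self):
        # Unit lower-triangular and upper-triangular factors packed in self.LU
        L = self.LU.tril(-1) + torch.eye(self.LU.size(0),
                                         dtype=self.LU.dtype,
                                         device=self.LU.device)
        U = self.LU.triu()
        # W = P L U
        return self.permutation @ L @ U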
Example #3
    def __init__(self, num_features, LU_decomposed=True):
        super(Invertible1x1ConvLU, self).__init__()
        w_shape = [num_features, num_features]
        w_init = torch.qr(torch.randn(*w_shape))[0]

        self.num_features = num_features

        if not LU_decomposed:
            self.weight = nn.Parameter(torch.Tensor(w_init))
        else:
            p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
            s = torch.diag(upper)
            sign_s = torch.sign(s)
            log_s = torch.log(torch.abs(s))
            upper = torch.triu(upper, 1)
            l_mask = torch.tril(torch.ones(w_shape), -1)
            eye = torch.eye(*w_shape)

            self.register_buffer("p", p)
            self.register_buffer("sign_s", sign_s)
            self.lower = nn.Parameter(lower)
            self.log_s = nn.Parameter(log_s)
            self.upper = nn.Parameter(upper)
            self.l_mask = l_mask
            self.eye = eye

        self.if_LU = LU_decomposed
        self.w_shape = w_shape
Example #4
    def __init__(self, num_channels, LU_decomposed):
        super().__init__()
        w_shape = [num_channels, num_channels]
        w_init = torch.qr(torch.randn(*w_shape))[0]

        if not LU_decomposed:
            self.weight = nn.Parameter(torch.Tensor(w_init))
        else:
            p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
            s = torch.diag(upper)
            sign_s = torch.sign(s)
            log_s = torch.log(torch.abs(s))
            upper = torch.triu(upper, 1)
            l_mask = torch.tril(torch.ones(w_shape), -1)
            eye = torch.eye(*w_shape)

            self.register_buffer("p", p)
            self.register_buffer("sign_s", sign_s)
            self.lower = nn.Parameter(lower)
            self.log_s = nn.Parameter(log_s)
            self.upper = nn.Parameter(upper)
            self.l_mask = l_mask
            self.eye = eye

        self.w_shape = w_shape
        self.LU_decomposed = LU_decomposed
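For the two constructors above (Examples #3 and #4), the stored pieces are typically recombined into the weight using the registered masks. A hedged sketch of that reconstruction follows; the method name is hypothetical and the actual forward methods are not shown here:

    def _reconstructed_weight(self):
        # Unit lower-triangular factor
        lower = self.lower * self.l_mask + self.eye
        # Upper factor with its diagonal restored from sign_s and log_s
        upper = self.upper * self.l_mask.T + torch.diag(self.sign_s * torch.exp(self.log_s))
        # W = P L U
        return self.p @ lower @ upper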
Example #5
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, nBatch = get_sizes(G, A)
    
    
    try:
        Q_LU = lu_hack(Q)
    except:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""")

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.
    qlu, pivots = Q_LU
    # put in a condition for m > 1
    if G.size(2) == 2:  # something funny here
        print(G.size(), qlu.size())
        G_invQ_GT = torch.matmul(G, G.lu_solve(*Q_LU).transpose(1, 2))
    else:
        print(G.size(2), G.transpose(1, 2).size(), qlu.size())
        G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    
    R = G_invQ_GT.clone()
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)
                           ).lu_solve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)),
                              1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
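Reading pre_factor_kkt together with factor_kkt above, the cached blocks appear to follow the standard block-LU layout. This is my reading of the code, not a comment from the original authors:

    # With S = [[S11, S12], [S21, S22 + D^{-1}]] and S11 = P1 L1 U1 (from lu_hack):
    #   S_LU_11 holds the packed LU factors of S11,
    #   S_LU_12 = L1^{-1} P1^T S12,
    #   S_LU_21 = S21 U1^{-1},
    #   S_LU_22 is left as zeros,
    # and R = S22 - S21 S11^{-1} S12 is the Schur complement without the D^{-1} term.
    # factor_kkt later forms T = R + D^{-1}, factors it, and writes T's LU data and
    # pivots into the (2,2) block, re-pivoting the (2,1) block to stay consistent.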
Example #6
    def __init__(self, num_channels, LU_decomposed=True):
        super().__init__()
        self.num_channels = num_channels
        self.LU_decomposed = LU_decomposed
        weight_shape = [num_channels, num_channels]
        weight, _ = torch.qr(torch.randn(*weight_shape))
        if not self.LU_decomposed:
            self.weight = nn.Parameter(weight)
        else:
            weight_lu, pivots = torch.lu(weight)
            w_p, w_l, w_u = torch.lu_unpack(weight_lu, pivots)
            w_s = torch.diag(w_u)
            sign_s = torch.sign(w_s)
            log_s = torch.log(torch.abs(w_s))
            w_u = torch.triu(w_u, 1)

            u_mask = torch.triu(torch.ones_like(w_u), 1)
            l_mask = u_mask.T.contiguous()
            eye = torch.eye(l_mask.shape[0])

            self.register_buffer('p', w_p)
            self.register_buffer('sign_s', sign_s)
            self.register_buffer('eye', eye)
            self.register_buffer('u_mask', u_mask)
            self.register_buffer('l_mask', l_mask)
            self.l = nn.Parameter(w_l)
            self.u = nn.Parameter(w_u)
            self.log_s = nn.Parameter(log_s)
Example #7
    def __init__(self, in_out_channels):
        super(InvertibleConv1x1, self).__init__()

        W = torch.zeros((in_out_channels, in_out_channels),
                        dtype=torch.float32)
        nn.init.orthogonal_(W)
        LU, pivots = torch.lu(W)

        P, L, U = torch.lu_unpack(LU, pivots)
        self.P = nn.Parameter(P, requires_grad=False)
        self.L = nn.Parameter(L, requires_grad=True)
        self.U = nn.Parameter(U, requires_grad=True)
        self.I = nn.Parameter(torch.eye(in_out_channels), requires_grad=False)
        self.pivots = nn.Parameter(pivots, requires_grad=False)

        L_mask = np.tril(np.ones((in_out_channels, in_out_channels),
                                 dtype='float32'),
                         k=-1)
        U_mask = L_mask.T.copy()
        self.L_mask = nn.Parameter(torch.from_numpy(L_mask),
                                   requires_grad=False)
        self.U_mask = nn.Parameter(torch.from_numpy(U_mask),
                                   requires_grad=False)

        s = torch.diag(U)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        self.log_s = nn.Parameter(log_s, requires_grad=True)
        self.sign_s = nn.Parameter(sign_s, requires_grad=False)
Example #8
 def sampling_W(self, dim):
     # sample a rotation matrix
     W = torch.empty(dim, dim)
     torch.nn.init.orthogonal_(W)
     # compute LU
     LU, pivot = torch.lu(W)
     P, L, U = torch.lu_unpack(LU, pivot)
     return W, P, L, U
Example #9
 def __init__(self, dim):
     super().__init__()
     self.dim = dim
     Q = nn.init.orthogonal_(torch.randn(dim, dim).to(device))
     P, L, U = torch.lu_unpack(*torch.lu(Q))
     self.register_buffer('P', P)
     self.L = nn.Parameter(L) # lower triangular portion
     self.S = nn.Parameter(U.diag()) # "crop out" the diagonal to its own parameter
     self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S
Example #10
 def __init__(self, dim):
     super().__init__()
     self.dim = dim
     Q = torch.nn.init.orthogonal_(torch.randn(dim, dim).to("cuda:0"))
     P, L, U = torch.lu_unpack(*Q.lu())
     self.P = P
     self.L = nn.Parameter(L)
     self.S = nn.Parameter(U.diag())
     self.U = nn.Parameter(torch.triu(U, diagonal=1))
Example #11
 def __init__(self, dim):
     super().__init__()
     self.dim = dim
     Q = torch.nn.init.orthogonal_(torch.randn(dim, dim))
     P, L, U = torch.lu_unpack(*Q.lu())
     self.P = P  # remains fixed during optimization
     self.L = nn.Parameter(L)  # lower triangular portion
     self.S = nn.Parameter(U.diag())  # "crop out" the diagonal to its own parameter
     self.U = nn.Parameter(torch.triu(U, diagonal=1))  # "crop out" diagonal, stored in S
Example #12
 def __init__(self, dim):
     super().__init__()
     self.register_buffer('placeholder', torch.randn(1))
     self.dim = dim
     Q = torch.nn.init.orthogonal_(torch.randn(dim, dim))
     P, L, U = torch.lu_unpack(*Q.lu())
     self.P = nn.Parameter(P, requires_grad=False)  # remains fixed during optimization
     self.L = nn.Parameter(L, requires_grad=True)  # lower triangular portion
     self.S = nn.Parameter(U.diag(), requires_grad=True)  # "crop out" the diagonal to its own parameter
     self.U = nn.Parameter(torch.triu(U, diagonal=1), requires_grad=True)  # "crop out" diagonal, stored in S
Example #13
 def blas_lapack_ops(self):
     m = torch.randn(3, 3)
     a = torch.randn(10, 3, 4)
     b = torch.randn(10, 4, 3)
     v = torch.randn(3)
     return (
         torch.addbmm(m, a, b),
         torch.addmm(torch.randn(2, 3), torch.randn(2, 3),
                     torch.randn(3, 3)),
         torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
         torch.addr(torch.zeros(3, 3), v, v),
         torch.baddbmm(m, a, b),
         torch.bmm(a, b),
         torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3),
                            torch.randn(3, 3)),
         # torch.cholesky(a), # deprecated
         torch.cholesky_inverse(torch.randn(3, 3)),
         torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
         torch.dot(v, v),
         torch.eig(m),
         torch.geqrf(a),
         torch.ger(v, v),
         torch.inner(m, m),
         torch.inverse(m),
         torch.det(m),
         torch.logdet(m),
         torch.slogdet(m),
         torch.lstsq(m, m),
         torch.lu(m),
         torch.lu_solve(m, *torch.lu(m)),
         torch.lu_unpack(*torch.lu(m)),
         torch.matmul(m, m),
         torch.matrix_power(m, 2),
         # torch.matrix_rank(m),
         torch.matrix_exp(m),
         torch.mm(m, m),
         torch.mv(m, v),
         # torch.orgqr(a, m),
         # torch.ormqr(a, m, v),
         torch.outer(v, v),
         torch.pinverse(m),
         # torch.qr(a),
         torch.solve(m, m),
         torch.svd(a),
         # torch.svd_lowrank(a),
         # torch.pca_lowrank(a),
         # torch.symeig(a), # deprecated
         # torch.lobpcg(a, b), # not supported
         torch.trapz(m, m),
         torch.trapezoid(m, m),
         torch.cumulative_trapezoid(m, m),
         # torch.triangular_solve(m, m),
         torch.vdot(v, v),
     )
Example #14
 def __init__(self, shape):
     super().__init__()
     self.d_cpu = torch.prod(torch.tensor(shape))
     Q = torch.nn.init.orthogonal_(torch.randn(self.d_cpu, self.d_cpu))
     P, L, U = torch.lu_unpack(*Q.lu())
     self.register_buffer('P', P)  # remains fixed during optimization
     self.L = nn.Parameter(L)  # lower triangular portion
     self.S = nn.Parameter(
         U.diag())  # "crop out" the diagonal to its own parameter
     self.U = nn.Parameter(torch.triu(
         U, diagonal=1))  # "crop out" diagonal, stored in S
Example #15
def factor_kkt(S_LU, R, d):
    """ Factor the U22 block that we can only do after we know D. """
    nBatch, nineq = d.size()
    neq = S_LU[1].size(1) - nineq
    # TODO: There's probably a better way to add a batched diagonal.
    global factor_kkt_eye
    if factor_kkt_eye is None or factor_kkt_eye.size() != d.size():
        # print('Updating batchedEye size.')
        factor_kkt_eye = torch.eye(nineq).repeat(nBatch, 1,
                                                 1).type_as(R).byte()
    T = R.clone()
    T[factor_kkt_eye] += (1. / d).squeeze().view(-1)

    T_LU = btrifact_hack(T)

    global shown_btrifact_warning
    if shown_btrifact_warning or not T.is_cuda:
        # TODO: Don't use pivoting in most cases because
        # torch.btriunpack is inefficient here:
        oldPivotsPacked = S_LU[1][:, -nineq:] - neq
        oldPivots, _, _ = torch.lu_unpack(T_LU[0],
                                          oldPivotsPacked,
                                          unpack_data=False)
        newPivotsPacked = T_LU[1]
        newPivots, _, _ = torch.lu_unpack(T_LU[0],
                                          newPivotsPacked,
                                          unpack_data=False)

        # Re-pivot the S_LU_21 block.
        if neq > 0:
            S_LU_21 = S_LU[0][:, -nineq:, :neq]
            S_LU[0][:, -nineq:, :neq] = newPivots.transpose(1, 2).bmm(
                oldPivots.bmm(S_LU_21))

        # Add the new S_LU_22 block pivots.
        S_LU[1][:, -nineq:] = newPivotsPacked + neq

    # Add the new S_LU_22 block.
    S_LU[0][:, -nineq:, -nineq:] = T_LU[0]
Example #16
 def __init__(self, dim, device='cpu', condition_size=0):
     super().__init__()
     self.conditional = True  # forward/backward are unchanged when conditional, so always True
     self.cond_size = condition_size  # for compatibility with the other flow layers, which have this attribute
     self.dim = dim
     self.device = device
     Q = torch.nn.init.orthogonal_(torch.randn(dim, dim).to(self.device))
     P, L, U = torch.lu_unpack(*Q.lu())
     self.P = P.to(self.device)  # remains fixed during optimization
     self.L = nn.Parameter(L).to(self.device)  # lower triangular portion
     self.S = nn.Parameter(U.diag()).to(
         self.device)  # "crop out" the diagonal to its own parameter
     self.U = nn.Parameter(torch.triu(U, diagonal=1)).to(
         self.device)  # "crop out" diagonal, stored in S
Example #17
    def __init__(self, in_channel, fixed, use_lu):
        super().__init__()
        self.use_lu = use_lu
        self.fixed = fixed
        if not use_lu:
            weight = th.randn(in_channel, in_channel)
            q, _ = th.qr(weight)
            weight = q
            if fixed:
                self.register_buffer('weight', weight.data)
                self.register_buffer('weight_inverse', weight.data.inverse())
                self.register_buffer(
                    'fixed_log_det',
                    th.slogdet(self.weight.double())[1].float())
            else:
                self.weight = nn.Parameter(weight)
        if use_lu:
            assert not fixed
            #weight = np.random.randn(in_channel, in_channel)
            weight = th.randn(in_channel, in_channel)
            #q, _ = la.qr(weight)
            q, _ = th.qr(weight)

            # w_p, w_l, w_u = la.lu(q.astype(np.float32))
            w_p, w_l, w_u = th.lu_unpack(*th.lu(q))

            #w_s = np.diag(w_u)
            w_s = th.diag(w_u)
            #w_u = np.triu(w_u, 1)
            w_u = th.triu(w_u, 1)
            #u_mask = np.triu(np.ones_like(w_u), 1)
            u_mask = th.triu(th.ones_like(w_u), 1)
            #l_mask = u_mask.T
            l_mask = u_mask.t()

            #w_p = th.from_numpy(w_p)
            #w_l = th.from_numpy(w_l)
            #w_s = th.from_numpy(w_s)
            #w_u = th.from_numpy(w_u)

            self.register_buffer('w_p', w_p)
            self.register_buffer('u_mask', u_mask)
            self.register_buffer('l_mask', l_mask)
            self.register_buffer('s_sign', th.sign(w_s))
            self.register_buffer('l_eye', th.eye(l_mask.shape[0]))
            self.w_l = nn.Parameter(w_l)
            self.w_s = nn.Parameter(th.log(th.abs(w_s)))
            self.w_u = nn.Parameter(w_u)
Example #18
    def __init__(self, c):
        super(Invertible1x1ConvLUS, self).__init__()
        # Sample a random orthonormal matrix to initialize weights
        W, _ = torch.linalg.qr(torch.randn(c, c))
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        p, lower, upper = torch.lu_unpack(*torch.lu(W))

        self.register_buffer('p', p)
        # diagonals of lower will always be 1s anyway
        lower = torch.tril(lower, -1)
        lower_diag = torch.diag(torch.eye(c, c))
        self.register_buffer('lower_diag', lower_diag)
        self.lower = nn.Parameter(lower)
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))
Example #19
    def __init__(self, in_channel):
        super().__init__()

        weight = torch.randn(in_channel, in_channel)
        q, _ = torch.qr(weight)
        w_p, w_l, w_u = torch.lu_unpack(*q.lu())
        w_s = torch.diag(w_u)
        w_u = torch.triu(w_u, 1)
        u_mask = torch.triu(torch.ones_like(w_u), 1)
        l_mask = u_mask.T
        self.register_buffer("w_p", w_p)
        self.register_buffer("u_mask", u_mask)
        self.register_buffer("l_mask", l_mask)
        self.register_buffer("s_sign", torch.sign(w_s))
        self.register_buffer("l_eye", torch.eye(l_mask.size(0)))
        self.w_l = nn.Parameter(w_l)
        self.w_s = nn.Parameter(logabs(w_s))
        self.w_u = nn.Parameter(w_u)
Example #20
    def __init__(self, dim):
        super(Invertible1x1Conv, self).__init__()
        self.dim = dim

        # Grab the weight and bias from a randomly initialized Conv2d.
        m = nn.Conv2d(dim, dim, kernel_size=1)
        W = m.weight.clone().detach().reshape(dim, dim)
        LU, pivots = torch.lu(W)
        P, _, _ = torch.lu_unpack(LU, pivots)

        s = torch.diag(LU)
        # noinspection PyTypeChecker
        LU = torch.where(torch.eye(dim) == 0, LU, torch.zeros_like(LU))

        self.register_buffer("P", P)
        self.register_buffer("s_sign", torch.sign(s))
        self.register_parameter("s_log", nn.Parameter(torch.log(torch.abs(s) + 1e-3)))
        self.register_parameter("LU", nn.Parameter(LU))
Example #21
    def forward(ctx, input):

        # LUP decompose the matrices
        inp_lu, pivots = input.lu()
        perm, inpl, inpu = torch.lu_unpack(inp_lu, pivots)

        # get the number of permutations
        s = (pivots != torch.as_tensor(range(
            1, input.shape[1] + 1)).int()).sum(1).type(
                torch.get_default_dtype())

        # get the prod of the diag of U
        d = torch.diagonal(inpu, dim1=-2, dim2=-1).prod(1)

        # assemble
        det = ((-1)**s * d)
        ctx.save_for_backward(input, det)

        return det
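Not part of the snippet, but a quick sanity check of the identity it uses (the parity of the pivots times the product of U's diagonal equals the determinant):

import torch

A = torch.randn(5, 6, 6, dtype=torch.float64)   # a small batch of square matrices
LU, pivots = A.lu()
_, _, U = torch.lu_unpack(LU, pivots)

# Number of row exchanges per matrix, read off the packed pivots
swaps = (pivots != torch.arange(1, 7, dtype=torch.int32)).sum(1)
sign = 1.0 - 2.0 * (swaps % 2).to(torch.float64)    # +1 for an even number of swaps, -1 for odd
det_lu = sign * torch.diagonal(U, dim1=-2, dim2=-1).prod(-1)
print(torch.allclose(det_lu, torch.det(A)))         # True up to rounding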
Example #22
    def __init__(self, dim_inputs):
        super().__init__()
        self.dim_inputs = dim_inputs
        w_shape = [dim_inputs, dim_inputs]
        w_init = torch.qr(torch.randn(w_shape))[0]

        p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
        s = torch.diag(upper)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        upper = torch.triu(upper, 1)
        l_mask = torch.tril(torch.ones(w_shape), -1)
        eye = torch.eye(*w_shape)

        self.register_buffer('p', p)
        self.register_buffer('sign_s', sign_s)
        self.lower = nn.Parameter(lower)
        self.log_s = nn.Parameter(log_s)
        self.upper = nn.Parameter(upper)
        self.register_buffer('l_mask', l_mask)
        self.register_buffer('eye', eye)
Example #23
    def __init__(self, num_channels, lu_decomposition=False):
        """
        Invertible 1x1 convolution layer

        :param num_channels: number of channels
        :type num_channels: int
        :param lu_decomposition: whether to use LU decomposition
        :type lu_decomposition: bool
        """
        super().__init__()
        self.num_channels = num_channels
        self.lu_decomposition = lu_decomposition
        w_shape = [num_channels, num_channels]
        tolerance = 1e-4
        # Sample a random orthogonal matrix
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
        if self.lu_decomposition:
            w_LU_pts, pivots = torch.lu(torch.Tensor(w_init))
            p, w_l, w_u = torch.lu_unpack(w_LU_pts, pivots)
            s = torch.diag(torch.diag(w_u))
            w_u -= s
            print((torch.Tensor(w_init) -
                   torch.matmul(p, torch.matmul(w_l, (w_u + s)))).abs().sum())
            assert (torch.Tensor(w_init) - torch.matmul(
                p, torch.matmul(w_l, (w_u + s)))).abs().sum() < tolerance
            self.register_parameter('weight_L',
                                    nn.Parameter(torch.Tensor(w_l)))
            self.register_parameter('weight_U',
                                    nn.Parameter(torch.Tensor(w_u)))
            self.register_parameter('s', nn.Parameter(torch.Tensor(s)))
            self.register_buffer('weight_P', torch.FloatTensor(p))

        else:
            #w_shape = [num_channels, num_channels]
            # Sample a random orthogonal matrix
            #w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
            self.register_parameter('weight',
                                    nn.Parameter(torch.Tensor(w_init)))
Example #24
 def __init__(self, num_channels, use_lu=False):
     """
     Constructor
     :param num_channels: Number of channels of the data
     :param use_lu: Flag whether to parametrize weights through the LU decomposition
     """
     super().__init__()
     self.num_channels = num_channels
     self.use_lu = use_lu
     Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0]
     if use_lu:
         P, L, U = torch.lu_unpack(*Q.lu())
         self.register_buffer('P', P)  # remains fixed during optimization
         self.L = nn.Parameter(L)  # lower triangular portion
         S = U.diag()  # "crop out" the diagonal to its own parameter
         self.register_buffer("sign_S", torch.sign(S))
         self.log_S = nn.Parameter(torch.log(torch.abs(S)))
         self.U = nn.Parameter(torch.triu(
             U, diagonal=1))  # "crop out" diagonal, stored in S
         self.register_buffer("eye",
                              torch.diag(torch.ones(self.num_channels)))
     else:
         self.W = nn.Parameter(Q)
Example #25
def rank_revealing_LUP_GPU(A, threshold=10**-10):
    """
    It returns the rank of A using GPU acceleration. This is based on PyTorch's LUP implementation.
    
    Parameters
    __________

    A: (2D float64 Numpy Array)
    Input array for which rank needs to be computed

    threshold: (Float)
    See documentation of the function "sort_rows_fast"


    Returns:
    _________

    rank: (Integer)
    rank of A

    """

    try:
        import torch
    except ImportError as e:
        print(str(e))
        print("rank_revealing_LUP_GPU requires PyTorch to run.")

    A = A / np.abs(A).max()
    A_tensor = torch.from_numpy(A).cuda()
    A_LU, pivots = torch.lu(A_tensor, get_infos=False)
    _, _, u = torch.lu_unpack(A_LU, pivots, unpack_pivots=False)
    u = u.cpu().numpy()
    ref = compute_ref_LUP_fast(u, threshold)
    rank = compute_rank_ref(ref, threshold)

    return rank
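compute_ref_LUP_fast and compute_rank_ref are external helpers not shown here. As a rough CPU-only illustration of the same idea (my sketch, not the original implementation): after LU with partial pivoting, the rows of U that are numerically zero do not contribute to the rank.

import torch

torch.manual_seed(0)
B = torch.randn(50, 30, dtype=torch.float64)
C = torch.randn(30, 50, dtype=torch.float64)
A = B @ C                                       # 50 x 50 matrix of rank 30

LU, pivots = torch.lu(A)
_, _, U = torch.lu_unpack(LU, pivots, unpack_pivots=False)

# Count rows of U with at least one entry above a relative tolerance
tol = 1e-10 * U.abs().max()
est_rank = int((U.abs().max(dim=1).values > tol).sum())
print(est_rank)                                 # 30 for this example

This row-counting heuristic is cruder than the row-echelon reduction the original helpers perform, but it shows where the rank information sits after torch.lu / torch.lu_unpack.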
Example #26
    def __init__(self, in_channel, weight=None):
        # in_channel indicates the channel size of the input
        super().__init__()
        self.in_channel = in_channel

        # set a random orthogonal matrix as the initial weight
        if weight is None:
            weight = torch.randn(in_channel, in_channel)
        Weight, _ = torch.qr(weight)

        # PLU decomposition
        self.weight = Weight  # only for testing

        Weight_LU, pivots = torch.lu(Weight)
        w_p, w_l, w_u = torch.lu_unpack(Weight_LU, pivots)
        w_s = torch.diag(w_u)
        w_logs = torch.log(torch.abs(w_s))
        s_sign = torch.sign(w_s)

        w_u = torch.triu(w_u, 1)

        u_mask = torch.triu(torch.ones_like(w_u), 1)
        l_mask = u_mask.T

        l_eye = torch.eye(l_mask.shape[0])

        # fix P
        self.register_buffer("w_p", w_p)
        self.register_buffer("u_mask", u_mask)
        self.register_buffer("l_mask", l_mask)
        self.register_buffer("s_sign", s_sign)
        self.register_buffer("l_eye", l_eye)

        self.w_l = torch.nn.Parameter(w_l)
        self.w_u = torch.nn.Parameter(w_u)
        # self.w_s = torch.nn.Parameter(w_s)
        self.w_logs = torch.nn.Parameter(w_logs)
Example #27
    def __init__(self, nb_channels: int, lu_decomposition: bool = False):
        """
        Invertible 1x1 Convolution for 2D inputs with LU parameterization.
        References : See Glow paper for details https://arxiv.org/abs/1807.03039

        :param nb_channels: the number of input/output channels
        :param lu_decomposition: whether to use LU parameterization
        """
        super(Invertible1x1Conv, self).__init__()

        self.__channels = nb_channels
        # Initialize W as a random orthogonal matrix
        W = torch.zeros(nb_channels, nb_channels)
        nn.init.orthogonal_(W)
        # Make sure the det is 1 (and not -1) to have a defined log-det
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        self.__lu = lu_decomposition
        if self.__lu:
            W_LU, pivots = W.lu()
            P, L, U = torch.lu_unpack(W_LU, pivots)
            s = torch.diag(U)
            U = U - torch.diag(s)
            # Assign the module trainable parameters
            self.sign_s = nn.Parameter(torch.sign(s), requires_grad=True)
            self.log_s = nn.Parameter(torch.log(s.abs()), requires_grad=True)
            self.lower = nn.Parameter(L, requires_grad=True)
            self.upper = nn.Parameter(U, requires_grad=True)
            # Assign the non trainable ones
            self.register_buffer('permutation', P)
            self.register_buffer('permutation_inv', torch.inverse(P))
            self.register_buffer('eye', torch.eye(nb_channels))
            self.register_buffer(
                'l_mask', torch.tril(torch.ones(nb_channels, nb_channels), -1))
        else:
            self.weight = nn.Parameter(W, requires_grad=True)
Example #28
    def backward(ctx, LU_grad, pivots_grad, infors_grad):
        """
        Here we derive the gradients for the LU decomposition.
        LIMITATIONS: square inputs of full rank.
        If not stated otherwise, for tensors A and B,
        `A B` means the matrix product of A and B.

        Forward AD:
        Note that PyTorch returns packed LU, it is a mapping
        A -> (B:= L + U - I, P), such that A = P L U, and
        P is a permutation matrix, and is non-differentiable.

        Using B = L + U - I, A = P L U, we get

        dB = dL + dU and     (*)
        P^T dA = dL U + L dU (**)

        By left/right multiplication of (**) with L^{-1}/U^{-1} we get:
        L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}.

        Note that L^{-1} dL is lower-triangular with zero diagonal,
        and dU U^{-1} is upper-triangular.
        Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so

        L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}),
        dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}), where * denotes the Hadamard product.

        Hence we finally get:
        dL = L 1_L * (L^{-1} P^T dA U^{-1}),
        dU = 1_U * (L^{-1} P^T dA U^{-1}) U

        Backward AD:
        The backward sensitivity is then:
        Tr(B_grad^T dB) = Tr(B_grad^T dL) + Tr(B_grad^T dU) = [1] + [2].

        [1] = Tr(B_grad^T dL) = Tr(B_grad^T L 1_L * (L^{-1} P^T dA U^{-1}))
            = [using Tr(A (B * C)) = Tr((A * B^T) C)]
            = Tr((B_grad^T L * 1_L^T) L^{-1} P^T dA U^{-1})
            = [cyclic property of trace]
            = Tr(U^{-1} (B_grad^T L * 1_L^T) L^{-1} P^T dA)
            = Tr((P L^{-T} (L^T B_grad * 1_L) U^{-T})^T dA).
        Similarly, [2] can be rewritten as:
        [2] = Tr(P L^{-T} (B_grad U^T * 1_U) U^{-T})^T dA, hence
        Tr(A_grad^T dA) = [1] + [2]
                        = Tr((P L^{-T} (L^T B_grad * 1_L + B_grad U^T * 1_U) U^{-T})^T dA), so
        A_grad = P L^{-T} (L^T B_grad * 1_L + B_grad U^T * 1_U) U^{-T}.

        In the code below we use the name `LU` instead of `B`, so that there is no confusion
        in the derivation above between the matrix product and a two-letter variable name.
        """
        LU, pivots = ctx.saved_tensors
        P, L, U = torch.lu_unpack(LU, pivots)

        # To make sure MyPy infers types right
        assert (L is not None) and (U is not None)

        I = LU_grad.new_zeros(LU_grad.shape)
        I.diagonal(dim1=-2, dim2=-1).fill_(1)

        Lt_inv = torch.triangular_solve(I, L, upper=False).solution.transpose(-1, -2)
        Ut_inv = torch.triangular_solve(I, U, upper=True).solution.transpose(-1, -2)

        phi_L = (L.transpose(-1, -2) @ LU_grad).tril_()
        phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0)
        phi_U = (LU_grad @ U.transpose(-1, -2)).triu_()

        self_grad_perturbed = Lt_inv @ (phi_L + phi_U) @ Ut_inv
        return P @ self_grad_perturbed, None, None
Example #29
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""")

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        # if any(torch.isnan(torch.flatten(invQ_AT)).tolist()):
        #     logging.info("nan comes in invq AT")
        # else:
        #     logging.info("non NAN in invq AT")
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)
        # if any(torch.isnan(torch.flatten(G_invQ_AT)).tolist()):
        #     logging.info("nan comes in G_invQ_AT")
        # else:
        #     logging.info("non NAN in G_invQ_AT")
        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).lu_solve(
            *LU_A_invQ_AT)
        # if any(torch.isnan(torch.flatten(U_A_invQ_AT_inv)).tolist()):
        #     logging.info("nan in U_A_invQ_AT_inv")
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = sp_lu_solve(G_invQ_AT.transpose(1, 2), *LU_A_invQ_AT)
        # T = G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat(
            (S_LU_11, S_LU_12), 2), torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]
        # if any(torch.isnan(torch.flatten(T)).tolist()):
        #     logging.info("nan comes in T")
        # else:
        #     logging.info("non NAN in T")
        R -= G_invQ_AT.bmm(T)
        # if any(torch.isnan(torch.flatten(R)).tolist()):

        #     logging.info("nan is here")
        # R[torch.isnan(R)] = 0

    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)
    # S_LU_data[torch.isnan(S_LU_data)] = 0
    S_LU = [S_LU_data, S_LU_pivots]

    return Q_LU, S_LU, R
Example #30
    def backward(ctx, LU_grad, pivots_grad, infors_grad):
        """
        Here we derive the gradients for the LU decomposition.
        LIMITATIONS: square inputs of full rank.
        If not stated otherwise, for tensors A and B,
        `A B` means the matrix product of A and B.

        Let A^H = (A^T).conj()

        Forward AD:
        Note that PyTorch returns packed LU, it is a mapping
        A -> (B:= L + U - I, P), such that A = P L U, and
        P is a permutation matrix, and is non-differentiable.

        Using B = L + U - I, A = P L U, we get

        dB = dL + dU and     (*)
        P^T dA = dL U + L dU (**)

        By left/right multiplication of (**) with L^{-1}/U^{-1} we get:
        L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}.

        Note that L^{-1} dL is lower-triangular with zero diagonal,
        and dU U^{-1} is upper-triangular.
        Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so

        L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}),
        dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}), where * denotes the Hadamard product.

        Hence we finally get:
        dL = L 1_L * (L^{-1} P^T dA U^{-1}),
        dU = 1_U * (L^{-1} P^T dA U^{-1}) U

        Backward AD:
        The backward sensitivity is then:
        Tr(B_grad^H dB) = Tr(B_grad^H dL) + Tr(B_grad^H dU) = [1] + [2].

        [1] = Tr(B_grad^H dL) = Tr(B_grad^H L 1_L * (L^{-1} P^T dA U^{-1}))
            = [using Tr(A (B * C)) = Tr((A * B^T) C)]
            = Tr((B_grad^H L * 1_L^T) L^{-1} P^T dA U^{-1})
            = [cyclic property of trace]
            = Tr(U^{-1} (B_grad^H L * 1_L^T) L^{-1} P^T dA)
            = Tr((P L^{-H} (L^H B_grad * 1_L) U^{-H})^H dA).
        Similarly, [2] can be rewritten as:
        [2] = Tr(P L^{-H} (B_grad U^H * 1_U) U^{-H})^H dA, hence
        Tr(A_grad^H dA) = [1] + [2]
                        = Tr((P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H})^H dA), so
        A_grad = P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H}.

        In the code below we use the name `LU` instead of `B`, so that there is no confusion
        in the derivation above between the matrix product and a two-letter variable name.
        """
        LU, pivots = ctx.saved_tensors
        P, L, U = torch.lu_unpack(LU, pivots)

        # To make sure MyPy infers types right
        assert (L is not None) and (U is not None) and (P is not None)

        # phi_L = L^H B_grad * 1_L
        phi_L = (L.transpose(-1, -2).conj() @ LU_grad).tril_()
        phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0)
        # phi_U = B_grad U^H * 1_U
        phi_U = (LU_grad @ U.transpose(-1, -2).conj()).triu_()
        phi = phi_L + phi_U

        # using the notation from above plus the variable names, note
        # A_grad = P L^{-H} phi U^{-H}.
        # Instead of inverting L and U, we solve two systems of equations, i.e.,
        # the above expression could be rewritten as
        # L^H P^T A_grad U^H = phi.
        # Let X = P^T A_grad U^H, then
        # X = L^{-H} phi, where L^{-H} is upper triangular, or
        # X = torch.triangular_solve(phi, L^H)
        # using the definition of X we see:
        # X = P^T A_grad U^H => P X = A_grad U^H => U A_grad^H = X^H P^T, so
        # A_grad = (U^{-1} X^H P^T)^H, or
        # A_grad = torch.triangular_solve(X^H P^T, U)^H
        X = torch.triangular_solve(phi, L.transpose(-1, -2).conj(),
                                   upper=True).solution
        A_grad = torch.triangular_solve(X.transpose(-1, -2).conj() @ P.transpose(-1, -2), U, upper=True) \
            .solution.transpose(-1, -2).conj()

        return A_grad, None, None
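Not part of the original function, but a small numerical check of the closed-form gradient derived in the docstring, specialized to the real case (the conjugations reduce to plain transposes). It compares A_grad for loss = sum(L + U - I) against central finite differences; the check assumes the pivots do not change under the small perturbation:

import torch

torch.manual_seed(0)
n = 4
A = torch.randn(n, n, dtype=torch.float64)

LU, pivots = torch.lu(A)
P, L, U = torch.lu_unpack(LU, pivots)

# Gradient of loss = sum(LU) w.r.t. the packed factors B = L + U - I is all ones.
B_grad = torch.ones(n, n, dtype=torch.float64)
ones_U = torch.triu(torch.ones(n, n, dtype=torch.float64))
ones_L = 1.0 - ones_U
phi = (L.T @ B_grad) * ones_L + (B_grad @ U.T) * ones_U
A_grad = P @ torch.inverse(L.T) @ phi @ torch.inverse(U.T)

# Central finite differences on the packed LU data.
def lu_sum(M):
    return torch.lu(M)[0].sum()

eps = 1e-6
fd = torch.zeros_like(A)
for i in range(n):
    for j in range(n):
        E = torch.zeros_like(A)
        E[i, j] = eps
        fd[i, j] = (lu_sum(A + E) - lu_sum(A - E)) / (2 * eps)

print(torch.allclose(A_grad, fd, atol=1e-5))    # True when the pivots are stable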