def factor_kkt(S_LU, R, d):
    """ Factor the U22 block that we can only do after we know D. """
    nineq = d.size(0)
    neq = S_LU[1].size(0) - nineq

    global factor_kkt_eye
    # Rebuild the cached identity mask only when the problem size changes.
    if factor_kkt_eye is None or factor_kkt_eye.size(0) != d.size(0):
        factor_kkt_eye = torch.eye(nineq).type_as(R).bool()

    T = R.clone()
    T[factor_kkt_eye] += (1. / d).squeeze().view(-1)

    T_LU = lu_hack(T)

    # TODO: Don't use pivoting in most cases because
    # torch.lu_unpack is inefficient here:
    oldPivotsPacked = S_LU[1][-nineq:] - neq
    oldPivots, _, _ = torch.lu_unpack(T_LU[0], oldPivotsPacked,
                                      unpack_data=False)
    newPivotsPacked = T_LU[1]
    newPivots, _, _ = torch.lu_unpack(T_LU[0], newPivotsPacked,
                                      unpack_data=False)

    # Re-pivot the S_LU_21 block.
    if neq > 0:
        S_LU_21 = S_LU[0][-nineq:, :neq]
        S_LU[0][-nineq:, :neq] = newPivots.T.mm(oldPivots.mm(S_LU_21))

    # Add the new S_LU_22 block pivots.
    S_LU[1][-nineq:] = newPivotsPacked + neq

    # Add the new S_LU_22 block.
    S_LU[0][-nineq:, -nineq:] = T_LU[0]
def __init__(self, channels=3, permutation=None):
    super(GeneralizedChannelPermute, self).__init__()
    self.__delattr__('permutation')

    # Sample a random orthogonal matrix
    W, _ = torch.qr(torch.randn(channels, channels))

    # Construct the partially pivoted LU-form and the pivots
    LU, pivots = W.lu()

    # Convert the pivots into the permutation matrix
    if permutation is None:
        P, _, _ = torch.lu_unpack(LU, pivots)
    else:
        if len(permutation) != channels:
            raise ValueError(
                'Keyword argument "permutation" expected to have {} elements '
                'but {} found.'.format(channels, len(permutation)))
        P = torch.eye(channels, channels)[permutation.type(dtype=torch.int64)]

    # We register the permutation matrix so that the model can be serialized
    self.register_buffer('permutation', P)

    # NOTE: For this implementation I have chosen to store the parameters
    # densely, rather than storing L, U, and s separately.
    self.LU = torch.nn.Parameter(LU)
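# Hedged sanity check (not part of the module above): the P produced by
# torch.lu_unpack is a genuine permutation matrix, so every row and column
# sums to one and P is orthogonal. Constructing the layer directly like
# this is hypothetical usage.
layer = GeneralizedChannelPermute(channels=4)
P = layer.permutation
assert torch.allclose(P.sum(0), torch.ones(4))
assert torch.allclose(P.sum(1), torch.ones(4))
assert torch.allclose(P @ P.T, torch.eye(4))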
def __init__(self, num_features, LU_decomposed=True):
    super(Invertible1x1ConvLU, self).__init__()
    w_shape = [num_features, num_features]
    w_init = torch.qr(torch.randn(*w_shape))[0]
    self.num_features = num_features
    if not LU_decomposed:
        self.weight = nn.Parameter(torch.Tensor(w_init))
    else:
        p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
        s = torch.diag(upper)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        upper = torch.triu(upper, 1)
        l_mask = torch.tril(torch.ones(w_shape), -1)
        eye = torch.eye(*w_shape)

        self.register_buffer("p", p)
        self.register_buffer("sign_s", sign_s)
        self.lower = nn.Parameter(lower)
        self.log_s = nn.Parameter(log_s)
        self.upper = nn.Parameter(upper)
        self.l_mask = l_mask
        self.eye = eye
    self.if_LU = LU_decomposed
    self.w_shape = w_shape
def __init__(self, num_channels, LU_decomposed):
    super().__init__()
    w_shape = [num_channels, num_channels]
    w_init = torch.qr(torch.randn(*w_shape))[0]
    if not LU_decomposed:
        self.weight = nn.Parameter(torch.Tensor(w_init))
    else:
        p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
        s = torch.diag(upper)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s))
        upper = torch.triu(upper, 1)
        l_mask = torch.tril(torch.ones(w_shape), -1)
        eye = torch.eye(*w_shape)

        self.register_buffer("p", p)
        self.register_buffer("sign_s", sign_s)
        self.lower = nn.Parameter(lower)
        self.log_s = nn.Parameter(log_s)
        self.upper = nn.Parameter(upper)
        self.l_mask = l_mask
        self.eye = eye
    self.w_shape = w_shape
    self.LU_decomposed = LU_decomposed
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except Exception:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""")

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    qlu, pivots = Q_LU
    # TODO: put in a proper condition for m > 1; something funny
    # happens in this branch.
    if G.size(2) == 2:
        print(G.size(), qlu.size())
        G_invQ_GT = torch.matmul(G, G.lu_solve(*Q_LU).transpose(1, 2))
    else:
        print(G.size(2), G.transpose(1, 2).size(), qlu.size())
        G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).lu_solve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
def __init__(self, num_channels, LU_decomposed=True):
    super().__init__()
    self.num_channels = num_channels
    self.LU_decomposed = LU_decomposed
    weight_shape = [num_channels, num_channels]
    weight, _ = torch.qr(torch.randn(*weight_shape))
    if not self.LU_decomposed:
        self.weight = nn.Parameter(weight)
    else:
        weight_lu, pivots = torch.lu(weight)
        w_p, w_l, w_u = torch.lu_unpack(weight_lu, pivots)
        w_s = torch.diag(w_u)
        sign_s = torch.sign(w_s)
        log_s = torch.log(torch.abs(w_s))
        w_u = torch.triu(w_u, 1)
        u_mask = torch.triu(torch.ones_like(w_u), 1)
        l_mask = u_mask.T.contiguous()
        eye = torch.eye(l_mask.shape[0])

        self.register_buffer('p', w_p)
        self.register_buffer('sign_s', sign_s)
        self.register_buffer('eye', eye)
        self.register_buffer('u_mask', u_mask)
        self.register_buffer('l_mask', l_mask)
        self.l = nn.Parameter(w_l)
        self.u = nn.Parameter(w_u)
        self.log_s = nn.Parameter(log_s)
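# A minimal sketch (not part of the module above) of how such an LU
# parametrization is typically reassembled into the dense weight in the
# forward pass; `reassemble_weight` is a hypothetical helper name.
import torch

def reassemble_weight(p, l, u, log_s, sign_s, l_mask, u_mask, eye):
    lower = l * l_mask + eye  # masked strict lower triangle + unit diagonal
    upper = u * u_mask + torch.diag(sign_s * torch.exp(log_s))  # rebuild diag of U
    # W = P L U; conveniently, log|det W| is just log_s.sum().
    return p @ lower @ upper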
def __init__(self, in_out_channels):
    super(InvertibleConv1x1, self).__init__()
    W = torch.zeros((in_out_channels, in_out_channels), dtype=torch.float32)
    nn.init.orthogonal_(W)
    LU, pivots = torch.lu(W)
    P, L, U = torch.lu_unpack(LU, pivots)

    self.P = nn.Parameter(P, requires_grad=False)
    self.L = nn.Parameter(L, requires_grad=True)
    self.U = nn.Parameter(U, requires_grad=True)
    self.I = nn.Parameter(torch.eye(in_out_channels), requires_grad=False)
    self.pivots = nn.Parameter(pivots, requires_grad=False)

    L_mask = np.tril(np.ones((in_out_channels, in_out_channels), dtype='float32'), k=-1)
    U_mask = L_mask.T.copy()
    self.L_mask = nn.Parameter(torch.from_numpy(L_mask), requires_grad=False)
    self.U_mask = nn.Parameter(torch.from_numpy(U_mask), requires_grad=False)

    s = torch.diag(U)
    sign_s = torch.sign(s)
    log_s = torch.log(torch.abs(s))
    self.log_s = nn.Parameter(log_s, requires_grad=True)
    self.sign_s = nn.Parameter(sign_s, requires_grad=False)
def sampling_W(self, dim):
    # sample a rotation matrix
    W = torch.empty(dim, dim)
    torch.nn.init.orthogonal_(W)
    # compute the LU decomposition and unpack it
    LU, pivot = torch.lu(W)
    P, L, U = torch.lu_unpack(LU, pivot)
    return W, P, L, U
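# Quick, hedged sanity check (not in the original): the unpacked factors
# should reconstruct W up to floating-point error. `sampler` stands in
# for any instance of the class that defines sampling_W.
W, P, L, U = sampler.sampling_W(8)
assert torch.allclose(P @ L @ U, W, atol=1e-5)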
def __init__(self, dim):
    super().__init__()
    self.dim = dim
    Q = nn.init.orthogonal_(torch.randn(dim, dim).to(device))
    P, L, U = torch.lu_unpack(*torch.lu(Q))
    self.register_buffer('P', P)
    self.L = nn.Parameter(L)  # lower triangular portion
    self.S = nn.Parameter(U.diag())  # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1))  # "crop out" diagonal, stored in S
def __init__(self, dim):
    super().__init__()
    self.dim = dim
    Q = torch.nn.init.orthogonal_(torch.randn(dim, dim).to("cuda:0"))
    P, L, U = torch.lu_unpack(*Q.lu())
    self.P = P
    self.L = nn.Parameter(L)
    self.S = nn.Parameter(U.diag())
    self.U = nn.Parameter(torch.triu(U, diagonal=1))
def __init__(self, dim):
    super().__init__()
    self.dim = dim
    Q = torch.nn.init.orthogonal_(torch.randn(dim, dim))
    P, L, U = torch.lu_unpack(*Q.lu())
    self.P = P  # remains fixed during optimization
    self.L = nn.Parameter(L)  # lower triangular portion
    self.S = nn.Parameter(U.diag())  # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1))  # "crop out" diagonal, stored in S
def __init__(self, dim):
    super().__init__()
    self.register_buffer('placeholder', torch.randn(1))
    self.dim = dim
    Q = torch.nn.init.orthogonal_(torch.randn(dim, dim))
    P, L, U = torch.lu_unpack(*Q.lu())
    self.P = nn.Parameter(P, requires_grad=False)  # remains fixed during optimization
    self.L = nn.Parameter(L, requires_grad=True)  # lower triangular portion
    self.S = nn.Parameter(U.diag(), requires_grad=True)  # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1), requires_grad=True)  # "crop out" diagonal, stored in S
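# A minimal sketch shared by the P/L/S/U variants above (but not part of
# them) of rebuilding the weight and its log-determinant;
# `assemble_plsu` is a hypothetical helper name.
import torch

def assemble_plsu(P, L, S, U):
    dim = S.shape[0]
    L_full = torch.tril(L, diagonal=-1) + torch.eye(dim)  # force a unit diagonal on L
    U_full = torch.triu(U, diagonal=1) + torch.diag(S)    # put S back on the diagonal of U
    W = P @ L_full @ U_full
    log_det = torch.log(torch.abs(S)).sum()  # log|det W|, since |det P| = 1
    return W, log_det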
def blas_lapack_ops(self):
    m = torch.randn(3, 3)
    a = torch.randn(10, 3, 4)
    b = torch.randn(10, 4, 3)
    v = torch.randn(3)
    return (
        torch.addbmm(m, a, b),
        torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
        torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
        torch.addr(torch.zeros(3, 3), v, v),
        torch.baddbmm(m, a, b),
        torch.bmm(a, b),
        torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)),
        # torch.cholesky(a),  # deprecated
        torch.cholesky_inverse(torch.randn(3, 3)),
        torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
        torch.dot(v, v),
        torch.eig(m),
        torch.geqrf(a),
        torch.ger(v, v),
        torch.inner(m, m),
        torch.inverse(m),
        torch.det(m),
        torch.logdet(m),
        torch.slogdet(m),
        torch.lstsq(m, m),
        torch.lu(m),
        torch.lu_solve(m, *torch.lu(m)),
        torch.lu_unpack(*torch.lu(m)),
        torch.matmul(m, m),
        torch.matrix_power(m, 2),
        # torch.matrix_rank(m),
        torch.matrix_exp(m),
        torch.mm(m, m),
        torch.mv(m, v),
        # torch.orgqr(a, m),
        # torch.ormqr(a, m, v),
        torch.outer(v, v),
        torch.pinverse(m),
        # torch.qr(a),
        torch.solve(m, m),
        torch.svd(a),
        # torch.svd_lowrank(a),
        # torch.pca_lowrank(a),
        # torch.symeig(a),  # deprecated
        # torch.lobpcg(a, b),  # not supported
        torch.trapz(m, m),
        torch.trapezoid(m, m),
        torch.cumulative_trapezoid(m, m),
        # torch.triangular_solve(m, m),
        torch.vdot(v, v),
    )
def __init__(self, shape):
    super().__init__()
    self.d_cpu = torch.prod(torch.tensor(shape))
    Q = torch.nn.init.orthogonal_(torch.randn(self.d_cpu, self.d_cpu))
    P, L, U = torch.lu_unpack(*Q.lu())
    self.register_buffer('P', P)  # remains fixed during optimization
    self.L = nn.Parameter(L)  # lower triangular portion
    self.S = nn.Parameter(U.diag())  # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1))  # "crop out" diagonal, stored in S
def factor_kkt(S_LU, R, d):
    """ Factor the U22 block that we can only do after we know D. """
    nBatch, nineq = d.size()
    neq = S_LU[1].size(1) - nineq
    # TODO: There's probably a better way to add a batched diagonal.
    global factor_kkt_eye
    if factor_kkt_eye is None or factor_kkt_eye.size() != d.size():
        # print('Updating batchedEye size.')
        factor_kkt_eye = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(R).byte()
    T = R.clone()
    T[factor_kkt_eye] += (1. / d).squeeze().view(-1)

    T_LU = btrifact_hack(T)

    global shown_btrifact_warning
    if shown_btrifact_warning or not T.is_cuda:
        # TODO: Don't use pivoting in most cases because
        # torch.lu_unpack is inefficient here:
        oldPivotsPacked = S_LU[1][:, -nineq:] - neq
        oldPivots, _, _ = torch.lu_unpack(T_LU[0], oldPivotsPacked,
                                          unpack_data=False)
        newPivotsPacked = T_LU[1]
        newPivots, _, _ = torch.lu_unpack(T_LU[0], newPivotsPacked,
                                          unpack_data=False)

        # Re-pivot the S_LU_21 block.
        if neq > 0:
            S_LU_21 = S_LU[0][:, -nineq:, :neq]
            S_LU[0][:, -nineq:, :neq] = newPivots.transpose(1, 2).bmm(
                oldPivots.bmm(S_LU_21))

        # Add the new S_LU_22 block pivots.
        S_LU[1][:, -nineq:] = newPivotsPacked + neq

    # Add the new S_LU_22 block.
    S_LU[0][:, -nineq:, -nineq:] = T_LU[0]
def __init__(self, dim, device='cpu', condition_size=0):
    super().__init__()
    self.conditional = True  # forward/backward are unchanged when conditional, thus always True
    self.cond_size = condition_size  # for compatibility with the rest of the flows, which have this attribute
    self.dim = dim
    self.device = device
    Q = torch.nn.init.orthogonal_(torch.randn(dim, dim).to(self.device))
    P, L, U = torch.lu_unpack(*Q.lu())
    self.P = P.to(self.device)  # remains fixed during optimization
    # Move tensors to the device *before* wrapping them in nn.Parameter;
    # nn.Parameter(x).to(device) returns a plain non-leaf tensor that the
    # module never registers and the optimizer never sees.
    self.L = nn.Parameter(L.to(self.device))  # lower triangular portion
    self.S = nn.Parameter(U.diag().to(self.device))  # "crop out" the diagonal to its own parameter
    self.U = nn.Parameter(torch.triu(U, diagonal=1).to(self.device))  # "crop out" diagonal, stored in S
def __init__(self, in_channel, fixed, use_lu):
    super().__init__()
    self.use_lu = use_lu
    self.fixed = fixed
    if not use_lu:
        weight = th.randn(in_channel, in_channel)
        q, _ = th.qr(weight)
        weight = q
        if fixed:
            self.register_buffer('weight', weight.data)
            self.register_buffer('weight_inverse', weight.data.inverse())
            self.register_buffer(
                'fixed_log_det', th.slogdet(self.weight.double())[1].float())
        else:
            self.weight = nn.Parameter(weight)
    if use_lu:
        assert not fixed
        # Torch port of the original numpy/scipy initialization
        # (np.random.randn / la.qr / la.lu).
        weight = th.randn(in_channel, in_channel)
        q, _ = th.qr(weight)
        w_p, w_l, w_u = th.lu_unpack(*th.lu(q))
        w_s = th.diag(w_u)
        w_u = th.triu(w_u, 1)
        u_mask = th.triu(th.ones_like(w_u), 1)
        l_mask = u_mask.t()

        self.register_buffer('w_p', w_p)
        self.register_buffer('u_mask', u_mask)
        self.register_buffer('l_mask', l_mask)
        self.register_buffer('s_sign', th.sign(w_s))
        self.register_buffer('l_eye', th.eye(l_mask.shape[0]))
        self.w_l = nn.Parameter(w_l)
        self.w_s = nn.Parameter(th.log(th.abs(w_s)))
        self.w_u = nn.Parameter(w_u)
def __init__(self, c):
    super(Invertible1x1ConvLUS, self).__init__()
    # Sample a random orthonormal matrix to initialize weights
    W, _ = torch.linalg.qr(torch.randn(c, c))
    # Ensure determinant is 1.0, not -1.0
    if torch.det(W) < 0:
        W[:, 0] = -1 * W[:, 0]
    p, lower, upper = torch.lu_unpack(*torch.lu(W))

    self.register_buffer('p', p)
    # the diagonal of lower will always be 1s anyway
    lower = torch.tril(lower, -1)
    lower_diag = torch.diag(torch.eye(c, c))
    self.register_buffer('lower_diag', lower_diag)
    self.lower = nn.Parameter(lower)
    self.upper_diag = nn.Parameter(torch.diag(upper))
    self.upper = nn.Parameter(torch.triu(upper, 1))
def __init__(self, in_channel):
    super().__init__()
    weight = torch.randn(in_channel, in_channel)
    q, _ = torch.qr(weight)
    w_p, w_l, w_u = torch.lu_unpack(*q.lu())
    w_s = torch.diag(w_u)
    w_u = torch.triu(w_u, 1)
    u_mask = torch.triu(torch.ones_like(w_u), 1)
    l_mask = u_mask.T

    self.register_buffer("w_p", w_p)
    self.register_buffer("u_mask", u_mask)
    self.register_buffer("l_mask", l_mask)
    self.register_buffer("s_sign", torch.sign(w_s))
    self.register_buffer("l_eye", torch.eye(l_mask.size(0)))
    self.w_l = nn.Parameter(w_l)
    self.w_s = nn.Parameter(logabs(w_s))  # logabs(x) == torch.log(torch.abs(x)), a helper defined elsewhere in the repo
    self.w_u = nn.Parameter(w_u)
def __init__(self, dim):
    super(Invertible1x1Conv, self).__init__()
    self.dim = dim
    # Grab the weight and bias from a randomly initialized Conv2d.
    m = nn.Conv2d(dim, dim, kernel_size=1)
    W = m.weight.clone().detach().reshape(dim, dim)
    LU, pivots = torch.lu(W)
    P, _, _ = torch.lu_unpack(LU, pivots)
    s = torch.diag(LU)
    # Zero out the diagonal of the packed LU; its entries are carried by
    # s_log / s_sign instead.
    # noinspection PyTypeChecker
    LU = torch.where(torch.eye(dim) == 0, LU, torch.zeros_like(LU))
    self.register_buffer("P", P)
    self.register_buffer("s_sign", torch.sign(s))
    self.register_parameter("s_log", nn.Parameter(torch.log(torch.abs(s) + 1e-3)))
    self.register_parameter("LU", nn.Parameter(LU))
def forward(ctx, input):
    # LUP-decompose the matrices
    inp_lu, pivots = input.lu()
    perm, inpl, inpu = torch.lu_unpack(inp_lu, pivots)
    # get the number of permutations (row swaps)
    s = (pivots != torch.as_tensor(range(1, input.shape[1] + 1)).int()).sum(1).type(
        torch.get_default_dtype())
    # get the product of the diagonal of U
    d = torch.diagonal(inpu, dim1=-2, dim2=-1).prod(1)
    # assemble: det = (-1)^(#row swaps) * prod(diag(U))
    det = ((-1) ** s * d)
    ctx.save_for_backward(input, det)
    return det
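# Hedged sanity check (not part of the original Function): the LU-based
# batch determinant should agree with torch.linalg.det.
import torch

A = torch.randn(4, 5, 5)
lu, piv = A.lu()
_, _, U = torch.lu_unpack(lu, piv)
swaps = (piv != torch.arange(1, 6, dtype=torch.int32)).sum(1).float()
det = (-1.0) ** swaps * torch.diagonal(U, dim1=-2, dim2=-1).prod(1)
assert torch.allclose(det, torch.linalg.det(A), atol=1e-4)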
def __init__(self, dim_inputs):
    super().__init__()
    self.dim_inputs = dim_inputs
    w_shape = [dim_inputs, dim_inputs]
    w_init = torch.qr(torch.randn(w_shape))[0]
    p, lower, upper = torch.lu_unpack(*torch.lu(w_init))
    s = torch.diag(upper)
    sign_s = torch.sign(s)
    log_s = torch.log(torch.abs(s))
    upper = torch.triu(upper, 1)
    l_mask = torch.tril(torch.ones(w_shape), -1)
    eye = torch.eye(*w_shape)

    self.register_buffer('p', p)
    self.register_buffer('sign_s', sign_s)
    self.lower = nn.Parameter(lower)
    self.log_s = nn.Parameter(log_s)
    self.upper = nn.Parameter(upper)
    self.register_buffer('l_mask', l_mask)
    self.register_buffer('eye', eye)
def __init__(self, num_channels, lu_decomposition=False):
    """
    Invertible 1x1 convolution layer

    :param num_channels: number of channels
    :type num_channels: int
    :param lu_decomposition: whether to use LU decomposition
    :type lu_decomposition: bool
    """
    super().__init__()
    self.num_channels = num_channels
    self.lu_decomposition = lu_decomposition
    w_shape = [num_channels, num_channels]
    tolerance = 1e-4
    # Sample a random orthogonal matrix
    w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
    if self.lu_decomposition:
        w_LU_pts, pivots = torch.lu(torch.Tensor(w_init))
        p, w_l, w_u = torch.lu_unpack(w_LU_pts, pivots)
        s = torch.diag(torch.diag(w_u))
        w_u -= s
        # Verify the reconstruction w_init ~= P L (U + S) before training.
        print((torch.Tensor(w_init)
               - torch.matmul(p, torch.matmul(w_l, (w_u + s)))).abs().sum())
        assert (torch.Tensor(w_init)
                - torch.matmul(p, torch.matmul(w_l, (w_u + s)))).abs().sum() < tolerance
        self.register_parameter('weight_L', nn.Parameter(torch.Tensor(w_l)))
        self.register_parameter('weight_U', nn.Parameter(torch.Tensor(w_u)))
        self.register_parameter('s', nn.Parameter(torch.Tensor(s)))
        self.register_buffer('weight_P', torch.FloatTensor(p))
    else:
        self.register_parameter('weight', nn.Parameter(torch.Tensor(w_init)))
def __init__(self, num_channels, use_lu=False):
    """
    Constructor

    :param num_channels: Number of channels of the data
    :param use_lu: Flag whether to parametrize weights through the LU decomposition
    """
    super().__init__()
    self.num_channels = num_channels
    self.use_lu = use_lu
    Q = torch.qr(torch.randn(self.num_channels, self.num_channels))[0]
    if use_lu:
        P, L, U = torch.lu_unpack(*Q.lu())
        self.register_buffer('P', P)  # remains fixed during optimization
        self.L = nn.Parameter(L)  # lower triangular portion
        S = U.diag()  # "crop out" the diagonal to its own parameter
        self.register_buffer("sign_S", torch.sign(S))
        self.log_S = nn.Parameter(torch.log(torch.abs(S)))
        self.U = nn.Parameter(torch.triu(U, diagonal=1))  # "crop out" diagonal, stored in S
        self.register_buffer("eye", torch.diag(torch.ones(self.num_channels)))
    else:
        self.W = nn.Parameter(Q)
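# Why the LU form pays off (a hedged sketch, not part of the module
# above): the inverse weight needs only two triangular solves instead of
# a dense inverse. `inverse_weight` is a hypothetical helper; it assumes
# torch.linalg.solve_triangular is available (PyTorch >= 1.11).
import torch

def inverse_weight(P, L, log_S, sign_S, U, eye):
    L_full = torch.tril(L, -1) + eye
    U_full = torch.triu(U, 1) + torch.diag(sign_S * torch.exp(log_S))
    # W = P L U  =>  W^{-1} = U^{-1} L^{-1} P^T
    L_inv = torch.linalg.solve_triangular(L_full, eye, upper=False)
    U_inv = torch.linalg.solve_triangular(U_full, eye, upper=True)
    return U_inv @ L_inv @ P.T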
def rank_revealing_LUP_GPU(A, threshold=10**-10):
    """
    Returns the rank of A using GPU acceleration, based on PyTorch's
    LUP implementation.

    Parameters
    ----------
    A : 2D float64 numpy array
        Input array for which the rank needs to be computed.
    threshold : float
        See the documentation of the function "sort_rows_fast".

    Returns
    -------
    rank : int
        Rank of A.
    """
    try:
        import torch
    except ImportError as e:
        print(str(e))
        print("rank_revealing_LUP_GPU requires PyTorch to run.")

    A = A / np.abs(A).max()
    A_tensor = torch.from_numpy(A).cuda()
    A_LU, pivots = torch.lu(A_tensor, get_infos=False)
    _, _, u = torch.lu_unpack(A_LU, pivots, unpack_pivots=False)
    u = u.cpu().numpy()
    ref = compute_ref_LUP_fast(u, threshold)
    rank = compute_rank_ref(ref, threshold)
    return rank
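# Hypothetical usage (assumes a CUDA device and that this module's
# helpers compute_ref_LUP_fast / compute_rank_ref are importable): the
# product of rank-3 factors should have rank 3.
import numpy as np

A = np.random.randn(100, 3) @ np.random.randn(3, 80)
print(rank_revealing_LUP_GPU(A))  # expected: 3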
def __init__(self, in_channel, weight=None):
    # in_channel indicates the channel size of the input
    super().__init__()
    self.in_channel = in_channel

    # set a random orthogonal matrix as the initial weight
    if weight is None:
        weight = torch.randn(in_channel, in_channel)
    Weight, _ = torch.qr(weight)

    # PLU decomposition
    self.weight = Weight  # only for testing
    Weight_LU, pivots = torch.lu(Weight)
    w_p, w_l, w_u = torch.lu_unpack(Weight_LU, pivots)
    w_s = torch.diag(w_u)
    w_logs = torch.log(torch.abs(w_s))
    s_sign = torch.sign(w_s)
    w_u = torch.triu(w_u, 1)
    u_mask = torch.triu(torch.ones_like(w_u), 1)
    l_mask = u_mask.T
    l_eye = torch.eye(l_mask.shape[0])

    # fix P
    self.register_buffer("w_p", w_p)
    self.register_buffer("u_mask", u_mask)
    self.register_buffer("l_mask", l_mask)
    self.register_buffer("s_sign", s_sign)
    self.register_buffer("l_eye", l_eye)
    self.w_l = torch.nn.Parameter(w_l)
    self.w_u = torch.nn.Parameter(w_u)
    # self.w_s = torch.nn.Parameter(w_s)
    self.w_logs = torch.nn.Parameter(w_logs)
def __init__(self, nb_channels: int, lu_decomposition: bool = False):
    """
    Invertible 1x1 Convolution for 2D inputs with LU parameterization.

    References: see the Glow paper for details, https://arxiv.org/abs/1807.03039

    :param nb_channels: the number of input/output channels
    :param lu_decomposition: whether to use the LU parameterization
    """
    super(Invertible1x1Conv, self).__init__()
    self.__channels = nb_channels
    # Initialize W as a random orthogonal matrix
    W = torch.zeros(nb_channels, nb_channels)
    nn.init.orthogonal_(W)
    # Make sure the det is 1 (and not -1) to have a defined log-det
    if torch.det(W) < 0:
        W[:, 0] = -1 * W[:, 0]
    self.__lu = lu_decomposition
    if self.__lu:
        W_LU, pivots = W.lu()
        P, L, U = torch.lu_unpack(W_LU, pivots)
        s = torch.diag(U)
        U = U - torch.diag(s)
        # Assign the module's trainable parameters
        self.sign_s = nn.Parameter(torch.sign(s), requires_grad=True)
        self.log_s = nn.Parameter(torch.log(s.abs()), requires_grad=True)
        self.lower = nn.Parameter(L, requires_grad=True)
        self.upper = nn.Parameter(U, requires_grad=True)
        # Assign the non-trainable ones
        self.register_buffer('permutation', P)
        self.register_buffer('permutation_inv', torch.inverse(P))
        self.register_buffer('eye', torch.eye(nb_channels))
        self.register_buffer(
            'l_mask', torch.tril(torch.ones(nb_channels, nb_channels), -1))
    else:
        self.weight = nn.Parameter(W, requires_grad=True)
def backward(ctx, LU_grad, pivots_grad, infors_grad):
    """
    Here we derive the gradients for the LU decomposition.
    LIMITATIONS: square inputs of full rank.
    If not stated otherwise, for tensors A and B,
    `A B` means the matrix product of A and B.

    Forward AD:
    Note that PyTorch returns packed LU, it is a mapping
    A -> (B := L + U - I, P), such that A = P L U, where
    P is a permutation matrix and is non-differentiable.

    Using B = L + U - I, A = P L U, we get

    dB = dL + dU                      (*)
    P^T dA = dL U + L dU              (**)

    By left/right multiplication of (**) with L^{-1}/U^{-1} we get:
    L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}.

    Note that L^{-1} dL is lower-triangular with zero diagonal,
    and dU U^{-1} is upper-triangular.
    Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so

    L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}),
    dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}),

    where * denotes the Hadamard product.
    Hence we finally get:

    dL = L 1_L * (L^{-1} P^T dA U^{-1}),
    dU = 1_U * (L^{-1} P^T dA U^{-1}) U

    Backward AD:
    The backward sensitivity is then:
    Tr(B_grad^T dB) = Tr(B_grad^T dL) + Tr(B_grad^T dU) = [1] + [2].

    [1] = Tr(B_grad^T dL) = Tr(B_grad^T L 1_L * (L^{-1} P^T dA U^{-1}))
        = [using Tr(A (B * C)) = Tr((A * B^T) C)]
        = Tr((B_grad^T L * 1_L^T) L^{-1} P^T dA U^{-1})
        = [cyclic property of trace]
        = Tr(U^{-1} (B_grad^T L * 1_L^T) L^{-1} P^T dA)
        = Tr((P L^{-T} (L^T B_grad * 1_L) U^{-T})^T dA).

    Similarly, [2] can be rewritten as:
    [2] = Tr((P L^{-T} (B_grad U^T * 1_U) U^{-T})^T dA),

    hence
    Tr(A_grad^T dA) = [1] + [2]
                    = Tr((P L^{-T} (L^T B_grad * 1_L + B_grad U^T * 1_U) U^{-T})^T dA),
    so
    A_grad = P L^{-T} (L^T B_grad * 1_L + B_grad U^T * 1_U) U^{-T}.

    In the code below we use the name `LU` instead of `B`, so that there
    is no confusion in the derivation above between the matrix product
    and a two-letter variable name.
    """
    LU, pivots = ctx.saved_tensors
    P, L, U = torch.lu_unpack(LU, pivots)

    # To make sure MyPy infers types right
    assert (L is not None) and (U is not None)

    I = LU_grad.new_zeros(LU_grad.shape)
    I.diagonal(dim1=-2, dim2=-1).fill_(1)

    Lt_inv = torch.triangular_solve(I, L, upper=False).solution.transpose(-1, -2)
    Ut_inv = torch.triangular_solve(I, U, upper=True).solution.transpose(-1, -2)

    phi_L = (L.transpose(-1, -2) @ LU_grad).tril_()
    phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0)
    phi_U = (LU_grad @ U.transpose(-1, -2)).triu_()

    self_grad_perturbed = Lt_inv @ (phi_L + phi_U) @ Ut_inv
    return P @ self_grad_perturbed, None, None
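# A hedged numerical check of the derivation above (real case). It
# assumes a PyTorch version in which torch.lu participates in autograd,
# i.e. the Function whose backward is shown here. The diagonal shift
# keeps the matrix well-conditioned so the pivots stay stable under the
# finite-difference perturbations used by gradcheck.
import torch

A = torch.randn(5, 5, dtype=torch.float64) + 5 * torch.eye(5, dtype=torch.float64)
A.requires_grad_(True)

def packed_lu(x):
    lu_data, _ = torch.lu(x)  # pivots are integer-valued, not differentiated
    return lu_data

torch.autograd.gradcheck(packed_lu, (A,), atol=1e-6)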
def pre_factor_kkt(Q, G, A):
    """ Perform all one-time factorizations and cache relevant matrix products"""
    nineq, nz, neq, nBatch = get_sizes(G, A)

    try:
        Q_LU = lu_hack(Q)
    except Exception:
        raise RuntimeError("""
qpth Error: Cannot perform LU factorization on Q.
Please make sure that your Q matrix is PSD and has
a non-zero diagonal.
""")

    # S = [ A Q^{-1} A^T        A Q^{-1} G^T          ]
    #     [ G Q^{-1} A^T        G Q^{-1} G^T + D^{-1} ]
    #
    # We compute a partial LU decomposition of the S matrix
    # that can be completed once D^{-1} is known.
    # See the 'Block LU factorization' part of our website
    # for more details.

    G_invQ_GT = torch.bmm(G, G.transpose(1, 2).lu_solve(*Q_LU))
    R = G_invQ_GT.clone()
    S_LU_pivots = torch.IntTensor(range(1, 1 + neq + nineq)).unsqueeze(0) \
        .repeat(nBatch, 1).type_as(Q).int()
    if neq > 0:
        invQ_AT = A.transpose(1, 2).lu_solve(*Q_LU)
        A_invQ_AT = torch.bmm(A, invQ_AT)
        G_invQ_AT = torch.bmm(G, invQ_AT)

        LU_A_invQ_AT = lu_hack(A_invQ_AT)
        P_A_invQ_AT, L_A_invQ_AT, U_A_invQ_AT = torch.lu_unpack(*LU_A_invQ_AT)
        P_A_invQ_AT = P_A_invQ_AT.type_as(A_invQ_AT)

        S_LU_11 = LU_A_invQ_AT[0]
        U_A_invQ_AT_inv = (P_A_invQ_AT.bmm(L_A_invQ_AT)).lu_solve(*LU_A_invQ_AT)
        S_LU_21 = G_invQ_AT.bmm(U_A_invQ_AT_inv)
        T = sp_lu_solve(G_invQ_AT.transpose(1, 2), *LU_A_invQ_AT)
        # T = G_invQ_AT.transpose(1, 2).lu_solve(*LU_A_invQ_AT)
        S_LU_12 = U_A_invQ_AT.bmm(T)
        S_LU_22 = torch.zeros(nBatch, nineq, nineq).type_as(Q)
        S_LU_data = torch.cat((torch.cat((S_LU_11, S_LU_12), 2),
                               torch.cat((S_LU_21, S_LU_22), 2)), 1)
        S_LU_pivots[:, :neq] = LU_A_invQ_AT[1]

        R -= G_invQ_AT.bmm(T)
    else:
        S_LU_data = torch.zeros(nBatch, nineq, nineq).type_as(Q)

    S_LU = [S_LU_data, S_LU_pivots]
    return Q_LU, S_LU, R
def backward(ctx, LU_grad, pivots_grad, infors_grad):
    """
    Here we derive the gradients for the LU decomposition.
    LIMITATIONS: square inputs of full rank.
    If not stated otherwise, for tensors A and B,
    `A B` means the matrix product of A and B.
    Let A^H = (A^T).conj().

    Forward AD:
    Note that PyTorch returns packed LU, it is a mapping
    A -> (B := L + U - I, P), such that A = P L U, where
    P is a permutation matrix and is non-differentiable.

    Using B = L + U - I, A = P L U, we get

    dB = dL + dU                      (*)
    P^T dA = dL U + L dU              (**)

    By left/right multiplication of (**) with L^{-1}/U^{-1} we get:
    L^{-1} P^T dA U^{-1} = L^{-1} dL + dU U^{-1}.

    Note that L^{-1} dL is lower-triangular with zero diagonal,
    and dU U^{-1} is upper-triangular.
    Define 1_U := triu(ones(n, n)), and 1_L := ones(n, n) - 1_U, so

    L^{-1} dL = 1_L * (L^{-1} P^T dA U^{-1}),
    dU U^{-1} = 1_U * (L^{-1} P^T dA U^{-1}),

    where * denotes the Hadamard product.
    Hence we finally get:

    dL = L 1_L * (L^{-1} P^T dA U^{-1}),
    dU = 1_U * (L^{-1} P^T dA U^{-1}) U

    Backward AD:
    The backward sensitivity is then:
    Tr(B_grad^H dB) = Tr(B_grad^H dL) + Tr(B_grad^H dU) = [1] + [2].

    [1] = Tr(B_grad^H dL) = Tr(B_grad^H L 1_L * (L^{-1} P^T dA U^{-1}))
        = [using Tr(A (B * C)) = Tr((A * B^T) C)]
        = Tr((B_grad^H L * 1_L^T) L^{-1} P^T dA U^{-1})
        = [cyclic property of trace]
        = Tr(U^{-1} (B_grad^H L * 1_L^T) L^{-1} P^T dA)
        = Tr((P L^{-H} (L^H B_grad * 1_L) U^{-H})^H dA).

    Similarly, [2] can be rewritten as:
    [2] = Tr((P L^{-H} (B_grad U^H * 1_U) U^{-H})^H dA),

    hence
    Tr(A_grad^H dA) = [1] + [2]
                    = Tr((P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H})^H dA),
    so
    A_grad = P L^{-H} (L^H B_grad * 1_L + B_grad U^H * 1_U) U^{-H}.

    In the code below we use the name `LU` instead of `B`, so that there
    is no confusion in the derivation above between the matrix product
    and a two-letter variable name.
    """
    LU, pivots = ctx.saved_tensors
    P, L, U = torch.lu_unpack(LU, pivots)

    # To make sure MyPy infers types right
    assert (L is not None) and (U is not None) and (P is not None)

    # phi_L = L^H B_grad * 1_L
    phi_L = (L.transpose(-1, -2).conj() @ LU_grad).tril_()
    phi_L.diagonal(dim1=-2, dim2=-1).fill_(0.0)
    # phi_U = B_grad U^H * 1_U
    phi_U = (LU_grad @ U.transpose(-1, -2).conj()).triu_()

    phi = phi_L + phi_U

    # Using the notation from above plus the variable names, note
    # A_grad = P L^{-H} phi U^{-H}.
    # Instead of inverting L and U, we solve two systems of equations, i.e.
    # the above expression could be rewritten as
    #   L^H P^T A_grad U^H = phi.
    # Let X = P^T A_grad U^H, then
    #   X = L^{-H} phi, where L^{-H} is upper-triangular, or
    #   X = torch.triangular_solve(phi, L^H).
    # Using the definition of X we see:
    #   X = P^T A_grad U^H => P X = A_grad U^H => U A_grad^H = X^H P^T, so
    #   A_grad = (U^{-1} X^H P^T)^H, or
    #   A_grad = torch.triangular_solve(X^H P^T, U)^H
    X = torch.triangular_solve(phi, L.transpose(-1, -2).conj(), upper=True).solution
    A_grad = torch.triangular_solve(X.transpose(-1, -2).conj() @ P.transpose(-1, -2),
                                    U, upper=True) \
        .solution.transpose(-1, -2).conj()

    return A_grad, None, None
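# The same hedged gradcheck in the complex case, assuming this PyTorch
# version supports complex autograd through torch.lu (which is what the
# ^H-based derivation above enables).
import torch

A = torch.randn(4, 4, dtype=torch.complex128) + 4 * torch.eye(4, dtype=torch.complex128)
A.requires_grad_(True)
torch.autograd.gradcheck(lambda x: torch.lu(x)[0], (A,), atol=1e-6)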