def test_saint_subgraph():
    row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
    col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
    adj = SparseTensor(row=row, col=col)

    node_idx = torch.tensor([0, 1, 2])
    adj, edge_index = adj.saint_subgraph(node_idx)

def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    row, col, value = src.coo()
    inv_mask = row != col if k == 0 else row != (col - k)

    new_row, new_col = row[inv_mask], col[inv_mask]
    if value is not None:
        value = value[inv_mask]

    rowcount = src.storage._rowcount
    colcount = src.storage._colcount
    if rowcount is not None or colcount is not None:
        mask = ~inv_mask
        if rowcount is not None:
            rowcount = rowcount.clone()
            rowcount[row[mask]] -= 1
        if colcount is not None:
            colcount = colcount.clone()
            colcount[col[mask]] -= 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
                            sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
                            colptr=None, colcount=colcount, csr2csc=None,
                            csc2csr=None, is_sorted=True)
    return src.from_storage(storage)

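
# Usage sketch for `remove_diag` (my example, not library code; assumes
# `torch_sparse` is installed so that `SparseTensor` is available):
import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_dense(torch.tensor([[1., 2.], [3., 4.]]))
adj = remove_diag(adj)  # Drops the entries at (0, 0) and (1, 1).
assert adj.nnz() == 2
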
def partition(
    src: SparseTensor,
    num_parts: int,
    recursive: bool = False,
    weighted: bool = False,
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:

    rowptr, col, value = src.csr()
    rowptr, col = rowptr.cpu(), col.cpu()

    if value is not None and weighted:
        assert value.numel() == col.numel()
        value = value.view(-1).detach().cpu()
        if value.is_floating_point():
            value = weight2metis(value)
    else:
        value = None

    cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
                                               recursive)
    cluster = cluster.to(src.device())

    cluster, perm = cluster.sort()
    out = permute(src, perm)
    partptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)

    return out, partptr, perm

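
# A hedged example for `partition`: split a small all-ones graph into two
# clusters. This assumes a torch-sparse build compiled with METIS support and
# the module-level `permute` helper used above.
import torch
from torch_sparse import SparseTensor

mat = SparseTensor.from_dense(torch.ones(6, 6))
out, partptr, perm = partition(mat, num_parts=2, recursive=False)
# `partptr` delimits the clusters in CSR style; `perm` maps new ids to old.
assert partptr.numel() == 3 and perm.numel() == 6
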
def t(src: SparseTensor) -> SparseTensor:
    csr2csc = src.storage.csr2csc()

    row, col, value = src.coo()
    if value is not None:
        value = value[csr2csc]

    sparse_sizes = src.storage.sparse_sizes()

    storage = SparseStorage(
        row=col[csr2csc],
        rowptr=src.storage._colptr,
        col=row[csr2csc],
        value=value,
        sparse_sizes=(sparse_sizes[1], sparse_sizes[0]),
        rowcount=src.storage._colcount,
        colptr=src.storage._rowptr,
        colcount=src.storage._rowcount,
        csr2csc=src.storage._csc2csr,
        csc2csr=csr2csc,
        is_sorted=True,
    )
    return src.from_storage(storage)

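
# Minimal sketch of `t`: transposing flips every (row, col) pair and swaps the
# sparse sizes, reusing the cached CSC layout instead of re-sorting.
import torch
from torch_sparse import SparseTensor

adj = SparseTensor(row=torch.tensor([0, 0, 1]), col=torch.tensor([1, 2, 0]),
                   sparse_sizes=(2, 3))
assert t(adj).sparse_sizes() == (3, 2)
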
def test_fill_diag(dtype, device):
    row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], torch.long, device)
    value = tensor([1, 2, 3, 4], dtype, device)
    mat = SparseTensor(row=row, col=col, value=value)
    mat = mat.fill_diag(-8, k=-1)
    mat = mat.fill_diag(-8, k=1)

def adamic_adar(indexA, valueA, indexB, valueB, m, k, n, coalesced=False,
                sampling=True):
    A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
                     sparse_sizes=(m, k), is_sorted=not coalesced)
    B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
                     sparse_sizes=(k, n), is_sorted=not coalesced)

    deg_A = A.storage.colcount()
    deg_B = B.storage.rowcount()
    deg_normalized = 1.0 / (deg_A + deg_B).to(torch.float)
    deg_normalized[deg_normalized == float('inf')] = 0.0

    D = SparseTensor(
        row=torch.arange(deg_normalized.size(0), device=valueA.device),
        col=torch.arange(deg_normalized.size(0), device=valueA.device),
        value=deg_normalized.type_as(valueA),
        sparse_sizes=(deg_normalized.size(0), deg_normalized.size(0)))

    out = A @ D @ B
    row, col, values = out.coo()

    num_samples = min(int(valueA.numel()), int(valueB.numel()),
                      values.numel())
    if sampling and values.numel() > num_samples:
        idx = torch.multinomial(values, num_samples=num_samples,
                                replacement=False)
        row, col, values = row[idx], col[idx], values[idx]

    return torch.stack([row, col], dim=0), values

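
# Illustrative call of `adamic_adar` on tiny COO inputs (A is 2x3, B is 3x2);
# `sampling=False` keeps the full product. Shapes and values are made up.
import torch

indexA = torch.tensor([[0, 0, 1], [0, 1, 2]])
valueA = torch.ones(3)
indexB = torch.tensor([[0, 1, 2], [0, 1, 1]])
valueB = torch.ones(3)
index, value = adamic_adar(indexA, valueA, indexB, valueB,
                           m=2, k=3, n=2, sampling=False)
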
def spspmm(indexA, valueA, indexB, valueB, m, k, n, coalesced=False):
    """Matrix product of two sparse tensors. Both input sparse matrices need
    to be coalesced (use the :obj:`coalesced` argument to force).

    Args:
        indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
        valueA (:class:`Tensor`): The value tensor of first sparse matrix.
        indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
        valueB (:class:`Tensor`): The value tensor of second sparse matrix.
        m (int): The first dimension of first corresponding dense matrix.
        k (int): The second dimension of first corresponding dense matrix and
            first dimension of second corresponding dense matrix.
        n (int): The second dimension of second corresponding dense matrix.
        coalesced (bool, optional): If set to :obj:`True`, will coalesce both
            input sparse matrices. (default: :obj:`False`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    A = SparseTensor(row=indexA[0], col=indexA[1], value=valueA,
                     sparse_sizes=(m, k), is_sorted=not coalesced)
    B = SparseTensor(row=indexB[0], col=indexB[1], value=valueB,
                     sparse_sizes=(k, n), is_sorted=not coalesced)

    C = matmul(A, B)
    row, col, value = C.coo()

    return torch.stack([row, col], dim=0), value

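
# Usage example for `spspmm`, multiplying a sparse 3x3 matrix with a sparse
# 3x2 matrix in COO form (values are arbitrary):
import torch

indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 1, 2]])
valueA = torch.tensor([1., 2., 3., 4., 5.])
indexB = torch.tensor([[0, 2], [1, 0]])
valueB = torch.tensor([2., 4.])
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2)
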
def test_permute(device):
    row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], torch.long, device)
    value = tensor([1, 2, 3, 4, 5], torch.float, device)
    adj = SparseTensor(row=row, col=col, value=value)

    row, col, value = adj.permute(torch.tensor([1, 0, 2])).coo()
    assert row.tolist() == [0, 1, 1, 2, 2]
    assert col.tolist() == [1, 0, 1, 0, 2]
    assert value.tolist() == [3, 2, 1, 4, 5]

def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    # This function builds the inverse operation of `cat_diag` and should
    # hence only be used on *diagonally stacked* sparse matrices.
    # That's the reason why this method is marked as *private*.
    rowptr, col, value = src.csr()

    rowptr = rowptr.narrow(0, start=start[0], length=length[0] + 1)
    row_start = int(rowptr[0])
    rowptr = rowptr - row_start
    row_length = int(rowptr[-1])

    row = src.storage._row
    if row is not None:
        row = row.narrow(0, row_start, row_length) - start[0]

    col = col.narrow(0, row_start, row_length) - start[1]

    if value is not None:
        value = value.narrow(0, row_start, row_length)

    sparse_sizes = length

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, start[0], length[0])

    colptr = src.storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, start[1], length[1] + 1)
        colptr = colptr - int(colptr[0])  # i.e. `row_start`

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, start[1], length[1])

    csr2csc = src.storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, row_start, row_length) - row_start

    csc2csr = src.storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, row_start, row_length) - row_start

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=sparse_sizes, rowcount=rowcount,
                            colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return src.from_storage(storage)

def fill_diag(src: SparseTensor, fill_value: float,
              k: int = 0) -> SparseTensor:
    num_diag = min(src.sparse_size(0), src.sparse_size(1) - k)
    if k < 0:
        num_diag = min(src.sparse_size(0) + k, src.sparse_size(1))

    value = src.storage.value()
    if value is not None:
        sizes = [num_diag] + src.sizes()[2:]
        return set_diag(src, value.new_full(sizes, fill_value), k)
    else:
        return set_diag(src, None, k)

def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
             k: int = 0) -> SparseTensor:
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    start, num_diag = -k if k < 0 else 0, mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    new_col[inv_mask] = diag.add_(k)

    new_value: Optional[torch.Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
                                             device=value.device)

    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
                            value=new_value, sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount, colptr=None, colcount=colcount,
                            csr2csc=None, csc2csr=None, is_sorted=True)
    return src.from_storage(storage)

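
# Sketch of `set_diag` (relies on `remove_diag` above and the compiled
# `non_diag_mask` op): overwrite the main diagonal with explicit values.
import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_dense(torch.tensor([[1., 2.], [3., 4.]]))
adj = set_diag(adj, torch.tensor([9., 9.]))
assert adj.get_diag().tolist() == [9., 9.]
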
def get_diag(src: SparseTensor) -> Tensor:
    row, col, value = src.coo()
    if value is None:
        value = torch.ones(row.size(0), device=row.device)

    sizes = list(value.size())
    sizes[0] = min(src.size(0), src.size(1))

    out = value.new_zeros(sizes)
    mask = row == col
    out[row[mask]] = value[mask]
    return out

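
# `get_diag` in action: extract the main diagonal as a dense vector, with
# zeros filled in for missing entries.
import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_dense(torch.tensor([[1., 0.], [0., 0.]]))
assert get_diag(adj).tolist() == [1., 0.]
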
def forward(self, x, g, **kwargs):
    # Pre-compute (I+D)^{-1} (A+I) (multiply each row of A+I with (1+di)^-1).
    adj_mat_filled_diag = SparseTensor.from_torch_sparse_coo_tensor(
        g.adjacency_matrix(False)).fill_diag(1.)
    # Divide each row by (1+di).
    adj_mat_filled_diag = adj_mat_filled_diag / adj_mat_filled_diag.sum(
        -1).unsqueeze(-1)
    if torch.cuda.is_available() and x.is_cuda:
        adj_mat_filled_diag = adj_mat_filled_diag.cuda()

    # This will be of shape [num_output_dim, nx, nx] -> prohibitive for big nx.
    covar_xx = self.covar_module(x).evaluate()

    # covar_xx = (I+D)^{-1} (A+I) K_xx (A+I)^top (I+D)^{-1}.
    # First compute (I+D)^{-1} (A+I) @ K_xx.
    xx_t1 = self.sparse_adj_matmul(adj_mat_filled_diag, covar_xx)
    # Then compute (I+D)^{-1} (A+I) @ ((A+I) @ K_xx).T
    # = (A+I) @ K_xx @ (A+I).T.
    covar_full = self.sparse_adj_matmul(adj_mat_filled_diag,
                                        xx_t1.transpose(-2, -1))

    mean_full = self.mean_module(x)
    mean_full = self.sparse_adj_matmul(adj_mat_filled_diag, mean_full)
    return gpytorch.distributions.MultivariateNormal(mean_full, covar_full)

def test_metis(device):
    value1 = torch.randn(6 * 6, device=device).view(6, 6)
    value2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
    value3 = torch.ones(6 * 6, device=device).view(6, 6)

    for value in [value1, value2, value3]:
        mat = SparseTensor.from_dense(value)

        _, partptr, perm = mat.partition(num_parts=2, recursive=False,
                                         weighted=True)
        assert partptr.numel() == 3
        assert perm.numel() == 6

        _, partptr, perm = mat.partition(num_parts=2, recursive=False,
                                         weighted=False)
        assert partptr.numel() == 3
        assert perm.numel() == 6

        _, partptr, perm = mat.partition(num_parts=1, recursive=False,
                                         weighted=True)
        assert partptr.numel() == 2
        assert perm.numel() == 6

def test_spmm(dtype, device, reduce):
    src = torch.randn((10, 8), dtype=dtype, device=device)
    src[2:4, :] = 0  # Remove multiple rows.
    src[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src).requires_grad_()
    row, col, value = src.coo()

    other = torch.randn((2, 8, 2), dtype=dtype, device=device,
                        requires_grad=True)

    src_col = other.index_select(-2, col) * value.unsqueeze(-1)
    expected = torch_scatter.scatter(src_col, row, dim=-2, reduce=reduce)
    if reduce == 'min':
        expected[expected > 1000] = 0
    if reduce == 'max':
        expected[expected < -1000] = 0

    grad_out = torch.randn_like(expected)

    expected.backward(grad_out)
    expected_grad_value = value.grad
    value.grad = None
    expected_grad_other = other.grad
    other.grad = None

    out = matmul(src, other, reduce)
    out.backward(grad_out)

    assert torch.allclose(expected, out, atol=1e-6)
    assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
    assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)

def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
        other = gather_csr(other.squeeze(1), rowptr)
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add_(1)

    return src.set_value_(value, layout='coo')

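
# Broadcast-add sketch for `add_` (needs `gather_csr` from `torch_scatter`):
# an (M, 1) tensor adds row-wise, a (1, N) tensor adds column-wise.
import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_dense(torch.tensor([[1., 2.], [3., 4.]]))
adj = add_(adj, torch.tensor([[10.], [20.]]))  # Row-wise, in-place.
assert adj.to_dense().tolist() == [[11., 12.], [23., 24.]]
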
def spmm_max(src: SparseTensor,
             other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    rowptr, col, value = src.csr()

    if value is not None:
        value = value.to(other.dtype)

    return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)

def add_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.add_(other.to(value.dtype))
    else:
        value = other.add(1)
    return src.set_value_(value, layout=layout)

def sparse_select(adj, dim, index):
    """Index select on a sparse tensor (temporary function).

    torch.index_select on a sparse tensor is too slow to be useful, see
    https://github.com/pytorch/pytorch/issues/54561
    """
    adj = SparseTensor.from_torch_sparse_coo_tensor(adj)
    adj = index_select(adj, dim, index)
    return adj.to_torch_sparse_coo_tensor()

def test_eye(dtype, device):
    mat = SparseTensor.eye(3, dtype=dtype, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (3, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.value().tolist() == [1, 1, 1]
    assert mat.storage.value().dtype == dtype
    assert mat.storage.num_cached_keys() == 0

    mat = SparseTensor.eye(3, has_value=False, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (3, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.value() is None
    assert mat.storage.num_cached_keys() == 0

    mat = SparseTensor.eye(3, 4, fill_cache=True, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (3, 4)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.num_cached_keys() == 5
    assert mat.storage.rowcount().tolist() == [1, 1, 1]
    assert mat.storage.colptr().tolist() == [0, 1, 2, 3, 3]
    assert mat.storage.colcount().tolist() == [1, 1, 1, 0]
    assert mat.storage.csr2csc().tolist() == [0, 1, 2]
    assert mat.storage.csc2csr().tolist() == [0, 1, 2]

    mat = SparseTensor.eye(4, 3, fill_cache=True, device=device)
    assert mat.storage.col().device == device
    assert mat.storage.sparse_sizes() == (4, 3)
    assert mat.storage.row().tolist() == [0, 1, 2]
    assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3]
    assert mat.storage.col().tolist() == [0, 1, 2]
    assert mat.storage.num_cached_keys() == 5
    assert mat.storage.rowcount().tolist() == [1, 1, 1, 0]
    assert mat.storage.colptr().tolist() == [0, 1, 2, 3]
    assert mat.storage.colcount().tolist() == [1, 1, 1]
    assert mat.storage.csr2csc().tolist() == [0, 1, 2]
    assert mat.storage.csc2csr().tolist() == [0, 1, 2]

def reverse_cuthill_mckee(
    src: SparseTensor,
    is_symmetric: Optional[bool] = None,
) -> Tuple[SparseTensor, torch.Tensor]:
    if is_symmetric is None:
        is_symmetric = src.is_symmetric()

    if not is_symmetric:
        src = src.to_symmetric()

    sp_src = src.to_scipy(layout='csr')
    perm = sp.csgraph.reverse_cuthill_mckee(sp_src, symmetric_mode=True).copy()
    perm = torch.from_numpy(perm).to(torch.long).to(src.device())
    out = permute(src, perm)

    return out, perm

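
# Reordering sketch for `reverse_cuthill_mckee` (assumes `sp` above is
# `scipy.sparse` and the `permute` helper is in scope):
import torch
from torch_sparse import SparseTensor

dense = torch.tensor([[1., 0., 1.], [0., 1., 0.], [1., 0., 1.]])
out, perm = reverse_cuthill_mckee(SparseTensor.from_dense(dense))
assert perm.numel() == 3  # `perm` is a bandwidth-reducing node ordering.
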
def mul_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    value = src.storage.value()
    if value is not None:
        value = value.mul(other.to(value.dtype))
    else:
        value = other
    return src.set_value(value, layout=layout)

def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    assert src.sparse_size(1) == other.sparse_size(0)
    rowptrA, colA, valueA = src.csr()
    rowptrB, colB, valueB = other.csr()
    M, N = src.sparse_size(0), other.sparse_size(1)  # Output is M x N.
    rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
        rowptrA, colA, valueA, rowptrB, colB, valueB, N)
    return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
                        sparse_sizes=(M, N), is_sorted=True)

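
# `spspmm_sum` sanity sketch: multiplying by the identity returns the input.
import torch
from torch_sparse import SparseTensor

eye = SparseTensor.eye(3)
mat = SparseTensor.from_dense(
    torch.tensor([[0., 1., 0.], [2., 0., 0.], [0., 0., 3.]]))
out = spspmm_sum(eye, mat)
assert out.to_dense().tolist() == mat.to_dense().tolist()
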
def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    assert mask.dim() == 1

    if get_layout(layout) == 'csc':
        mask = mask[src.storage.csc2csr()]

    row, col, value = src.coo()
    row, col = row[mask], col[mask]

    if value is not None:
        value = value[mask]

    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=True)

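
# `masked_select_nnz` sketch: keep only the nonzeros above a threshold while
# preserving the sparse sizes (`get_layout` resolves the mask ordering).
import torch
from torch_sparse import SparseTensor

adj = SparseTensor.from_dense(torch.tensor([[1., 5.], [3., 0.]]))
row, col, value = adj.coo()
out = masked_select_nnz(adj, value > 2., layout='coo')
assert out.nnz() == 2
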
def saint_subgraph(
    src: SparseTensor,
    node_idx: torch.Tensor,
) -> Tuple[SparseTensor, torch.Tensor]:
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    data = torch.ops.torch_sparse.saint_subgraph(node_idx, rowptr, row, col)
    row, col, edge_index = data

    if value is not None:
        value = value[edge_index]

    out = SparseTensor(row=row, rowptr=None, col=col, value=value,
                       sparse_sizes=(node_idx.size(0), node_idx.size(0)),
                       is_sorted=True)
    return out, edge_index

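
# GraphSAINT sketch: induce the subgraph on nodes {0, 1}; `edge_index` maps
# the kept edges back to positions in the original nonzero list.
import torch
from torch_sparse import SparseTensor

adj = SparseTensor(row=torch.tensor([0, 0, 1, 1]),
                   col=torch.tensor([1, 2, 0, 2]), sparse_sizes=(3, 3))
sub, edge_index = saint_subgraph(adj, torch.tensor([0, 1]))
assert sub.sparse_sizes() == (2, 2)
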
def index_select_nnz(src: SparseTensor, idx: torch.Tensor,
                     layout: Optional[str] = None) -> SparseTensor:
    assert idx.dim() == 1

    if get_layout(layout) == 'csc':
        idx = src.storage.csc2csr()[idx]

    row, col, value = src.coo()
    row, col = row[idx], col[idx]

    if value is not None:
        value = value[idx]

    return SparseTensor(row=row, rowptr=None, col=col, value=value,
                        sparse_sizes=src.sparse_sizes(), is_sorted=True)

def add(src, other):  # noqa: F811
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')

        if value is not None:
            value = other.to(value.dtype).add_(value)
        else:
            value = other.add_(1)

        return src.set_value(value, layout='coo')

    elif isinstance(other, SparseTensor):
        rowA, colA, valueA = src.coo()
        rowB, colB, valueB = other.coo()

        row = torch.cat([rowA, rowB], dim=0)
        col = torch.cat([colA, colB], dim=0)

        value: Optional[Tensor] = None
        if valueA is not None and valueB is not None:
            value = torch.cat([valueA, valueB], dim=0)

        M = max(src.size(0), other.size(0))
        N = max(src.size(1), other.size(1))
        sparse_sizes = (M, N)

        out = SparseTensor(row=row, col=col, value=value,
                           sparse_sizes=sparse_sizes)
        out = out.coalesce(reduce='sum')
        return out

    else:
        raise NotImplementedError

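
# Sparse + sparse sketch for the `add` overload above: overlapping entries
# are merged by `coalesce(reduce='sum')`.
import torch
from torch_sparse import SparseTensor

a = SparseTensor.from_dense(torch.tensor([[1., 0.], [0., 2.]]))
b = SparseTensor.from_dense(torch.tensor([[3., 0.], [4., 0.]]))
assert add(a, b).to_dense().tolist() == [[4., 0.], [4., 2.]]
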
def test_spmm_half_precision():
    src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
    src_dense[2:4, :] = 0  # Remove multiple rows.
    src_dense[:, 2:4] = 0  # Remove multiple columns.
    src = SparseTensor.from_dense(src_dense)

    other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')

    expected = (src_dense.to(torch.float) @ other).to(torch.half)
    out = src @ other.to(torch.half)
    assert torch.allclose(expected, out, atol=1e-2)

def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
               replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
    rowptr, col, value = src.csr()

    rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
        rowptr, col, subset, num_neighbors, replace)

    if value is not None:
        value = value[e_id]

    out = SparseTensor(rowptr=rowptr, row=None, col=col, value=value,
                       sparse_sizes=(subset.size(0), n_id.size(0)),
                       is_sorted=True)
    return out, n_id

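
# Neighbor-sampling sketch for `sample_adj`: draw up to two neighbors for the
# seed nodes {0, 1}; `n_id` lists all nodes in the sampled bipartite block.
import torch
from torch_sparse import SparseTensor

adj = SparseTensor(row=torch.tensor([0, 0, 1, 2]),
                   col=torch.tensor([1, 2, 0, 0]), sparse_sizes=(3, 3))
out, n_id = sample_adj(adj, torch.tensor([0, 1]), num_neighbors=2)
assert out.sparse_size(0) == 2
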
def test_get_diag(dtype, device):
    row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    value = tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype, device)
    mat = SparseTensor(row=row, col=col, value=value)
    assert mat.get_diag().tolist() == [[1, 1], [0, 0], [4, 4]]

    row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], torch.long, device)
    mat = SparseTensor(row=row, col=col)
    assert mat.get_diag().tolist() == [1, 0, 1]