def weighted_dimwise_median(A: torch.sparse.FloatTensor, x: torch.Tensor, **kwargs) -> torch.Tensor:
    """A weighted dimension-wise Median aggregation.

    Non-CUDA inputs are delegated to `weighted_dimwise_median_cpu`; CUDA
    inputs use the custom `dimmedian_idx` kernel.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse [n, n] tensor of the weighted/normalized adjacency matrix
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings
    **kwargs
        Forwarded to the CPU implementation only.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d]
    """
    if not A.is_cuda:
        return weighted_dimwise_median_cpu(A, x, **kwargs)

    assert A.is_sparse
    N, D = x.shape

    # Per (node, feature) row index into `x` chosen by the CUDA kernel.
    median_idx = custom_cuda_kernels.dimmedian_idx(x, A)
    col_idx = torch.arange(D, device=A.device).view(1, -1).expand(N, D)
    x_selected = x[median_idx, col_idx]

    # `dim_size=N` guarantees one sum per node even when the last rows of A
    # have no stored entries; without it the scatter output is shorter than N
    # and the `expand` below fails.
    a_row_sum = torch_scatter.scatter_sum(
        A._values(), A._indices()[0], dim=-1, dim_size=N
    ).view(-1, 1).expand(N, D)
    return a_row_sum * x_selected
def add_eye_sparse_tensor(x: torch.sparse.FloatTensor, bandwidth: int) -> torch.sparse.FloatTensor:
    """Add ones on the main diagonal and the first ``bandwidth - 1`` off-diagonals.

    NOTE: the original docstring marked this helper as "Not used!".

    Args:
        x: Sparse COO tensor; the band length is taken from ``x.shape[1]``
            (NOTE(review): this presumably expects a square matrix).
        bandwidth: Number of diagonals to add on each side, counting the main
            diagonal. A falsy value returns ``x`` unchanged.

    Returns:
        An uncoalesced sparse COO tensor with the same size as ``x`` whose
        entries are those of ``x`` plus the added ones (overlapping entries
        sum up when coalesced / densified).
    """
    if not bandwidth:
        return x

    indices = x._indices()
    values = x._values()
    n = x.shape[1]

    index_parts = [indices]
    value_parts = [values]
    # k == 0 is the main diagonal; k >= 1 adds the k-th upper and lower
    # off-diagonal. Diagonals with k >= n would be empty, so stop there.
    for k in range(min(bandwidth, n)):
        base = torch.arange(n - k, dtype=torch.long, device=indices.device)
        ones = torch.ones(n - k, dtype=values.dtype, device=values.device)
        index_parts.append(torch.stack([base, base + k]))
        value_parts.append(ones)
        if k:
            index_parts.append(torch.stack([base + k, base]))
            value_parts.append(ones)

    # Index blocks are [2, nnz_i]; they must be joined along the nnz axis.
    return torch.sparse_coo_tensor(
        torch.cat(index_parts, dim=1),
        torch.cat(value_parts),
        size=x.size(),
    )
def torch_coo_to_scipy_coo(m: torch.sparse.FloatTensor) -> coo_matrix:
    """Convert a :class:`torch.sparse.FloatTensor` to a :class:`scipy.sparse.coo_matrix`.

    The tensor is detached, moved to the CPU and coalesced first, so
    uncoalesced, CUDA or grad-tracking inputs are accepted. Duplicate entries
    are summed by the coalesce step.

    Args:
        m: Two-dimensional sparse COO tensor.

    Returns:
        A scipy COO matrix with the same shape and entries as ``m``.
    """
    m = m.detach().cpu().coalesce()
    data = m.values().numpy()
    # Hand scipy plain numpy index arrays rather than torch tensors.
    row, col = m.indices().numpy()
    return coo_matrix((data, (row, col)), shape=tuple(m.size()))
def _sparse_masked_select_abs(self, sparse_tensor: torch.sparse.FloatTensor, thr): indices = sparse_tensor._indices() values = sparse_tensor._values() prune_mask = torch.abs(values) >= thr return torch.sparse_coo_tensor( indices=indices.masked_select(prune_mask).reshape(2, -1), values=values.masked_select(prune_mask), size=[self.n_outputs, self.n_inputs]).coalesce()
def _reduce(x: torch.sparse.FloatTensor):
    """Pickle reducer that stores sparse indices in a compact integer type.

    Sparse float/long tensors are reduced to
    ``(_sparse_tensor_constructor, (indices, values, size))`` with the indices
    cast to the type chosen by ``_get_int_type`` for the largest index.
    Everything else falls back to ``torch.Tensor.__reduce_ex__``.
    """
    # The dispatch table cannot distinguish between torch.sparse.FloatTensor
    # and torch.Tensor, so the check has to happen here.
    if isinstance(x, (torch.sparse.FloatTensor, torch.sparse.LongTensor)):
        indices = x._indices()
        # torch.max raises on an empty tensor; a tensor without stored entries
        # has no largest index.
        # NOTE(review): assumes _get_int_type(0) yields a valid (smallest)
        # integer dtype — confirm against its definition.
        largest = torch.max(indices).item() if indices.numel() > 0 else 0
        int_type = _get_int_type(largest)
        return _sparse_tensor_constructor, (indices.to(int_type), x._values(), x.size())
    return torch.Tensor.__reduce_ex__(x, pickle.HIGHEST_PROTOCOL)  # use your own protocol
def get_nonzero_idx(sp: torch.sparse.FloatTensor) -> torch.LongTensor:
    """Get indices of nonzero elements of a sparse tensor

    The tensor is coalesced first, so duplicate entries are summed before the
    test. Explicitly stored zeros are excluded; negative values are included.

    Args:
        sp (torch.sparse.FloatTensor): a sparse tensor

    Returns:
        torch.LongTensor: indices of nonzero elements, one column per element
    """
    sp = sp.coalesce()
    # `!= 0` rather than `> 0`: negative entries are nonzero too.
    return sp.indices()[:, sp.values() != 0]
def forward(self, input: torch.Tensor, adj: torch.sparse.FloatTensor):
    """Apply ``self.conv`` and propagate the result over the adjacency matrix.

    Args:
        input: Tensor passed to ``self.conv``
            (NOTE(review): presumably [batch, channels, nodes] — confirm).
        adj: Adjacency matrix; it is densified before use.

    Returns:
        ``conv(input) @ adj^T`` (with ``adj`` optionally normalised), plus
        ``self.bias`` broadcast over the second axis when a bias is set.
    """
    support = self.conv(input)
    adj = adj.to_dense()
    if self.normalize:
        adj_sum = adj.sum(dim=0)
        # Avoid dividing by zero for nodes without any connection.
        adj_sum = torch.where(adj_sum == 0, torch.ones_like(adj_sum), adj_sum)
        # Same result as `adj_sum.diag().inverse() @ adj` (row i scaled by
        # 1 / adj_sum[i]) without building and inverting a dense n x n matrix.
        adj = adj / adj_sum.unsqueeze(1)
    output = support @ adj.transpose(0, 1)
    if self.bias is not None:
        # Move the channel axis last so the bias broadcasts, then move it back.
        output = (output.transpose(1, 2) + self.bias).transpose(2, 1)
    return output
def remove_eye_sparse_tensor(x: torch.sparse.FloatTensor, bandwidth: int) -> torch.sparse.FloatTensor:
    """Set diagonal (and offdiagonal) elements to zero.

    Entries at positions ``(i, j)`` with ``|i - j| <= bandwidth`` are dropped;
    all other stored entries are kept. The result has the same size as ``x``.

    Args:
        x: Input array.
        bandwidth: Width of the diagonal 0 band. A falsy value returns ``x``
            unchanged.
    """
    if not bandwidth:
        return x

    indices = x._indices()
    values = x._values()
    keep_mask = (indices[0] - indices[1]).abs() > bandwidth
    # Pass the size explicitly: otherwise it is inferred from the surviving
    # indices and the tensor silently shrinks.
    return torch.sparse_coo_tensor(indices[:, keep_mask], values[keep_mask], size=x.size())
def soft_weighted_medoid(A: torch.sparse.FloatTensor, x: torch.Tensor, temperature: float = 1.0, **kwargs) -> torch.Tensor:
    """Soft (softmax-relaxed) weighted Medoid aggregation.

    The [n, n, n] intermediate is built on the CPU from a dense copy of ``A``;
    the final result lives on ``x.device``.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse [n, n] tensor of the weighted/normalized adjacency matrix.
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings.
    temperature : float, optional
        Temperature for the argmin approximation by softmax, by default 1.0

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n_nodes, _ = x.shape

    # presumably the pairwise distance matrix of the rows of x — see _distance_matrix
    pairwise = _distance_matrix(x).cpu()
    adj = A.cpu()
    if A.is_sparse:
        adj = adj.to_dense()

    # weighted[i, j, k] = A[i, k] * pairwise[j, k]
    weighted = adj.unsqueeze(1).expand(n_nodes, n_nodes, n_nodes) * pairwise
    # Candidates j that are not connected to i get the largest finite cost.
    weighted[adj == 0] = torch.finfo(weighted.dtype).max

    cost = weighted.sum(-1).to(x.device)
    # Summing the sentinel overflows; clamp back to the largest finite value.
    cost[~torch.isfinite(cost)] = torch.finfo(cost.dtype).max

    scale = adj.sum(-1).unsqueeze(-1).to(x.device)
    soft_assignment = F.softmax(-cost / temperature, dim=-1)
    return scale * (soft_assignment @ x)
def weighted_medoid(A: torch.sparse.FloatTensor, x: torch.Tensor, **kwargs) -> torch.Tensor:
    """Weighted Medoid aggregation (hard argmin over candidate nodes).

    The [n, n, n] intermediate is built on the CPU from a dense copy of ``A``;
    the final result lives on ``x.device``.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse [n, n] tensor of the weighted/normalized adjacency matrix.
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n_nodes, _ = x.shape

    # presumably the pairwise distance matrix of the rows of x — see _distance_matrix
    pairwise = _distance_matrix(x).cpu()
    adj = A.cpu()
    if A.is_sparse:
        adj = adj.to_dense()

    # weighted[i, j, k] = A[i, k] * pairwise[j, k]
    weighted = adj.unsqueeze(1).expand(n_nodes, n_nodes, n_nodes) * pairwise
    # Candidates j that are not connected to i get the largest finite cost.
    weighted[adj == 0] = torch.finfo(weighted.dtype).max

    cost = weighted.sum(-1).to(x.device)
    # Summing the sentinel overflows; clamp back to the largest finite value.
    cost[~torch.isfinite(cost)] = torch.finfo(cost.dtype).max

    scale = adj.sum(-1).unsqueeze(-1).to(x.device)
    winners = cost.argmin(-1)
    return scale * x[winners]
def weighted_dimwise_median_cpu(A: torch.sparse.FloatTensor, x: torch.Tensor, **kwargs) -> torch.Tensor:
    """Weighted dimension-wise Median aggregation (cpu implementation).

    For each node and feature, the feature values are sorted, the node's
    adjacency weights are accumulated in that order, and the value at which
    the running weight first reaches half of the total is selected. The
    selection is scaled by the node's total weight.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse [n, n] tensor of the weighted/normalized adjacency matrix.
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n_nodes, n_feats = x.shape
    _, sort_order = torch.sort(x, dim=0)

    dense_adj = A.cpu()
    if A.is_sparse:
        dense_adj = dense_adj.to_dense()

    # running_weights[i, j, d]: weight node i puts on the j+1 smallest values
    # of feature d.
    row_selector = torch.arange(n_nodes, dtype=torch.long)[:, None, None].expand(
        n_nodes, n_nodes, n_feats)
    running_weights = dense_adj[row_selector, sort_order].cumsum(1)
    total_weight = running_weights.max(1)[0]

    # Rank (in sorted order) of the first value whose running weight is not
    # below half of the total.
    half_weight = (total_weight / 2)[:, None].expand(n_nodes, n_nodes, n_feats)
    median_rank = (running_weights < half_weight).sum(1).to(A.device)

    feat_selector = torch.arange(n_feats, dtype=torch.long)[None, :].expand(
        n_nodes, n_feats).to(A.device)
    picked_rows = sort_order[median_rank, feat_selector]
    return total_weight.to(A.device) * x[picked_rows, feat_selector]
def dense_cpu_soft_weighted_medoid_k_neighborhood(
        A: torch.sparse.FloatTensor,
        x: torch.Tensor,
        k: int = 32,
        temperature: float = 1.0,
        with_weight_correction: bool = False,
        **kwargs) -> torch.Tensor:
    """Dense cpu implementation (for details see `soft_weighted_medoid_k_neighborhood`).

    Only the ``k`` largest weights of every row of ``A`` are considered as
    medoid candidates; the softmax over their costs is scattered back into an
    ``A``-shaped weight matrix that is multiplied with ``x`` and scaled by the
    row sums of ``A``.
    """
    n, _ = x.size()
    dense_adj = A.to_dense()
    # presumably the pairwise distance matrix of the rows of x — see _distance_matrix
    pairwise = _distance_matrix(x)

    nbr_weight, nbr_idx = torch.topk(dense_adj, k=k, dim=1)
    nbr_grid = nbr_idx[:, None, :].expand(n, k, k)
    # cost[i, a] = sum_b weight[i, b] * pairwise[nbr[i, b], nbr[i, a]]
    nbr_pairwise = pairwise[nbr_grid, nbr_grid.transpose(1, 2)]
    cost = (nbr_weight[:, None, :].expand(n, k, k) * nbr_pairwise).sum(-1)

    largest = torch.finfo(cost.dtype).max
    cost[nbr_weight == 0] = largest  # padded (zero-weight) candidates
    cost[~torch.isfinite(cost)] = largest

    row_sum = dense_adj.sum(-1)[:, None]

    soft = F.softmax(-cost / temperature, dim=-1)
    if with_weight_correction:
        soft = soft * nbr_weight

    row_grid = torch.arange(n)[:, None].expand(n, k)
    topk_weights = torch.zeros(A.shape, device=A.device)
    topk_weights[row_grid, nbr_idx] = soft
    if with_weight_correction:
        topk_weights = topk_weights / topk_weights.sum(-1)[:, None]

    return row_sum * (topk_weights @ x)
def expand_adjacency_tensor(
        adj: torch.sparse.FloatTensor) -> torch.sparse.FloatTensor:
    """Turn every stored entry of ``adj`` into two rows of a new sparse matrix.

    For the e-th stored entry ``(row, col)`` the result has a one at
    ``(2 * e, row)`` and a one at ``(2 * e + 1, col)``. The output has
    ``2 * nnz`` rows and ``adj.shape[0]`` columns and is created on
    ``settings.device``.
    """
    src, dst = adj._indices()
    n_rows = 2 * len(src)

    index_opts = dict(dtype=torch.long, device=settings.device)
    even_rows = torch.arange(0, n_rows, 2, **index_opts)
    # max(1, ...) keeps the range valid (and empty) when there are no entries.
    odd_rows = torch.arange(1, max(1, n_rows), 2, **index_opts)

    new_indices = torch.stack([
        torch.cat([even_rows, odd_rows]),
        torch.cat([src, dst]),
    ])
    new_values = torch.ones(n_rows, dtype=torch.float, device=settings.device)

    # number of interactions x sequence length
    new_size = (n_rows, adj.shape[0])
    return torch.sparse_coo_tensor(new_indices, new_values, size=new_size)
def weighted_medoid_k_neighborhood(A: torch.sparse.FloatTensor, x: torch.Tensor, k: int = 32, **kwargs) -> torch.Tensor:
    """Weighted Medoid aggregation restricted to the top-``k`` neighbors.

    Falls back to `weighted_medoid` when ``k`` exceeds the number of nodes.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse [n, n] tensor of the weighted/normalized adjacency matrix.
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings.
    k : int, optional
        Number of largest-weight neighbors considered per node, by default 32.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n_nodes, _ = x.shape
    if k > n_nodes:
        return weighted_medoid(A, x)

    # presumably the pairwise distance matrix of the rows of x — see _distance_matrix
    pairwise = _distance_matrix(x)
    dense_adj = A.to_dense() if A.is_sparse else A

    nbr_weight, nbr_idx = torch.topk(dense_adj, k=k, dim=1)
    nbr_grid = nbr_idx[:, None, :].expand(n_nodes, k, k)
    # cost[i, a] = sum_b weight[i, b] * pairwise[nbr[i, b], nbr[i, a]]
    nbr_pairwise = pairwise[nbr_grid, nbr_grid.transpose(1, 2)]
    cost = (nbr_weight[:, None, :].expand(n_nodes, k, k) * nbr_pairwise).sum(-1)

    largest = torch.finfo(cost.dtype).max
    cost[nbr_weight == 0] = largest  # padded (zero-weight) candidates
    cost[~torch.isfinite(cost)] = largest

    row_sum = dense_adj.sum(-1)[:, None]
    best_slot = cost.argmin(-1)
    medoid_idx = nbr_idx[torch.arange(n_nodes), best_slot]
    return row_sum * x[medoid_idx]
def __init__(self, weights: torch.sparse.FloatTensor, biases=None, activation=None):
    """Build the layer from a sparse, coalesced weight matrix.

    Args:
        weights: Sparse coalesced tensor of shape [n_outputs, n_inputs].
        biases: Optional bias tensor; when omitted a zero-initialised
            [n_outputs, 1] parameter is created.
        activation: Stored as ``self._activation``.

    Raises:
        ValueError: If ``weights`` is not sparse or not coalesced.
    """
    super(SparseFCLayer, self).__init__()

    if not weights.is_sparse:
        raise ValueError("Left weights must be sparse")
    if not weights.is_coalesced():
        raise ValueError("Left weights must be coalesced")

    # Dimension is reversed: rows are outputs, columns are inputs.
    self.n_inputs = weights.size(1)
    self.n_outputs = weights.size(0)
    self._activation = activation
    self._weights = Parameter(weights)

    if biases is not None:
        self._biases = Parameter(biases)
    else:
        self._biases = Parameter(torch.Tensor(self.n_outputs, 1))
        torch.nn.init.zeros_(self._biases)
def fit(self,
        adj: torch.sparse.FloatTensor,
        attr: torch.Tensor,
        labels: torch.Tensor,
        idx_train: np.ndarray,
        idx_val: np.ndarray,
        max_epochs: int = 200,
        **kwargs):
    """Adapter around the parent class's ``fit``.

    Records the adjacency's device on ``self.device``, densifies ``adj`` and
    maps the argument names onto the parent's keywords (``attr`` ->
    ``features``, ``max_epochs`` -> ``train_iters``). Extra ``kwargs`` are
    accepted but not forwarded.
    """
    self.device = adj.device
    dense_adj = adj.to_dense()
    super().fit(
        features=attr,
        adj=dense_adj,
        labels=labels,
        idx_train=idx_train,
        idx_val=idx_val,
        train_iters=max_epochs,
    )
def __init__(self,
             adj: torch.sparse.FloatTensor,
             X: torch.Tensor,
             labels: torch.Tensor,
             idx_attack: np.ndarray,
             model: DenseGCN,
             **kwargs):
    """Store the attack inputs and a private copy of the model.

    ``adj`` is densified; ``self.original_adj`` keeps that dense matrix and
    ``self.adj`` is a clone of it with gradients enabled. The model is
    deep-copied and moved to the device of ``X``. Extra ``kwargs`` are ignored.
    """
    super().__init__()
    assert adj.device == X.device, 'The device of the features and adjacency matrix must match'

    self.device = X.device

    dense_adj = adj.to_dense()
    self.original_adj = dense_adj
    self.adj = dense_adj.clone().requires_grad_(True)

    self.X = X
    self.labels = labels
    self.idx_attack = idx_attack

    model_copy = deepcopy(model)
    self.model = model_copy.to(self.device)

    # Filled in later; nothing has been computed at construction time.
    self.attr_adversary = None
    self.adj_adversary = None
def _sparse_top_k(A: torch.sparse.FloatTensor, k: int, return_sparse: bool = True):
    """Keep the (at most) ``k`` selected entries per row of a sparse matrix.

    CUDA inputs are handled by ``custom_cuda_kernels.topk``; other inputs go
    through ``_select_k_idx_cpu`` with ``method='top'``.

    Args:
        A: Sparse COO matrix.
        k: Maximum number of entries kept per row.
        return_sparse: If True return a sparse tensor with the shape of ``A``,
            otherwise return dense ``(values, indices)`` tensors of shape
            [n, k] in which unused index slots hold -1.

    Returns:
        Either a sparse COO tensor or a ``(values, indices)`` tuple, see
        ``return_sparse``.
    """
    n = A.shape[0]

    if A.is_cuda:
        topk_values, topk_idx = custom_cuda_kernels.topk(
            A._indices(), A._values(), n, k)
        if not return_sparse:
            return topk_values, topk_idx.long()

        mask = topk_idx != -1  # -1 marks unused slots
        row_idx = torch.arange(n, device=A.device).view(-1, 1).expand(n, k)
        # sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
        # constructor; the explicit size keeps the shape of A instead of
        # inferring it from the surviving indices.
        return torch.sparse_coo_tensor(
            torch.stack((row_idx[mask], topk_idx[mask].long())),
            topk_values[mask],
            A.shape)

    # dim_size=n yields one count per row even if trailing rows are empty.
    n_edges_per_row = torch_scatter.scatter_sum(torch.ones_like(A._values()),
                                                A._indices()[0],
                                                dim=0,
                                                dim_size=n)
    k_per_row = torch.clamp(n_edges_per_row, max=k).long()

    new_idx, value_idx, unroll_idx = _select_k_idx_cpu(
        A.indices()[0].cpu().numpy(),
        A.indices()[1].cpu().numpy(),
        A._values().cpu().numpy(),
        k_per_row.cpu().numpy(),
        n,
        method='top')

    new_idx = torch.from_numpy(np.hstack(new_idx)).to(A.device)
    value_idx = torch.from_numpy(np.hstack(value_idx)).to(A.device)

    if return_sparse:
        return torch.sparse_coo_tensor(new_idx, A._values()[value_idx], A.shape)

    unroll_idx = np.hstack(unroll_idx)
    values = torch.zeros((n, k), device=A.device)
    indices = -torch.ones((n, k), device=A.device, dtype=torch.long)
    values[new_idx[0], unroll_idx] = A._values()[value_idx]
    indices[new_idx[0], unroll_idx] = new_idx[1]
    return values, indices
def forward(self, h: torch.sparse.FloatTensor, adj: torch.sparse.FloatTensor):
    """
    Forward pass of the sparse GAT layer.

    Inputs:
        h.shape == (num_nodes, in_features)
        adj.shape == (num_nodes, num_nodes)
    Outputs:
        out.shape == (num_nodes, out_features)
    """
    num_nodes = h.size()[0]
    adj_size = torch.Size([num_nodes, num_nodes])

    # h over here is actually w_h_transpose: we use the transposed form -
    # in the end this doesn't matter.
    # w_h_transpose.shape == (N, out_features)
    h = torch.mm(h, self.weight)
    assert not torch.isnan(h).any()

    # (2, E) tensor of indices, one column per non-zero entry of adj
    edge = torch.nonzero(adj, as_tuple=False).t()

    # edge_h.shape == (2*D x E) - Select relevant Wh nodes
    edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()

    # Self-attention on the nodes - Shared attention mechanism
    # edge_e.shape == (E) - This is the numerator in equation 3.
    # squeeze(0) drops only the leading axis; a bare squeeze() would also drop
    # the edge axis when the graph has exactly one edge.
    edge_e = torch.exp(self.leakyrelu(
        self.alpha.mm(edge_h).squeeze(0)))  # I just removed the minus here
    assert not torch.isnan(edge_e).any()

    # e_rowsum.shape == (num_nodes x 1) - This is the denominator in equation 3.
    # The ones vector follows edge_e's dtype/device so the layer is not tied
    # to the CPU / default dtype.
    ones = torch.ones(size=(num_nodes, 1), dtype=edge_e.dtype, device=edge_e.device)
    e_rowsum = self.special_spmm(edge, edge_e, adj_size, ones)

    # h_prime.shape == (num_nodes x out)
    # Aggregate Wh using the attention coefficients
    edge_e = self.dropout(edge_e)
    h_prime = self.special_spmm(edge, edge_e, adj_size, h)
    assert not torch.isnan(h_prime).any()

    # h_prime.shape == (num_nodes x out)
    # Divide by normalising factor
    h_prime = h_prime.div(e_rowsum)
    assert not torch.isnan(h_prime).any()
    return h_prime
def forward(self, x: torch.Tensor, adj: torch.sparse.FloatTensor):
    """Run the configured convolution and post-process with neighbor counts.

    Steps:
      1. Optionally restrict ``adj`` with ``cut_to_max_distance``.
      2. Apply exactly one of ``_conv_wbarcode`` / ``_conv_wself`` / ``_conv``.
      3. If ``self.normalize``, divide by the column sums of ``adj`` (zeros
         are replaced by one to avoid division by zero).
      4. If ``self.add_counts``, append the raw column sums as one extra
         channel along dim 1.
    """
    if self.max_distance:
        adj = cut_to_max_distance(adj, self.max_distance)

    # Pick the convolution variant; only the chosen attribute is looked up.
    if self.barcode_method is not None:
        conv = self._conv_wbarcode
    elif self.wself:
        conv = self._conv_wself
    else:
        conv = self._conv
    x = conv(x, adj)

    counts = adj.to_dense().sum(dim=0)

    if self.normalize:
        safe_counts = torch.where(counts == 0, torch.ones_like(counts), counts)
        x = x / safe_counts

    if self.add_counts:
        count_channel = counts.expand(x.size(0), 1, -1)
        x = torch.cat([x, count_channel], dim=1)

    return x