def weighted_medoid_k_neighborhood(A: torch.sparse.FloatTensor, x: torch.Tensor, k: int = 32, **kwargs) -> torch.Tensor:
    """A weighted Medoid aggregation over the k highest-weighted neighbors of each node.

    For every row i of ``A`` the k largest entries are selected. Among those k
    candidate nodes, the one minimizing the ``A``-weighted sum of distances to
    the other candidates is chosen, and its embedding is returned scaled by the
    row sum of ``A``.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        Sparse (or dense) [n, n] tensor of the weighted/normalized adjacency matrix.
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings.
    k : int, optional
        Neighborhood size. If ``k`` exceeds the number of nodes, the call is
        delegated to ``weighted_medoid(A, x)``.
    **kwargs
        Accepted for interface compatibility; unused.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n_nodes, _ = x.shape
    if k > n_nodes:
        # torch.topk cannot select more than n entries; delegate to
        # `weighted_medoid` (presumably the non-truncated variant).
        return weighted_medoid(A, x)

    pairwise = _distance_matrix(x)
    dense_adj = A.to_dense() if A.is_sparse else A

    weights, nbr_idx = torch.topk(dense_adj, k=k, dim=1)
    # idx_grid[i, j, l] = nbr_idx[i, l]; its (1, 2)-transpose holds nbr_idx[i, j].
    idx_grid = nbr_idx[:, None, :].expand(n_nodes, k, k)
    # cost[i, j] = sum_l weights[i, l] * pairwise[nbr_idx[i, l], nbr_idx[i, j]]
    cost = (weights[:, None, :].expand(n_nodes, k, k)
            * pairwise[idx_grid, idx_grid.transpose(1, 2)]).sum(-1)

    # Zero-weight (padding) candidates and non-finite costs can never be chosen.
    largest = torch.finfo(cost.dtype).max
    cost[weights == 0] = largest
    cost[~torch.isfinite(cost)] = largest

    row_totals = dense_adj.sum(-1)[:, None]
    medoid_nodes = nbr_idx[torch.arange(n_nodes), cost.argmin(-1)]
    return row_totals * x[medoid_nodes]
def dense_cpu_soft_weighted_medoid_k_neighborhood(
        A: torch.sparse.FloatTensor,
        x: torch.Tensor,
        k: int = 32,
        temperature: float = 1.0,
        with_weight_correction: bool = False,
        **kwargs) -> torch.Tensor:
    """Dense cpu implementation (for details see `soft_weighted_medoid_k_neighborhood`).

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        [n, n] weighted/normalized adjacency matrix (sparse or dense).
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings.
    k : int, optional
        Number of highest-weighted entries kept per row; must be <= n
        (requirement of ``torch.topk``).
    temperature : float, optional
        Softmax temperature applied to the negated weighted distances.
    with_weight_correction : bool, optional
        If True, the soft weights are multiplied by the top-k adjacency weights
        and renormalized per row. Rows whose weights sum to zero yield zeros
        instead of NaN.
    **kwargs
        Accepted for interface compatibility; unused.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n, _ = x.size()
    A_dense = A.to_dense() if A.is_sparse else A

    l2 = _distance_matrix(x)

    topk_a, topk_a_idx = torch.topk(A_dense, k=k, dim=1)
    # topk_l2_idx[i, j, l] = topk_a_idx[i, l]
    topk_l2_idx = topk_a_idx[:, None, :].expand(n, k, k)
    # distances_k[i, j] = sum_l topk_a[i, l] * l2[topk_a_idx[i, l], topk_a_idx[i, j]]
    distances_k = (topk_a[:, None, :].expand(n, k, k)
                   * l2[topk_l2_idx, topk_l2_idx.transpose(1, 2)]).sum(-1)
    max_val = torch.finfo(distances_k.dtype).max
    distances_k[topk_a == 0] = max_val
    distances_k[~torch.isfinite(distances_k)] = max_val

    row_sum = A_dense.sum(-1)[:, None]

    row_idx = torch.arange(n, device=topk_a_idx.device)[:, None].expand(n, k)
    # Match x's dtype so the final matmul works for e.g. float64 embeddings.
    topk_weights = torch.zeros(A.shape, dtype=x.dtype, device=A.device)
    topk_weights[row_idx, topk_a_idx] = F.softmax(-distances_k / temperature, dim=-1)
    if with_weight_correction:
        topk_weights[row_idx, topk_a_idx] *= topk_a
        normalizer = topk_weights.sum(-1)[:, None]
        # An all-zero row would otherwise produce 0 / 0 = NaN.
        normalizer[normalizer == 0] = 1
        topk_weights /= normalizer
    return row_sum * (topk_weights @ x)
def forward(self, input: torch.Tensor, adj: torch.sparse.FloatTensor):
    """Apply ``self.conv`` and aggregate the result over the adjacency matrix.

    Parameters
    ----------
    input : torch.Tensor
        Passed to ``self.conv``. Presumably [batch, channels, n]; TODO confirm.
    adj : torch.sparse.FloatTensor
        [n, n] adjacency matrix (sparse or dense).

    Returns
    -------
    torch.Tensor
        ``self.conv(input) @ adj.T``. If ``self.normalize`` is set, row i of
        ``adj`` is first divided by its column sum ``adj[:, i].sum()``
        (zeros replaced by 1). If ``self.bias`` is set, it is broadcast-added
        along dim 1 of the output.
    """
    support = self.conv(input)
    if adj.is_sparse:
        adj = adj.to_dense()
    if self.normalize:
        adj_sum = adj.sum(dim=0)
        # Avoid dividing by zero for nodes with an empty column.
        adj_sum[adj_sum == 0] = 1
        # Same as diag(adj_sum).inverse() @ adj, but O(n^2) instead of an
        # O(n^3) inverse of an n x n matrix.
        adj = adj / adj_sum[:, None]
    output = support @ adj.transpose(0, 1)
    if self.bias is not None:
        # Move dim 1 to the end so the bias broadcasts along it, then move back.
        output = (output.transpose(1, 2) + self.bias).transpose(2, 1)
    return output
def fit(self, adj: torch.sparse.FloatTensor, attr: torch.Tensor, labels: torch.Tensor, idx_train: np.ndarray, idx_val: np.ndarray, max_epochs: int = 200, **kwargs):
    """Train by delegating to the parent class's ``fit`` with a dense adjacency.

    Maps this method's argument names onto the parent's keyword interface:
    ``attr`` -> ``features`` and ``max_epochs`` -> ``train_iters``. ``adj`` is
    converted with ``to_dense()`` before being passed on. Extra ``**kwargs``
    are accepted but not forwarded.
    """
    # Remember where the adjacency lives before handing it to the parent.
    self.device = adj.device
    parent_kwargs = dict(
        features=attr,
        adj=adj.to_dense(),
        labels=labels,
        idx_train=idx_train,
        idx_val=idx_val,
        train_iters=max_epochs,
    )
    super().fit(**parent_kwargs)
def __init__(self, adj: torch.sparse.FloatTensor, X: torch.Tensor, labels: torch.Tensor, idx_attack: np.ndarray, model: DenseGCN, **kwargs):
    """Store the graph, labels and a private copy of ``model``.

    Parameters
    ----------
    adj : torch.sparse.FloatTensor
        Adjacency matrix; stored densified as ``original_adj``. A clone with
        ``requires_grad`` enabled is stored as ``adj``.
    X : torch.Tensor
        Node features; must live on the same device as ``adj``.
    labels : torch.Tensor
        Node labels.
    idx_attack : np.ndarray
        Node indices, presumably the attacked/evaluated nodes; TODO confirm.
    model : DenseGCN
        Model that is deep-copied and moved to ``X.device``, so the caller's
        instance is not modified.
    **kwargs
        Accepted for interface compatibility; unused here.

    Raises
    ------
    ValueError
        If ``adj`` and ``X`` are on different devices.
    """
    super().__init__()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if adj.device != X.device:
        raise ValueError('The device of the features and adjacency matrix must match')
    self.device = X.device
    self.original_adj = adj.to_dense()
    # Gradient-enabled copy; the untouched original is kept alongside it.
    self.adj = self.original_adj.clone().requires_grad_(True)
    self.X = X
    self.labels = labels
    self.idx_attack = idx_attack
    self.model = deepcopy(model).to(self.device)
    # Initialized empty here; presumably populated later by the attack.
    self.attr_adversary = None
    self.adj_adversary = None
def forward(self, x: torch.Tensor, adj: torch.sparse.FloatTensor):
    """Aggregate ``x`` over ``adj``, optionally normalizing and appending counts.

    Steps:
      1. If ``self.max_distance`` is truthy, ``adj`` is passed through
         ``cut_to_max_distance``.
      2. One of ``_conv_wbarcode`` (when ``barcode_method`` is not None),
         ``_conv_wself`` (when ``wself``) or ``_conv`` is applied.
      3. With ``normalize``, the result is divided by the column sums of
         ``adj`` (zeros replaced by 1), broadcast over the last dim.
      4. With ``add_counts``, the raw column sums are concatenated as an
         extra entry along dim 1.

    Parameters
    ----------
    x : torch.Tensor
        Node features. Presumably [batch, channels, n]; TODO confirm.
    adj : torch.sparse.FloatTensor
        [n, n] adjacency matrix (sparse COO or another layout).

    Returns
    -------
    torch.Tensor
        The aggregated features.
    """
    if self.max_distance:
        adj = cut_to_max_distance(adj, self.max_distance)

    if self.barcode_method is not None:
        x = self._conv_wbarcode(x, adj)
    elif self.wself:
        x = self._conv_wself(x, adj)
    else:
        x = self._conv(x, adj)

    if adj.is_sparse:
        # Column sums without materializing the dense n x n matrix.
        adj_sum = torch.sparse.sum(adj, dim=0).to_dense()
    else:
        adj_sum = adj.to_dense().sum(dim=0)

    if self.normalize:
        adj_sum_clamped = adj_sum.clone()
        # Avoid dividing by zero for nodes with an empty column.
        adj_sum_clamped[adj_sum_clamped == 0] = 1
        x = x / adj_sum_clamped
    if self.add_counts:
        x = torch.cat([x, adj_sum.expand(x.size(0), 1, -1)], dim=1)
    return x