示例#1
0
def weighted_medoid_k_neighborhood(A: torch.sparse.FloatTensor,
                                   x: torch.Tensor,
                                   k: int = 32,
                                   **kwargs) -> torch.Tensor:
    """Weighted Medoid aggregation restricted to each node's k heaviest neighbors.

    For every node, the candidates are the ``k`` columns with the largest
    weights in its row of ``A``. Each candidate is scored by the
    ``A``-weighted sum of its distances (from ``_distance_matrix``) to the
    candidates; the embedding of the lowest-scoring candidate, scaled by the
    node's row sum of ``A``, becomes the node's new embedding.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        [n, n] tensor of the weighted/normalized adjacency matrix; a sparse
        tensor is densified, a dense one is used as is.
    x : torch.Tensor
        Dense [n, d] tensor containing the node attributes/embeddings.
    k : int, optional
        Neighborhood size. If ``k`` exceeds the number of nodes the call is
        delegated to ``weighted_medoid(A, x)``. Default 32.
    **kwargs
        Accepted for interface compatibility; ignored.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n_nodes, _ = x.shape
    if k > n_nodes:
        return weighted_medoid(A, x)

    pairwise = _distance_matrix(x)
    adj = A.to_dense() if A.is_sparse else A

    weights, cand = torch.topk(adj, k=k, dim=1)
    # cand_rows[i, a, b] == cand[i, b]; its (1, 2) transpose yields cand[i, a],
    # so the gathered entry is the distance between candidates b and a.
    cand_rows = cand.unsqueeze(1).expand(n_nodes, k, k)
    gathered = pairwise[cand_rows, cand_rows.transpose(1, 2)]
    scores = (weights.unsqueeze(1).expand(n_nodes, k, k) * gathered).sum(-1)

    # Zero-weight candidates (and overflowed scores) can never be selected.
    largest = torch.finfo(scores.dtype).max
    scores[weights == 0] = largest
    scores[~torch.isfinite(scores)] = largest

    degree = adj.sum(-1).unsqueeze(-1)
    best = cand[torch.arange(n_nodes), scores.argmin(-1)]
    return degree * x[best]
示例#2
0
def dense_cpu_soft_weighted_medoid_k_neighborhood(
        A: torch.sparse.FloatTensor,
        x: torch.Tensor,
        k: int = 32,
        temperature: float = 1.0,
        with_weight_correction: bool = False,
        **kwargs) -> torch.Tensor:
    """Dense cpu implementation (for details see `soft_weighted_medoid_k_neighborhood`).

    For each node, the ``k`` columns with the largest weights in its row of
    ``A`` are candidates. Each candidate is scored by the ``A``-weighted sum
    of its distances (from ``_distance_matrix``) to the candidates, and the
    candidates' embeddings are averaged with weights
    ``softmax(-score / temperature)``; the result is scaled by the node's row
    sum of ``A``.

    Parameters
    ----------
    A : torch.sparse.FloatTensor
        [n, n] weighted/normalized adjacency matrix. Non-strided (sparse)
        layouts are densified; a dense tensor is used as is.
    x : torch.Tensor
        Dense [n, d] node attributes/embeddings.
    k : int, optional
        Neighborhood size; clamped to ``n`` so that ``torch.topk`` does not
        fail on small graphs. Default 32.
    temperature : float, optional
        Softmax temperature; expected to be positive. Default 1.0.
    with_weight_correction : bool, optional
        If True, multiply the soft weights by the adjacency weights and
        renormalize each row to sum to one. Default False.
    **kwargs
        Accepted for interface compatibility; ignored.

    Returns
    -------
    torch.Tensor
        The new embeddings [n, d].
    """
    n, _ = x.size()
    k = min(k, n)
    A_dense = A.to_dense() if A.layout != torch.strided else A

    l2 = _distance_matrix(x)

    topk_a, topk_a_idx = torch.topk(A_dense, k=k, dim=1)
    topk_l2_idx = topk_a_idx[:, None, :].expand(n, k, k)
    distances_k = (topk_a[:, None, :].expand(n, k, k) *
                   l2[topk_l2_idx, topk_l2_idx.transpose(1, 2)]).sum(-1)
    finfo = torch.finfo(distances_k.dtype)
    # Zero-weight candidates (and overflowed scores) get the largest
    # representable distance so they receive (numerically) zero weight.
    distances_k[topk_a == 0] = finfo.max
    distances_k[~torch.isfinite(distances_k)] = finfo.max

    row_sum = A_dense.sum(-1)[:, None]

    # For temperature < 1, -finfo.max / temperature overflows to -inf, and a
    # row whose candidates are all masked would softmax to NaN. Clamping to
    # finfo.min keeps such rows finite (uniform weights) and leaves rows with
    # at least one unmasked candidate unchanged (exp underflows to 0 either way).
    logits = torch.clamp(-distances_k / temperature, min=finfo.min)
    soft_weights = F.softmax(logits, dim=-1)

    rows = torch.arange(n)[:, None].expand(n, k)
    # Allocate with the weights' dtype; a default-dtype buffer makes the
    # indexed assignment fail for e.g. float64 inputs.
    topk_weights = torch.zeros(A_dense.shape, dtype=soft_weights.dtype,
                               device=A_dense.device)
    topk_weights[rows, topk_a_idx] = soft_weights
    if with_weight_correction:
        topk_weights[rows, topk_a_idx] *= topk_a
        # Rows whose weights all vanish stay zero instead of becoming 0/0 = NaN.
        weight_sum = topk_weights.sum(-1)[:, None]
        topk_weights /= weight_sum.masked_fill(weight_sum == 0, 1)
    return row_sum * (topk_weights @ x)
示例#3
0
 def forward(self, input: torch.Tensor, adj: torch.sparse.FloatTensor):
     """Apply ``self.conv`` to ``input`` and aggregate over the graph ``adj``.

     Parameters
     ----------
     input : torch.Tensor
         Features passed through ``self.conv``. The result is multiplied on
         the right by ``adj.T``, so its last dimension must equal the number
         of nodes (presumably [batch, channels, n] — TODO confirm).
     adj : torch.sparse.FloatTensor
         [n, n] adjacency matrix; non-strided (sparse) layouts are densified,
         a dense tensor is used as is.

     Returns
     -------
     torch.Tensor
         ``self.conv(input) @ adj.T`` (with ``adj`` normalized when
         ``self.normalize`` is set), with ``self.bias`` (if not None) added
         along dim 1 of the output.
     """
     support = self.conv(input)
     if adj.layout != torch.strided:
         adj = adj.to_dense()
     if self.normalize:
         # Row i is divided by the sum of column i (sum over dim 0); empty
         # columns count as 1 to avoid division by zero. Broadcasting gives
         # the same scaling as diag(adj_sum)^-1 @ adj without an O(n^3)
         # general matrix inverse.
         adj_sum = adj.sum(dim=0)
         adj_sum = adj_sum.masked_fill(adj_sum == 0, 1)
         adj = adj / adj_sum[:, None]
     output = support @ adj.transpose(0, 1)
     if self.bias is not None:
         # Move dim 1 last so a 1-D bias broadcasts over it, then move it back.
         output = (output.transpose(1, 2) + self.bias).transpose(2, 1)
     return output
示例#4
0
    def fit(self,
            adj: torch.sparse.FloatTensor,
            attr: torch.Tensor,
            labels: torch.Tensor,
            idx_train: np.ndarray,
            idx_val: np.ndarray,
            max_epochs: int = 200,
            **kwargs):
        """Train through the parent class's ``fit`` with a densified adjacency.

        Records ``adj.device`` on ``self.device``, then calls the parent's
        ``fit`` with ``adj.to_dense()``; ``attr`` is forwarded as ``features``
        and ``max_epochs`` as ``train_iters``. Extra ``**kwargs`` are
        accepted but not forwarded. Returns None.
        """
        self.device = adj.device

        dense_adj = adj.to_dense()
        parent_kwargs = dict(features=attr,
                             adj=dense_adj,
                             labels=labels,
                             idx_train=idx_train,
                             idx_val=idx_val,
                             train_iters=max_epochs)
        super().fit(**parent_kwargs)
示例#5
0
 def __init__(self,
              adj: torch.sparse.FloatTensor,
              X: torch.Tensor,
              labels: torch.Tensor,
              idx_attack: np.ndarray,
              model: DenseGCN,
              **kwargs):
     """Store the inputs for an attack on ``model``.

     Keeps a dense copy of ``adj`` as ``original_adj`` and a clone of it that
     tracks gradients as ``adj``, stores ``X``, ``labels`` and ``idx_attack``
     (presumably the indices of the attacked nodes — TODO confirm), and
     keeps a deep copy of ``model`` moved to ``X.device``. The adversarial
     results start out as None. Extra ``**kwargs`` are ignored.
     """
     super().__init__()
     assert adj.device == X.device, 'The device of the features and adjacency matrix must match'
     device = X.device
     self.device = device

     dense_adj = adj.to_dense()
     self.original_adj = dense_adj
     self.adj = dense_adj.clone().requires_grad_(True)

     self.X, self.labels, self.idx_attack = X, labels, idx_attack
     self.model = deepcopy(model).to(device)
     self.attr_adversary = self.adj_adversary = None
示例#6
0
    def forward(self, x: torch.Tensor, adj: torch.sparse.FloatTensor):
        """Convolve ``x`` over ``adj``, then optionally normalize and append counts.

        Parameters
        ----------
        x : torch.Tensor
            Node features. The per-node column sums below broadcast over the
            last dimension, and counts are concatenated along dim 1, so ``x``
            is presumably [batch, channels, n] — TODO confirm.
        adj : torch.sparse.FloatTensor
            [n, n] adjacency matrix (sparse COO, another sparse layout, or
            dense). If ``self.max_distance`` is truthy it is first passed
            through ``cut_to_max_distance``.

        Returns
        -------
        torch.Tensor
            The output of the selected convolution (``_conv_wbarcode`` when
            ``self.barcode_method`` is set, else ``_conv_wself`` when
            ``self.wself``, else ``_conv``), divided by the column sums of
            ``adj`` if ``self.normalize`` (zero sums count as 1), with the
            raw column sums appended as one extra entry along dim 1 if
            ``self.add_counts``.
        """
        if self.max_distance:
            adj = cut_to_max_distance(adj, self.max_distance)

        if self.barcode_method is not None:
            x = self._conv_wbarcode(x, adj)
        elif self.wself:
            x = self._conv_wself(x, adj)
        else:
            x = self._conv(x, adj)

        # Column sums of adj (sum over dim 0). For sparse COO input reduce
        # directly instead of materializing the dense [n, n] matrix.
        if adj.is_sparse:
            adj_sum = torch.sparse.sum(adj, dim=0)
            if adj_sum.is_sparse:
                adj_sum = adj_sum.to_dense()
        elif adj.layout == torch.strided:
            adj_sum = adj.sum(dim=0)
        else:
            adj_sum = adj.to_dense().sum(dim=0)

        if self.normalize:
            # Empty columns count as 1 to avoid division by zero.
            x = x / adj_sum.masked_fill(adj_sum == 0, 1)
        if self.add_counts:
            x = torch.cat([x, adj_sum.expand(x.size(0), 1, -1)], dim=1)
        return x