Beispiel #1
0
def batched_spmm(nzt, adj, x, m=None, n=None):
    """Multiply one sparse matrix per head with a dense feature matrix.

    The per-head sparse matrices are packed into a single block-diagonal
    COO matrix (head ``i`` is shifted by ``i * m`` rows and ``i * n``
    columns) so that every head is handled by one ``spmm`` call.

    Args:
        nzt: Tensor [num_edges, heads] or [batch, num_edges, heads]
            -- non-zero tensor
        adj: Tensor or list(Tensor)       -- adjacency matrix (COO); a list
            must contain exactly one matrix per head
        x:   Tensor [num_nodes, channels] or [batch, num_nodes, channels]
            -- feature matrix
        m:   int -- rows of each sparse matrix (inferred from ``adj`` if None)
        n:   int -- columns of each sparse matrix (inferred if None)

    Returns:
        For 2-D ``x`` the ``spmm`` result viewed as [heads, m, channels];
        otherwise the per-batch ``spmm`` results stacked along a new leading
        dimension.
    """
    heads = nzt.shape[-1]
    num_nodes, channels = x.shape[-2:]
    batched = len(x.shape) != 2

    # Tile the features once per head and flatten the values head-major so
    # that both line up with the block-diagonal adjacency built below.
    if batched:
        x_ = repeat(x, 't n c -> t (h n) c', h=heads)
        nzt_ = rearrange(nzt, 't e h -> t (h e)')
    else:
        x_ = repeat(x, 'n c -> (h n) c', h=heads)
        nzt_ = rearrange(nzt, 'e h -> (h e)')

    if isinstance(adj, Tensor):
        adjs = [adj] * heads
        m = maybe_num_nodes(adj[0], m)
        n_adj = maybe_num_nodes(adj[1], n)
    else:  # adj is list of adjacency matrices
        assert heads == len(
            adj), "the number of heads and the number of adjacency matrices are not matched"
        adjs = list(adj)
        m = max(maybe_num_nodes(a[0], m) for a in adjs)
        n_adj = max(maybe_num_nodes(a[1], n) for a in adjs)

    # The column stride must be at least num_nodes, otherwise head i would
    # read rows of x_ that belong to a different head.
    n = max(num_nodes, n_adj)

    # new_tensor keeps the offset on the same device/dtype as the indices.
    offset = adjs[0].new_tensor([[m], [n]])
    adj_ = torch.cat([a + offset * i for i, a in enumerate(adjs)], dim=1)

    if not batched:
        out = spmm(adj_, nzt_, heads * m, heads * n, x_)
        return out.view(-1, m, channels)  # [heads, m, channels]

    return torch.stack([
        spmm(adj_, nzt_[b], heads * m, heads * n, x_[b])
        for b in range(x_.shape[0])
    ])
Beispiel #2
0
def gcn_no_norm(edge_index,
                edge_weight=None,
                num_nodes=None,
                improved=False,
                add_self_loops=True,
                dtype=None):
    """Prepare a graph like ``gcn_norm`` but skip the degree normalization.

    Args:
        edge_index: Either a ``SparseTensor`` adjacency or a COO edge index
            tensor.
        edge_weight: Optional edge weights for the COO case; all-ones
            weights are created when omitted.
        num_nodes: Optional number of nodes (inferred from ``edge_index``
            when ``None``).
        improved: If true, self-loops get weight ``2.`` instead of ``1.``.
        add_self_loops: Whether to add (remaining) self-loops.
        dtype: dtype used when default weights have to be created.

    Returns:
        The (possibly value-filled / diagonal-filled) ``SparseTensor`` when
        one was passed in, otherwise the tuple ``(edge_index, edge_weight)``.
    """
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        return adj_t

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ),
                                 dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, tmp_edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    return edge_index, edge_weight
Beispiel #3
0
def add_self_loops(edge_index, edge_weight: Optional[torch.Tensor] = None,
                   fill_value: float = 1., num_nodes: Optional[int] = None):
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted, self-loops will be added with edge weights
    denoted by :obj:`fill_value`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_weight (Tensor, optional): Edge weights whose first dimension
            indexes the edges. (default: :obj:`None`)
        fill_value (float, optional): If :obj:`edge_weight` is not :obj:`None`,
            will add self-loops with edge weights of :obj:`fill_value` to the
            graph. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        # Keep the weights aligned with the extended edge list: one
        # `fill_value` entry per appended self-loop, preserving any trailing
        # feature dimensions of the existing weights.
        loop_shape = (N, ) + tuple(edge_weight.shape[1:])
        loop_weight = edge_weight.new_full(loop_shape, fill_value)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    new_edge_index = torch.cat([edge_index, loop_index], dim=1)

    return new_edge_index, edge_weight
Beispiel #4
0
    def __init__(self, edge_index: Union[Tensor, SparseTensor],
                 sizes: List[int], node_idx: Optional[Tensor] = None,
                 num_nodes: Optional[int] = None, return_e_id: bool = True,
                 **kwargs):
        """Build the CPU adjacency used for sampling and set up the loader.

        Args:
            edge_index: Graph connectivity, either a COO index tensor or a
                ``SparseTensor``.
            sizes: Stored on the instance as ``self.sizes``.
            node_idx: Optional node indices (or boolean mask) to iterate
                over; defaults to every row of the stored adjacency.
            num_nodes: Optional number of nodes, only used for tensor input.
            return_e_id: If true, the adjacency values are replaced by
                running edge ids.
            **kwargs: Forwarded to the parent class constructor.
        """
        self.sizes = sizes
        self.return_e_id = return_e_id
        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
        self.__val__ = None

        # Sampling works on a CPU copy of the connectivity.
        edge_index = edge_index.to('cpu')

        if self.is_sparse_tensor:
            sparse_adj = edge_index
            if return_e_id:
                # Remember the original values before overwriting them with
                # edge ids.
                self.__val__ = sparse_adj.storage.value()
                edge_ids = torch.arange(sparse_adj.nnz())
                sparse_adj = sparse_adj.set_value(edge_ids, layout='coo')
            self.adj_t = sparse_adj
        else:
            num_nodes = maybe_num_nodes(edge_index, num_nodes)
            edge_ids = None
            if return_e_id:
                edge_ids = torch.arange(edge_index.size(1))
            dense_built = SparseTensor(row=edge_index[0], col=edge_index[1],
                                       value=edge_ids,
                                       sparse_sizes=(num_nodes, num_nodes))
            # Store the *transposed* adjacency.
            self.adj_t = dense_built.t()

        # Touch the row pointer once up front.
        self.adj_t.storage.rowptr()

        if node_idx is None:
            seeds = torch.arange(self.adj_t.sparse_size(0))
        elif node_idx.dtype == torch.bool:
            seeds = node_idx.nonzero(as_tuple=False).view(-1)
        else:
            seeds = node_idx

        super(NeighborSampler, self).__init__(
            seeds.view(-1).tolist(), collate_fn=self.sample, **kwargs)
Beispiel #5
0
def real_softmax(src, index, num_nodes=None):
    r"""Evaluate a softmax separately within each index group.

    Entries of :attr:`src` that share the same value in :attr:`index` form a
    group along the first dimension; the softmax is normalized per group.
    Both the exponentials and the final result are asserted to be free of
    NaN/Inf via ``nan_or_inf``.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Subtract each group's maximum before exponentiating.
    group_max, _ = scatter_max(src, index, dim=0, dim_size=num_nodes)
    exp_vals = (src - group_max[index]).exp()
    assert not nan_or_inf(exp_vals)

    group_sum = scatter_add(exp_vals, index, dim=0, dim_size=num_nodes)
    probs = exp_vals / (group_sum[index] + 1e-16)
    assert not nan_or_inf(probs)

    return probs
Beispiel #6
0
def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None):
    """Drop nodes that have no non-self-loop edge and relabel the rest.

    Self-loops are split off first (via ``segregate_self_loops``); a node is
    kept only if it appears in at least one of the remaining edges. Kept
    nodes are renumbered consecutively, and self-loops are re-attached only
    for kept nodes.

    Args:
        edge_index: The edge indices.
        edge_attr: Optional per-edge attributes, indexed along dim 0.
        num_nodes: Optional number of nodes (inferred from ``edge_index``
            when ``None``).

    Returns:
        ``(edge_index, edge_attr, mask)`` where ``mask`` is a boolean tensor
        of length ``num_nodes`` marking the nodes that were kept.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    out = segregate_self_loops(edge_index, edge_attr)
    edge_index, edge_attr, loop_edge_index, loop_edge_attr = out

    # Nodes touched by at least one non-loop edge survive.
    mask = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)
    mask[edge_index.view(-1)] = True

    # Old node id -> new consecutive id (-1 for removed nodes).
    assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device)
    assoc[mask] = torch.arange(int(mask.sum()), device=assoc.device)
    edge_index = assoc[edge_index]

    # Keep a self-loop only if its node survived. The mask stays boolean:
    # indexing with uint8 masks is deprecated in PyTorch.
    loop_mask = torch.zeros_like(mask)
    loop_mask[loop_edge_index[0]] = True
    loop_mask = loop_mask & mask

    # Node id -> position of its self-loop in `loop_edge_index`.
    loop_assoc = torch.full_like(assoc, -1)
    loop_assoc[loop_edge_index[0]] = torch.arange(loop_edge_index.size(1),
                                                  device=loop_assoc.device)
    loop_idx = loop_assoc[loop_mask]
    loop_edge_index = assoc[loop_edge_index[:, loop_idx]]

    edge_index = torch.cat([edge_index, loop_edge_index], dim=1)

    if edge_attr is not None:
        loop_edge_attr = loop_edge_attr[loop_idx]
        edge_attr = torch.cat([edge_attr, loop_edge_attr], dim=0)

    return edge_index, edge_attr, mask
Beispiel #7
0
def entropy(src, edge_index, num_nodes=None):
    """Sum ``-p * log(p)`` of the edge values per node.

    Values in ``src`` are grouped by the first row of ``edge_index`` and the
    per-edge terms ``-src * log(src + 1e-32)`` are added up for each group.

    Args:
        src: Per-edge values.
        edge_index: Edge indices; only the first row is used for grouping.
        num_nodes: Optional number of groups (inferred when ``None``).

    Returns:
        Tensor with one entry per node along dim 0.
    """
    eps = 1e-32  # keeps log() finite for zero entries
    group, _ = edge_index
    num_nodes = maybe_num_nodes(group, num_nodes)

    per_edge = -src * torch.log(src + eps)
    return scatter_add(per_edge, group, dim=0, dim_size=num_nodes)
Beispiel #8
0
def batched_transpose(adj, value, m=None, n=None):
    """Transpose one sparse matrix, or one sparse matrix per column of value.

    Args:
        adj: Tensor or list of Tensor -- COO indices; a list holds one index
            tensor per column of ``value``
        value: Tensor [num_edges, ] for a single ``adj``; for a list,
            column ``j`` of ``value`` belongs to ``adj[j]``
        m: int -- number of rows (inferred from ``adj`` if None)
        n: int -- number of columns (inferred from ``adj`` if None)

    Returns:
        The result of ``transpose`` for a single ``adj``; otherwise a tuple
        ``(list of transposed indices, transposed values)`` where the values
        tensor has the same shape, dtype and device as ``value``.
    """
    if isinstance(adj, Tensor):
        m = maybe_num_nodes(adj[0], m)
        n = maybe_num_nodes(adj[1], n)
        return transpose(adj, value, m, n)

    # adj is a list of Tensor: use common sizes for all of them.
    m = max(maybe_num_nodes(a_[0], m) for a_ in adj)
    n = max(maybe_num_nodes(a_[1], n) for a_ in adj)

    adj_ = [None] * value.shape[1]
    # zeros_like keeps the output on value's device and in value's dtype
    # (a plain torch.zeros would always be a default-dtype CPU tensor).
    vs = torch.zeros_like(value)
    for j, a_ in enumerate(adj):
        adj_[j], vs[:, j] = transpose(a_, value[:, j], m, n)
    return adj_, vs
    def message(self, x_j, pos_j, pos_i, edge_index):
        """Weight neighbour features by normalized inverse distance.

        Each edge gets the weight ``1 / ||pos_j - pos_i||`` (distance
        floored at 1e-10); the weights are then divided by their sum over
        all edges sharing the same second-row index of ``edge_index``.

        Returns:
            ``x_j`` scaled row-wise by the normalized weights.
        """
        delta = pos_j - pos_i
        dist = delta.pow(2).sum(dim=1).pow(0.5)
        floor = torch.Tensor([1e-10]).to(dist.device, dist.dtype)
        inv_dist = 1.0 / torch.max(dist, floor)  # (E,)

        _, target = edge_index
        num_nodes = maybe_num_nodes(target, None)
        totals = scatter_add(inv_dist, target, dim=0, dim_size=num_nodes)
        inv_dist /= totals[target] + 1e-16  # (E,)

        return inv_dist.view(-1, 1) * x_j
def softmax(src: Tensor, index: Optional[Tensor], ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    """Group-wise softmax of ``src`` using either CSR pointers or indices.

    The global maximum of ``src`` is subtracted before exponentiating. The
    normalizer is computed per segment from ``ptr`` when given, otherwise
    per group from ``index``.

    Raises:
        NotImplementedError: If both ``ptr`` and ``index`` are ``None``.
    """
    shifted = src - src.max() if src.numel() > 0 else src
    exp_vals = shifted.exp()

    if ptr is not None:
        segment_sum = segment_csr(exp_vals, ptr, reduce='sum')
        denom = gather_csr(segment_sum, ptr)
    elif index is not None:
        size = maybe_num_nodes(index, num_nodes)
        group_sum = scatter(exp_vals, index, dim=0, dim_size=size,
                            reduce='sum')
        denom = group_sum[index]
    else:
        raise NotImplementedError

    return exp_vals / (denom + 1e-16)
Beispiel #11
0
def gcn_norm(edge_index,
             edge_weight=None,
             num_nodes=None,
             improved=False,
             add_self_loops=True,
             flow="source_to_target",
             dtype=None):
    """Apply symmetric ``D^-1/2 A D^-1/2`` scaling to a graph.

    Args:
        edge_index: ``SparseTensor`` adjacency or COO edge index tensor.
        edge_weight: Optional edge weights (COO case); defaults to ones.
        num_nodes: Optional number of nodes (inferred when ``None``).
        improved: If true, self-loops get weight ``2.`` instead of ``1.``.
        add_self_loops: Whether to add (remaining) self-loops first.
        flow: ``"source_to_target"`` or ``"target_to_source"``; selects the
            index row used for the degree. Only ``"source_to_target"`` is
            accepted for ``SparseTensor`` input.
        dtype: dtype used when default weights have to be created.

    Returns:
        The scaled ``SparseTensor``, or ``(edge_index, edge_weight)`` for
        COO input.
    """
    loop_weight = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        assert flow in ["source_to_target"]
        sparse_adj = edge_index
        if not sparse_adj.has_value():
            sparse_adj = sparse_adj.fill_value(1., dtype=dtype)
        if add_self_loops:
            sparse_adj = fill_diag(sparse_adj, loop_weight)

        inv_sqrt = sparsesum(sparse_adj, dim=1).pow_(-0.5)
        # Zero-degree rows produce inf; neutralize them.
        inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0.)
        row_scaled = mul(sparse_adj, inv_sqrt.view(-1, 1))
        return mul(row_scaled, inv_sqrt.view(1, -1))

    assert flow in ["source_to_target", "target_to_source"]
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ),
                                 dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, looped_weight = add_remaining_self_loops(
            edge_index, edge_weight, loop_weight, num_nodes)
        assert looped_weight is not None
        edge_weight = looped_weight

    src, dst = edge_index[0], edge_index[1]
    if flow == "source_to_target":
        degree_index = dst
    else:
        degree_index = src

    inv_sqrt = scatter_add(edge_weight, degree_index, dim=0,
                           dim_size=num_nodes).pow_(-0.5)
    inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0)
    return edge_index, inv_sqrt[src] * edge_weight * inv_sqrt[dst]
Beispiel #12
0
def add_self_loops_mul(edge_index,
                       edge_weight=None,
                       fill_value=1,
                       num_nodes=None):
    """Append a self-loop for every node, with multi-column edge weights.

    Args:
        edge_index: The edge indices.
        edge_weight: Optional 2-D weights; the appended rows have
            ``edge_weight.shape[1]`` columns filled with ``fill_value``.
        fill_value: Weight given to each new self-loop.
        num_nodes: Optional number of nodes (inferred when ``None``).

    Returns:
        ``(edge_index, edge_weight)`` with the self-loops appended.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    node_ids = torch.arange(0, num_nodes, dtype=torch.long,
                            device=edge_index.device)
    loops = torch.stack([node_ids, node_ids], dim=0)

    if edge_weight is not None:
        loop_shape = (num_nodes, edge_weight.shape[1])
        extra = edge_weight.new_full(loop_shape, fill_value)
        edge_weight = torch.cat([edge_weight, extra], dim=0)

    return torch.cat([edge_index, loops], dim=1), edge_weight
Beispiel #13
0
    def forward(self, x, edge_index, M=None, batch=None, num_nodes=None):
        """Score nodes, keep the top-scoring ones and filter the graph.

        The score combines the transformed features ``x @ self.transform``
        with a weighted degree computed from the edge weights returned by
        ``self.panentropy_sparse``. ``M`` is accepted but not used here.

        Returns:
            ``(x, edge_index, edge_weight, batch, perm, score[perm])`` for
            the selected nodes.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # Path integral
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        edge_index, edge_weight = self.panentropy_sparse(edge_index, num_nodes)

        # Weighted degree, accumulated over the first row of edge_index.
        num_nodes = x.size(0)
        weighted_degree = scatter_add(
            edge_weight, edge_index[0],
            out=torch.zeros(num_nodes, device=edge_index.device))

        # Linear transform of the node features.
        projected = torch.matmul(x, self.transform)

        # Aggregate score from both terms.
        feature_term = self.pan_pool_weight[0] * projected
        degree_term = self.pan_pool_weight[1] * weighted_degree
        score = feature_term + degree_term

        if self.min_score is not None:
            score = softmax(score, batch)
        else:
            score = self.nonlinearity(score)

        perm = self.topk(score, self.ratio, batch, self.min_score)
        kept_score = score[perm]
        x = x[perm] * kept_score.view(-1, 1)
        if self.multiplier != 1:
            x = self.multiplier * x

        batch = batch[perm]
        edge_index, edge_weight = self.filter_adj(edge_index,
                                                  edge_weight,
                                                  perm,
                                                  num_nodes=score.size(0))

        return x, edge_index, edge_weight, batch, perm, score[perm]
Beispiel #14
0
def add_self_loops_no_zero(edge_index, num_nodes=None):
    r"""Appends a self-loop :math:`(i,i)` for every node :math:`i` of the
    graph given by :attr:`edge_index`, except for node :math:`0`.
    Modified from `torch_geometric.utils.loop.add_self_loops`.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    # Start at 1 so that node 0 does not receive a self-loop.
    node_ids = torch.arange(1, num_nodes, dtype=torch.long,
                            device=edge_index.device)
    loops = torch.stack([node_ids, node_ids], dim=0)

    return torch.cat([edge_index, loops], dim=1)
Beispiel #15
0
    def gcn_norm(self,
                 edge_index,
                 num_nodes=None,
                 dtype=None,
                 add_self_loops=True):
        """Compute the per-edge factor ``deg^-1/2[row] * deg^-1/2[col]``.

        Args:
            edge_index: The edge indices.
            num_nodes: Optional number of nodes (inferred from
                ``edge_index`` when ``None``).
            dtype: dtype requested for the degree vector.
            add_self_loops: Whether to add the remaining self-loops before
                computing degrees.

        Returns:
            One normalization coefficient per edge.

        NOTE(review): when ``add_self_loops`` is true the coefficients are
        aligned with the self-loop-augmented edge list, which is not
        returned -- confirm that callers augment ``edge_index`` the same way.
        """
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

        if add_self_loops:
            # Only the connectivity matters for the unweighted degree below,
            # so no edge weights are passed.
            edge_index, _ = add_remaining_self_loops(
                edge_index, None, fill_value=1, num_nodes=num_nodes)

        row, col = edge_index[0], edge_index[1]
        # The node count and dtype come from the arguments; this method has
        # no feature matrix to read them from.
        deg = degree(col, num_nodes, dtype=dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        return deg_inv_sqrt[row] * deg_inv_sqrt[col]
Beispiel #16
0
def softmax(src, index, num_nodes=None):
    r"""Evaluates a group-wise softmax together with its logarithm.

    Entries of :attr:`src` sharing the same value in :attr:`index` form a
    group along the first dimension. Each group's maximum is subtracted
    before exponentiating, and the log-softmax is derived from the shifted
    values rather than from the log of the softmax output.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`Tensor`) -- softmax and log-softmax
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    group_max = scatter_max(src, index, dim=0, dim_size=num_nodes)[0]
    centered = src - group_max[index]
    numer = centered.exp()

    group_sum = scatter_add(numer, index, dim=0, dim_size=num_nodes)
    denom = group_sum[index] + 1e-16

    return numer / denom, centered - denom.log()
Beispiel #17
0
def to_undirected(edge_index, num_nodes=None):
    r"""Converts the graph given by :attr:`edge_index` to an undirected graph,
    so that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in
    \mathcal{E}`, and tags every resulting edge with a direction marker.

    Original edges are tagged with :obj:`2` and the added reverse edges with
    :obj:`1` before duplicates are merged by :obj:`coalesce` (presumably
    summing the tags of merged duplicates -- verify against the
    :obj:`coalesce` in use).

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`) -- the coalesced edge
        indices and a float32 column vector holding one tag per edge
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    row, col = edge_index
    # Create the tags on the same device as the indices and with an explicit
    # dtype, instead of relying on CPU defaults and dtype inference.
    value_ori = torch.full((row.size(0), ), 2, dtype=torch.long,
                           device=row.device)
    value_add = torch.full((col.size(0), ), 1, dtype=torch.long,
                           device=col.device)
    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
    edge_attr = torch.cat([value_ori, value_add], dim=0)
    edge_index = torch.stack([row, col], dim=0)
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
                                     num_nodes)

    return edge_index, edge_attr.view(-1, 1).type(torch.float32)
Beispiel #18
0
def topk_softmax(src, index, k, num_nodes=None):
    r"""Evaluates a group-wise softmax restricted to each group's top-k values.

    Entries of :attr:`src` sharing the same value in :attr:`index` form a
    group along the first dimension. In each of :attr:`k` rounds the current
    maximum of every group is marked as selected and then masked out, so
    entries tied with a group's maximum are selected in the same round.
    Entries that were never selected get an output of zero.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        k (int): The number of indexes to select from each group.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Build a 0/1 mask of the selected entries, one selection round per k.
    remaining = src.clone()
    selected = torch.zeros_like(remaining)
    for _ in range(k):
        round_best = scatter_max(remaining, index, dim=0,
                                 dim_size=num_nodes)[0]
        hits = remaining == round_best[index]
        selected[hits] = 1.
        remaining[hits] = float("-Inf")

    # Exponentials shifted by each group's maximum, zeroed outside the mask.
    group_max = scatter_max(src, index, dim=0, dim_size=num_nodes)[0]
    numer = (src - group_max[index]).exp() * selected

    group_sum = scatter_add(numer, index, dim=0, dim_size=num_nodes)
    return numer / (group_sum[index] + 1e-16)