# Shared imports assumed by the snippets below (the helper `nan_or_inf` is
# project-local and left as-is).
from typing import List, Optional, Union

import torch
from torch import Tensor
from torch_scatter import (gather_csr, scatter, scatter_add, scatter_max,
                           segment_csr)
from torch_sparse import SparseTensor, coalesce, fill_diag, mul, spmm, transpose
from torch_sparse import sum as sparsesum
from torch_geometric.utils import (add_remaining_self_loops, degree,
                                   segregate_self_loops)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from einops import rearrange, repeat


def batched_spmm(nzt, adj, x, m=None, n=None):
    """
    Args:
        nzt: Tensor [(batch,) num_edges, heads] -- non-zero tensor
        adj: Tensor or list(Tensor) -- adjacency matrix (COO)
        x: Tensor [(batch,) num_nodes, channels] -- feature matrix
        m: int
        n: int
    """
    heads = nzt.shape[-1]
    num_nodes, channels = x.shape[-2:]
    # Preparation of data: replicate the features once per head and flatten
    # the non-zero values, so all heads form one block-diagonal problem.
    if x.dim() == 2:
        x_ = torch.cat(heads * [x])  # [heads * num_nodes, channels]
        nzt_ = rearrange(nzt, 'e h -> (h e)')
    else:
        x_ = repeat(x, 't n c -> t (h n) c', h=heads)
        nzt_ = rearrange(nzt, 't e h -> t (h e)')
    if isinstance(adj, Tensor):
        m = maybe_num_nodes(adj[0], m)
        n = max(num_nodes, maybe_num_nodes(adj[1], n))
        offset = torch.tensor([[m], [n]])
        adj_ = torch.cat([adj + offset * i for i in range(heads)], dim=1)
    else:  # adj is a list of adjacency matrices, one per head
        assert heads == len(adj), \
            "the number of heads and the number of adjacency matrices do not match"
        m = max([maybe_num_nodes(a[0], m) for a in adj])
        n = max([maybe_num_nodes(a[1], n) for a in adj])
        offset = torch.tensor([[m], [n]])
        adj_ = torch.cat([adj[i] + offset * i for i in range(heads)], dim=1)
    if x.dim() == 2:
        out = spmm(adj_, nzt_, heads * m, heads * n, x_)
        return out.view(-1, m, channels)  # [heads, m, channels]
    else:
        _size = x_.shape[0]
        out = torch.stack([spmm(adj_, nzt_[i], heads * m, heads * n, x_[i])
                           for i in range(_size)])
        return out  # [batch, heads * m, channels]

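# A minimal usage sketch for `batched_spmm`, assuming `spmm` comes from
# `torch_sparse`; the toy tensors below are illustrative only. It aggregates
# 4-channel node features over 3 edges with 2 attention heads.
edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])  # COO adjacency, 3 edges
alpha = torch.rand(3, 2)                           # [num_edges, heads]
x = torch.rand(3, 4)                               # [num_nodes, channels]
out = batched_spmm(alpha, edge_index, x, m=3, n=3)
assert out.shape == (2, 3, 4)                      # [heads, m, channels]
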
def gcn_no_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
                add_self_loops=True, dtype=None):
    fill_value = 2. if improved else 1.
    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        return adj_t
    else:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)
        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight
        return edge_index, edge_weight

def add_self_loops(edge_index, edge_weight: Optional[torch.Tensor] = None,
                   fill_value: float = 1.,
                   num_nodes: Optional[int] = None):
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted, self-loops will be added with edge weights
    denoted by :obj:`fill_value`.

    Args:
        edge_index (LongTensor): The edge indices.
        edge_weight (Tensor, optional): One-dimensional edge weights.
            (default: :obj:`None`)
        fill_value (float, optional): If :obj:`edge_weight` is not
            :obj:`None`, will add self-loops with edge weights of
            :obj:`fill_value` to the graph. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        # Append one `fill_value` weight per added self-loop.
        loop_weight = edge_weight.new_full((N, ), fill_value)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    new_edge_index = torch.cat([edge_index, loop_index], dim=1)
    return new_edge_index, edge_weight

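# A minimal sketch of `add_self_loops` on a hypothetical 3-node path graph;
# the weights are made up for illustration.
edge_index = torch.tensor([[0, 1], [1, 2]])
edge_weight = torch.tensor([0.5, 0.5])
ei, ew = add_self_loops(edge_index, edge_weight, fill_value=1., num_nodes=3)
# ei == [[0, 1, 0, 1, 2], [1, 2, 0, 1, 2]]; ew == [0.5, 0.5, 1., 1., 1.]
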
def __init__(self, edge_index: Union[Tensor, SparseTensor], sizes: List[int],
             node_idx: Optional[Tensor] = None,
             num_nodes: Optional[int] = None, return_e_id: bool = True,
             **kwargs):
    self.sizes = sizes
    self.return_e_id = return_e_id
    self.is_sparse_tensor = isinstance(edge_index, SparseTensor)
    self.__val__ = None

    # Obtain a *transposed* `SparseTensor` instance.
    edge_index = edge_index.to('cpu')
    if not self.is_sparse_tensor:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        value = torch.arange(edge_index.size(1)) if return_e_id else None
        self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],
                                  value=value,
                                  sparse_sizes=(num_nodes, num_nodes)).t()
    else:
        adj_t = edge_index
        if return_e_id:
            self.__val__ = adj_t.storage.value()
            value = torch.arange(adj_t.nnz())
            adj_t = adj_t.set_value(value, layout='coo')
        self.adj_t = adj_t

    self.adj_t.storage.rowptr()

    if node_idx is None:
        node_idx = torch.arange(self.adj_t.sparse_size(0))
    elif node_idx.dtype == torch.bool:
        node_idx = node_idx.nonzero(as_tuple=False).view(-1)

    super(NeighborSampler, self).__init__(
        node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)

def real_softmax(src, index, num_nodes=None):
    r"""Computes a sparsely evaluated softmax.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Shift by the per-group maximum for numerical stability.
    src = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    out = src.exp()
    assert not nan_or_inf(out)
    out = out / (scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)
    assert not nan_or_inf(out)
    return out

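# A toy check of `real_softmax`: two edges point at node 0 and one at node 1,
# so the first two scores are normalized together (values are illustrative).
src = torch.tensor([1.0, 2.0, 0.5])
index = torch.tensor([0, 0, 1])
out = real_softmax(src, index)  # approx. [0.269, 0.731, 1.0]
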
def remove_isolated_nodes(edge_index, edge_attr=None, num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    out = segregate_self_loops(edge_index, edge_attr)
    edge_index, edge_attr, loop_edge_index, loop_edge_attr = out

    # Mark every node that appears in a non-loop edge.
    mask = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)
    mask[edge_index.view(-1)] = True

    # Relabel the surviving nodes to consecutive indices.
    assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device)
    assoc[mask] = torch.arange(mask.sum(), device=assoc.device)
    edge_index = assoc[edge_index]

    # Keep only self-loops that sit on surviving nodes.
    loop_mask = torch.zeros_like(mask)
    loop_mask[loop_edge_index[0]] = True
    loop_mask = loop_mask & mask
    loop_assoc = torch.full_like(assoc, -1)
    loop_assoc[loop_edge_index[0]] = torch.arange(loop_edge_index.size(1),
                                                  device=loop_assoc.device)
    loop_idx = loop_assoc[loop_mask]
    loop_edge_index = assoc[loop_edge_index[:, loop_idx]]

    edge_index = torch.cat([edge_index, loop_edge_index], dim=1)

    if edge_attr is not None:
        loop_edge_attr = loop_edge_attr[loop_idx]
        edge_attr = torch.cat([edge_attr, loop_edge_attr], dim=0)

    return edge_index, edge_attr, mask

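# A minimal sketch of `remove_isolated_nodes` on a hypothetical graph where
# node 1 has no edges: it is dropped and the remaining nodes are relabeled.
edge_index = torch.tensor([[0, 2], [2, 0]])
ei, _, mask = remove_isolated_nodes(edge_index, num_nodes=3)
# mask == [True, False, True]; ei == [[0, 1], [1, 0]]
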
def entropy(src, edge_index, num_nodes=None):
    EPS = 1e-32
    index, _ = edge_index
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Shannon entropy of the (already normalized) scores, summed per group.
    out = -src * torch.log(src + EPS)
    out = scatter_add(out, index, dim=0, dim_size=num_nodes)
    return out

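# A toy check of `entropy` with made-up, already-normalized attention scores:
# node 0 splits its attention evenly (entropy ln 2), node 1 does not (~0).
edge_index = torch.tensor([[0, 0, 1], [1, 2, 2]])
att = torch.tensor([0.5, 0.5, 1.0])
h = entropy(att, edge_index, num_nodes=3)  # approx. [0.693, 0.0, 0.0]
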
def batched_transpose(adj, value, m=None, n=None):
    """
    Args:
        adj: Tensor or list of Tensor
        value: Tensor [num_edges, ] or [num_edges, heads]
        m: int
        n: int
    """
    if isinstance(adj, Tensor):
        m = maybe_num_nodes(adj[0], m)
        n = maybe_num_nodes(adj[1], n)
        return transpose(adj, value, m, n)
    else:  # adj is a list of Tensor, one per head
        adj_ = [None] * value.shape[1]
        vs = torch.zeros_like(value)
        m = max([maybe_num_nodes(a_[0], m) for a_ in adj])
        n = max([maybe_num_nodes(a_[1], n) for a_ in adj])
        for j in range(len(adj)):
            adj_[j], vs[:, j] = transpose(adj[j], value[:, j], m, n)
        return adj_, vs

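# A minimal sketch of `batched_transpose` for a single adjacency, assuming
# `transpose` comes from `torch_sparse`; the toy values are illustrative.
edge_index = torch.tensor([[0, 1], [1, 2]])
value = torch.tensor([1.0, 2.0])
ei_t, v_t = batched_transpose(edge_index, value, m=3, n=3)
# ei_t == [[1, 2], [0, 1]]: row and column indices are swapped (and sorted)
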
def message(self, x_j, pos_j, pos_i, edge_index):
    # Inverse-distance weights, normalized over each target node's neighbors.
    dist = (pos_j - pos_i).pow(2).sum(dim=1).pow(0.5)
    dist = torch.max(dist, torch.tensor([1e-10], device=dist.device,
                                        dtype=dist.dtype))
    weight = 1.0 / dist  # (E,)

    row, col = edge_index
    index = col
    num_nodes = maybe_num_nodes(index, None)
    wsum = scatter_add(weight, col, dim=0, dim_size=num_nodes)[index] + 1e-16  # (E,)
    weight /= wsum

    return weight.view(-1, 1) * x_j

def softmax(src: Tensor, index: Optional[Tensor], ptr: Optional[Tensor] = None,
            num_nodes: Optional[int] = None) -> Tensor:
    out = src
    if src.numel() > 0:
        out = out - src.max()
    out = out.exp()

    if ptr is not None:
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
    else:
        raise NotImplementedError

    return out / (out_sum + 1e-16)

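# A sketch of the CSR path of `softmax`: `ptr` gives group boundaries instead
# of a per-element `index` (toy values; `segment_csr`/`gather_csr` come from
# `torch_scatter`, as assumed by the function above).
src = torch.tensor([1.0, 2.0, 0.5])
ptr = torch.tensor([0, 2, 3])            # group 0 = src[0:2], group 1 = src[2:3]
out = softmax(src, index=None, ptr=ptr)  # approx. [0.269, 0.731, 1.0]
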
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, flow="source_to_target", dtype=None):
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        assert flow in ["source_to_target"]
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        deg = sparsesum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
        return adj_t
    else:
        assert flow in ["source_to_target", "target_to_source"]
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)
        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

        row, col = edge_index[0], edge_index[1]
        idx = col if flow == "source_to_target" else row
        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

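# A toy check of `gcn_norm` (dense path): on a 2-node graph with reciprocal
# edges, every edge, including the added self-loops, gets weight
# 1/sqrt(deg_i * deg_j) = 0.5.
edge_index = torch.tensor([[0, 1], [1, 0]])
ei, ew = gcn_norm(edge_index, num_nodes=2)
# ew == [0.5, 0.5, 0.5, 0.5]
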
def add_self_loops_mul(edge_index, edge_weight=None, fill_value=1,
                       num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, num_nodes, dtype=torch.long,
                              device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        # Multi-channel edge weights: one `fill_value` row per node.
        loop_weight = edge_weight.new_full((num_nodes, edge_weight.shape[1]),
                                           fill_value)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)

    return edge_index, edge_weight

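# A minimal sketch of `add_self_loops_mul` with hypothetical multi-channel
# edge weights: each node receives one loop row filled with `fill_value`.
edge_index = torch.tensor([[0], [1]])
edge_weight = torch.rand(1, 3)  # one edge, 3 weight channels
ei, ew = add_self_loops_mul(edge_index, edge_weight, num_nodes=2)
# ei == [[0, 0, 1], [1, 0, 1]]; ew.shape == (3, 3)
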
def forward(self, x, edge_index, M=None, batch=None, num_nodes=None):
    """"""
    if batch is None:
        batch = edge_index.new_zeros(x.size(0))

    # Path integral
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    edge_index, edge_weight = self.panentropy_sparse(edge_index, num_nodes)

    # Weighted degree
    num_nodes = x.size(0)
    degree = torch.zeros(num_nodes, device=edge_index.device)
    degree = scatter_add(edge_weight, edge_index[0], out=degree)

    # Linear transform
    xtransform = torch.matmul(x, self.transform)

    # Aggregate score
    x_transform_norm = xtransform  # / xtransform.norm(p=2, dim=-1)
    degree_norm = degree  # / degree.norm(p=2, dim=-1)
    score = (self.pan_pool_weight[0] * x_transform_norm +
             self.pan_pool_weight[1] * degree_norm)

    if self.min_score is None:
        score = self.nonlinearity(score)
    else:
        score = softmax(score, batch)

    perm = self.topk(score, self.ratio, batch, self.min_score)
    x = x[perm] * score[perm].view(-1, 1)
    x = self.multiplier * x if self.multiplier != 1 else x

    batch = batch[perm]
    edge_index, edge_weight = self.filter_adj(edge_index, edge_weight, perm,
                                              num_nodes=score.size(0))

    return x, edge_index, edge_weight, batch, perm, score[perm]

def add_self_loops_no_zero(edge_index, num_nodes=None):
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`,
    for :math:`i \neq 0`. Modified from
    `torch_geometric.utils.loop.add_self_loops`.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    loop_index = torch.arange(1, num_nodes, dtype=torch.long,
                              device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    edge_index = torch.cat([edge_index, loop_index], dim=1)
    return edge_index

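# A toy check of `add_self_loops_no_zero`: loops are added for nodes 1 and 2
# but not for node 0.
edge_index = torch.tensor([[0, 1], [1, 2]])
ei = add_self_loops_no_zero(edge_index, num_nodes=3)
# ei == [[0, 1, 1, 2], [1, 2, 1, 2]]
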
def gcn_norm(self, edge_index, edge_weight=None, num_nodes=None, dtype=None,
             add_self_loops=True):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)
    if add_self_loops:
        edge_index, tmp_edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value=1, num_nodes=num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    row, col = edge_index[0], edge_index[1]
    # Unweighted in-degree (edge count), as in the original implementation.
    deg = degree(col, num_nodes, dtype=edge_weight.dtype)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    return deg_inv_sqrt[row] * deg_inv_sqrt[col]

def softmax(src, index, num_nodes=None):
    r"""Computes a sparsely evaluated softmax along with its log in a
    numerically stable way.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`Tensor`)
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    shifted = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    exp_shifted = shifted.exp()
    exp_sum = scatter_add(exp_shifted, index, dim=0, dim_size=num_nodes)[index]
    scores = exp_shifted / (exp_sum + 1e-16)
    return scores, shifted - (exp_sum + 1e-16).log()

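# A toy check of the (softmax, log-softmax) pair returned above; the values
# are illustrative.
src = torch.tensor([1.0, 2.0, 0.5])
index = torch.tensor([0, 0, 1])
probs, log_probs = softmax(src, index)
# probs sums to 1 within each group; log_probs == probs.log() up to the
# 1e-16 stabilizer
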
def to_undirected(edge_index, num_nodes=None):
    r"""Converts the graph given by :attr:`edge_index` to an undirected graph,
    so that :math:`(j,i) \in \mathcal{E}` for every edge
    :math:`(i,j) \in \mathcal{E}`. Original edges are marked with weight 2
    and the added reverse edges with weight 1, so after coalescing a weight
    of 3 means both directions were present in the input.

    Args:
        edge_index (LongTensor): The edge indices.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    row, col = edge_index
    value_ori = torch.full([row.size(0)], 2, dtype=torch.long,
                           device=edge_index.device)
    value_add = torch.full([col.size(0)], 1, dtype=torch.long,
                           device=edge_index.device)
    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
    edge_attr = torch.cat([value_ori, value_add], dim=0)
    edge_index = torch.stack([row, col], dim=0)
    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,
                                     num_nodes)
    return edge_index, edge_attr.view(-1, 1).type(torch.float32)

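# A minimal sketch of `to_undirected` on a hypothetical directed path: each
# original edge keeps weight 2 and its reverse is added with weight 1.
edge_index = torch.tensor([[0, 1], [1, 2]])
ei, w = to_undirected(edge_index, num_nodes=3)
# ei == [[0, 1, 1, 2], [1, 0, 2, 1]]; w.view(-1) == [2., 1., 2., 1.]
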
def topk_softmax(src, index, k, num_nodes=None):
    r"""Computes a sparsely evaluated softmax using only the top-k values.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    then selects the top-k values of each group and computes the softmax over
    them. The output at indices that were not selected is zero.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        k (int): The number of indices to select from each group.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    num_nodes = maybe_num_nodes(index, num_nodes)

    # Build a mask that is 1. at the top-k values of each group, 0. elsewhere.
    out = src.clone()
    topk_mask = torch.zeros_like(out)
    for _ in range(k):
        v_max = scatter_max(out, index, dim=0, dim_size=num_nodes)[0]
        i_max = (out == v_max[index])
        topk_mask[i_max] = 1.
        out[i_max] = float("-Inf")

    out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    out = out.exp()

    # Zero out everything except the top-k values.
    out = out * topk_mask
    out = out / (scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)

    return out

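# A toy check of `topk_softmax` with k=2: in the three-element group only the
# two largest scores survive; the smallest gets exactly zero.
src = torch.tensor([1.0, 3.0, 2.0, 5.0])
index = torch.tensor([0, 0, 0, 1])
out = topk_softmax(src, index, k=2)
# approx. [0.0, 0.731, 0.269, 1.0]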