def one_hot(src, size=None, out=None):
    """Encodes a vector of class indices as a one-hot matrix.

    Row ``i`` of the result is all zeros except for a one in column
    ``src[i]``.

    Args:
        src: One-dimensional index tensor (or a wrapper exposing the raw
            tensor through ``.data``).
        size (int, optional): Number of columns. When omitted it is taken
            as ``src.max() + 1``.
        out (optional): Container that is resized to
            ``(src.size(0), size)``, overwritten and returned. When omitted
            one is obtained from the ``new`` helper applied to
            ``src.float()``.

    Returns:
        ``out``, holding the one-hot encoding.
    """
    # Work on the raw tensor for indexing regardless of how src is wrapped.
    if torch.is_tensor(src):
        src_data = src
    else:
        src_data = src.data

    if size is None:
        size = src_data.max() + 1
    sizes = (src.size(0), size)

    if out is None:
        out = new(src.float())

    # Wrapped outputs have to be resized through their underlying tensor.
    if torch.is_tensor(out):
        out.resize_(*sizes)
    else:
        out.data.resize_(*sizes)

    out.fill_(0)
    # Pair each row number with its class index and set those entries to 1.
    row_ids = torch.arange(sizes[0], out=src_data.new())
    out[row_ids, src_data] = 1
    return out
def softmax(index, src, num_nodes=None):
    """Applies a softmax to ``src`` separately within each index group.

    Entries of ``src`` (along its first dimension) that share the same
    value in ``index`` are exponentiated and divided by the sum of the
    exponentials of their group.

    Note:
        The exponentials are taken directly from ``src`` without
        subtracting a maximum first, so very large inputs can overflow.

    Args:
        index: Group id for every entry along the first dimension of
            ``src``.
        src: Values to normalize.
        num_nodes (int, optional): Number of groups. When omitted it is
            taken as ``index.max() + 1``.

    Returns:
        A result with the same size as ``src`` holding the normalized
        values.
    """
    num_nodes = index.max() + 1 if num_nodes is None else num_nodes
    # Wrap the index so its type matches a non-tensor (Variable) src.
    index = index if torch.is_tensor(src) else Variable(index)
    sizes = list(src.size())[1:]

    output = src.exp()

    # Accumulate the per-group sums of the exponentials.
    output_sum = new(output, num_nodes, *sizes).fill_(0)
    index_expand = index.view(-1, *repeat(1, len(sizes))).expand_as(src)
    output_sum.scatter_add_(0, index_expand, output)

    # Divide out of place: `output` is the result of exp(), which autograd
    # keeps for the backward pass. Modifying it in place (`output /= ...`)
    # invalidates that saved value and makes backward() fail.
    return output / output_sum[index]
def matmul(edge_index, edge_attr, tensor):
    """Multiplies a sparse matrix, given as an edge list, with a dense one.

    For every edge ``(row, col)`` with weight ``w`` the product
    ``w * tensor[col]`` is added to row ``row`` of the result.

    Args:
        edge_index: Pair ``(row, col)`` of index tensors, one entry per
            edge.
        edge_attr: One-dimensional edge weights.
        tensor: Dense operand. A one-dimensional input is treated as a
            single-column matrix.

    Returns:
        A two-dimensional result with the same size as the (possibly
        unsqueezed) ``tensor``.
    """
    # Promote vectors to single-column matrices.
    if tensor.dim() <= 1:
        tensor = tensor.unsqueeze(-1)
    assert edge_attr.dim() == 1 and tensor.dim() == 2

    num_nodes, dim = tensor.size()
    row, col = edge_index
    # Wrap the scatter index so its type matches a non-tensor operand.
    if not torch.is_tensor(tensor):
        row = Variable(row)

    # Weighted source rows, one per edge.
    weighted = edge_attr.unsqueeze(-1) * tensor[col]

    # Sum the weighted rows into their target rows.
    output = new(weighted, num_nodes, dim).fill_(0)
    target = row.unsqueeze(-1).expand_as(weighted)
    output.scatter_add_(0, target, weighted)
    return output
def degree(index, num_nodes=None, out=None):
    """Counts how often each node id occurs in an index tensor.

    Args:
        index (LongTensor): Source or target indices of edges
        num_nodes (int, optional): The number of nodes in :obj:`index`.
            When omitted it is taken as ``index.max() + 1``.
        out (Tensor, optional): The result tensor. It is resized to
            ``num_nodes`` entries and overwritten.

    :rtype: :class:`Tensor`

    .. testsetup::

        import torch

    .. testcode::

        from torch_geometric.utils import degree
        index = torch.LongTensor([0, 1, 0, 2, 0])
        output = degree(index)
        print(output)

    .. testoutput::

         3
         1
         1
        [torch.FloatTensor of size 3]
    """
    if num_nodes is None:
        num_nodes = index.max() + 1
    if out is None:
        out = index.new().float()

    if torch.is_tensor(out):
        out.resize_(num_nodes)
    else:
        # A wrapped output needs a wrapped index, and has to be resized
        # through its underlying tensor.
        index = Variable(index)
        out.data.resize_(num_nodes)

    # Scatter-adding a one per index entry yields the occurrence counts.
    ones = new(out, index.size(0)).fill_(1)
    out.fill_(0)
    return out.scatter_add_(0, index, ones)
def normalized_cut(edge_index, edge_attr, num_nodes=None):
    """Scales every edge attribute by the inverse degrees of its endpoints.

    Computes ``edge_attr * (1 / deg[row] + 1 / deg[col])``, where ``deg``
    is obtained from ``degree`` applied to the ``row`` indices.

    Args:
        edge_index: Pair ``(row, col)`` of index tensors, one entry per
            edge.
        edge_attr: Edge attributes to rescale.
        num_nodes (int, optional): Number of nodes, forwarded to
            ``degree``.

    Returns:
        The rescaled edge attributes.
    """
    row, col = edge_index
    # The degree buffer is obtained from the `new` helper applied to edge_attr.
    inv_deg = 1 / degree(row, num_nodes, new(edge_attr))
    endpoint_weight = inv_deg[row] + inv_deg[col]
    return edge_attr * endpoint_weight