def forward(self, x, edge_index=None, batch=None):
    """Pool node features after packing them into a dense batch.

    When ``self.global_pool`` is truthy, the rows of ``x`` are grouped by
    ``batch``. Otherwise the features gathered at ``edge_index[1]`` are
    grouped by ``edge_index[0]``, with ``dim_size`` set to the number of
    rows in ``x``.

    Args:
        x (Tensor): Node feature matrix.
        edge_index: Pair ``(row, col)`` of index tensors; only read when
            ``self.global_pool`` is falsy.
        batch: Grouping vector handed to ``to_batch``; only read when
            ``self.global_pool`` is truthy.

    Returns:
        The first element of the pair returned by ``self.pool``.
    """
    if self.global_pool:
        # One group per entry of ``batch``.
        packed, counts = to_batch(x, batch)
    else:
        # One group per source index: gather the features at ``col`` and
        # bucket them by ``row``.
        row, col = edge_index
        packed, counts = to_batch(x[col], row, dim_size=x.size(0))

    # ``self.pool`` is fed the packed tensor with dims 1 and 2 swapped.
    pooled, _ = self.pool(packed.transpose(1, 2), counts)
    return pooled
def test_to_batch():
    """Check that ``to_batch`` zero-pads each group and reports group sizes."""
    features = torch.Tensor(
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
    )
    assignment = torch.tensor([0, 0, 1, 2, 2, 2])

    dense, counts = to_batch(features, assignment)

    # Groups 0 and 1 hold fewer than three rows, so they are padded with
    # zero rows up to the size of the largest group.
    group_0 = [[1, 2], [3, 4], [0, 0]]
    group_1 = [[5, 6], [0, 0], [0, 0]]
    group_2 = [[7, 8], [9, 10], [11, 12]]

    assert dense.tolist() == [group_0, group_1, group_2]
    assert counts.tolist() == [2, 1, 3]
def global_sort_pool(x, batch, k):
    r"""The global pooling operator from the `"An End-to-End Deep Learning
    Architecture for Graph Classification"
    <https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf>`_ paper,
    where node features are first sorted individually and then sorted in
    descending order based on their last features. The first :math:`k` nodes
    form the output of the layer.

    Graphs with fewer than :math:`k` nodes are padded with zero rows.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        batch (LongTensor): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example.
        k (int): The number of nodes to hold for each graph.

    Returns:
        Tensor: Matrix of shape :math:`B \times (k \cdot F)` holding, for each
        example, the flattened features of its first :math:`k` sorted nodes.

    :rtype: :class:`Tensor`
    """
    # Sort the channels of every node in ascending order.
    x, _ = x.sort(dim=-1)

    # A value strictly below every real entry: padded rows sort last in the
    # descending sort below and can be recognised (and zeroed) afterwards.
    fill_value = x.min().item() - 1
    batch_x, _ = to_batch(x, batch, fill_value)
    B, N, D = batch_x.size()

    # Per-example ordering of the nodes by their last channel, descending.
    _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True)

    # Turn per-example positions into positions in the flattened (B * N)
    # node axis so that a single gather reorders every example.
    arange = torch.arange(B, dtype=torch.long, device=perm.device) * N
    perm = perm + arange.view(-1, 1)

    batch_x = batch_x.view(B * N, D)
    batch_x = batch_x[perm]
    batch_x = batch_x.view(B, N, D)

    if N >= k:
        batch_x = batch_x[:, :k].contiguous()
    else:
        # Pad up to k nodes. ``new_full`` allocates on the device and with
        # the dtype of ``batch_x``; a bare ``torch.full`` would always create
        # a default-dtype CPU tensor and break for CUDA / non-default inputs.
        padding = batch_x.new_full((B, k - N, D), fill_value)
        batch_x = torch.cat([batch_x, padding], dim=1)

    # Replace every padding marker with zero.
    batch_x[batch_x == fill_value] = 0

    return batch_x.view(B, k * D)
def sort_pool(x, batch, k):
    """Sort-pool node features into a fixed-size representation per example.

    The channels of every node are sorted in ascending order, the nodes of
    each example are then ordered by their last channel in descending order,
    and the first ``k`` nodes are kept and flattened. Examples with fewer
    than ``k`` nodes are padded with zero rows.

    Args:
        x (Tensor): Node feature matrix with one row per node.
        batch (LongTensor): Vector assigning each node to an example.
        k (int): The number of nodes to hold for each example.

    Returns:
        Tensor: Matrix with one row per example and ``k * D`` columns, where
        ``D`` is the number of channels of ``x``.
    """
    # Sort the channels of every node in ascending order.
    x, _ = x.sort(dim=-1)

    # A value strictly below every real entry: padded rows sort last in the
    # descending sort below and can be recognised (and zeroed) afterwards.
    fill_value = x.min().item() - 1
    batch_x, _ = to_batch(x, batch, fill_value)
    B, N, D = batch_x.size()

    # Per-example ordering of the nodes by their last channel, descending.
    _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True)

    # Offset the per-example positions into the flattened (B * N) node axis.
    arange = torch.arange(B, dtype=torch.long, device=perm.device) * N
    perm = perm + arange.view(-1, 1)

    batch_x = batch_x.view(B * N, D)
    batch_x = batch_x[perm]
    batch_x = batch_x.view(B, N, D)

    if N >= k:
        batch_x = batch_x[:, :k].contiguous()
    else:
        # Fewer than k node slots are available: without padding the final
        # ``view(B, k * D)`` fails on a size mismatch. ``new_full`` keeps the
        # device and dtype of ``batch_x``.
        padding = batch_x.new_full((B, k - N, D), fill_value)
        batch_x = torch.cat([batch_x, padding], dim=1)

    # Replace every padding marker with zero.
    batch_x[batch_x == fill_value] = 0

    return batch_x.view(B, k * D)
def forward(self, x, batch):
    """Aggregate node features per example with iterative LSTM attention.

    The nodes are packed into a dense, padded ``(batch_size, max_nodes, C)``
    tensor. For ``self.processing_steps`` rounds, ``self.lstm`` produces a
    query per example, the query is scored against every node by a dot
    product, the scores are normalised with a softmax over the real nodes of
    the example, and the attention-weighted node sum is concatenated to the
    query to form the next LSTM input.

    Args:
        x (Tensor): Node feature matrix with one row per node.
        batch (LongTensor): Vector assigning each node to an example.

    Returns:
        Tensor: Matrix of shape ``(batch_size, self.out_channels)``.
    """
    # NOTE(review): relies on the second value returned by ``to_batch`` being
    # the number of real nodes per example, on the same device as ``x``.
    x, num_nodes = to_batch(x, batch)
    batch_size, max_nodes, _ = x.size()

    # True at padded slots. Computed as ``>=`` directly so that no mask
    # inversion is needed.
    slot = torch.arange(max_nodes, device=x.device).view(1, -1)
    is_padding = slot >= num_nodes.view(-1, 1)

    # Zero-initialised LSTM state (two tensors, as unpacked from the LSTM
    # output below).
    h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
         x.new_zeros((self.num_layers, batch_size, self.in_channels)))
    q_star = x.new_zeros(1, batch_size, self.out_channels)

    for _ in range(self.processing_steps):
        q, h = self.lstm(q_star, h)
        q = q.view(batch_size, 1, self.in_channels)

        e = (x * q).sum(dim=-1)  # Dot product.

        # Padded slots hold zeros and would score 0, taking attention mass
        # away from the real nodes. Exclude them from the softmax.
        e = e.masked_fill(is_padding, float('-inf'))
        a = torch.softmax(e, dim=-1)
        # An example without any node has only -inf scores, which the softmax
        # turns into NaN; force the weights of padded slots to exactly zero.
        a = a.masked_fill(is_padding, 0)
        a = a.view(batch_size, max_nodes, 1)

        r = (a * x).sum(dim=1, keepdim=True)

        q_star = torch.cat([q, r], dim=-1)
        q_star = q_star.view(1, batch_size, self.out_channels)

    q_star = q_star.view(batch_size, self.out_channels)
    return q_star