Example #1
def forward(self, x, edge_index=None, batch=None):
    if not self.global_pool:
        # Local pooling: densify the neighborhood features of each node.
        row, col = edge_index
        dense_x, num_nodes = to_batch(x[col], row, dim_size=x.size(0))
    else:
        # Global pooling: densify the node features of each graph.
        dense_x, num_nodes = to_batch(x, batch)
    dense_x = dense_x.transpose(1, 2)  # [B, N, F] -> [B, F, N]
    x, _ = self.pool(dense_x, num_nodes)
    return x
Example #2
def test_to_batch():
    x = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    batch = torch.tensor([0, 0, 1, 2, 2, 2])

    x, num_nodes = to_batch(x, batch)
    expected = [
        [[1, 2], [3, 4], [0, 0]],
        [[5, 6], [0, 0], [0, 0]],
        [[7, 8], [9, 10], [11, 12]],
    ]
    assert x.tolist() == expected
    assert num_nodes.tolist() == [2, 1, 3]
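Every example on this page relies on a to_batch helper that turns a flat node feature matrix plus a batch assignment vector into a padded [B, max_nodes, F] tensor together with the per-graph node counts (roughly what later PyTorch Geometric releases expose as to_dense_batch). The helper itself never appears here, so below is a minimal sketch consistent with the test above; it assumes batch is sorted, the fill_value argument is inferred from Examples #3 and #4, and the dim_size keyword from Example #1 is omitted:

import torch

def to_batch(x, batch, fill_value=0):
    # Nodes per graph, number of graphs, and the padded node dimension.
    num_nodes = torch.bincount(batch)
    B, max_nodes = num_nodes.size(0), int(num_nodes.max())

    # Position of each node inside its own graph (assumes sorted batch).
    pos = torch.arange(batch.size(0), device=batch.device)
    pos = pos - torch.cumsum(num_nodes, dim=0)[batch] + num_nodes[batch]

    # Scatter every node into its padded slot; leave the rest at fill_value.
    out = x.new_full((B, max_nodes, x.size(-1)), fill_value)
    out[batch, pos] = x
    return out, num_nodes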
Example #3
def global_sort_pool(x, batch, k):
    r"""The global pooling operator from the `"An End-to-End Deep Learning
    Architecture for Graph Classification"
    <https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf>`_ paper,
    where node features are first sorted individually and then  sorted in
    descending order based on their last features. The first :math:`k` nodes
    form the output of the layer.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        batch (LongTensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots,
            B-1\}}^N`, which assigns each node to a specific example.
        k (int): The number of nodes to hold for each graph.

    :rtype: :class:`Tensor`
    """
    # Sort the feature channels of every node individually.
    x, _ = x.sort(dim=-1)

    # Pad with a value strictly smaller than any real feature so that
    # padding rows sort to the back and can be zeroed out afterwards.
    fill_value = x.min().item() - 1
    batch_x, _ = to_batch(x, batch, fill_value)
    B, N, D = batch_x.size()

    # Rank the nodes of every graph by their last (largest) feature value.
    _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True)
    arange = torch.arange(B, dtype=torch.long, device=perm.device) * N
    perm = perm + arange.view(-1, 1)  # Offset row indices into the flat batch.

    batch_x = batch_x.view(B * N, D)
    batch_x = batch_x[perm]
    batch_x = batch_x.view(B, N, D)

    if N >= k:
        batch_x = batch_x[:, :k].contiguous()
    else:
        # Pad graphs with fewer than k nodes (new_full keeps dtype/device).
        expand_batch_x = batch_x.new_full((B, k, D), fill_value)
        expand_batch_x[:, :N, :] = batch_x
        batch_x = expand_batch_x.contiguous()

    batch_x[batch_x == fill_value] = 0  # Zero out the padding entries.
    x = batch_x.view(B, k * D)  # One flat k * D vector per graph.

    return x
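For reference, a small usage sketch of global_sort_pool (the graph sizes and feature width are made-up toy values, not from the source): with F = 4 features and k = 3, every graph is flattened to one k * F vector, shorter graphs being padded with zeros.

x = torch.randn(5, 4)                  # 5 nodes, 4 features each
batch = torch.tensor([0, 0, 1, 1, 1])  # graph 0 has 2 nodes, graph 1 has 3
out = global_sort_pool(x, batch, k=3)
print(out.size())                      # torch.Size([2, 12]), i.e. [B, k * F]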
Example #4
def sort_pool(x, batch, k):
    # Sort the feature channels of every node individually.
    x, _ = x.sort(dim=-1)

    # Pad with a value smaller than any real feature (see Example #3).
    fill_value = x.min().item() - 1
    batch_x, _ = to_batch(x, batch, fill_value)
    B, N, D = batch_x.size()

    _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True)
    arange = torch.arange(B, dtype=torch.long, device=perm.device) * N
    perm = perm + arange.view(-1, 1)

    batch_x = batch_x.view(B * N, D)
    batch_x = batch_x[perm]
    batch_x = batch_x.view(B, N, D)

    # Keep the top-k nodes; unlike Example #3, this variant assumes every
    # batch's padded size satisfies N >= k, otherwise the view below fails.
    batch_x = batch_x[:, :k].contiguous()
    batch_x[batch_x == fill_value] = 0  # Zero out the padding entries.
    x = batch_x.view(B, k * D)

    return x
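Unlike Example #3, this variant does not pad when k exceeds the largest graph in the batch, so the final view fails in that case. A quick check with assumed toy values:

x = torch.randn(5, 4)
batch = torch.tensor([0, 0, 1, 1, 1])
out = sort_pool(x, batch, k=2)  # fine: k <= largest graph size
print(out.size())               # torch.Size([2, 8])
# sort_pool(x, batch, k=4) would raise at batch_x.view(B, k * D).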
Example #5
    def forward(self, x, batch):
        """"""
        x, _ = to_batch(x, batch)
        batch_size, max_nodes, _ = x.size()

        # The LSTM state and the initial query vector start at zero.
        h = (x.new_zeros((self.num_layers, batch_size, self.in_channels)),
             x.new_zeros((self.num_layers, batch_size, self.in_channels)))
        q_star = x.new_zeros(1, batch_size, self.out_channels)

        for _ in range(self.processing_steps):
            q, h = self.lstm(q_star, h)
            q = q.view(batch_size, 1, self.in_channels)
            e = (x * q).sum(dim=-1)  # Dot-product attention scores.
            a = torch.softmax(e, dim=-1)
            a = a.view(batch_size, max_nodes, 1)
            r = (a * x).sum(dim=1, keepdim=True)  # Attention-weighted readout.
            q_star = torch.cat([q, r], dim=-1)  # out_channels = 2 * in_channels.
            q_star = q_star.view(1, batch_size, self.out_channels)

        q_star = q_star.view(batch_size, self.out_channels)
        return q_star
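This loop matches the Set2Set readout from the "Order Matters: Sequence to Sequence for Sets" paper: an LSTM repeatedly emits a query q, attends over the node features, and appends the attention readout r, which is why out_channels must equal 2 * in_channels. A hypothetical module skeleton around this forward pass (class and argument names are assumptions, not from the source):

class Set2SetReadout(torch.nn.Module):
    def __init__(self, in_channels, processing_steps, num_layers=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = 2 * in_channels  # [q, r] concatenation
        self.processing_steps = processing_steps
        self.num_layers = num_layers
        # The LSTM consumes q_star (2 * in_channels) and emits q (in_channels).
        self.lstm = torch.nn.LSTM(self.out_channels, in_channels, num_layers)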