Example #1
import torch
from torch_scatter import scatter_softmax


def test_softmax_broadcasting(dtype, device):
    src = torch.randn(10, 5, dtype=dtype, device=device)
    index = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], dtype=torch.long,
                         device=device)

    # scatter_softmax normalizes along dim 0 within each index group, so the
    # two rows of every group must sum to one.
    out = scatter_softmax(src, index, dim=0).view(5, 2, 5)
    out = out.sum(dim=1)
    assert torch.allclose(out, torch.ones_like(out))
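For intuition, here is a minimal standalone sketch (the values and group layout are invented for illustration, not taken from the test suite) showing that scatter_softmax normalizes within each index group:

import torch
from torch_scatter import scatter_softmax

src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])  # two groups of two elements each

out = scatter_softmax(src, index, dim=0)
# Each group sums to 1: out[0] + out[1] == 1 and out[2] + out[3] == 1.
print(out)  # tensor([0.2689, 0.7311, 0.2689, 0.7311])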
Example #2
    # Needs: from typing import Optional
    #        from torch_scatter import scatter_softmax, scatter_sum
    def aggregate(self,
                  inputs: torch.Tensor,
                  index: torch.Tensor,
                  dim_size: Optional[int] = None) -> torch.Tensor:
        # Attention-style aggregation: softmax the scaled messages of each
        # target node, then take the softmax-weighted sum.
        out = scatter_softmax(inputs * self.beta, index, dim=self.node_dim)
        out = scatter_sum(inputs * out,
                          index,
                          dim=self.node_dim,
                          dim_size=dim_size)
        return out
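Below is a hypothetical sketch of how such an aggregate could sit inside a torch_geometric MessagePassing layer. The class name SoftmaxAggrConv and the beta hyperparameter are assumptions made for illustration, not part of the original module:

import torch
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_softmax, scatter_sum


class SoftmaxAggrConv(MessagePassing):
    """Hypothetical layer wrapping the softmax aggregation above."""
    def __init__(self, beta: float = 1.0):
        super().__init__(aggr='add')  # ignored: aggregate() is overridden
        self.beta = beta

    def message(self, x_j):
        return x_j

    def aggregate(self, inputs, index, dim_size=None):
        # Softmax-weighted sum over each target node's incoming messages.
        w = scatter_softmax(inputs * self.beta, index, dim=self.node_dim)
        return scatter_sum(inputs * w, index, dim=self.node_dim,
                           dim_size=dim_size)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

Here beta acts as an inverse temperature: larger values sharpen the softmax over each node's incoming messages, smaller values flatten it toward mean aggregation.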
Example #3
    def forward(self, x, edge_index, edge_attr, u, node_batch):
        """Edge model -> edge-weight normalization -> node model."""
        row, col = edge_index
        edge_batch = node_batch[row]
        edge_attr, unnormalized_wts, global_normalized_wts = self.edge_model(
            x[row], x[col], edge_attr, u, edge_batch)
        # unnormalized_wts = torch.sigmoid(unnormalized_wts)
        # Local normalization: softmax over all edges sharing a target node.
        local_normalized_wts = scatter_softmax(unnormalized_wts, col, dim=0)
        x = self.node_model(x, edge_index, edge_attr, u, node_batch,
                            edge_batch, local_normalized_wts)

        return x, edge_attr, u, unnormalized_wts, global_normalized_wts
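As a toy illustration of that "local" normalization (tensors invented for the example), scattering over col makes the weights of all edges that share a target node sum to one:

import torch
from torch_scatter import scatter_softmax

# Four edges; col[i] is the target node of edge i.
unnormalized_wts = torch.tensor([0.3, 1.2, -0.5, 0.0])
col = torch.tensor([0, 0, 1, 1])

local_normalized_wts = scatter_softmax(unnormalized_wts, col, dim=0)
# The weights of the two edges into node 0 sum to 1; likewise for node 1.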
Example #4
    def forward(self, src, dest, edge_attr, u, edge_batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs.
        # edge_batch: [E] with max entry B - 1.
        out = torch.cat([src, dest, edge_attr, u[edge_batch]], 1)

        out_1 = torch.cat([edge_attr, u[edge_batch]], 1)
        wts = self.weight_mlp(out_1)  # wts: [E, 1]
        unnormalized_wts = wts
        # Global normalization: softmax over all edges of the same graph.
        wts = scatter_softmax(wts.squeeze(1), edge_batch, dim=0)
        normalized_wts = wts.unsqueeze(1)
        return self.edge_mlp(out), unnormalized_wts, normalized_wts
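In contrast to Example #3, the softmax here runs over edge_batch, i.e. over all edges of a graph at once. A small sketch of that "global" normalization (made-up tensors, two graphs in a batch):

import torch
from torch_scatter import scatter_softmax

wts = torch.tensor([[0.5], [1.5], [0.1], [2.0], [2.0]])  # [E, 1]
edge_batch = torch.tensor([0, 0, 0, 1, 1])               # graph id per edge

normalized = scatter_softmax(wts.squeeze(1), edge_batch, dim=0).unsqueeze(1)
print(normalized.sum())  # tensor(2.): one unit of probability mass per graph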
Example #5
import torch
from torch_scatter import scatter_softmax


def test_softmax(dtype, device):
    src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')],
                       dtype=dtype, device=device)
    index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4], dtype=torch.long,
                         device=device)

    out = scatter_softmax(src, index)

    # Reference: an ordinary softmax per index group (index 3 is unused; the
    # -inf entry in group 4 gets zero probability).
    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                         dim=-1)

    # Scatter the reference values back into the original element order.
    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0).to(device)

    assert torch.allclose(out, expected)
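The -inf entry in this test doubles as a masking idiom: a score of float('-inf') receives exactly zero weight after the softmax, as in this minimal check:

import torch
from torch_scatter import scatter_softmax

src = torch.tensor([-1.0, float('-inf')])
index = torch.tensor([0, 0])

print(scatter_softmax(src, index, dim=0))  # tensor([1., 0.])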