コード例 #1
def test_feast_conv():
    """Exercise FeaStConv eagerly and through TorchScript.

    Covers a single feature tensor and a ``(x1, x2)`` feature pair, each fed
    with a dense ``edge_index`` and with a transposed ``SparseTensor``.
    """
    feat_a = torch.randn(4, 16)
    feat_b = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = FeaStConv(16, 32, heads=2)
    assert repr(conv) == 'FeaStConv(16, 32, heads=2)'

    # Single feature tensor.
    ref = conv(feat_a, edge_index)
    assert ref.shape == (4, 32)
    assert torch.allclose(conv(feat_a, adj.t()), ref, atol=1e-6)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(feat_a, edge_index).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(feat_a, adj.t()), ref, atol=1e-6)

    # Feature pair; the sparse adjacency is resized to 4 x 2 first.
    adj = adj.sparse_resize((4, 2))
    ref = conv((feat_a, feat_b), edge_index)
    assert ref.shape == (2, 32)
    assert torch.allclose(conv((feat_a, feat_b), adj.t()), ref, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, Tensor) -> Tensor'))
    assert scripted((feat_a, feat_b), edge_index).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted((feat_a, feat_b), adj.t()), ref,
                          atol=1e-6)
コード例 #2
def test_heat_conv():
    """Exercise HEATConv with ``concat=True`` and ``concat=False``.

    Each configuration is checked for its string form, output size and
    agreement between the ``edge_index`` and ``SparseTensor`` call forms.
    TorchScript checks only run when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    edge_attr = torch.randn((4, 2))
    adj = SparseTensor(row=src, col=dst, value=edge_attr, sparse_sizes=(4, 4))
    node_type = torch.tensor([0, 0, 1, 2])
    edge_type = torch.tensor([0, 2, 1, 2])

    # Constructor arguments shared by both configurations below.
    shared = dict(in_channels=8, out_channels=16, num_node_types=3,
                  num_edge_types=3, edge_type_emb_dim=5, edge_dim=2,
                  edge_attr_emb_dim=6, heads=2)

    conv = HEATConv(concat=True, **shared)
    assert str(conv) == 'HEATConv(8, 16, heads=2)'
    ref = conv(x, edge_index, node_type, edge_type, edge_attr)
    assert ref.shape == (4, 32)
    assert torch.allclose(conv(x, adj.t(), node_type, edge_type), ref)

    if is_full_test():
        sig = '(Tensor, Tensor, Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        scripted_out = scripted(x, edge_index, node_type, edge_type,
                                edge_attr)
        assert torch.allclose(scripted_out, ref)

    conv = HEATConv(concat=False, **shared)
    assert str(conv) == 'HEATConv(8, 16, heads=2)'
    ref = conv(x, edge_index, node_type, edge_type, edge_attr)
    assert ref.shape == (4, 16)
    assert torch.allclose(conv(x, adj.t(), node_type, edge_type), ref)

    if is_full_test():
        sig = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        scripted_out = scripted(x, adj.t(), node_type, edge_type)
        assert torch.allclose(scripted_out, ref)
コード例 #3
def test_eg_conv():
    """Exercise EGConv: repr, output size, ``cached`` mode and TorchScript."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = EGConv(16, 32)
    assert repr(conv) == "EGConv(16, 32, aggregators=['symnorm'])"
    ref = conv(x, edge_index)
    assert ref.shape == (4, 32)
    assert torch.allclose(conv(x, adj.t()), ref, atol=1e-6)

    # After enabling `cached`, run each input form once and then compare
    # the result of a second call against the uncached reference.
    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == ref.tolist()
    conv(x, adj.t())
    assert torch.allclose(conv(x, adj.t()), ref, atol=1e-6)

    if is_full_test():
        scripted = torch.jit.script(
            conv.jittable('(Tensor, Tensor) -> Tensor'))
        assert scripted(x, edge_index).tolist() == ref.tolist()

        scripted = torch.jit.script(
            conv.jittable('(Tensor, SparseTensor) -> Tensor'))
        assert torch.allclose(scripted(x, adj.t()), ref, atol=1e-6)
コード例 #4
def test_dna_conv():
    """Exercise DNAConv with and without edge weights, scripted and cached."""
    x = torch.randn((4, 3, 32))
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    adj_weighted = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    adj_plain = adj_weighted.set_value(None)

    conv = DNAConv(32, heads=4, groups=8, dropout=0.0)
    assert repr(conv) == 'DNAConv(32, heads=4, groups=8)'

    ref_plain = conv(x, edge_index)
    assert ref_plain.shape == (4, 32)
    assert torch.allclose(conv(x, adj_plain.t()), ref_plain, atol=1e-6)

    ref_weighted = conv(x, edge_index, weight)
    assert ref_weighted.shape == (4, 32)
    assert torch.allclose(conv(x, adj_weighted.t()), ref_weighted, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == ref_plain.tolist()
    assert scripted(x, edge_index, weight).tolist() == ref_weighted.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_plain.t()), ref_plain, atol=1e-6)
    assert torch.allclose(scripted(x, adj_weighted.t()), ref_weighted,
                          atol=1e-6)

    # With `cached` enabled, call once and compare the second call.
    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == ref_plain.tolist()
    conv(x, adj_plain.t())
    assert torch.allclose(conv(x, adj_plain.t()), ref_plain, atol=1e-6)
コード例 #5
def test_res_gated_graph_conv():
    """Exercise ResGatedGraphConv on a single tensor and on a feature pair.

    All comparisons between call forms are exact (``tolist()`` equality).
    """
    feat_a = torch.randn(4, 8)
    feat_b = torch.randn(2, 32)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    # Single input channel size.
    conv = ResGatedGraphConv(8, 32)
    assert repr(conv) == 'ResGatedGraphConv(8, 32)'
    ref = conv(feat_a, edge_index)
    assert ref.shape == (4, 32)
    assert conv(feat_a, adj.t()).tolist() == ref.tolist()

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(feat_a, edge_index).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert scripted(feat_a, adj.t()).tolist() == ref.tolist()

    # Pair of input channel sizes; adjacency resized to 4 x 2.
    adj = adj.sparse_resize((4, 2))
    conv = ResGatedGraphConv((8, 32), 32)
    assert repr(conv) == 'ResGatedGraphConv((8, 32), 32)'
    ref = conv((feat_a, feat_b), edge_index)
    assert ref.shape == (2, 32)
    assert conv((feat_a, feat_b), adj.t()).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, Tensor) -> Tensor'))
    assert scripted((feat_a, feat_b), edge_index).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, SparseTensor) -> Tensor'))
    assert scripted((feat_a, feat_b), adj.t()).tolist() == ref.tolist()
コード例 #6
def test_gated_graph_conv():
    """Exercise GatedGraphConv with and without edge weights, plus JIT."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    adj_weighted = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    adj_plain = adj_weighted.set_value(None)

    conv = GatedGraphConv(32, num_layers=3)
    assert repr(conv) == 'GatedGraphConv(32, num_layers=3)'

    ref_plain = conv(x, edge_index)
    assert ref_plain.shape == (4, 32)
    assert torch.allclose(conv(x, adj_plain.t()), ref_plain, atol=1e-6)

    ref_weighted = conv(x, edge_index, weight)
    assert ref_weighted.shape == (4, 32)
    assert torch.allclose(conv(x, adj_weighted.t()), ref_weighted, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == ref_plain.tolist()
    assert scripted(x, edge_index, weight).tolist() == ref_weighted.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_plain.t()), ref_plain, atol=1e-6)
    assert torch.allclose(scripted(x, adj_weighted.t()), ref_weighted,
                          atol=1e-6)
コード例 #7
def test_gcn_conv():
    """Exercise GCNConv: weighted/unweighted edges, JIT and ``cached`` mode.

    The TorchScript checks only run when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    adj_weighted = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    adj_plain = adj_weighted.set_value(None)

    conv = GCNConv(16, 32)
    assert repr(conv) == 'GCNConv(16, 32)'

    ref_plain = conv(x, edge_index)
    assert ref_plain.shape == (4, 32)
    assert torch.allclose(conv(x, adj_plain.t()), ref_plain, atol=1e-6)

    ref_weighted = conv(x, edge_index, weight)
    assert ref_weighted.shape == (4, 32)
    assert torch.allclose(conv(x, adj_weighted.t()), ref_weighted, atol=1e-6)

    if is_full_test():
        scripted = torch.jit.script(
            conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        assert scripted(x, edge_index).tolist() == ref_plain.tolist()
        assert (scripted(x, edge_index, weight).tolist() ==
                ref_weighted.tolist())

        scripted = torch.jit.script(
            conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
        assert torch.allclose(scripted(x, adj_plain.t()), ref_plain,
                              atol=1e-6)
        assert torch.allclose(scripted(x, adj_weighted.t()), ref_weighted,
                              atol=1e-6)

    # With `cached` enabled, call once and compare the second call.
    conv.cached = True
    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == ref_plain.tolist()
    conv(x, adj_plain.t())
    assert torch.allclose(conv(x, adj_plain.t()), ref_plain, atol=1e-6)
コード例 #8
def test_pna_conv():
    """Exercise PNAConv with 3-dimensional edge features, eager and scripted.

    NOTE(review): the ``PNAConv(...)`` call was truncated in this snippet.
    ``out_channels=32``, ``towers=4`` and ``edge_dim=3`` follow from the
    asserted repr string; the aggregator and scaler lists are assumed to be
    the full sets listed in the PNAConv documentation -- confirm against the
    upstream test module.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    # Defined locally so the test does not depend on module-level names.
    aggregators = ['sum', 'mean', 'min', 'max', 'var', 'std']
    scalers = [
        'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'
    ]

    conv = PNAConv(16, 32, aggregators, scalers,
                   deg=torch.tensor([0, 3, 0, 1]), edge_dim=3, towers=4)
    assert conv.__repr__() == 'PNAConv(16, 32, towers=4, edge_dim=3)'
    out = conv(x, edge_index, value)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x, edge_index, value), out, atol=1e-6)

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
コード例 #9
def test_lg_conv():
    """Exercise LGConv with and without edge weights, eager and scripted."""
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    weight = torch.rand(src.size(0))
    adj_weighted = SparseTensor(row=src, col=dst, value=weight,
                                sparse_sizes=(4, 4))
    adj_plain = adj_weighted.set_value(None)

    conv = LGConv()
    assert str(conv) == 'LGConv()'

    ref_plain = conv(x, edge_index)
    assert ref_plain.shape == (4, 8)
    assert torch.allclose(conv(x, adj_plain.t()), ref_plain)

    ref_weighted = conv(x, edge_index, weight)
    assert ref_weighted.shape == (4, 8)
    assert torch.allclose(conv(x, adj_weighted.t()), ref_weighted)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, edge_index), ref_plain)
    assert torch.allclose(scripted(x, edge_index, weight), ref_weighted)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj_plain.t()), ref_plain)
    assert torch.allclose(scripted(x, adj_weighted.t()), ref_weighted)
コード例 #10
def test_ppf_conv():
    """Exercise PPFConv on single and paired position/normal inputs.

    NOTE(review): this snippet had lost the first and last lines of the
    expected repr literal and the trailing ``out, atol=1e-6)`` of three
    ``torch.allclose`` calls. They are restored here following the visible
    repr fragments and the tolerance used by the other assertions in this
    test -- confirm against the upstream test module.
    """
    x1 = torch.randn(4, 16)
    pos1 = torch.randn(4, 3)
    pos2 = torch.randn(2, 3)
    n1 = F.normalize(torch.rand(4, 3), dim=-1)
    n2 = F.normalize(torch.rand(2, 3), dim=-1)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    # local_nn takes 16 feature channels plus 4 extra input channels.
    local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))
    global_nn = Seq(Lin(32, 32))
    conv = PPFConv(local_nn, global_nn)
    assert conv.__repr__() == (
        'PPFConv(local_nn=Sequential(\n'
        '  (0): Linear(in_features=20, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '), global_nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')
    out = conv(x1, pos1, n1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, pos1, n1, adj.t()), out, atol=1e-6)

    t = '(OptTensor, Tensor, Tensor, Tensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x1, pos1, n1, edge_index).tolist() == out.tolist()

    t = '(OptTensor, Tensor, Tensor, SparseTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x1, pos1, n1, adj.t()), out, atol=1e-6)

    # Paired positions/normals; adjacency resized to 4 x 2.
    adj = adj.sparse_resize((4, 2))
    out = conv(x1, (pos1, pos2), (n1, n2), edge_index)
    assert out.size() == (2, 32)
    assert conv((x1, None), (pos1, pos2), (n1, n2),
                edge_index).tolist() == out.tolist()
    assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj.t()),
                          out, atol=1e-6)
    assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj.t()),
                          out, atol=1e-6)

    t = '(PairOptTensor, PairTensor, PairTensor, Tensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit((x1, None), (pos1, pos2), (n1, n2),
               edge_index).tolist() == out.tolist()

    t = '(PairOptTensor, PairTensor, PairTensor, SparseTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit((x1, None), (pos1, pos2), (n1, n2), adj.t()),
                          out, atol=1e-6)
コード例 #11
def test_cg_conv():
    """Exercise CGConv: single input, feature pair, and ``batch_norm=True``.

    Every comparison between call forms is exact (``tolist()`` equality).
    """
    feat_a = torch.randn(4, 8)
    feat_b = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    # Single input channel size.
    conv = CGConv(8)
    assert repr(conv) == 'CGConv(8, dim=0)'
    ref = conv(feat_a, edge_index)
    assert ref.shape == (4, 8)
    assert conv(feat_a, edge_index, size=(4, 4)).tolist() == ref.tolist()
    assert conv(feat_a, adj.t()).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(feat_a, edge_index).tolist() == ref.tolist()
    assert scripted(feat_a, edge_index, size=(4, 4)).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted(feat_a, adj.t()).tolist() == ref.tolist()

    # Pair of input channel sizes; adjacency resized to 4 x 2.
    adj = adj.sparse_resize((4, 2))
    conv = CGConv((8, 16))
    assert repr(conv) == 'CGConv((8, 16), dim=0)'
    ref = conv((feat_a, feat_b), edge_index)
    assert ref.shape == (2, 16)
    assert (conv((feat_a, feat_b), edge_index, size=(4, 2)).tolist() ==
            ref.tolist())
    assert conv((feat_a, feat_b), adj.t()).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted((feat_a, feat_b), edge_index).tolist() == ref.tolist()
    assert (scripted((feat_a, feat_b), edge_index, size=(4, 2)).tolist() ==
            ref.tolist())

    scripted = torch.jit.script(
        conv.jittable('(PairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted((feat_a, feat_b), adj.t()).tolist() == ref.tolist()

    # Same single-input checks with batch_norm switched on.
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))
    conv = CGConv(8, batch_norm=True)
    assert repr(conv) == 'CGConv(8, dim=0)'
    ref = conv(feat_a, edge_index)
    assert ref.shape == (4, 8)
    assert conv(feat_a, edge_index, size=(4, 4)).tolist() == ref.tolist()
    assert conv(feat_a, adj.t()).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(feat_a, edge_index).tolist() == ref.tolist()
    assert scripted(feat_a, edge_index, size=(4, 4)).tolist() == ref.tolist()
コード例 #12
def test_point_transformer_conv():
    """Exercise PointTransformerConv with default and custom sub-networks.

    NOTE(review): two ``torch.allclose`` calls in the paired-input section
    had lost their closing ``out, atol=1e-6)`` line in this snippet. It is
    restored using the tolerance of the other assertions in this test --
    confirm against the upstream test module.
    """
    x1 = torch.rand(4, 16)
    x2 = torch.randn(2, 8)
    pos1 = torch.rand(4, 3)
    pos2 = torch.randn(2, 3)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = PointTransformerConv(in_channels=16, out_channels=32)
    assert str(conv) == 'PointTransformerConv(16, 32)'

    out = conv(x1, pos1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)

    if is_full_test():
        t = '(Tensor, Tensor, Tensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x1, pos1, edge_index).tolist() == out.tolist()

        t = '(Tensor, Tensor, SparseTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x1, pos1, adj.t()), out, atol=1e-6)

    # Explicit position and attention networks instead of the defaults.
    pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))
    attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))
    conv = PointTransformerConv(16, 32, pos_nn, attn_nn)

    out = conv(x1, pos1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)

    # Paired features/positions; adjacency resized to 4 x 2.
    conv = PointTransformerConv((16, 8), 32)
    adj = adj.sparse_resize((4, 2))

    out = conv((x1, x2), (pos1, pos2), edge_index)
    assert out.size() == (2, 32)
    assert torch.allclose(conv((x1, x2), (pos1, pos2), adj.t()),
                          out, atol=1e-6)

    if is_full_test():
        t = '(PairTensor, PairTensor, Tensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit((x1, x2), (pos1, pos2), edge_index).tolist() == out.tolist()

        t = '(PairTensor, PairTensor, SparseTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit((x1, x2), (pos1, pos2), adj.t()),
                              out, atol=1e-6)
コード例 #13
def test_fa_conv():
    """Exercise FAConv, including ``return_attention_weights=True``.

    After each attention-weight call the test asserts that ``conv._alpha``
    is ``None``.
    """
    x = torch.randn(4, 16)
    x_0 = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = FAConv(16, eps=1.0)
    assert repr(conv) == 'FAConv(16, eps=1.0)'
    ref = conv(x, x_0, edge_index)
    assert ref.shape == (4, 16)
    assert torch.allclose(conv(x, x_0, adj.t()), ref, atol=1e-6)

    sig = '(Tensor, Tensor, Tensor, OptTensor, NoneType) -> Tensor'
    scripted = torch.jit.script(conv.jittable(sig))
    assert scripted(x, x_0, edge_index).tolist() == ref.tolist()

    sig = '(Tensor, Tensor, SparseTensor, OptTensor, NoneType) -> Tensor'
    scripted = torch.jit.script(conv.jittable(sig))
    assert torch.allclose(scripted(x, x_0, adj.t()), ref, atol=1e-6)

    # Attention weights requested, eager mode, dense edge_index.
    result = conv(x, x_0, edge_index, return_attention_weights=True)
    attention = result[1]
    assert result[0].tolist() == ref.tolist()
    assert attention[0].shape == (2, 10)
    assert attention[1].shape == (10, )
    assert conv._alpha is None

    # Attention weights requested, eager mode, SparseTensor input.
    result = conv(x, x_0, adj.t(), return_attention_weights=True)
    attention = result[1]
    assert torch.allclose(result[0], ref, atol=1e-6)
    assert attention.sizes() == [4, 4] and attention.nnz() == 10
    assert conv._alpha is None

    # Same two checks through TorchScript.
    sig = ('(Tensor, Tensor, Tensor, OptTensor, bool) '
           '-> Tuple[Tensor, Tuple[Tensor, Tensor]]')
    scripted = torch.jit.script(conv.jittable(sig))
    result = scripted(x, x_0, edge_index, return_attention_weights=True)
    attention = result[1]
    assert result[0].tolist() == ref.tolist()
    assert attention[0].shape == (2, 10)
    assert attention[1].shape == (10, )
    assert conv._alpha is None

    sig = ('(Tensor, Tensor, SparseTensor, OptTensor, bool) '
           '-> Tuple[Tensor, SparseTensor]')
    scripted = torch.jit.script(conv.jittable(sig))
    result = scripted(x, x_0, adj.t(), return_attention_weights=True)
    attention = result[1]
    assert torch.allclose(result[0], ref, atol=1e-6)
    assert attention.sizes() == [4, 4] and attention.nnz() == 10
    assert conv._alpha is None
コード例 #14
def test_sparse_tensor_spspmm(dtype, device):
    """Check sparse-dense and sparse-sparse products of a SparseTensor.

    Both ``x @ x.to_dense().t()`` and ``x @ x.t()`` are compared against a
    10 x 10 identity matrix.

    NOTE(review): the ``SparseTensor(...)`` constructor was truncated in this
    snippet (three bare lists, no closing parenthesis). The ``row=``,
    ``col=`` and ``value=`` wrappers are a reconstruction -- confirm against
    the upstream test module.
    """
    x = SparseTensor(
        row=torch.tensor(
            [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
            device=device),
        col=torch.tensor(
            [0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15],
            device=device),
        value=torch.tensor([
            1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5,
            -2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5,
            -2**-0.5, 2**-0.5, -2**-0.5
        ], dtype=dtype, device=device),
    )

    expected = torch.eye(10, dtype=dtype, device=device)

    # Sparse @ dense.
    out = x @ x.to_dense().t()
    assert torch.allclose(out, expected, atol=1e-7)

    # Sparse @ sparse, densified for comparison.
    out = x @ x.t()
    out = out.to_dense()
    assert torch.allclose(out, expected, atol=1e-7)
コード例 #15
ファイル: sampler.py プロジェクト: zheng-da/pytorch_geometric
    def __init__(self,
                 edge_index: torch.Tensor,
                 sizes: List[int],
                 node_idx: Optional[torch.Tensor] = None,
                 num_nodes: Optional[int] = None,
                 flow: str = "source_to_target",
                 **kwargs):
        """Build the CPU adjacency and node list the sampler iterates over.

        Args:
            edge_index: Edge tensor; rows 0 and 1 are used as the sparse
                row and column indices.
            sizes: Stored on ``self.sizes`` unchanged.
            node_idx: Nodes to iterate over. ``None`` selects all ``N``
                nodes; a boolean tensor is converted to indices.
            num_nodes: Number of nodes ``N``. If ``None``, it is taken as
                ``edge_index.max() + 1``.
            flow: ``'source_to_target'`` (adjacency is transposed) or
                ``'target_to_source'``.
            **kwargs: Forwarded to the parent constructor.

        NOTE(review): this snippet had lost the end of the signature, the
        ``col``/``value`` arguments of ``SparseTensor`` and the tail of the
        ``super().__init__`` call. ``**kwargs`` and
        ``collate_fn=self.sample`` are a reconstruction -- confirm against
        the upstream class.
        """
        # Validate `flow` before it decides whether to transpose.
        assert flow in ['source_to_target', 'target_to_source']

        N = int(edge_index.max() + 1) if num_nodes is None else num_nodes
        # Each edge carries its own position as its value.
        edge_attr = torch.arange(edge_index.size(1))
        adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                           value=edge_attr, sparse_sizes=(N, N))
        adj = adj.t() if flow == 'source_to_target' else adj
        self.adj = adj.to('cpu')

        if node_idx is None:
            node_idx = torch.arange(N)
        elif node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero().view(-1)

        self.sizes = sizes
        self.flow = flow

        super(NeighborSampler, self).__init__(node_idx.tolist(),
                                              collate_fn=self.sample,
                                              **kwargs)
コード例 #16
def test_rgat_conv_jittable():
    """Exercise RGATConv eagerly and through TorchScript.

    NOTE(review): the ``RGATConv(8, ...)`` call was truncated in this
    snippet. The arguments below are a reconstruction chosen to match the
    visible data: 8-dimensional edge features (``edge_dim=8``), edge types
    up to 2, and an output width of 40 (``out_channels=20`` with
    ``heads=2``). Confirm against the upstream test module.
    """
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    edge_attr = torch.randn((4, 8))
    adj = SparseTensor(row=row, col=col, value=edge_attr, sparse_sizes=(4, 4))
    edge_type = torch.tensor([0, 2, 1, 2])

    conv = RGATConv(8, 20, num_relations=4, num_bases=4, mod='additive',
                    attention_mechanism='across-relation',
                    attention_mode='additive-self-attention', heads=2, dim=1,
                    edge_dim=8, bias=False)

    out = conv(x, edge_index, edge_type, edge_attr)
    assert out.size() == (4, 40)
    assert torch.allclose(conv(x, adj.t(), edge_type), out)

    # The scripted module is compared with an eager call made without
    # edge attributes.
    t = '(Tensor, Tensor, OptTensor, OptTensor, Size, NoneType) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert torch.allclose(jit(x, edge_index, edge_type),
                          conv(x, edge_index, edge_type))
コード例 #17
 def forward(self, x, edge_index, edge_weight=None):
     """Normalize ``x``, run it through the GCN layers, then project it.

     Args:
         x: Input features; ``x.shape[0]`` is used as the number of nodes.
         edge_index: Edge tensor; rows 0 and 1 are used as sparse row and
             column indices when ``self.use_sparse`` is set.
         edge_weight: Optional edge weights.

     Returns:
         The output of ``self.project`` applied after all layers.

     NOTE(review): this snippet had lost the ``col``/``value`` arguments of
     ``SparseTensor``, the tail of the sparse ``layer(...)`` call and the
     ``else:`` between the two loops. They are reconstructed here from the
     commented-out variants in the original -- confirm against the upstream
     source.
     """
     x = self.bn(x)
     if self.use_sparse:
         adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                            value=edge_weight,
                            sparse_sizes=(x.shape[0], x.shape[0]))
         # The transposed adjacency is loop-invariant, so build it once.
         adj_t = adj.t()
         for layer in self.gcn_layer:
             x = layer(x=x, edge_index=adj_t)
             x = F.relu(x, inplace=True)
     else:
         for layer in self.gcn_layer:
             x = layer(x=x, edge_index=edge_index, edge_weight=edge_weight)
             x = F.relu(x, inplace=True)
     x = self.project(x)
     return x
コード例 #18
ファイル: asap.py プロジェクト: xugangwu95/pytorch_geometric
    def forward(self, x, edge_index, edge_weight=None, batch=None):
        """Pool the graph: score edges, select nodes, and coarsen edges.

        Args:
            x: Node features; a 1-D input gets a trailing dimension added.
            edge_index: Edge tensor with source row 0 and target row 1.
            edge_weight: Optional edge weights.
            batch: Optional per-node batch assignment; defaults to zeros.

        Returns:
            Tuple ``(x, edge_index, edge_weight, batch, perm)`` for the
            pooled graph, where ``perm`` holds the selected node indices.

        NOTE(review): this snippet had lost the arguments of
        ``add_remaining_self_loops``, ``self.gnn_intra_cluster`` and the
        first ``SparseTensor``, plus the ``else:`` before ``remove_diag``.
        They are reconstructed here -- confirm against the upstream source.
        """
        N = x.size(0)

        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value=1., num_nodes=N)

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x

        x_pool = x
        if self.GNN is not None:
            x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,
                                            edge_weight=edge_weight)

        # Per-edge source features, and a per-target max over them.
        x_pool_j = x_pool[edge_index[0]]
        x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')
        x_q = self.lin(x_q)[edge_index[1]]

        score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)
        score = F.leaky_relu(score, self.negative_slope)
        score = softmax(score, edge_index[1], num_nodes=N)

        # Sample attention coefficients stochastically.
        score = F.dropout(score, p=self.dropout, training=self.training)

        v_j = x[edge_index[0]] * score.view(-1, 1)
        x = scatter(v_j, edge_index[1], dim=0, reduce='add')

        # Cluster selection.
        fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)
        perm = topk(fitness, self.ratio, batch)
        x = x[perm] * fitness[perm].view(-1, 1)
        batch = batch[perm]

        # Graph coarsening.
        row, col = edge_index
        A = SparseTensor(row=row, col=col, value=edge_weight,
                         sparse_sizes=(N, N))
        S = SparseTensor(row=row, col=col, value=score, sparse_sizes=(N, N))
        S = S[:, perm]

        A = S.t() @ A @ S

        # Either force unit self-loops or drop the diagonal. Without the
        # `else`, fill_diag would be undone straight away by remove_diag.
        if self.add_self_loops:
            A = A.fill_diag(1.)
        else:
            A = A.remove_diag()

        row, col, edge_weight = A.coo()
        edge_index = torch.stack([row, col], dim=0)

        return x, edge_index, edge_weight, batch, perm
コード例 #19
def test_gen_conv(aggr):
    """Exercise GENConv for the given ``aggr`` on single and paired inputs.

    Each input form is checked with and without 16-dimensional edge values,
    with and without an explicit ``size``, eagerly and through TorchScript.
    All comparisons are exact (``tolist()`` equality).
    """
    feat_a = torch.randn(4, 16)
    feat_b = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    value = torch.randn(src.size(0), 16)
    adj_plain = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))
    adj_valued = SparseTensor(row=src, col=dst, value=value,
                              sparse_sizes=(4, 4))

    conv = GENConv(16, 32, aggr)
    assert repr(conv) == f'GENConv(16, 32, aggr={aggr})'

    # Single feature tensor, no edge values.
    ref_plain = conv(feat_a, edge_index)
    assert ref_plain.shape == (4, 32)
    assert (conv(feat_a, edge_index, size=(4, 4)).tolist() ==
            ref_plain.tolist())
    assert conv(feat_a, adj_plain.t()).tolist() == ref_plain.tolist()

    # Single feature tensor, with edge values.
    ref_valued = conv(feat_a, edge_index, value)
    assert ref_valued.shape == (4, 32)
    assert (conv(feat_a, edge_index, value, (4, 4)).tolist() ==
            ref_valued.tolist())
    assert conv(feat_a, adj_valued.t()).tolist() == ref_valued.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(feat_a, edge_index).tolist() == ref_plain.tolist()
    assert (scripted(feat_a, edge_index, size=(4, 4)).tolist() ==
            ref_plain.tolist())
    assert scripted(feat_a, edge_index, value).tolist() == ref_valued.tolist()
    assert (scripted(feat_a, edge_index, value, size=(4, 4)).tolist() ==
            ref_valued.tolist())

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted(feat_a, adj_plain.t()).tolist() == ref_plain.tolist()
    assert scripted(feat_a, adj_valued.t()).tolist() == ref_valued.tolist()

    # Feature pair; both adjacencies are resized to 4 x 2.
    adj_plain = adj_plain.sparse_resize((4, 2))
    adj_valued = adj_valued.sparse_resize((4, 2))

    pair_plain = conv((feat_a, feat_b), edge_index)
    assert pair_plain.shape == (2, 32)
    assert (conv((feat_a, feat_b), edge_index, size=(4, 2)).tolist() ==
            pair_plain.tolist())
    assert (conv((feat_a, feat_b), adj_plain.t()).tolist() ==
            pair_plain.tolist())

    pair_valued = conv((feat_a, feat_b), edge_index, value)
    assert pair_valued.shape == (2, 32)
    assert (conv((feat_a, feat_b), edge_index, value, (4, 2)).tolist() ==
            pair_valued.tolist())
    assert (conv((feat_a, feat_b), adj_valued.t()).tolist() ==
            pair_valued.tolist())

    scripted = torch.jit.script(
        conv.jittable('(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'))
    assert (scripted((feat_a, feat_b), edge_index).tolist() ==
            pair_plain.tolist())
    assert (scripted((feat_a, feat_b), edge_index, size=(4, 2)).tolist() ==
            pair_plain.tolist())
    assert (scripted((feat_a, feat_b), edge_index, value).tolist() ==
            pair_valued.tolist())
    assert (scripted((feat_a, feat_b), edge_index, value, (4, 2)).tolist() ==
            pair_valued.tolist())

    scripted = torch.jit.script(
        conv.jittable(
            '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert (scripted((feat_a, feat_b), adj_plain.t()).tolist() ==
            pair_plain.tolist())
    assert (scripted((feat_a, feat_b), adj_valued.t()).tolist() ==
            pair_valued.tolist())
コード例 #20
def test_my_default_arg_conv():
    """Check that MyDefaultArgConv returns all zeros for both input forms."""
    x = torch.randn(4, 1)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = MyDefaultArgConv()
    zeros = [0, 0, 0, 0]
    assert conv(x, edge_index).view(-1).tolist() == zeros
    assert conv(x, adj.t()).view(-1).tolist() == zeros
コード例 #21
def test_agnn_conv(requires_grad):
    """Exercise AGNNConv for the given ``requires_grad`` setting."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = AGNNConv(requires_grad=requires_grad)
    assert repr(conv) == 'AGNNConv()'
    ref = conv(x, edge_index)
    assert ref.shape == (4, 16)
    assert torch.allclose(conv(x, adj.t()), ref, atol=1e-6)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj.t()), ref, atol=1e-6)
コード例 #22
def test_cluster_gcn_conv():
    """Exercise ClusterGCNConv eagerly and through TorchScript.

    SparseTensor results are compared with ``atol=1e-5``.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = ClusterGCNConv(16, 32, diag_lambda=1.)
    assert repr(conv) == 'ClusterGCNConv(16, 32, diag_lambda=1.0)'
    ref = conv(x, edge_index)
    assert ref.shape == (4, 32)
    assert torch.allclose(conv(x, adj.t()), ref, atol=1e-5)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj.t()), ref, atol=1e-5)
コード例 #23
def test_arma_conv():
    """Exercise ARMAConv eagerly and through TorchScript.

    Checks the repr, the output size, and that the ``edge_index`` and
    ``SparseTensor`` call forms agree in both eager and scripted mode.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = ARMAConv(16, 32, num_stacks=8, num_layers=4)
    assert conv.__repr__() == 'ARMAConv(16, 32, num_stacks=8, num_layers=4)'
    out = conv(x, edge_index)
    assert out.size() == (4, 32)
    assert conv(x, adj.t()).tolist() == out.tolist()

    t = '(Tensor, Tensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    assert jit(x, edge_index).tolist() == out.tolist()

    t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
    jit = torch.jit.script(conv.jittable(t))
    # Call the scripted module here; calling `conv` again would leave the
    # SparseTensor TorchScript path untested.
    assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
コード例 #24
def test_appnp():
    """Exercise APPNP eagerly and through TorchScript."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = APPNP(K=10, alpha=0.1)
    assert repr(conv) == 'APPNP(K=10, alpha=0.1)'
    ref = conv(x, edge_index)
    assert ref.shape == (4, 16)
    assert torch.allclose(conv(x, adj.t()), ref, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert scripted(x, edge_index).tolist() == ref.tolist()

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj.t()), ref, atol=1e-6)
コード例 #25
def test_message_passing_with_aggr_module(aggr_module):
    """Check MyAggregatorConv built with the given aggregation module.

    The output must have 4 rows and either 8 or 16 columns, and the
    ``edge_index`` and ``SparseTensor`` call forms must agree.
    """
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    conv = MyAggregatorConv(aggr=aggr_module)
    assert isinstance(conv.aggr_module, aggr.Aggregation)

    ref = conv(x, edge_index)
    assert ref.shape[0] == 4 and ref.shape[1] in {8, 16}
    assert torch.allclose(conv(x, adj.t()), ref)
コード例 #26
def test_pdn_conv():
    """Exercise PDNConv with 8-dimensional edge attributes, eager and JIT."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    src, dst = edge_index
    edge_attr = torch.randn(6, 8)
    adj = SparseTensor(row=src, col=dst, value=edge_attr, sparse_sizes=(4, 4))

    conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128)
    assert str(conv) == "PDNConv(16, 32)"
    ref = conv(x, edge_index, edge_attr)
    assert ref.shape == (4, 32)
    assert torch.allclose(conv(x, adj.t()), ref, atol=1e-6)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, edge_index, edge_attr), ref)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj.t()), ref, atol=1e-6)
コード例 #27
def test_my_edge_conv():
    """Compare MyEdgeConv against a scatter-add of ``x[row] - x[col]``."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(4, 4))

    # Reference: per-edge feature difference, summed per target node.
    diff = x[src] - x[dst]
    expected = scatter(diff, dst, dim=0, dim_size=4, reduce='add')

    conv = MyEdgeConv()
    result = conv(x, edge_index)
    assert result.shape == (4, 16)
    assert torch.allclose(result, expected)
    assert torch.allclose(conv(x, adj.t()), result)

    scripted = torch.jit.script(conv.jittable('(Tensor, Tensor) -> Tensor'))
    assert torch.allclose(scripted(x, edge_index), expected)

    scripted = torch.jit.script(
        conv.jittable('(Tensor, SparseTensor) -> Tensor'))
    assert torch.allclose(scripted(x, adj.t()), expected)
コード例 #28
def test_my_conv():
    """Check `MyConv` on a single feature tensor and on a feature pair.

    Each configuration is run with an edge index plus edge weights and with
    a transposed `SparseTensor`; the sparse path is checked both with the
    layer's `fuse` flag as constructed and with it switched off.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index[0], edge_index[1]
    weight = torch.randn(edge_index.size(1))
    sparse_adj = SparseTensor(row=src, col=dst, value=weight,
                              sparse_sizes=(4, 4))

    # --- single input tensor ---
    layer = MyConv(8, 32)
    out = layer(x1, edge_index, weight)
    assert out.size() == (4, 32)
    ref = out.tolist()
    assert layer(x1, edge_index, weight, (4, 4)).tolist() == ref
    assert layer(x1, sparse_adj.t()).tolist() == ref
    layer.fuse = False
    assert layer(x1, sparse_adj.t()).tolist() == ref
    layer.fuse = True

    # --- pair of input tensors (second entry may be None) ---
    sparse_adj = sparse_adj.sparse_resize((4, 2))
    layer = MyConv((8, 16), 32)
    out1 = layer((x1, x2), edge_index, weight)
    out2 = layer((x1, None), edge_index, weight, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    ref1 = out1.tolist()
    ref2 = out2.tolist()
    assert layer((x1, x2), edge_index, weight, (4, 2)).tolist() == ref1
    assert layer((x1, x2), sparse_adj.t()).tolist() == ref1
    assert layer((x1, None), sparse_adj.t()).tolist() == ref2
    layer.fuse = False
    assert layer((x1, x2), sparse_adj.t()).tolist() == ref1
    assert layer((x1, None), sparse_adj.t()).tolist() == ref2
    layer.fuse = True
コード例 #29
def test_cg_conv_with_edge_features():
    """Exercise `CGConv` with 3-column edge features.

    Covers a single feature tensor and a pair of feature tensors, each via
    an edge index and via a transposed `SparseTensor`, in eager mode and
    after scripting with `torch.jit.script`.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    src, dst = edge_index[0], edge_index[1]
    edge_feat = torch.rand(edge_index.size(1), 3)
    sparse_adj = SparseTensor(row=src, col=dst, value=edge_feat,
                              sparse_sizes=(4, 4))

    # --- single input tensor ---
    layer = CGConv(8, dim=3)
    assert repr(layer) == 'CGConv(8, dim=3)'
    out = layer(x1, edge_index, edge_feat)
    assert out.size() == (4, 8)
    ref = out.tolist()
    assert layer(x1, edge_index, edge_feat, size=(4, 4)).tolist() == ref
    assert layer(x1, sparse_adj.t()).tolist() == ref

    scripted = torch.jit.script(
        layer.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted(x1, edge_index, edge_feat).tolist() == ref
    assert scripted(x1, edge_index, edge_feat, size=(4, 4)).tolist() == ref

    scripted = torch.jit.script(
        layer.jittable('(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted(x1, sparse_adj.t()).tolist() == ref

    # --- pair of input tensors ---
    sparse_adj = sparse_adj.sparse_resize((4, 2))
    layer = CGConv((8, 16), dim=3)
    assert repr(layer) == 'CGConv((8, 16), dim=3)'
    out = layer((x1, x2), edge_index, edge_feat)
    assert out.size() == (2, 16)
    ref = out.tolist()
    assert layer((x1, x2), edge_index, edge_feat, (4, 2)).tolist() == ref
    assert layer((x1, x2), sparse_adj.t()).tolist() == ref

    scripted = torch.jit.script(
        layer.jittable('(PairTensor, Tensor, OptTensor, Size) -> Tensor'))
    assert scripted((x1, x2), edge_index, edge_feat).tolist() == ref
    assert scripted((x1, x2), edge_index, edge_feat, (4, 2)).tolist() == ref

    scripted = torch.jit.script(
        layer.jittable(
            '(PairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
    assert scripted((x1, x2), sparse_adj.t()).tolist() == ref
コード例 #30
def test_nn_conv():
    """Exercise `NNConv` with an edge-conditioned network.

    Covers a single feature tensor and a pair of feature tensors (whose
    second entry may be `None`), each via an edge index with edge features
    and via a transposed `SparseTensor`. The TorchScript checks only run
    when `is_full_test()` is true.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    # Maps each 3-dim edge feature to 8 * 32 values.
    nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32))
    conv = NNConv(8, 32, nn=nn)
    # The expected repr must end with '))': one parenthesis closes the
    # nested `Sequential(...)`, the other closes `NNConv(...)`.
    assert conv.__repr__() == (
        'NNConv(8, 32, aggr=add, nn=Sequential(\n'
        '  (0): Linear(in_features=3, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=256, bias=True)\n'
        '))')
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x1, edge_index, value).tolist() == out.tolist()
        assert jit(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()

        t = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit(x1, adj.t()).tolist() == out.tolist()

    # Pair-input case: the adjacency is resized to (4, 2) to match x1/x2.
    adj = adj.sparse_resize((4, 2))
    conv = NNConv((8, 16), 32, nn=nn)
    assert conv.__repr__() == (
        'NNConv((8, 16), 32, aggr=add, nn=Sequential(\n'
        '  (0): Linear(in_features=3, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=256, bias=True)\n'
        '))')
    out1 = conv((x1, x2), edge_index, value)
    out2 = conv((x1, None), edge_index, value, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out1.tolist()
    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
    assert conv((x1, None), adj.t()).tolist() == out2.tolist()

    if is_full_test():
        t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit((x1, x2), edge_index, value).tolist() == out1.tolist()
        assert jit((x1, x2), edge_index, value,
                   size=(4, 2)).tolist() == out1.tolist()
        assert jit((x1, None), edge_index, value,
                   size=(4, 2)).tolist() == out2.tolist()

        t = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
        assert jit((x1, None), adj.t()).tolist() == out2.tolist()