示例#1
0
def test_nn_conv():
    """Exercise ``NNConv`` on single-tensor and paired node features.

    Each configuration is called with ``edge_index`` plus edge features and
    with the transposed ``SparseTensor``; all outputs must match exactly.
    When ``is_full_test()`` is true, the TorchScript-compiled layer is held
    to the same outputs.
    """
    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    edge_attr = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=edge_attr,
                       sparse_sizes=(4, 4))

    # Edge network: 3 input features -> 32 hidden -> 8 * 32 outputs.
    edge_nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32))

    layer = NNConv(8, 32, nn=edge_nn)
    expected_repr = (
        'NNConv(8, 32, aggr=add, nn=Sequential(\n'
        '  (0): Linear(in_features=3, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=256, bias=True)\n'
        '))')
    assert repr(layer) == expected_repr

    ref = layer(x_src, edge_index, edge_attr)
    assert ref.shape == (4, 32)
    ref_list = ref.tolist()
    assert layer(x_src, edge_index, edge_attr,
                 size=(4, 4)).tolist() == ref_list
    assert layer(x_src, adj.t()).tolist() == ref_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
        assert scripted(x_src, edge_index, edge_attr).tolist() == ref_list
        assert scripted(x_src, edge_index, edge_attr,
                        size=(4, 4)).tolist() == ref_list

        scripted = torch.jit.script(
            layer.jittable(
                '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
        assert scripted(x_src, adj.t()).tolist() == ref_list

    # Paired-input case: the adjacency is resized to (4, 2).
    adj = adj.sparse_resize((4, 2))
    layer = NNConv((8, 16), 32, nn=edge_nn)
    expected_repr = (
        'NNConv((8, 16), 32, aggr=add, nn=Sequential(\n'
        '  (0): Linear(in_features=3, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=256, bias=True)\n'
        '))')
    assert repr(layer) == expected_repr

    ref_pair = layer((x_src, x_dst), edge_index, edge_attr)
    ref_half = layer((x_src, None), edge_index, edge_attr, (4, 2))
    assert ref_pair.shape == (2, 32)
    assert ref_half.shape == (2, 32)
    pair_list = ref_pair.tolist()
    half_list = ref_half.tolist()
    assert layer((x_src, x_dst), edge_index, edge_attr,
                 (4, 2)).tolist() == pair_list
    assert layer((x_src, x_dst), adj.t()).tolist() == pair_list
    assert layer((x_src, None), adj.t()).tolist() == half_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable(
                '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'))
        assert scripted((x_src, x_dst), edge_index,
                        edge_attr).tolist() == pair_list
        assert scripted((x_src, x_dst), edge_index, edge_attr,
                        size=(4, 2)).tolist() == pair_list
        assert scripted((x_src, None), edge_index, edge_attr,
                        size=(4, 2)).tolist() == half_list

        scripted = torch.jit.script(
            layer.jittable(
                '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
        assert scripted((x_src, x_dst), adj.t()).tolist() == pair_list
        assert scripted((x_src, None), adj.t()).tolist() == half_list
示例#2
0
def test_dna_conv():
    """Run ``DNAConv`` in three configurations and compare with TorchScript.

    Covers ``heads=4, groups=8``, ``heads=1, groups=1`` and the latter with
    ``cached=True``. Each output must have shape ``(num_nodes, channels)``;
    in full-test mode the scripted layer must reproduce it exactly.
    """
    channels, num_layers = 32, 3
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, num_layers, channels))
    out_shape = (num_nodes, channels)

    layer = DNAConv(channels, heads=4, groups=8, dropout=0.0)
    assert repr(layer) == 'DNAConv(32, heads=4, groups=8)'
    ref = layer(x, edge_index)
    assert ref.shape == out_shape

    if is_full_test():
        scripted = torch.jit.script(layer.jittable())
        assert scripted(x, edge_index).tolist() == ref.tolist()

    layer = DNAConv(channels, heads=1, groups=1, dropout=0.0)
    assert repr(layer) == 'DNAConv(32, heads=1, groups=1)'
    ref = layer(x, edge_index)
    assert ref.shape == out_shape

    if is_full_test():
        scripted = torch.jit.script(layer.jittable())
        assert scripted(x, edge_index).tolist() == ref.tolist()

    layer = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True)
    # Two calls on the same input; only the second result is checked
    # (presumably the first one populates the cache -- confirm in DNAConv).
    layer(x, edge_index)
    ref = layer(x, edge_index)
    assert ref.shape == out_shape

    if is_full_test():
        scripted = torch.jit.script(layer.jittable())
        assert scripted(x, edge_index).tolist() == ref.tolist()
示例#3
0
def test_jumping_knowledge():
    """Check ``JumpingKnowledge`` in 'cat', 'max' and 'lstm' modes.

    For every mode the repr and the output shape are asserted, and in
    full-test mode the scripted module must be close to the eager output.
    """
    num_nodes, channels, num_layers = 100, 17, 5
    xs = [torch.randn(num_nodes, channels) for _ in range(num_layers)]

    # (constructor args, expected repr, expected number of output channels)
    cases = [
        (('cat', ), 'JumpingKnowledge(cat)', channels * num_layers),
        (('max', ), 'JumpingKnowledge(max)', channels),
        (('lstm', channels, num_layers), 'JumpingKnowledge(lstm)', channels),
    ]

    for ctor_args, expected_repr, out_channels in cases:
        model = JumpingKnowledge(*ctor_args)
        assert repr(model) == expected_repr

        result = model(xs)
        assert result.shape == (num_nodes, out_channels)

        if is_full_test():
            scripted = torch.jit.script(model)
            assert torch.allclose(scripted(xs), result)
示例#4
0
def test_heat_conv():
    """Exercise ``HEATConv`` with ``concat=True`` and ``concat=False``.

    Both variants are called with ``edge_index`` + ``edge_attr`` and with
    the transposed ``SparseTensor``; results must be close. In full-test
    mode the ``concat=True`` layer is scripted with a ``Tensor`` edge
    argument and the ``concat=False`` layer with a ``SparseTensor`` one.
    """
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    edge_attr = torch.randn((4, 2))
    adj = SparseTensor(row=row, col=col, value=edge_attr,
                       sparse_sizes=(4, 4))
    node_type = torch.tensor([0, 0, 1, 2])
    edge_type = torch.tensor([0, 2, 1, 2])

    # Constructor arguments shared by both variants.
    shared = dict(in_channels=8, out_channels=16, num_node_types=3,
                  num_edge_types=3, edge_type_emb_dim=5, edge_dim=2,
                  edge_attr_emb_dim=6, heads=2)

    layer = HEATConv(concat=True, **shared)
    assert str(layer) == 'HEATConv(8, 16, heads=2)'
    ref = layer(x, edge_index, node_type, edge_type, edge_attr)
    assert ref.shape == (4, 32)
    assert torch.allclose(layer(x, adj.t(), node_type, edge_type), ref)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable(
                '(Tensor, Tensor, Tensor, Tensor, OptTensor) -> Tensor'))
        scripted_out = scripted(x, edge_index, node_type, edge_type,
                                edge_attr)
        assert torch.allclose(scripted_out, ref)

    layer = HEATConv(concat=False, **shared)
    assert str(layer) == 'HEATConv(8, 16, heads=2)'
    ref = layer(x, edge_index, node_type, edge_type, edge_attr)
    assert ref.shape == (4, 16)
    assert torch.allclose(layer(x, adj.t(), node_type, edge_type), ref)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable(
                '(Tensor, SparseTensor, Tensor, Tensor, OptTensor) -> Tensor'))
        scripted_out = scripted(x, adj.t(), node_type, edge_type)
        assert torch.allclose(scripted_out, ref)
def test_signed_conv():
    """Chain two ``SignedConv`` layers and compare all input variants.

    The first layer uses ``first_aggr=True`` and the second
    ``first_aggr=False``. Both are run on ``edge_index`` and on the
    transposed ``SparseTensor``, then on paired inputs whose result must be
    close to the first two rows of the full output. TorchScript versions
    are checked when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    first = SignedConv(16, 32, first_aggr=True)
    assert repr(first) == 'SignedConv(16, 32, first_aggr=True)'

    second = SignedConv(32, 48, first_aggr=False)
    assert repr(second) == 'SignedConv(32, 48, first_aggr=False)'

    h1 = first(x, edge_index, edge_index)
    assert h1.shape == (4, 64)
    h1_list = h1.tolist()
    assert first(x, adj.t(), adj.t()).tolist() == h1_list

    h2 = second(h1, edge_index, edge_index)
    assert h2.shape == (4, 96)
    h2_list = h2.tolist()
    assert second(h1, adj.t(), adj.t()).tolist() == h2_list

    if is_full_test():
        sig = '(Tensor, Tensor, Tensor) -> Tensor'
        scripted1 = torch.jit.script(first.jittable(sig))
        scripted2 = torch.jit.script(second.jittable(sig))
        assert scripted1(x, edge_index, edge_index).tolist() == h1_list
        assert scripted2(h1, edge_index, edge_index).tolist() == h2_list

        sig = '(Tensor, SparseTensor, SparseTensor) -> Tensor'
        scripted1 = torch.jit.script(first.jittable(sig))
        scripted2 = torch.jit.script(second.jittable(sig))
        assert scripted1(x, adj.t(), adj.t()).tolist() == h1_list
        assert scripted2(h1, adj.t(), adj.t()).tolist() == h2_list

    # Paired inputs: the second element is the first two rows of the first.
    adj = adj.sparse_resize((4, 2))
    x_pair = (x, x[:2])
    h1_pair = (h1, h1[:2])
    tol = 1e-6

    result = first(x_pair, edge_index, edge_index)
    assert torch.allclose(result, h1[:2], atol=tol)
    result = first(x_pair, adj.t(), adj.t())
    assert torch.allclose(result, h1[:2], atol=tol)
    result = second(h1_pair, edge_index, edge_index)
    assert torch.allclose(result, h2[:2], atol=tol)
    result = second(h1_pair, adj.t(), adj.t())
    assert torch.allclose(result, h2[:2], atol=tol)

    if is_full_test():
        sig = '(PairTensor, Tensor, Tensor) -> Tensor'
        scripted1 = torch.jit.script(first.jittable(sig))
        scripted2 = torch.jit.script(second.jittable(sig))
        result = scripted1(x_pair, edge_index, edge_index)
        assert torch.allclose(result, h1[:2], atol=tol)
        result = scripted2(h1_pair, edge_index, edge_index)
        assert torch.allclose(result, h2[:2], atol=tol)

        sig = '(PairTensor, SparseTensor, SparseTensor) -> Tensor'
        scripted1 = torch.jit.script(first.jittable(sig))
        scripted2 = torch.jit.script(second.jittable(sig))
        result = scripted1(x_pair, adj.t(), adj.t())
        assert torch.allclose(result, h1[:2], atol=tol)
        result = scripted2(h1_pair, adj.t(), adj.t())
        assert torch.allclose(result, h2[:2], atol=tol)
示例#6
0
def test_gmm_conv(separate_gaussians):
    """Exercise ``GMMConv`` on single-tensor and paired node features.

    Args:
        separate_gaussians: Forwarded unchanged to the ``GMMConv``
            constructor for both layers built here.

    Dense ``edge_index`` calls and transposed ``SparseTensor`` calls must
    produce close outputs; TorchScript variants are checked when
    ``is_full_test()`` is true.
    """
    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    edge_attr = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=edge_attr,
                       sparse_sizes=(4, 4))

    layer = GMMConv(8, 32, dim=3, kernel_size=25,
                    separate_gaussians=separate_gaussians)
    assert repr(layer) == 'GMMConv(8, 32, dim=3)'
    ref = layer(x_src, edge_index, edge_attr)
    assert ref.shape == (4, 32)
    sized = layer(x_src, edge_index, edge_attr, size=(4, 4))
    assert torch.allclose(sized, ref)
    assert torch.allclose(layer(x_src, adj.t()), ref)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
        assert torch.allclose(scripted(x_src, edge_index, edge_attr), ref)
        sized = scripted(x_src, edge_index, edge_attr, size=(4, 4))
        assert torch.allclose(sized, ref)

        scripted = torch.jit.script(
            layer.jittable(
                '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
        assert torch.allclose(scripted(x_src, adj.t()), ref)

    # Paired-input case: the adjacency is resized to (4, 2).
    adj = adj.sparse_resize((4, 2))
    layer = GMMConv((8, 16), 32, dim=3, kernel_size=5,
                    separate_gaussians=separate_gaussians)
    assert repr(layer) == 'GMMConv((8, 16), 32, dim=3)'
    ref_pair = layer((x_src, x_dst), edge_index, edge_attr)
    ref_half = layer((x_src, None), edge_index, edge_attr, (4, 2))
    assert ref_pair.shape == (2, 32)
    assert ref_half.shape == (2, 32)
    sized = layer((x_src, x_dst), edge_index, edge_attr, (4, 2))
    assert torch.allclose(sized, ref_pair)
    assert torch.allclose(layer((x_src, x_dst), adj.t()), ref_pair)
    assert torch.allclose(layer((x_src, None), adj.t()), ref_half)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable(
                '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'))
        result = scripted((x_src, x_dst), edge_index, edge_attr)
        assert torch.allclose(result, ref_pair)
        result = scripted((x_src, x_dst), edge_index, edge_attr, size=(4, 2))
        assert torch.allclose(result, ref_pair)
        result = scripted((x_src, None), edge_index, edge_attr, size=(4, 2))
        assert torch.allclose(result, ref_half)

        scripted = torch.jit.script(
            layer.jittable(
                '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
        assert torch.allclose(scripted((x_src, x_dst), adj.t()), ref_pair)
        assert torch.allclose(scripted((x_src, None), adj.t()), ref_half)
示例#7
0
def test_ppf_conv():
    """Exercise ``PPFConv`` with shared and paired positions/normals.

    The layer is built from a local and a global ``Sequential`` network.
    Outputs from ``edge_index`` and the transposed ``SparseTensor`` must
    agree (exactly or within ``atol=1e-6``, as asserted below), and the
    TorchScript variants are checked when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    pos_src = torch.randn(4, 3)
    pos_dst = torch.randn(2, 3)
    normal_src = F.normalize(torch.rand(4, 3), dim=-1)
    normal_dst = F.normalize(torch.rand(2, 3), dim=-1)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))
    global_nn = Seq(Lin(32, 32))
    layer = PPFConv(local_nn, global_nn)
    expected_repr = (
        'PPFConv(local_nn=Sequential(\n'
        '  (0): Linear(in_features=20, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '), global_nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')
    assert repr(layer) == expected_repr

    tol = 1e-6
    ref = layer(x, pos_src, normal_src, edge_index)
    assert ref.shape == (4, 32)
    sparse_out = layer(x, pos_src, normal_src, adj.t())
    assert torch.allclose(sparse_out, ref, atol=tol)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(OptTensor, Tensor, Tensor, Tensor) -> Tensor'))
        scripted_out = scripted(x, pos_src, normal_src, edge_index)
        assert scripted_out.tolist() == ref.tolist()

        scripted = torch.jit.script(
            layer.jittable(
                '(OptTensor, Tensor, Tensor, SparseTensor) -> Tensor'))
        scripted_out = scripted(x, pos_src, normal_src, adj.t())
        assert torch.allclose(scripted_out, ref, atol=tol)

    # Paired positions and normals; the adjacency is resized to (4, 2).
    adj = adj.sparse_resize((4, 2))
    pos = (pos_src, pos_dst)
    normal = (normal_src, normal_dst)

    ref = layer(x, pos, normal, edge_index)
    assert ref.shape == (2, 32)
    assert layer((x, None), pos, normal, edge_index).tolist() == ref.tolist()
    assert torch.allclose(layer(x, pos, normal, adj.t()), ref, atol=tol)
    assert torch.allclose(layer((x, None), pos, normal, adj.t()), ref,
                          atol=tol)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable(
                '(PairOptTensor, PairTensor, PairTensor, Tensor) -> Tensor'))
        scripted_out = scripted((x, None), pos, normal, edge_index)
        assert scripted_out.tolist() == ref.tolist()

        scripted = torch.jit.script(
            layer.jittable(
                '(PairOptTensor, PairTensor, PairTensor, SparseTensor) -> Tensor'
            ))
        scripted_out = scripted((x, None), pos, normal, adj.t())
        assert torch.allclose(scripted_out, ref, atol=tol)
示例#8
0
def test_point_transformer_conv():
    """Exercise ``PointTransformerConv`` in three configurations.

    Covers the default constructor, explicit ``pos_nn``/``attn_nn``
    networks, and paired inputs with channels ``(16, 8)``. ``edge_index``
    and transposed ``SparseTensor`` results must agree within
    ``atol=1e-6``; TorchScript variants are checked for the first and last
    configuration when ``is_full_test()`` is true.
    """
    x_src = torch.rand(4, 16)
    x_dst = torch.randn(2, 8)
    pos_src = torch.rand(4, 3)
    pos_dst = torch.randn(2, 3)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    tol = 1e-6

    layer = PointTransformerConv(in_channels=16, out_channels=32)
    assert str(layer) == 'PointTransformerConv(16, 32)'

    ref = layer(x_src, pos_src, edge_index)
    assert ref.shape == (4, 32)
    assert torch.allclose(layer(x_src, pos_src, adj.t()), ref, atol=tol)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, Tensor) -> Tensor'))
        assert scripted(x_src, pos_src, edge_index).tolist() == ref.tolist()

        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, SparseTensor) -> Tensor'))
        scripted_out = scripted(x_src, pos_src, adj.t())
        assert torch.allclose(scripted_out, ref, atol=tol)

    # Same layer sizes, but with user-supplied position/attention networks.
    pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))
    attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))
    layer = PointTransformerConv(16, 32, pos_nn, attn_nn)

    ref = layer(x_src, pos_src, edge_index)
    assert ref.shape == (4, 32)
    assert torch.allclose(layer(x_src, pos_src, adj.t()), ref, atol=tol)

    # Paired inputs; the adjacency is resized to (4, 2).
    layer = PointTransformerConv((16, 8), 32)
    adj = adj.sparse_resize((4, 2))
    x_pair = (x_src, x_dst)
    pos_pair = (pos_src, pos_dst)

    ref = layer(x_pair, pos_pair, edge_index)
    assert ref.shape == (2, 32)
    assert torch.allclose(layer(x_pair, pos_pair, adj.t()), ref, atol=tol)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(PairTensor, PairTensor, Tensor) -> Tensor'))
        assert scripted(x_pair, pos_pair, edge_index).tolist() == ref.tolist()

        scripted = torch.jit.script(
            layer.jittable('(PairTensor, PairTensor, SparseTensor) -> Tensor'))
        scripted_out = scripted(x_pair, pos_pair, adj.t())
        assert torch.allclose(scripted_out, ref, atol=tol)
示例#9
0
def test_gin_conv():
    """Exercise ``GINConv`` (``train_eps=True``) on single and paired input.

    The same layer instance is used for both parts. Outputs from
    ``edge_index`` (with and without ``size``) and from the transposed
    ``SparseTensor`` must be exactly equal; TorchScript variants are
    checked when ``is_full_test()`` is true.
    """
    x_src = torch.randn(4, 16)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    mlp = Seq(Lin(16, 32), ReLU(), Lin(32, 32))
    layer = GINConv(mlp, train_eps=True)
    expected_repr = (
        'GINConv(nn=Sequential(\n'
        '  (0): Linear(in_features=16, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')
    assert repr(layer) == expected_repr

    ref = layer(x_src, edge_index)
    assert ref.shape == (4, 32)
    ref_list = ref.tolist()
    assert layer(x_src, edge_index, size=(4, 4)).tolist() == ref_list
    assert layer(x_src, adj.t()).tolist() == ref_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, Size) -> Tensor'))
        assert scripted(x_src, edge_index).tolist() == ref_list
        assert scripted(x_src, edge_index, size=(4, 4)).tolist() == ref_list

        scripted = torch.jit.script(
            layer.jittable('(Tensor, SparseTensor, Size) -> Tensor'))
        assert scripted(x_src, adj.t()).tolist() == ref_list

    # Paired-input case: the adjacency is resized to (4, 2).
    adj = adj.sparse_resize((4, 2))
    ref_pair = layer((x_src, x_dst), edge_index)
    ref_half = layer((x_src, None), edge_index, (4, 2))
    assert ref_pair.shape == (2, 32)
    assert ref_half.shape == (2, 32)
    pair_list = ref_pair.tolist()
    half_list = ref_half.tolist()
    assert layer((x_src, x_dst), edge_index, (4, 2)).tolist() == pair_list
    assert layer((x_src, x_dst), adj.t()).tolist() == pair_list
    assert layer((x_src, None), adj.t()).tolist() == half_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(OptPairTensor, Tensor, Size) -> Tensor'))
        assert scripted((x_src, x_dst), edge_index).tolist() == pair_list
        assert scripted((x_src, x_dst), edge_index,
                        size=(4, 2)).tolist() == pair_list
        assert scripted((x_src, None), edge_index,
                        size=(4, 2)).tolist() == half_list

        scripted = torch.jit.script(
            layer.jittable('(OptPairTensor, SparseTensor, Size) -> Tensor'))
        assert scripted((x_src, x_dst), adj.t()).tolist() == pair_list
        assert scripted((x_src, None), adj.t()).tolist() == half_list
示例#10
0
def test_spline_conv():
    """Exercise ``SplineConv`` on single-tensor and paired node features.

    A warnings filter is installed first to ignore messages matching
    ``'.*non-optimized CPU version.*'``. Dense and ``SparseTensor`` calls
    must give close outputs; TorchScript variants are checked when
    ``is_full_test()`` is true.
    """
    warnings.filterwarnings('ignore', '.*non-optimized CPU version.*')

    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    edge_attr = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=edge_attr,
                       sparse_sizes=(4, 4))

    layer = SplineConv(8, 32, dim=3, kernel_size=5)
    assert repr(layer) == 'SplineConv(8, 32, dim=3)'
    ref = layer(x_src, edge_index, edge_attr)
    assert ref.shape == (4, 32)
    sized = layer(x_src, edge_index, edge_attr, size=(4, 4))
    assert torch.allclose(sized, ref)
    assert torch.allclose(layer(x_src, adj.t()), ref)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor, Size) -> Tensor'))
        assert torch.allclose(scripted(x_src, edge_index, edge_attr), ref)
        sized = scripted(x_src, edge_index, edge_attr, size=(4, 4))
        assert torch.allclose(sized, ref)

        scripted = torch.jit.script(
            layer.jittable(
                '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'))
        assert torch.allclose(scripted(x_src, adj.t()), ref)

    # Paired-input case: the adjacency is resized to (4, 2).
    adj = adj.sparse_resize((4, 2))
    layer = SplineConv((8, 16), 32, dim=3, kernel_size=5)
    assert repr(layer) == 'SplineConv((8, 16), 32, dim=3)'
    ref_pair = layer((x_src, x_dst), edge_index, edge_attr)
    ref_half = layer((x_src, None), edge_index, edge_attr, (4, 2))
    assert ref_pair.shape == (2, 32)
    assert ref_half.shape == (2, 32)
    sized = layer((x_src, x_dst), edge_index, edge_attr, (4, 2))
    assert torch.allclose(sized, ref_pair)
    assert torch.allclose(layer((x_src, x_dst), adj.t()), ref_pair)
    assert torch.allclose(layer((x_src, None), adj.t()), ref_half)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable(
                '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'))
        result = scripted((x_src, x_dst), edge_index, edge_attr)
        assert torch.allclose(result, ref_pair)
        result = scripted((x_src, x_dst), edge_index, edge_attr, size=(4, 2))
        assert torch.allclose(result, ref_pair)
        result = scripted((x_src, None), edge_index, edge_attr, size=(4, 2))
        assert torch.allclose(result, ref_half)

        scripted = torch.jit.script(
            layer.jittable(
                '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'))
        assert torch.allclose(scripted((x_src, x_dst), adj.t()), ref_pair)
        assert torch.allclose(scripted((x_src, None), adj.t()), ref_half)
示例#11
0
def test_cg_conv():
    """Exercise ``CGConv`` without edge features in three setups.

    Covers a single-tensor layer, a paired ``(8, 16)`` layer and a layer
    built with ``batch_norm=True``. Outputs from ``edge_index`` and the
    transposed ``SparseTensor`` must be exactly equal; TorchScript variants
    are checked when ``is_full_test()`` is true.
    """
    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    layer = CGConv(8)
    assert repr(layer) == 'CGConv(8, dim=0)'
    ref = layer(x_src, edge_index)
    assert ref.shape == (4, 8)
    ref_list = ref.tolist()
    assert layer(x_src, adj.t()).tolist() == ref_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        assert scripted(x_src, edge_index).tolist() == ref_list

        scripted = torch.jit.script(
            layer.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
        assert scripted(x_src, adj.t()).tolist() == ref_list

    # Paired-input case: the adjacency is resized to (4, 2).
    adj = adj.sparse_resize((4, 2))
    layer = CGConv((8, 16))
    assert repr(layer) == 'CGConv((8, 16), dim=0)'
    ref = layer((x_src, x_dst), edge_index)
    assert ref.shape == (2, 16)
    ref_list = ref.tolist()
    assert layer((x_src, x_dst), adj.t()).tolist() == ref_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(PairTensor, Tensor, OptTensor) -> Tensor'))
        assert scripted((x_src, x_dst), edge_index).tolist() == ref_list

        scripted = torch.jit.script(
            layer.jittable('(PairTensor, SparseTensor, OptTensor) -> Tensor'))
        assert scripted((x_src, x_dst), adj.t()).tolist() == ref_list

    # Batch-norm variant, back on a fresh (4, 4) adjacency.
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    layer = CGConv(8, batch_norm=True)
    assert repr(layer) == 'CGConv(8, dim=0)'
    ref = layer(x_src, edge_index)
    assert ref.shape == (4, 8)
    ref_list = ref.tolist()
    assert layer(x_src, adj.t()).tolist() == ref_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        assert scripted(x_src, edge_index).tolist() == ref_list
示例#12
0
def test_lg_conv():
    """Exercise ``LGConv`` with and without edge weights.

    Unweighted and weighted calls are made with ``edge_index`` and with
    transposed ``SparseTensor`` adjacencies (``adj1`` has its value removed,
    ``adj2`` carries the weights); matching variants must be close.
    TorchScript variants are checked when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    edge_weight = torch.rand(row.size(0))
    adj_weighted = SparseTensor(row=row, col=col, value=edge_weight,
                                sparse_sizes=(4, 4))
    adj_plain = adj_weighted.set_value(None)

    layer = LGConv()
    assert str(layer) == 'LGConv()'

    ref_plain = layer(x, edge_index)
    assert ref_plain.shape == (4, 8)
    assert torch.allclose(layer(x, adj_plain.t()), ref_plain)

    ref_weighted = layer(x, edge_index, edge_weight)
    assert ref_weighted.shape == (4, 8)
    assert torch.allclose(layer(x, adj_weighted.t()), ref_weighted)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        assert torch.allclose(scripted(x, edge_index), ref_plain)
        assert torch.allclose(scripted(x, edge_index, edge_weight),
                              ref_weighted)

        scripted = torch.jit.script(
            layer.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
        assert torch.allclose(scripted(x, adj_plain.t()), ref_plain)
        assert torch.allclose(scripted(x, adj_weighted.t()), ref_weighted)
def test_gravnet_conv():
    """Exercise ``GravNetConv`` with/without batch vectors and paired input.

    Four eager calls are made (single or paired features, each with and
    without batch vectors) and their output shapes asserted. In full-test
    mode the scripted layer must reproduce each output exactly, and the
    layer is also scripted once without an explicit type signature.
    """
    x_src = torch.randn(8, 16)
    x_dst = torch.randn(4, 16)
    batch_src = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    batch_dst = torch.tensor([0, 0, 1, 1])

    layer = GravNetConv(16, 32, space_dimensions=4, propagate_dimensions=8,
                        k=2)
    assert repr(layer) == 'GravNetConv(16, 32, k=2)'

    single = layer(x_src)
    assert single.shape == (8, 32)

    single_batched = layer(x_src, batch_src)
    assert single_batched.shape == (8, 32)

    paired = layer((x_src, x_dst))
    assert paired.shape == (4, 32)

    paired_batched = layer((x_src, x_dst), (batch_src, batch_dst))
    assert paired_batched.shape == (4, 32)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, OptTensor) -> Tensor'))
        assert scripted(x_src).tolist() == single.tolist()
        assert scripted(x_src, batch_src).tolist() == single_batched.tolist()

        scripted = torch.jit.script(
            layer.jittable('(PairTensor, Optional[PairTensor]) -> Tensor'))
        assert scripted((x_src, x_dst)).tolist() == paired.tolist()
        scripted_out = scripted((x_src, x_dst), (batch_src, batch_dst))
        assert scripted_out.tolist() == paired_batched.tolist()

        # Scripting must also work when no signature string is given.
        torch.jit.script(layer.jittable())
示例#14
0
def test_eg_conv():
    """Exercise ``EGConv`` with default arguments, then with caching on.

    ``edge_index`` and transposed ``SparseTensor`` results must agree
    within ``atol=1e-6``. After setting ``cached = True`` each input form is
    called twice and the second result compared with the reference.
    TorchScript variants are checked when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    tol = 1e-6

    layer = EGConv(16, 32)
    assert repr(layer) == "EGConv(16, 32, aggregators=['symnorm'])"
    ref = layer(x, edge_index)
    assert ref.shape == (4, 32)
    assert torch.allclose(layer(x, adj.t()), ref, atol=tol)

    # Caching enabled: the first call of each pair is a warm-up whose result
    # is discarded.
    layer.cached = True
    layer(x, edge_index)
    assert layer(x, edge_index).tolist() == ref.tolist()
    layer(x, adj.t())
    assert torch.allclose(layer(x, adj.t()), ref, atol=tol)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor) -> Tensor'))
        assert scripted(x, edge_index).tolist() == ref.tolist()

        scripted = torch.jit.script(
            layer.jittable('(Tensor, SparseTensor) -> Tensor'))
        assert torch.allclose(scripted(x, adj.t()), ref, atol=tol)
示例#15
0
def test_mlp(batch_norm, act_first, plain_last):
    """Check that both ``MLP`` constructor forms build the same network.

    Args:
        batch_norm: Forwarded to ``MLP`` as ``batch_norm``.
        act_first: Forwarded to ``MLP`` as ``act_first``.
        plain_last: Forwarded to ``MLP`` as ``plain_last``.

    The RNG is re-seeded with the same value before each construction, so
    the channel-list form and the keyword form are expected to yield close
    outputs on the same input.
    """
    x = torch.randn(4, 16)
    options = dict(batch_norm=batch_norm, act_first=act_first,
                   plain_last=plain_last)

    torch.manual_seed(12345)
    from_list = MLP([16, 32, 32, 64], **options)
    assert str(from_list) == 'MLP(16, 32, 32, 64)'
    ref = from_list(x)
    assert ref.shape == (4, 64)

    if is_full_test():
        scripted = torch.jit.script(from_list)
        assert torch.allclose(scripted(x), ref)

    torch.manual_seed(12345)
    from_kwargs = MLP(16, hidden_channels=32, out_channels=64, num_layers=3,
                      **options)
    assert torch.allclose(from_kwargs(x), ref)
示例#16
0
def test_pna_conv():
    """Exercise ``PNAConv`` with edge features and four towers.

    ``aggregators`` and ``scalers`` are taken from the enclosing module.
    ``edge_index`` + edge features and the transposed ``SparseTensor`` must
    agree within ``atol=1e-6``; TorchScript variants are checked when
    ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    edge_attr = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=edge_attr,
                       sparse_sizes=(4, 4))
    tol = 1e-6

    deg = torch.tensor([0, 3, 0, 1])
    layer = PNAConv(16, 32, aggregators, scalers, deg=deg, edge_dim=3,
                    towers=4)
    assert repr(layer) == 'PNAConv(16, 32, towers=4, edge_dim=3)'
    ref = layer(x, edge_index, edge_attr)
    assert ref.shape == (4, 32)
    assert torch.allclose(layer(x, adj.t()), ref, atol=tol)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        scripted_out = scripted(x, edge_index, edge_attr)
        assert torch.allclose(scripted_out, ref, atol=tol)

        scripted = torch.jit.script(
            layer.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
        assert torch.allclose(scripted(x, adj.t()), ref, atol=tol)
示例#17
0
def test_gcn_conv():
    """Exercise ``GCNConv`` with and without edge weights, then cached.

    Unweighted and weighted calls are compared between ``edge_index`` and
    transposed ``SparseTensor`` inputs within ``atol=1e-6``. TorchScript
    variants are checked when ``is_full_test()`` is true. Finally
    ``cached`` is switched on and each unweighted input form is called
    twice, comparing the second result with the reference.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    edge_weight = torch.rand(row.size(0))
    adj_weighted = SparseTensor(row=row, col=col, value=edge_weight,
                                sparse_sizes=(4, 4))
    adj_plain = adj_weighted.set_value(None)
    tol = 1e-6

    layer = GCNConv(16, 32)
    assert repr(layer) == 'GCNConv(16, 32)'

    ref_plain = layer(x, edge_index)
    assert ref_plain.shape == (4, 32)
    assert torch.allclose(layer(x, adj_plain.t()), ref_plain, atol=tol)

    ref_weighted = layer(x, edge_index, edge_weight)
    assert ref_weighted.shape == (4, 32)
    assert torch.allclose(layer(x, adj_weighted.t()), ref_weighted, atol=tol)

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, OptTensor) -> Tensor'))
        assert scripted(x, edge_index).tolist() == ref_plain.tolist()
        scripted_out = scripted(x, edge_index, edge_weight)
        assert scripted_out.tolist() == ref_weighted.tolist()

        scripted = torch.jit.script(
            layer.jittable('(Tensor, SparseTensor, OptTensor) -> Tensor'))
        assert torch.allclose(scripted(x, adj_plain.t()), ref_plain,
                              atol=tol)
        assert torch.allclose(scripted(x, adj_weighted.t()), ref_weighted,
                              atol=tol)

    # Caching enabled: the first call of each pair is a warm-up whose result
    # is discarded.
    layer.cached = True
    layer(x, edge_index)
    assert layer(x, edge_index).tolist() == ref_plain.tolist()
    layer(x, adj_plain.t())
    assert torch.allclose(layer(x, adj_plain.t()), ref_plain, atol=tol)
示例#18
0
def test_sage_conv(project):
    """Exercise ``SAGEConv`` on single-tensor and paired node features.

    Args:
        project: Forwarded unchanged to the ``SAGEConv`` constructor for
            both layers built here.

    Outputs from ``edge_index`` (with and without ``size``) and from the
    transposed ``SparseTensor`` must be exactly equal; TorchScript variants
    are checked when ``is_full_test()`` is true.
    """
    x_src = torch.randn(4, 8)
    x_dst = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    layer = SAGEConv(8, 32, project=project)
    assert str(layer) == 'SAGEConv(8, 32, aggr=mean)'
    ref = layer(x_src, edge_index)
    assert ref.shape == (4, 32)
    ref_list = ref.tolist()
    assert layer(x_src, edge_index, size=(4, 4)).tolist() == ref_list
    assert layer(x_src, adj.t()).tolist() == ref_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(Tensor, Tensor, Size) -> Tensor'))
        assert scripted(x_src, edge_index).tolist() == ref_list
        assert scripted(x_src, edge_index, size=(4, 4)).tolist() == ref_list

        scripted = torch.jit.script(
            layer.jittable('(Tensor, SparseTensor, Size) -> Tensor'))
        assert scripted(x_src, adj.t()).tolist() == ref_list

    # Paired-input case: the adjacency is resized to (4, 2).
    adj = adj.sparse_resize((4, 2))
    layer = SAGEConv((8, 16), 32, project=project)
    assert str(layer) == 'SAGEConv((8, 16), 32, aggr=mean)'
    ref_pair = layer((x_src, x_dst), edge_index)
    ref_half = layer((x_src, None), edge_index, (4, 2))
    assert ref_pair.shape == (2, 32)
    assert ref_half.shape == (2, 32)
    pair_list = ref_pair.tolist()
    half_list = ref_half.tolist()
    assert layer((x_src, x_dst), edge_index, (4, 2)).tolist() == pair_list
    assert layer((x_src, x_dst), adj.t()).tolist() == pair_list
    assert layer((x_src, None), adj.t()).tolist() == half_list

    if is_full_test():
        scripted = torch.jit.script(
            layer.jittable('(OptPairTensor, Tensor, Size) -> Tensor'))
        assert scripted((x_src, x_dst), edge_index).tolist() == pair_list
        assert scripted((x_src, x_dst), edge_index,
                        size=(4, 2)).tolist() == pair_list
        assert scripted((x_src, None), edge_index,
                        size=(4, 2)).tolist() == half_list

        scripted = torch.jit.script(
            layer.jittable('(OptPairTensor, SparseTensor, Size) -> Tensor'))
        assert scripted((x_src, x_dst), adj.t()).tolist() == pair_list
        assert scripted((x_src, None), adj.t()).tolist() == half_list
def test_batch_norm(conf):
    """Check the repr, scriptability and output shape of ``BatchNorm``.

    Args:
        conf: Value passed as both ``affine`` and ``track_running_stats``.
    """
    features = torch.randn(100, 16)

    layer = BatchNorm(16, affine=conf, track_running_stats=conf)
    assert repr(layer) == 'BatchNorm(16)'

    # Scripting is only checked for success; its output is not compared.
    if is_full_test():
        torch.jit.script(layer)

    normalized = layer(features)
    assert normalized.shape == (100, 16)
def test_diff_group_norm():
    """Check ``DiffGroupNorm`` with ``lamda=0`` and ``lamda=0.01``.

    With ``lamda=0`` the output is asserted to equal the input exactly.
    With ``lamda=0.01`` only the output size is asserted. In full-test mode
    the scripted module must match the eager result in both cases.
    """
    x = torch.randn(6, 16)

    layer = DiffGroupNorm(16, groups=4, lamda=0)
    assert repr(layer) == 'DiffGroupNorm(16, groups=4)'

    result = layer(x)
    assert result.tolist() == x.tolist()

    if is_full_test():
        scripted = torch.jit.script(layer)
        assert scripted(x).tolist() == x.tolist()

    layer = DiffGroupNorm(16, groups=4, lamda=0.01)
    assert repr(layer) == 'DiffGroupNorm(16, groups=4)'

    result = layer(x)
    assert result.shape == x.shape

    if is_full_test():
        scripted = torch.jit.script(layer)
        assert scripted(x).tolist() == result.tolist()
示例#21
0
def test_hetero_linear():
    """Check the repr, output shape and scripted output of ``HeteroLinear``.

    Three input rows are given, each with a different type index (0, 1, 2).
    """
    features = torch.randn((3, 16))
    type_index = torch.tensor([0, 1, 2])

    layer = HeteroLinear(in_channels=16, out_channels=32, num_types=3)
    assert str(layer) == 'HeteroLinear(16, 32, bias=True)'

    ref = layer(features, type_index)
    assert ref.shape == (3, 32)

    if is_full_test():
        scripted = torch.jit.script(layer)
        assert torch.allclose(scripted(features, type_index), ref)
示例#22
0
def test_cheb_conv():
    """Check ``ChebConv`` (``K=3``) output shapes and TorchScript parity.

    The first half runs the layer on a 4-node star graph with and without
    edge weights and with a scalar ``lambda_max``. The second half reuses
    the same layer on two 2-node graphs described by a ``batch`` vector,
    with and without a per-graph ``lambda_max`` tensor. In full-test mode
    the scripted layer must reproduce every eager output exactly.
    """
    in_channels, out_channels = 16, 32

    # Star graph: node 0 is linked to nodes 1, 2 and 3 in both directions.
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = int(edge_index.max()) + 1
    edge_weight = torch.rand(edge_index.size(1))
    x = torch.randn((num_nodes, in_channels))

    conv = ChebConv(in_channels, out_channels, K=3)
    assert repr(conv) == 'ChebConv(16, 32, K=3, normalization=sym)'

    out_plain = conv(x, edge_index)
    assert out_plain.size() == (num_nodes, out_channels)
    out_weighted = conv(x, edge_index, edge_weight)
    assert out_weighted.size() == (num_nodes, out_channels)
    out_scaled = conv(x, edge_index, edge_weight, lambda_max=3.0)
    assert out_scaled.size() == (num_nodes, out_channels)

    if is_full_test():
        scripted = torch.jit.script(conv.jittable())
        assert scripted(x, edge_index).tolist() == out_plain.tolist()
        assert (scripted(x, edge_index, edge_weight).tolist() ==
                out_weighted.tolist())
        # The scripted call is given ``lambda_max`` as a tensor, whereas the
        # eager call above passed a Python float.
        scripted_scaled = scripted(x, edge_index, edge_weight,
                                   lambda_max=torch.tensor(3.0))
        assert scripted_scaled.tolist() == out_scaled.tolist()

    # Two graphs of two nodes each, told apart through ``batch``.
    batch = torch.tensor([0, 0, 1, 1])
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    num_nodes = int(edge_index.max()) + 1
    edge_weight = torch.rand(edge_index.size(1))
    x = torch.randn((num_nodes, in_channels))
    lambda_max = torch.tensor([2.0, 3.0])

    out_batched = conv(x, edge_index, edge_weight, batch)
    assert out_batched.size() == (num_nodes, out_channels)
    out_batched_scaled = conv(x, edge_index, edge_weight, batch, lambda_max)
    assert out_batched_scaled.size() == (num_nodes, out_channels)

    if is_full_test():
        # ``scripted`` was created in the earlier full-test branch.
        assert (scripted(x, edge_index, edge_weight, batch).tolist() ==
                out_batched.tolist())
        assert (scripted(x, edge_index, edge_weight, batch,
                         lambda_max).tolist() == out_batched_scaled.tolist())
示例#23
0
def test_cg_conv_with_edge_features():
    """Check ``CGConv`` with 3-dimensional edge features.

    The layer is run once on a single feature matrix and once on a pair of
    feature matrices. Each time the dense ``edge_index`` call is compared
    with the call on a transposed ``SparseTensor`` that carries the edge
    features as values. In full-test mode the scripted variants for both
    graph formats must reproduce the eager output exactly.
    """
    feat_a = torch.randn(4, 8)
    feat_b = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    edge_attr = torch.rand(edge_index.size(1), 3)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr,
                       sparse_sizes=(4, 4))

    # --- single feature matrix -------------------------------------------
    conv = CGConv(8, dim=3)
    assert repr(conv) == 'CGConv(8, dim=3)'
    expected = conv(feat_a, edge_index, edge_attr)
    assert expected.size() == (4, 8)
    assert conv(feat_a, adj.t()).tolist() == expected.tolist()

    if is_full_test():
        signature = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        dense_out = scripted(feat_a, edge_index, edge_attr)
        assert dense_out.tolist() == expected.tolist()

        signature = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        sparse_out = scripted(feat_a, adj.t())
        assert sparse_out.tolist() == expected.tolist()

    # --- pair of feature matrices ----------------------------------------
    adj = adj.sparse_resize((4, 2))
    conv = CGConv((8, 16), dim=3)
    assert repr(conv) == 'CGConv((8, 16), dim=3)'
    pair = (feat_a, feat_b)
    expected = conv(pair, edge_index, edge_attr)
    assert expected.size() == (2, 16)
    assert conv(pair, adj.t()).tolist() == expected.tolist()

    if is_full_test():
        signature = '(PairTensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        dense_out = scripted(pair, edge_index, edge_attr)
        assert dense_out.tolist() == expected.tolist()

        signature = '(PairTensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        sparse_out = scripted(pair, adj.t())
        assert sparse_out.tolist() == expected.tolist()
def test_res_gated_graph_conv():
    """Check ``ResGatedGraphConv`` on single and paired feature inputs.

    For each input form the dense ``edge_index`` result is compared with
    the result on a transposed ``SparseTensor``. In full-test mode the
    scripted variants for both graph formats must match the eager output
    exactly.
    """
    feat_a = torch.randn(4, 8)
    feat_b = torch.randn(2, 32)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                       sparse_sizes=(4, 4))

    # --- single feature matrix -------------------------------------------
    conv = ResGatedGraphConv(8, 32)
    assert repr(conv) == 'ResGatedGraphConv(8, 32)'
    expected = conv(feat_a, edge_index)
    assert expected.size() == (4, 32)
    assert conv(feat_a, adj.t()).tolist() == expected.tolist()

    if is_full_test():
        signature = '(Tensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        assert scripted(feat_a, edge_index).tolist() == expected.tolist()

        signature = '(Tensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        assert scripted(feat_a, adj.t()).tolist() == expected.tolist()

    # --- pair of feature matrices ----------------------------------------
    adj = adj.sparse_resize((4, 2))
    conv = ResGatedGraphConv((8, 32), 32)
    assert repr(conv) == 'ResGatedGraphConv((8, 32), 32)'
    pair = (feat_a, feat_b)
    expected = conv(pair, edge_index)
    assert expected.size() == (2, 32)
    assert conv(pair, adj.t()).tolist() == expected.tolist()

    if is_full_test():
        signature = '(PairTensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        assert scripted(pair, edge_index).tolist() == expected.tolist()

        signature = '(PairTensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        assert scripted(pair, adj.t()).tolist() == expected.tolist()
def test_feast_conv():
    """Check ``FeaStConv`` (``heads=2``) on single and paired feature inputs.

    The same layer instance is used for both input forms. Results obtained
    through a transposed ``SparseTensor`` are compared with a tolerance of
    ``1e-6``, while scripted dense ``edge_index`` results are compared
    exactly.
    """
    tolerance = 1e-6
    feat_a = torch.randn(4, 16)
    feat_b = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                       sparse_sizes=(4, 4))

    conv = FeaStConv(16, 32, heads=2)
    assert repr(conv) == 'FeaStConv(16, 32, heads=2)'

    # --- single feature matrix -------------------------------------------
    expected = conv(feat_a, edge_index)
    assert expected.size() == (4, 32)
    assert torch.allclose(conv(feat_a, adj.t()), expected, atol=tolerance)

    if is_full_test():
        signature = '(Tensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        assert scripted(feat_a, edge_index).tolist() == expected.tolist()

        signature = '(Tensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        sparse_out = scripted(feat_a, adj.t())
        assert torch.allclose(sparse_out, expected, atol=tolerance)

    # --- pair of feature matrices ----------------------------------------
    adj = adj.sparse_resize((4, 2))
    pair = (feat_a, feat_b)
    expected = conv(pair, edge_index)
    assert expected.size() == (2, 32)
    assert torch.allclose(conv(pair, adj.t()), expected, atol=tolerance)

    if is_full_test():
        signature = '(PairTensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        assert scripted(pair, edge_index).tolist() == expected.tolist()

        signature = '(PairTensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        sparse_out = scripted(pair, adj.t())
        assert torch.allclose(sparse_out, expected, atol=tolerance)
示例#26
0
def test_le_conv():
    """Check ``LEConv`` repr, output shape and TorchScript parity.

    The input graph is a 4-node star whose centre is node 0. In full-test
    mode the scripted layer must reproduce the eager output exactly.
    """
    in_channels, out_channels = 16, 32
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = int(edge_index.max()) + 1
    features = torch.randn((num_nodes, in_channels))

    conv = LEConv(in_channels, out_channels)
    assert repr(conv) == 'LEConv(16, 32)'

    expected = conv(features, edge_index)
    assert expected.size() == (num_nodes, out_channels)

    if is_full_test():
        signature = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(signature))
        assert scripted(features, edge_index).tolist() == expected.tolist()
示例#27
0
def test_pair_norm(scale_individually):
    """Check ``PairNorm`` with and without an explicit ``batch`` vector.

    The layer is first applied to a single ``(100, 16)`` input. The input
    is then stacked twice with batch ids 0 and 1, and each half of that
    result must match the single-input result within ``1e-6``.

    NOTE(review): ``scale_individually`` is presumably supplied by a pytest
    parametrization that is not visible in this part of the file.
    """
    num_rows = 100
    features = torch.randn(num_rows, 16)
    batch = torch.zeros(num_rows, dtype=torch.long)

    layer = PairNorm(scale_individually=scale_individually)
    assert repr(layer) == 'PairNorm()'

    if is_full_test():
        # Only compilation is verified here; the scripted module is not run.
        torch.jit.script(layer)

    single_out = layer(features)
    assert single_out.size() == (num_rows, 16)

    stacked_features = torch.cat([features, features], dim=0)
    stacked_batch = torch.cat([batch, batch + 1], dim=0)
    stacked_out = layer(stacked_features, stacked_batch)
    for half in (stacked_out[:num_rows], stacked_out[num_rows:]):
        assert torch.allclose(single_out, half, atol=1e-6)
示例#28
0
def test_gcn_conv_with_decomposed_layers():
    """Check that setting ``decomposed_layers = 2`` leaves ``GCNConv`` output
    unchanged.

    A deep copy of the layer gets ``decomposed_layers`` set to 2 and must
    produce an output close to the original layer's. In full-test mode the
    scripted copy must equal the original output exactly.
    """
    features = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])

    baseline = GCNConv(16, 32)

    # The copy differs from ``baseline`` only in ``decomposed_layers``.
    decomposed = copy.deepcopy(baseline)
    decomposed.decomposed_layers = 2

    expected = baseline(features, edge_index)
    actual = decomposed(features, edge_index)
    assert torch.allclose(expected, actual)

    if is_full_test():
        signature = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(decomposed.jittable(signature))
        assert scripted(features, edge_index).tolist() == expected.tolist()
示例#29
0
def test_layer_norm(affine):
    """Check ``LayerNorm`` with and without an explicit ``batch`` vector.

    Calling the layer without ``batch`` must agree with calling it with an
    all-zero ``batch``. When the input is stacked twice with batch ids 0
    and 1, each half must match the single-input result within ``1e-6``.

    NOTE(review): ``affine`` is presumably supplied by a pytest
    parametrization that is not visible in this part of the file.
    """
    num_rows = 100
    features = torch.randn(num_rows, 16)
    batch = torch.zeros(num_rows, dtype=torch.long)

    layer = LayerNorm(16, affine=affine)
    assert repr(layer) == 'LayerNorm(16)'

    if is_full_test():
        # Only compilation is verified here; the scripted module is not run.
        torch.jit.script(layer)

    single_out = layer(features)
    assert single_out.size() == (num_rows, 16)
    assert torch.allclose(layer(features, batch), single_out, atol=1e-6)

    stacked_features = torch.cat([features, features], dim=0)
    stacked_batch = torch.cat([batch, batch + 1], dim=0)
    stacked_out = layer(stacked_features, stacked_batch)
    for half in (stacked_out[:num_rows], stacked_out[num_rows:]):
        assert torch.allclose(single_out, half, atol=1e-6)
def test_instance_norm(conf):
    """Check that ``InstanceNorm`` agrees with and without a ``batch`` vector.

    Two identically configured layers are fed the same inputs: ``norm1``
    without ``batch`` and ``norm2`` with an all-zero ``batch``. Outputs
    must agree within ``1e-7`` in both training and eval mode, and when
    ``conf`` is truthy the running statistics must agree as well. Finally,
    in eval mode, a stacked two-graph batch must give the same per-graph
    results as two separate calls.

    NOTE(review): ``conf`` is presumably supplied by a pytest
    parametrization that is not visible in this part of the file.
    """
    batch = torch.zeros(100, dtype=torch.long)

    x1 = torch.randn(100, 16)
    x2 = torch.randn(100, 16)

    norm1 = InstanceNorm(16, affine=conf, track_running_stats=conf)
    norm2 = InstanceNorm(16, affine=conf, track_running_stats=conf)
    assert repr(norm1) == 'InstanceNorm(16)'

    if is_full_test():
        # Only compilation is verified here; the scripted module is not run.
        torch.jit.script(norm1)

    def check_outputs(x):
        # Run ``norm1`` without and ``norm2`` with ``batch`` (in that order)
        # and require matching outputs; hand back the ``norm1`` result.
        plain = norm1(x)
        batched = norm2(x, batch)
        assert torch.allclose(plain, batched, atol=1e-7)
        return plain

    def check_running_stats():
        # Running statistics are only compared when ``conf`` is truthy.
        if conf:
            assert torch.allclose(norm1.running_mean, norm2.running_mean)
            assert torch.allclose(norm1.running_var, norm2.running_var)

    # Training mode: both layers see x1, then x2.
    assert check_outputs(x1).size() == (100, 16)
    check_running_stats()

    check_outputs(x2)
    check_running_stats()

    # Eval mode: the same output comparison, now without stat checks.
    norm1.eval()
    norm2.eval()

    for x in (x1, x2):
        check_outputs(x)

    # Two separate calls versus one stacked call with batch ids 0 and 1.
    out1 = norm2(x1)
    out2 = norm2(x2)
    out3 = norm2(torch.cat([x1, x2], dim=0), torch.cat([batch, batch + 1]))
    assert torch.allclose(out1, out3[:100], atol=1e-7)
    assert torch.allclose(out2, out3[100:], atol=1e-7)