def test_dynamic_edge_conv_conv():
    """Exercise ``DynamicEdgeConv`` (k=2) on single and paired inputs.

    Verifies the module's repr, the output shapes of the eager forward
    pass with and without batch vectors, and that the TorchScript
    variants return exactly the same values as the eager calls.
    """
    x1 = torch.randn(8, 16)
    x2 = torch.randn(4, 16)
    batch1 = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    batch2 = torch.tensor([0, 0, 1, 1])

    mlp = Seq(Lin(32, 16), ReLU(), Lin(16, 32))
    conv = DynamicEdgeConv(mlp, k=2)

    expected_repr = (
        'DynamicEdgeConv(nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=16, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=16, out_features=32, bias=True)\n'
        '), k=2)'
    )
    assert repr(conv) == expected_repr

    # Single-tensor input: one output row per row of x1.
    out11 = conv(x1)
    assert out11.shape == (8, 32)

    out12 = conv(x1, batch1)
    assert out12.shape == (8, 32)

    # Pair input: the asserted row count matches x2 (4 rows), not x1.
    out21 = conv((x1, x2))
    assert out21.shape == (4, 32)

    out22 = conv((x1, x2), (batch1, batch2))
    assert out22.shape == (4, 32)

    # Scripted module typed for the single-tensor signature.
    scripted = torch.jit.script(
        conv.jittable('(Tensor, OptTensor) -> Tensor'))
    assert scripted(x1).tolist() == out11.tolist()
    assert scripted(x1, batch1).tolist() == out12.tolist()

    # Scripted module typed for the pair signature.
    scripted = torch.jit.script(
        conv.jittable('(PairTensor, Optional[PairTensor]) -> Tensor'))
    assert scripted((x1, x2)).tolist() == out21.tolist()
    assert scripted((x1, x2), (batch1, batch2)).tolist() == out22.tolist()
# Example #2
# 0
def test_dynamic_edge_conv_conv():
    """Check ``DynamicEdgeConv`` repr and output shape for a 4-row input.

    Note that ``k`` (6) is larger than the number of input rows (4); the
    test only asserts the resulting output shape.
    """
    features = torch.randn((4, 16))

    mlp = Seq(Lin(32, 16), ReLU(), Lin(16, 32))
    conv = DynamicEdgeConv(mlp, k=6)

    expected_repr = (
        'DynamicEdgeConv(nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=16, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=16, out_features=32, bias=True)\n'
        '), k=6)'
    )
    assert repr(conv) == expected_repr

    result = conv(features)
    assert result.shape == (4, 32)
def test_dynamic_edge_conv_conv():
    """Check ``DynamicEdgeConv`` repr and output shape for 20 input rows.

    The first linear layer takes ``2 * 16`` input features, and the
    expected output has one 32-wide row per input row.
    """
    n_in = 16
    n_out = 32
    n_rows = 20
    features = torch.randn((n_rows, n_in))

    mlp = Seq(Lin(2 * n_in, 32), ReLU(), Lin(32, n_out))
    conv = DynamicEdgeConv(mlp, k=6)

    expected_repr = (
        'DynamicEdgeConv(nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '), k=6)'
    )
    assert repr(conv) == expected_repr

    result = conv(features)
    assert result.shape == (n_rows, n_out)