def test_nn_conv():
    """Exercise ``NNConv`` on a square and on a bipartite toy graph.

    Each configuration checks the layer's repr, the output size, and that
    the ``edge_index``, explicit-``size`` and transposed ``SparseTensor``
    call styles return exactly the same values. The TorchScript variants
    are only compared when ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    edge_nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32))
    # ``torch.nn.Sequential`` indents every child line of its repr by two
    # spaces, so the expected text must carry exactly that indentation.
    edge_nn_repr = (
        'Sequential(\n'
        '  (0): Linear(in_features=3, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=256, bias=True)\n'
        ')')

    # --- single node set: 4 nodes, 8 input features ---
    conv = NNConv(8, 32, nn=edge_nn)
    assert repr(conv) == f'NNConv(8, 32, aggr=add, nn={edge_nn_repr})'
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index, value).tolist() == out.tolist()
        assert scripted(x1, edge_index, value,
                        size=(4, 4)).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, adj.t()).tolist() == out.tolist()

    # --- bipartite: 4 source nodes (8 features), 2 target nodes (16) ---
    adj = adj.sparse_resize((4, 2))
    conv = NNConv((8, 16), 32, nn=edge_nn)
    assert repr(conv) == f'NNConv((8, 16), 32, aggr=add, nn={edge_nn_repr})'
    out1 = conv((x1, x2), edge_index, value)
    out2 = conv((x1, None), edge_index, value, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert conv((x1, x2), edge_index, value,
                (4, 2)).tolist() == out1.tolist()
    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
    assert conv((x1, None), adj.t()).tolist() == out2.tolist()

    if is_full_test():
        sig = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), edge_index, value).tolist() == out1.tolist()
        assert scripted((x1, x2), edge_index, value,
                        size=(4, 2)).tolist() == out1.tolist()
        assert scripted((x1, None), edge_index, value,
                        size=(4, 2)).tolist() == out2.tolist()

        sig = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), adj.t()).tolist() == out1.tolist()
        assert scripted((x1, None), adj.t()).tolist() == out2.tolist()
def test_dna_conv():
    """Run ``DNAConv`` with several head/group settings on a star graph.

    Every configuration checks the output size; the first two also check
    the repr. When ``is_full_test()`` is true, the TorchScript version must
    reproduce the eager output exactly.
    """
    channels, num_layers = 32, 3
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, num_layers, channels))
    expected_size = (num_nodes, channels)

    # --- 4 heads, 8 groups ---
    conv = DNAConv(channels, heads=4, groups=8, dropout=0.0)
    assert repr(conv) == 'DNAConv(32, heads=4, groups=8)'
    out = conv(x, edge_index)
    assert out.size() == expected_size

    if is_full_test():
        scripted = torch.jit.script(conv.jittable())
        assert scripted(x, edge_index).tolist() == out.tolist()

    # --- 1 head, 1 group ---
    conv = DNAConv(channels, heads=1, groups=1, dropout=0.0)
    assert repr(conv) == 'DNAConv(32, heads=1, groups=1)'
    out = conv(x, edge_index)
    assert out.size() == expected_size

    if is_full_test():
        scripted = torch.jit.script(conv.jittable())
        assert scripted(x, edge_index).tolist() == out.tolist()

    # --- cached=True: the layer is called twice and only the second result
    # is checked (presumably so the second call runs against a warm cache).
    conv = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True)
    conv(x, edge_index)
    out = conv(x, edge_index)
    assert out.size() == expected_size

    if is_full_test():
        scripted = torch.jit.script(conv.jittable())
        assert scripted(x, edge_index).tolist() == out.tolist()
def test_jumping_knowledge():
    """Check the ``'cat'``, ``'max'`` and ``'lstm'`` modes of
    ``JumpingKnowledge``.

    For each mode the repr and the output size are asserted. When
    ``is_full_test()`` is true, the scripted module must match the eager
    output up to ``torch.allclose`` tolerance.
    """
    num_nodes, channels, num_layers = 100, 17, 5
    # One feature matrix per layer; a plain list comprehension is enough,
    # no extra ``list(...)`` wrapper is needed.
    xs = [torch.randn(num_nodes, channels) for _ in range(num_layers)]

    # 'cat': asserted width is channels * num_layers.
    model = JumpingKnowledge('cat')
    assert repr(model) == 'JumpingKnowledge(cat)'
    out = model(xs)
    assert out.size() == (num_nodes, channels * num_layers)

    if is_full_test():
        scripted = torch.jit.script(model)
        assert torch.allclose(scripted(xs), out)

    # 'max': asserted width stays at ``channels``.
    model = JumpingKnowledge('max')
    assert repr(model) == 'JumpingKnowledge(max)'
    out = model(xs)
    assert out.size() == (num_nodes, channels)

    if is_full_test():
        scripted = torch.jit.script(model)
        assert torch.allclose(scripted(xs), out)

    # 'lstm': additionally takes the channel and layer counts.
    model = JumpingKnowledge('lstm', channels, num_layers)
    assert repr(model) == 'JumpingKnowledge(lstm)'
    out = model(xs)
    assert out.size() == (num_nodes, channels)

    if is_full_test():
        scripted = torch.jit.script(model)
        assert torch.allclose(scripted(xs), out)
def test_heat_conv():
    """Check ``HEATConv`` with ``concat=True`` and with ``concat=False``.

    Both configurations verify the string form, the output size, and that
    the ``edge_index`` call (with explicit ``edge_attr``) agrees with the
    transposed ``SparseTensor`` call. TorchScript is exercised only when
    ``is_full_test()`` is true.
    """
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    edge_attr = torch.randn((4, 2))
    adj = SparseTensor(row=row, col=col, value=edge_attr, sparse_sizes=(4, 4))
    node_type = torch.tensor([0, 0, 1, 2])
    edge_type = torch.tensor([0, 2, 1, 2])

    # Constructor arguments shared by both configurations below.
    shared = dict(in_channels=8, out_channels=16, num_node_types=3,
                  num_edge_types=3, edge_type_emb_dim=5, edge_dim=2,
                  edge_attr_emb_dim=6, heads=2)

    # --- concat=True: asserted output width is 32 ---
    conv = HEATConv(concat=True, **shared)
    assert str(conv) == 'HEATConv(8, 16, heads=2)'
    out = conv(x, edge_index, node_type, edge_type, edge_attr)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out)

    if is_full_test():
        sig = '(Tensor, Tensor, Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        result = scripted(x, edge_index, node_type, edge_type, edge_attr)
        assert torch.allclose(result, out)

    # --- concat=False: asserted output width is 16 ---
    conv = HEATConv(concat=False, **shared)
    assert str(conv) == 'HEATConv(8, 16, heads=2)'
    out = conv(x, edge_index, node_type, edge_type, edge_attr)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out)

    if is_full_test():
        sig = '(Tensor, SparseTensor, Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        result = scripted(x, adj.t(), node_type, edge_type)
        assert torch.allclose(result, out)
def test_signed_conv():
    """Chain two ``SignedConv`` layers and compare all supported call styles.

    ``conv1`` (``first_aggr=True``) feeds ``conv2`` (``first_aggr=False``).
    The same ``edge_index`` is passed for both edge arguments of each call.
    Dense and ``SparseTensor`` inputs must agree exactly on the full graph
    and within ``atol=1e-6`` on the (4, 2) bipartite variant. TorchScript is
    exercised only when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv1 = SignedConv(16, 32, first_aggr=True)
    assert repr(conv1) == 'SignedConv(16, 32, first_aggr=True)'

    conv2 = SignedConv(32, 48, first_aggr=False)
    assert repr(conv2) == 'SignedConv(32, 48, first_aggr=False)'

    # --- full graph ---
    out1 = conv1(x, edge_index, edge_index)
    assert out1.size() == (4, 64)
    assert conv1(x, adj.t(), adj.t()).tolist() == out1.tolist()

    out2 = conv2(out1, edge_index, edge_index)
    assert out2.size() == (4, 96)
    assert conv2(out1, adj.t(), adj.t()).tolist() == out2.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor, Tensor) -> Tensor'
        scripted1 = torch.jit.script(conv1.jittable(sig))
        scripted2 = torch.jit.script(conv2.jittable(sig))
        assert scripted1(x, edge_index,
                         edge_index).tolist() == out1.tolist()
        assert scripted2(out1, edge_index,
                         edge_index).tolist() == out2.tolist()

        sig = '(Tensor, SparseTensor, SparseTensor) -> Tensor'
        scripted1 = torch.jit.script(conv1.jittable(sig))
        scripted2 = torch.jit.script(conv2.jittable(sig))
        assert scripted1(x, adj.t(), adj.t()).tolist() == out1.tolist()
        assert scripted2(out1, adj.t(), adj.t()).tolist() == out2.tolist()

    # --- bipartite: only the first two nodes act as targets ---
    adj = adj.sparse_resize((4, 2))
    x_pair = (x, x[:2])
    h_pair = (out1, out1[:2])

    assert torch.allclose(conv1(x_pair, edge_index, edge_index), out1[:2],
                          atol=1e-6)
    assert torch.allclose(conv1(x_pair, adj.t(), adj.t()), out1[:2],
                          atol=1e-6)
    assert torch.allclose(conv2(h_pair, edge_index, edge_index), out2[:2],
                          atol=1e-6)
    assert torch.allclose(conv2(h_pair, adj.t(), adj.t()), out2[:2],
                          atol=1e-6)

    if is_full_test():
        sig = '(PairTensor, Tensor, Tensor) -> Tensor'
        scripted1 = torch.jit.script(conv1.jittable(sig))
        scripted2 = torch.jit.script(conv2.jittable(sig))
        assert torch.allclose(scripted1(x_pair, edge_index, edge_index),
                              out1[:2], atol=1e-6)
        assert torch.allclose(scripted2(h_pair, edge_index, edge_index),
                              out2[:2], atol=1e-6)

        sig = '(PairTensor, SparseTensor, SparseTensor) -> Tensor'
        scripted1 = torch.jit.script(conv1.jittable(sig))
        scripted2 = torch.jit.script(conv2.jittable(sig))
        assert torch.allclose(scripted1(x_pair, adj.t(), adj.t()), out1[:2],
                              atol=1e-6)
        assert torch.allclose(scripted2(h_pair, adj.t(), adj.t()), out2[:2],
                              atol=1e-6)
def test_gmm_conv(separate_gaussians):
    """Check ``GMMConv`` on a square and on a bipartite toy graph.

    Args:
        separate_gaussians: Forwarded unchanged to the ``GMMConv``
            constructor. Its values are supplied from outside this block
            (no decorator or fixture is visible here).

    Dense, explicit-``size`` and ``SparseTensor`` call styles are compared
    with ``torch.allclose``. TorchScript is exercised only when
    ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    # --- single node set ---
    conv = GMMConv(8, 32, dim=3, kernel_size=25,
                   separate_gaussians=separate_gaussians)
    assert repr(conv) == 'GMMConv(8, 32, dim=3)'
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out)
    assert torch.allclose(conv(x1, adj.t()), out)

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x1, edge_index, value), out)
        assert torch.allclose(
            scripted(x1, edge_index, value, size=(4, 4)), out)

        sig = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x1, adj.t()), out)

    # --- bipartite: 4 source nodes, 2 target nodes ---
    adj = adj.sparse_resize((4, 2))
    conv = GMMConv((8, 16), 32, dim=3, kernel_size=5,
                   separate_gaussians=separate_gaussians)
    assert repr(conv) == 'GMMConv((8, 16), 32, dim=3)'
    out1 = conv((x1, x2), edge_index, value)
    out2 = conv((x1, None), edge_index, value, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
    assert torch.allclose(conv((x1, x2), adj.t()), out1)
    assert torch.allclose(conv((x1, None), adj.t()), out2)

    if is_full_test():
        sig = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted((x1, x2), edge_index, value), out1)
        assert torch.allclose(
            scripted((x1, x2), edge_index, value, size=(4, 2)), out1)
        assert torch.allclose(
            scripted((x1, None), edge_index, value, size=(4, 2)), out2)

        sig = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted((x1, x2), adj.t()), out1)
        assert torch.allclose(scripted((x1, None), adj.t()), out2)
def test_ppf_conv():
    """Check ``PPFConv`` with shared and with paired position/normal inputs.

    The repr is compared against the nested ``Sequential`` repr text. The
    ``edge_index`` and transposed ``SparseTensor`` call styles are then
    compared for a single node set and for a (4, 2) bipartite setup.
    TorchScript is exercised only when ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 16)
    pos1 = torch.randn(4, 3)
    pos2 = torch.randn(2, 3)
    n1 = F.normalize(torch.rand(4, 3), dim=-1)
    n2 = F.normalize(torch.rand(2, 3), dim=-1)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    # The local network takes 16 + 4 = 20 inputs, matching the repr below.
    local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))
    global_nn = Seq(Lin(32, 32))
    conv = PPFConv(local_nn, global_nn)
    # ``torch.nn.Sequential`` indents every child line of its repr by two
    # spaces, so the expected text must carry exactly that indentation.
    assert repr(conv) == (
        'PPFConv(local_nn=Sequential(\n'
        '  (0): Linear(in_features=20, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '), global_nn=Sequential(\n'
        '  (0): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')

    # --- single node set ---
    out = conv(x1, pos1, n1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, pos1, n1, adj.t()), out, atol=1e-6)

    if is_full_test():
        sig = '(OptTensor, Tensor, Tensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, pos1, n1, edge_index).tolist() == out.tolist()

        sig = '(OptTensor, Tensor, Tensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x1, pos1, n1, adj.t()), out, atol=1e-6)

    # --- bipartite: 4 source nodes, 2 target nodes ---
    adj = adj.sparse_resize((4, 2))
    pos = (pos1, pos2)
    normal = (n1, n2)

    out = conv(x1, pos, normal, edge_index)
    assert out.size() == (2, 32)
    assert conv((x1, None), pos, normal, edge_index).tolist() == out.tolist()
    assert torch.allclose(conv(x1, pos, normal, adj.t()), out, atol=1e-6)
    assert torch.allclose(conv((x1, None), pos, normal, adj.t()), out,
                          atol=1e-6)

    if is_full_test():
        sig = '(PairOptTensor, PairTensor, PairTensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, None), pos, normal,
                        edge_index).tolist() == out.tolist()

        sig = '(PairOptTensor, PairTensor, PairTensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted((x1, None), pos, normal, adj.t()), out,
                              atol=1e-6)
def test_point_transformer_conv():
    """Check ``PointTransformerConv`` in three configurations.

    Covers the default layer, a layer with explicit ``pos_nn``/``attn_nn``
    networks, and a bipartite ``(16, 8)`` layer. The ``edge_index`` and
    transposed ``SparseTensor`` call styles must agree within
    ``atol=1e-6``. TorchScript is exercised only when ``is_full_test()`` is
    true.
    """
    x1 = torch.rand(4, 16)
    x2 = torch.randn(2, 8)
    pos1 = torch.rand(4, 3)
    pos2 = torch.randn(2, 3)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    # --- default networks ---
    conv = PointTransformerConv(in_channels=16, out_channels=32)
    assert str(conv) == 'PointTransformerConv(16, 32)'
    out = conv(x1, pos1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)

    if is_full_test():
        sig = '(Tensor, Tensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, pos1, edge_index).tolist() == out.tolist()

        sig = '(Tensor, Tensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x1, pos1, adj.t()), out, atol=1e-6)

    # --- user-supplied position and attention networks ---
    pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))
    attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))
    conv = PointTransformerConv(16, 32, pos_nn, attn_nn)

    out = conv(x1, pos1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, pos1, adj.t()), out, atol=1e-6)

    # --- bipartite: 4 source nodes (16 features), 2 target nodes (8) ---
    conv = PointTransformerConv((16, 8), 32)
    adj = adj.sparse_resize((4, 2))
    x_pair = (x1, x2)
    pos_pair = (pos1, pos2)

    out = conv(x_pair, pos_pair, edge_index)
    assert out.size() == (2, 32)
    assert torch.allclose(conv(x_pair, pos_pair, adj.t()), out, atol=1e-6)

    if is_full_test():
        sig = '(PairTensor, PairTensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x_pair, pos_pair, edge_index).tolist() == out.tolist()

        sig = '(PairTensor, PairTensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x_pair, pos_pair, adj.t()), out,
                              atol=1e-6)
def test_gin_conv():
    """Check ``GINConv`` (``train_eps=True``) on square and bipartite input.

    The repr is compared against the wrapped ``Sequential`` repr text. The
    ``edge_index``, explicit-``size`` and transposed ``SparseTensor`` call
    styles must then return exactly the same values. TorchScript is
    exercised only when ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 16)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    mlp = Seq(Lin(16, 32), ReLU(), Lin(32, 32))
    conv = GINConv(mlp, train_eps=True)
    # ``torch.nn.Sequential`` indents every child line of its repr by two
    # spaces, so the expected text must carry exactly that indentation.
    assert repr(conv) == (
        'GINConv(nn=Sequential(\n'
        '  (0): Linear(in_features=16, out_features=32, bias=True)\n'
        '  (1): ReLU()\n'
        '  (2): Linear(in_features=32, out_features=32, bias=True)\n'
        '))')

    # --- single node set ---
    out = conv(x1, edge_index)
    assert out.size() == (4, 32)
    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index).tolist() == out.tolist()
        assert scripted(x1, edge_index, size=(4, 4)).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, adj.t()).tolist() == out.tolist()

    # --- bipartite: 4 source nodes, 2 target nodes ---
    adj = adj.sparse_resize((4, 2))
    out1 = conv((x1, x2), edge_index)
    out2 = conv((x1, None), edge_index, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
    assert conv((x1, None), adj.t()).tolist() == out2.tolist()

    if is_full_test():
        sig = '(OptPairTensor, Tensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), edge_index).tolist() == out1.tolist()
        assert scripted((x1, x2), edge_index,
                        size=(4, 2)).tolist() == out1.tolist()
        assert scripted((x1, None), edge_index,
                        size=(4, 2)).tolist() == out2.tolist()

        sig = '(OptPairTensor, SparseTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), adj.t()).tolist() == out1.tolist()
        assert scripted((x1, None), adj.t()).tolist() == out2.tolist()
def test_spline_conv():
    """Check ``SplineConv`` on a square and on a bipartite toy graph.

    Warnings matching ``'.*non-optimized CPU version.*'`` are silenced
    first. Dense, explicit-``size`` and ``SparseTensor`` call styles are
    compared with ``torch.allclose``. TorchScript is exercised only when
    ``is_full_test()`` is true.
    """
    warnings.filterwarnings('ignore', '.*non-optimized CPU version.*')

    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    # --- single node set ---
    conv = SplineConv(8, 32, dim=3, kernel_size=5)
    assert repr(conv) == 'SplineConv(8, 32, dim=3)'
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out)
    assert torch.allclose(conv(x1, adj.t()), out)

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x1, edge_index, value), out)
        assert torch.allclose(
            scripted(x1, edge_index, value, size=(4, 4)), out)

        sig = '(Tensor, SparseTensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x1, adj.t()), out)

    # --- bipartite: 4 source nodes, 2 target nodes ---
    adj = adj.sparse_resize((4, 2))
    conv = SplineConv((8, 16), 32, dim=3, kernel_size=5)
    assert repr(conv) == 'SplineConv((8, 16), 32, dim=3)'
    out1 = conv((x1, x2), edge_index, value)
    out2 = conv((x1, None), edge_index, value, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
    assert torch.allclose(conv((x1, x2), adj.t()), out1)
    assert torch.allclose(conv((x1, None), adj.t()), out2)

    if is_full_test():
        sig = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted((x1, x2), edge_index, value), out1)
        assert torch.allclose(
            scripted((x1, x2), edge_index, value, size=(4, 2)), out1)
        assert torch.allclose(
            scripted((x1, None), edge_index, value, size=(4, 2)), out2)

        sig = '(OptPairTensor, SparseTensor, OptTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted((x1, x2), adj.t()), out1)
        assert torch.allclose(scripted((x1, None), adj.t()), out2)
def test_cg_conv():
    """Check ``CGConv`` without edge features in three configurations.

    Covers a single node set, a bipartite ``(8, 16)`` layer, and a layer
    built with ``batch_norm=True``. The ``edge_index`` and transposed
    ``SparseTensor`` call styles must return exactly the same values.
    TorchScript is exercised only when ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    # --- single node set ---
    conv = CGConv(8)
    assert repr(conv) == 'CGConv(8, dim=0)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 8)
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, adj.t()).tolist() == out.tolist()

    # --- bipartite: 4 source nodes (8 features), 2 target nodes (16) ---
    adj = adj.sparse_resize((4, 2))
    conv = CGConv((8, 16))
    assert repr(conv) == 'CGConv((8, 16), dim=0)'
    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 16)
    assert conv((x1, x2), adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(PairTensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), edge_index).tolist() == out.tolist()

        sig = '(PairTensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), adj.t()).tolist() == out.tolist()

    # Test batch_norm true:
    # The adjacency is rebuilt at its original (4, 4) size because it was
    # resized for the bipartite case above.
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
    conv = CGConv(8, batch_norm=True)
    assert repr(conv) == 'CGConv(8, dim=0)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 8)
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index).tolist() == out.tolist()
def test_lg_conv():
    """Check ``LGConv`` with and without edge weights.

    ``adj1`` is the weight-free copy of ``adj2``. For both the unweighted
    and the weighted case, the ``edge_index`` call must match the
    transposed ``SparseTensor`` call under ``torch.allclose``. TorchScript
    is exercised only when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0))
    adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
    adj1 = adj2.set_value(None)

    conv = LGConv()
    assert str(conv) == 'LGConv()'

    # --- unweighted ---
    out1 = conv(x, edge_index)
    assert out1.size() == (4, 8)
    assert torch.allclose(conv(x, adj1.t()), out1)

    # --- weighted ---
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 8)
    assert torch.allclose(conv(x, adj2.t()), out2)

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x, edge_index), out1)
        assert torch.allclose(scripted(x, edge_index, value), out2)

        sig = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x, adj1.t()), out1)
        assert torch.allclose(scripted(x, adj2.t()), out2)
def test_gravnet_conv():
    """Check ``GravNetConv`` with single/paired inputs, with/without batch.

    Four eager outputs are computed (``out11``, ``out12``, ``out21``,
    ``out22``) and their sizes asserted. When ``is_full_test()`` is true,
    the TorchScript variants must reproduce them exactly, and the layer is
    additionally scripted once without an explicit type signature.
    """
    x1 = torch.randn(8, 16)
    x2 = torch.randn(4, 16)
    batch1 = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
    batch2 = torch.tensor([0, 0, 1, 1])

    conv = GravNetConv(16, 32, space_dimensions=4, propagate_dimensions=8,
                       k=2)
    assert repr(conv) == 'GravNetConv(16, 32, k=2)'

    # --- single input, without and with batch vector ---
    out11 = conv(x1)
    assert out11.size() == (8, 32)

    out12 = conv(x1, batch1)
    assert out12.size() == (8, 32)

    # --- paired input, without and with batch vectors ---
    out21 = conv((x1, x2))
    assert out21.size() == (4, 32)

    out22 = conv((x1, x2), (batch1, batch2))
    assert out22.size() == (4, 32)

    if is_full_test():
        sig = '(Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1).tolist() == out11.tolist()
        assert scripted(x1, batch1).tolist() == out12.tolist()

        sig = '(PairTensor, Optional[PairTensor]) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2)).tolist() == out21.tolist()
        assert scripted((x1, x2), (batch1, batch2)).tolist() == out22.tolist()

        # Test without explicit typing.
        torch.jit.script(conv.jittable())
def test_eg_conv():
    """Check ``EGConv`` with default settings, then with ``cached = True``.

    The ``edge_index`` and transposed ``SparseTensor`` call styles must
    agree within ``atol=1e-6``. After caching is switched on, each call
    style is invoked once before the asserted call. TorchScript is
    exercised only when ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = EGConv(16, 32)
    assert repr(conv) == "EGConv(16, 32, aggregators=['symnorm'])"
    out = conv(x, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    # --- cached mode: one warm-up call per input kind, then compare ---
    conv.cached = True

    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == out.tolist()

    conv(x, adj.t())
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    if is_full_test():
        sig = '(Tensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x, edge_index).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x, adj.t()), out, atol=1e-6)
def test_mlp(batch_norm, act_first, plain_last):
    """Check that both ``MLP`` constructor forms build the same network.

    Args:
        batch_norm: Forwarded to ``MLP(batch_norm=...)``.
        act_first: Forwarded to ``MLP(act_first=...)``.
        plain_last: Forwarded to ``MLP(plain_last=...)``.

    The argument values are supplied from outside this block (no decorator
    or fixture is visible here). The RNG is re-seeded with the same value
    before each construction, and the two models' outputs are then compared
    with ``torch.allclose``.
    """
    x = torch.randn(4, 16)
    options = dict(batch_norm=batch_norm, act_first=act_first,
                   plain_last=plain_last)

    # --- explicit channel list ---
    torch.manual_seed(12345)
    mlp = MLP([16, 32, 32, 64], **options)
    assert str(mlp) == 'MLP(16, 32, 32, 64)'
    out = mlp(x)
    assert out.size() == (4, 64)

    if is_full_test():
        scripted = torch.jit.script(mlp)
        assert torch.allclose(scripted(x), out)

    # --- in/hidden/out channels plus layer count, same seed ---
    torch.manual_seed(12345)
    mlp = MLP(16, hidden_channels=32, out_channels=64, num_layers=3,
              **options)
    assert torch.allclose(mlp(x), out)
def test_pna_conv():
    """Check ``PNAConv`` with 3-dimensional edge features and 4 towers.

    ``aggregators`` and ``scalers`` are read from the enclosing module
    (their definitions are not visible in this block). The ``edge_index``
    and transposed ``SparseTensor`` call styles must agree within
    ``atol=1e-6``. TorchScript is exercised only when ``is_full_test()`` is
    true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    deg = torch.tensor([0, 3, 0, 1])
    conv = PNAConv(16, 32, aggregators, scalers, deg=deg, edge_dim=3,
                   towers=4)
    assert repr(conv) == 'PNAConv(16, 32, towers=4, edge_dim=3)'

    out = conv(x, edge_index, value)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x, edge_index, value), out, atol=1e-6)

        sig = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x, adj.t()), out, atol=1e-6)
def test_gcn_conv():
    """Check ``GCNConv`` with/without edge weights and with ``cached=True``.

    ``adj1`` is the weight-free copy of ``adj2``. The ``edge_index`` and
    transposed ``SparseTensor`` call styles must agree within
    ``atol=1e-6``. After caching is switched on, each call style is invoked
    once before the asserted call. TorchScript is exercised only when
    ``is_full_test()`` is true.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    row, col = edge_index
    value = torch.rand(row.size(0))
    adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))
    adj1 = adj2.set_value(None)

    conv = GCNConv(16, 32)
    assert repr(conv) == 'GCNConv(16, 32)'

    # --- unweighted ---
    out1 = conv(x, edge_index)
    assert out1.size() == (4, 32)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

    # --- weighted ---
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 32)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x, edge_index).tolist() == out1.tolist()
        assert scripted(x, edge_index, value).tolist() == out2.tolist()

        sig = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x, adj1.t()), out1, atol=1e-6)
        assert torch.allclose(scripted(x, adj2.t()), out2, atol=1e-6)

    # --- cached mode: one warm-up call per input kind, then compare ---
    conv.cached = True

    conv(x, edge_index)
    assert conv(x, edge_index).tolist() == out1.tolist()

    conv(x, adj1.t())
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)
def test_sage_conv(project):
    """Check ``SAGEConv`` on a square and on a bipartite toy graph.

    Args:
        project: Forwarded unchanged to ``SAGEConv(project=...)``. Its
            values are supplied from outside this block (no decorator or
            fixture is visible here).

    The ``edge_index``, explicit-``size`` and transposed ``SparseTensor``
    call styles must return exactly the same values. TorchScript is
    exercised only when ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    # --- single node set ---
    conv = SAGEConv(8, 32, project=project)
    assert str(conv) == 'SAGEConv(8, 32, aggr=mean)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 32)
    assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index).tolist() == out.tolist()
        assert scripted(x1, edge_index, size=(4, 4)).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, adj.t()).tolist() == out.tolist()

    # --- bipartite: 4 source nodes (8 features), 2 target nodes (16) ---
    adj = adj.sparse_resize((4, 2))
    conv = SAGEConv((8, 16), 32, project=project)
    assert str(conv) == 'SAGEConv((8, 16), 32, aggr=mean)'
    out1 = conv((x1, x2), edge_index)
    out2 = conv((x1, None), edge_index, (4, 2))
    assert out1.size() == (2, 32)
    assert out2.size() == (2, 32)
    assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
    assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
    assert conv((x1, None), adj.t()).tolist() == out2.tolist()

    if is_full_test():
        sig = '(OptPairTensor, Tensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), edge_index).tolist() == out1.tolist()
        assert scripted((x1, x2), edge_index,
                        size=(4, 2)).tolist() == out1.tolist()
        assert scripted((x1, None), edge_index,
                        size=(4, 2)).tolist() == out2.tolist()

        sig = '(OptPairTensor, SparseTensor, Size) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), adj.t()).tolist() == out1.tolist()
        assert scripted((x1, None), adj.t()).tolist() == out2.tolist()
def test_batch_norm(conf):
    """Check the repr, scriptability and output size of ``BatchNorm``.

    Args:
        conf: Used for both ``affine`` and ``track_running_stats``. Its
            values are supplied from outside this block (no decorator or
            fixture is visible here).
    """
    x = torch.randn(100, 16)

    norm = BatchNorm(16, affine=conf, track_running_stats=conf)
    assert repr(norm) == 'BatchNorm(16)'

    # Scripting only has to succeed; the scripted module is not called.
    if is_full_test():
        torch.jit.script(norm)

    result = norm(x)
    assert result.size() == (100, 16)
def test_diff_group_norm():
    """Check ``DiffGroupNorm`` with ``lamda=0`` and with ``lamda=0.01``.

    With ``lamda=0`` the output is asserted to equal the input exactly;
    with ``lamda=0.01`` only the output size is checked. When
    ``is_full_test()`` is true, the scripted module must match the eager
    result exactly.
    """
    x = torch.randn(6, 16)

    # --- lamda=0: asserted to behave as the identity ---
    norm = DiffGroupNorm(16, groups=4, lamda=0)
    assert repr(norm) == 'DiffGroupNorm(16, groups=4)'
    assert norm(x).tolist() == x.tolist()

    if is_full_test():
        scripted = torch.jit.script(norm)
        assert scripted(x).tolist() == x.tolist()

    # --- lamda=0.01 ---
    norm = DiffGroupNorm(16, groups=4, lamda=0.01)
    assert repr(norm) == 'DiffGroupNorm(16, groups=4)'
    out = norm(x)
    assert out.size() == x.size()

    if is_full_test():
        scripted = torch.jit.script(norm)
        assert scripted(x).tolist() == out.tolist()
def test_hetero_linear():
    """Check ``HeteroLinear`` on three nodes that each have their own type.

    Verifies the string form and the output size. When ``is_full_test()``
    is true, the scripted module must match the eager output under
    ``torch.allclose``.
    """
    x = torch.randn((3, 16))
    node_type = torch.tensor([0, 1, 2])

    lin = HeteroLinear(in_channels=16, out_channels=32, num_types=3)
    assert str(lin) == 'HeteroLinear(16, 32, bias=True)'

    out = lin(x, node_type)
    assert out.size() == (3, 32)

    if is_full_test():
        scripted = torch.jit.script(lin)
        assert torch.allclose(scripted(x, node_type), out)
def test_cheb_conv():
    """Check ``ChebConv`` (K=3) on a single graph and on a two-graph batch.

    Five eager outputs are produced: plain, weighted, weighted with a
    scalar ``lambda_max``, batched, and batched with a per-graph
    ``lambda_max``. Only their sizes are asserted in eager mode. When
    ``is_full_test()`` is true, one scripted module (created in the first
    part and reused in the second) must reproduce all five exactly.
    """
    in_channels, out_channels = (16, 32)

    # --- single graph ---
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    edge_weight = torch.rand(edge_index.size(1))
    x = torch.randn((num_nodes, in_channels))
    expected_size = (num_nodes, out_channels)

    conv = ChebConv(in_channels, out_channels, K=3)
    assert repr(conv) == 'ChebConv(16, 32, K=3, normalization=sym)'

    out1 = conv(x, edge_index)
    assert out1.size() == expected_size
    out2 = conv(x, edge_index, edge_weight)
    assert out2.size() == expected_size
    out3 = conv(x, edge_index, edge_weight, lambda_max=3.0)
    assert out3.size() == expected_size

    if is_full_test():
        scripted = torch.jit.script(conv.jittable())
        assert scripted(x, edge_index).tolist() == out1.tolist()
        assert scripted(x, edge_index, edge_weight).tolist() == out2.tolist()
        # The scripted call is given ``lambda_max`` as a tensor.
        assert scripted(x, edge_index, edge_weight,
                        lambda_max=torch.tensor(3.0)).tolist() == out3.tolist()

    # --- batch of two graphs with two nodes each ---
    batch = torch.tensor([0, 0, 1, 1])
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    num_nodes = edge_index.max().item() + 1
    edge_weight = torch.rand(edge_index.size(1))
    x = torch.randn((num_nodes, in_channels))
    lambda_max = torch.tensor([2.0, 3.0])
    expected_size = (num_nodes, out_channels)

    out4 = conv(x, edge_index, edge_weight, batch)
    assert out4.size() == expected_size
    out5 = conv(x, edge_index, edge_weight, batch, lambda_max)
    assert out5.size() == expected_size

    if is_full_test():
        # ``scripted`` is the module created in the first full-test branch.
        assert scripted(x, edge_index, edge_weight,
                        batch).tolist() == out4.tolist()
        assert scripted(x, edge_index, edge_weight, batch,
                        lambda_max).tolist() == out5.tolist()
def test_cg_conv_with_edge_features():
    """Check ``CGConv`` with 3-dimensional edge features (``dim=3``).

    Covers a single node set and a bipartite ``(8, 16)`` layer. The
    ``edge_index`` call (with explicit ``value``) must match the transposed
    ``SparseTensor`` call exactly. TorchScript is exercised only when
    ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    value = torch.rand(row.size(0), 3)
    adj = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

    # --- single node set ---
    conv = CGConv(8, dim=3)
    assert repr(conv) == 'CGConv(8, dim=3)'
    out = conv(x1, edge_index, value)
    assert out.size() == (4, 8)
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index, value).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, adj.t()).tolist() == out.tolist()

    # --- bipartite: 4 source nodes (8 features), 2 target nodes (16) ---
    adj = adj.sparse_resize((4, 2))
    conv = CGConv((8, 16), dim=3)
    assert repr(conv) == 'CGConv((8, 16), dim=3)'
    out = conv((x1, x2), edge_index, value)
    assert out.size() == (2, 16)
    assert conv((x1, x2), adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(PairTensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), edge_index, value).tolist() == out.tolist()

        sig = '(PairTensor, SparseTensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), adj.t()).tolist() == out.tolist()
def test_res_gated_graph_conv():
    """Check ``ResGatedGraphConv`` on a square and on a bipartite graph.

    The ``edge_index`` and transposed ``SparseTensor`` call styles must
    return exactly the same values. TorchScript is exercised only when
    ``is_full_test()`` is true.
    """
    x1 = torch.randn(4, 8)
    x2 = torch.randn(2, 32)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    # --- single node set ---
    conv = ResGatedGraphConv(8, 32)
    assert repr(conv) == 'ResGatedGraphConv(8, 32)'
    out = conv(x1, edge_index)
    assert out.size() == (4, 32)
    assert conv(x1, adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(Tensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, adj.t()).tolist() == out.tolist()

    # --- bipartite: 4 source nodes (8 features), 2 target nodes (32) ---
    adj = adj.sparse_resize((4, 2))
    conv = ResGatedGraphConv((8, 32), 32)
    assert repr(conv) == 'ResGatedGraphConv((8, 32), 32)'
    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 32)
    assert conv((x1, x2), adj.t()).tolist() == out.tolist()

    if is_full_test():
        sig = '(PairTensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), edge_index).tolist() == out.tolist()

        sig = '(PairTensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), adj.t()).tolist() == out.tolist()
def test_feast_conv():
    """Check ``FeaStConv`` (2 heads) on a square and on a bipartite graph.

    The same layer instance is reused for both setups. The ``edge_index``
    and transposed ``SparseTensor`` call styles must agree within
    ``atol=1e-6``. TorchScript is exercised only when ``is_full_test()`` is
    true.
    """
    x1 = torch.randn(4, 16)
    x2 = torch.randn(2, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    row, col = edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

    conv = FeaStConv(16, 32, heads=2)
    assert repr(conv) == 'FeaStConv(16, 32, heads=2)'

    # --- single node set ---
    out = conv(x1, edge_index)
    assert out.size() == (4, 32)
    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)

    if is_full_test():
        sig = '(Tensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x1, edge_index).tolist() == out.tolist()

        sig = '(Tensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted(x1, adj.t()), out, atol=1e-6)

    # --- bipartite: 4 source nodes, 2 target nodes ---
    adj = adj.sparse_resize((4, 2))
    out = conv((x1, x2), edge_index)
    assert out.size() == (2, 32)
    assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)

    if is_full_test():
        sig = '(PairTensor, Tensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted((x1, x2), edge_index).tolist() == out.tolist()

        sig = '(PairTensor, SparseTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert torch.allclose(scripted((x1, x2), adj.t()), out, atol=1e-6)
def test_le_conv():
    """Check the repr and output size of ``LEConv`` on a star graph.

    When ``is_full_test()`` is true, the TorchScript version must reproduce
    the eager output exactly.
    """
    in_channels, out_channels = (16, 32)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    num_nodes = edge_index.max().item() + 1
    x = torch.randn((num_nodes, in_channels))

    conv = LEConv(in_channels, out_channels)
    assert repr(conv) == 'LEConv(16, 32)'

    out = conv(x, edge_index)
    assert out.size() == (num_nodes, out_channels)

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(conv.jittable(sig))
        assert scripted(x, edge_index).tolist() == out.tolist()
def test_pair_norm(scale_individually):
    """Check that ``PairNorm`` normalizes each graph of a batch separately.

    Args:
        scale_individually: Forwarded unchanged to ``PairNorm``. Its values
            are supplied from outside this block (no decorator or fixture
            is visible here).

    The same 100 rows are stacked twice and assigned to batch ids 0 and 1.
    Both halves of the batched output must match the output for the
    unbatched input within ``atol=1e-6``.
    """
    x = torch.randn(100, 16)
    batch = torch.zeros(100, dtype=torch.long)

    norm = PairNorm(scale_individually=scale_individually)
    assert repr(norm) == 'PairNorm()'

    # Scripting only has to succeed; the scripted module is not called.
    if is_full_test():
        torch.jit.script(norm)

    out1 = norm(x)
    assert out1.size() == (100, 16)

    stacked_x = torch.cat([x, x], dim=0)
    stacked_batch = torch.cat([batch, batch + 1], dim=0)
    out2 = norm(stacked_x, stacked_batch)
    assert torch.allclose(out1, out2[:100], atol=1e-6)
    assert torch.allclose(out1, out2[100:], atol=1e-6)
def test_gcn_conv_with_decomposed_layers():
    """Check that ``decomposed_layers = 2`` leaves ``GCNConv`` output intact.

    A deep copy of the layer has ``decomposed_layers`` set to 2 and must
    match the original under ``torch.allclose``. When ``is_full_test()`` is
    true, the scripted decomposed layer must match the original output
    exactly.
    """
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])

    conv = GCNConv(16, 32)

    # The deep copy carries the same parameters as ``conv``.
    decomposed_conv = copy.deepcopy(conv)
    decomposed_conv.decomposed_layers = 2

    out1 = conv(x, edge_index)
    out2 = decomposed_conv(x, edge_index)
    assert torch.allclose(out1, out2)

    if is_full_test():
        sig = '(Tensor, Tensor, OptTensor) -> Tensor'
        scripted = torch.jit.script(decomposed_conv.jittable(sig))
        assert scripted(x, edge_index).tolist() == out1.tolist()
def test_layer_norm(affine):
    """Check ``LayerNorm`` with and without an explicit batch vector.

    Args:
        affine: Forwarded unchanged to ``LayerNorm(affine=...)``. Its values
            are supplied from outside this block (no decorator or fixture
            is visible here).

    An all-zero batch vector must give the same result as passing no batch
    at all. When the same rows are stacked twice under batch ids 0 and 1,
    both halves must match the unbatched output within ``atol=1e-6``.
    """
    x = torch.randn(100, 16)
    batch = torch.zeros(100, dtype=torch.long)

    norm = LayerNorm(16, affine=affine)
    assert repr(norm) == 'LayerNorm(16)'

    # Scripting only has to succeed; the scripted module is not called.
    if is_full_test():
        torch.jit.script(norm)

    out1 = norm(x)
    assert out1.size() == (100, 16)
    assert torch.allclose(norm(x, batch), out1, atol=1e-6)

    stacked_x = torch.cat([x, x], dim=0)
    stacked_batch = torch.cat([batch, batch + 1], dim=0)
    out2 = norm(stacked_x, stacked_batch)
    assert torch.allclose(out1, out2[:100], atol=1e-6)
    assert torch.allclose(out1, out2[100:], atol=1e-6)
def test_instance_norm(conf):
    """Check ``InstanceNorm`` with and without an explicit batch vector.

    Args:
        conf: Used for both ``affine`` and ``track_running_stats``. Its
            values are supplied from outside this block (no decorator or
            fixture is visible here).

    ``norm1`` is always called without a batch vector and ``norm2`` with an
    all-zero one. Their outputs (and, when ``conf`` is truthy, their
    running statistics) must agree after each call, first in the default
    mode and then after ``eval()``. Finally, a two-graph batch through
    ``norm2`` must match its two separate per-graph calls.
    """
    batch = torch.zeros(100, dtype=torch.long)

    x1 = torch.randn(100, 16)
    x2 = torch.randn(100, 16)

    norm1 = InstanceNorm(16, affine=conf, track_running_stats=conf)
    norm2 = InstanceNorm(16, affine=conf, track_running_stats=conf)
    assert repr(norm1) == 'InstanceNorm(16)'

    # Scripting only has to succeed; the scripted module is not called.
    if is_full_test():
        torch.jit.script(norm1)

    def assert_running_stats_match():
        # Compare the running statistics of the two modules.
        assert torch.allclose(norm1.running_mean, norm2.running_mean)
        assert torch.allclose(norm1.running_var, norm2.running_var)

    # --- before eval(): first input ---
    out1 = norm1(x1)
    out2 = norm2(x1, batch)
    assert out1.size() == (100, 16)
    assert torch.allclose(out1, out2, atol=1e-7)
    if conf:
        assert_running_stats_match()

    # --- before eval(): second input ---
    out1 = norm1(x2)
    out2 = norm2(x2, batch)
    assert torch.allclose(out1, out2, atol=1e-7)
    if conf:
        assert_running_stats_match()

    # --- after eval(): both inputs, same order as above ---
    norm1.eval()
    norm2.eval()
    for x in (x1, x2):
        out1 = norm1(x)
        out2 = norm2(x, batch)
        assert torch.allclose(out1, out2, atol=1e-7)

    # --- two graphs in one batch versus two separate calls ---
    out1 = norm2(x1)
    out2 = norm2(x2)
    out3 = norm2(torch.cat([x1, x2], dim=0), torch.cat([batch, batch + 1]))
    assert torch.allclose(out1, out3[:100], atol=1e-7)
    assert torch.allclose(out2, out3[100:], atol=1e-7)