def test_scatter_group_requires_grad():
    """Groups built from a tensor that requires grad must also require grad."""
    idx = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float)
    x.requires_grad = True
    out = torch_scatter_group(x, idx)
    # Guard against a vacuous pass: the loop below checks nothing if no
    # groups come back. idx has three distinct values, so expect three groups.
    assert len(out[1]) == 3
    for group in out[1]:
        assert group.requires_grad
def test_scatter_group4():
    """Grouping a 2-D tensor keeps each row's full feature vector intact."""
    idx = torch.tensor([0, 0, 0, 0, 0, 1, 1])
    x = torch.randn(7, 3)
    out = torch_scatter_group(x, idx)
    # Keys and group count, checked the same way the other scatter-group tests do.
    assert out[0].tolist() == [0, 1]
    assert len(out[1]) == 2
    # torch.equal also checks shape. torch.all(torch.eq(...)) broadcasts and
    # could accept a group with the wrong number of rows.
    assert torch.equal(out[1][0], x[:5])
    assert torch.equal(out[1][1], x[5:])
def test_scatter_group_0():
    """Sorted, contiguous indices split ``x`` into three runs of three values."""
    idx = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
    result = torch_scatter_group(x, idx)

    assert (result[0] == torch.tensor([0, 1, 2])).all()
    assert len(result[1]) == 3

    # Expected contents of each group, in key order.
    expected_groups = ([0, 1, 2], [3, 4, 5], [6, 7, 8])
    for position, values in enumerate(expected_groups):
        assert (result[1][position] == torch.tensor(values)).all()
def test_scatter_group_3():
    """Unsorted indices group by key, keeping the original order inside each group."""
    idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6])
    out = torch_scatter_group(x, idx)
    assert out[0].tolist() == [0, 1, 2]
    assert len(out[1]) == 3
    # Compare as lists. Unlike torch.all(torch.eq(...)), this cannot be fooled
    # by broadcasting (e.g. a group of [2, 2] would "equal" tensor([2])).
    assert out[1][0].tolist() == [2]
    assert out[1][1].tolist() == [3, 4, 5]
    assert out[1][2].tolist() == [0, 1, 6]
def test_find_neighbors():
    """Group the nodes within shortest-path distance 3 of each query node.

    NOTE(review): ``floyd_warshall(data)`` is assumed to return a dense torch
    tensor of pairwise distances, with ``inf`` marking unreachable pairs.
    Confirm against its implementation.
    """
    data = GraphData.random(5, 4, 3, min_nodes=1000, min_edges=1000)
    W = floyd_warshall(data)
    nodes = torch.LongTensor([[0], [1], [2], [3]])
    x = W[nodes]
    # `inf <= 3` is False, so this comparison alone already excludes
    # unreachable nodes; a separate `x != inf` mask is redundant.
    reachable = x <= 3
    d = torch.where(reachable)
    # d[0] is the query row and d[-1] the candidate-neighbour column. The
    # middle entry (if any) indexes the singleton axis introduced by the
    # (4, 1) index tensor, so it is NOT the neighbour index.
    out = torch_scatter_group(d[-1], d[0])
    keys, groups = out[0], out[1]

    assert len(keys) == len(groups)
    assert list(keys.tolist()) == torch.unique(d[0]).tolist()
    for key, group in zip(keys, groups):
        # torch.where yields indices in row-major order, so one query's
        # neighbours appear in ascending column order.
        expected = torch.where(reachable[int(key)].reshape(-1))[0]
        assert group.tolist() == expected.tolist()