Exemplo n.º 1
0
def test_scatter_group_requires_grad():
    """Grouped outputs must stay attached to the autograd graph of ``x``.

    Besides checking ``requires_grad`` on every group, the test
    back-propagates through the concatenated groups. Each element of ``x``
    lands in exactly one group, so the gradient of their sum with respect
    to ``x`` is all ones.
    """
    idx = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
    x = torch.tensor(
        [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, requires_grad=True
    )
    out = torch_scatter_group(x, idx)
    groups = out[1]

    # Guard against a vacuous pass: an empty group sequence would make the
    # requires_grad loop below succeed trivially.
    assert len(groups) == 3
    for o in groups:
        assert o.requires_grad

    # requires_grad alone does not prove the graph is wired correctly;
    # check that gradients actually reach every element of x.
    torch.cat(list(groups)).sum().backward()
    assert torch.equal(x.grad, torch.ones_like(x))
Exemplo n.º 2
0
def test_scatter_group4():
    """Grouping 2-D rows by a contiguous index yields row slices of ``x``."""
    idx = torch.tensor([0, 0, 0, 0, 0, 1, 1])
    x = torch.randn(7, 3)
    out = torch_scatter_group(x, idx)

    assert torch.equal(out[0], torch.tensor([0, 1]))
    assert len(out[1]) == 2
    # torch.equal compares shapes as well as values, so a wrongly-shaped
    # group cannot slip through via broadcasting the way it could with
    # torch.all(torch.eq(...)).
    assert torch.equal(out[1][0], x[:5])
    assert torch.equal(out[1][1], x[5:])
Exemplo n.º 3
0
def test_scatter_group_0():
    """Contiguous index runs produce sorted keys and one group per key."""
    idx = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])

    out = torch_scatter_group(x, idx)
    expected_groups = [
        torch.tensor([0, 1, 2]),
        torch.tensor([3, 4, 5]),
        torch.tensor([6, 7, 8]),
    ]

    assert torch.equal(out[0], torch.tensor([0, 1, 2]))
    assert len(out[1]) == len(expected_groups)
    for got, want in zip(out[1], expected_groups):
        # torch.equal checks shape as well as values, unlike
        # torch.all(torch.eq(...)), which accepts broadcastable mismatches.
        assert torch.equal(got, want)
Exemplo n.º 4
0
def test_scatter_group_3():
    """Unsorted indices are grouped by sorted key, preserving input order.

    Key 2 appears at positions 0, 1 and 6, so its group is ``[0, 1, 6]``.
    """
    idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6])

    out = torch_scatter_group(x, idx)
    expected_groups = [
        torch.tensor([2]),
        torch.tensor([3, 4, 5]),
        torch.tensor([0, 1, 6]),
    ]

    assert torch.equal(out[0], torch.tensor([0, 1, 2]))
    assert len(out[1]) == len(expected_groups)
    for got, want in zip(out[1], expected_groups):
        # torch.equal also compares shapes. With torch.all(torch.eq(...)),
        # the single-element expectation tensor([2]) would broadcast and
        # accept any group made only of 2s, whatever its length.
        assert torch.equal(got, want)
Exemplo n.º 5
0
def test_find_neighbors():
    """Group the entries within distance 3 of each query node by query node.

    Computes pairwise distances with ``floyd_warshall`` and selects the rows
    for four query nodes. It then checks that ``torch_scatter_group``
    partitions the matching target indices by query-node position.
    """
    from caldera.utils import torch_scatter_group

    data = GraphData.random(5, 4, 3, min_nodes=1000, min_edges=1000)
    W = floyd_warshall(data)
    nodes = torch.LongTensor([[0], [1], [2], [3]])

    # Advanced indexing with a (4, 1) index tensor keeps a leading dim of
    # size nodes.shape[0], i.e. one slice per query node.
    x = W[nodes]
    # `inf <= 3` evaluates to False, so this comparison already excludes
    # infinite (unreachable) entries; no separate `!= inf` mask is needed.
    reachable = x <= 3
    d = torch.where(reachable)

    # d[0] is the query-node position (dim 0 of x); d[-1] the target index.
    neighbors = torch_scatter_group(d[-1], d[0])
    keys, groups = neighbors[0], neighbors[1]

    # Every key must be a valid query-node position along dim 0 of x.
    assert torch.all((keys >= 0) & (keys < nodes.shape[0]))
    # One group per key.
    assert len(groups) == keys.numel()
    # Grouping partitions the target indices: nothing lost or duplicated.
    assert sum(g.numel() for g in groups) == d[-1].numel()