# Example no. 1
# 0
def test_mask_no_edges():
    """An all-True edge mask keeps every node, edge, and feature tensor intact."""
    deterministic_seed(0)

    data = GraphData(
        torch.randn(5, 5),
        torch.randn(3, 2),
        torch.randn(1, 1),
        edges=torch.LongTensor([[0, 0, 0], [1, 2, 3]]),
    )
    edge_mask = torch.BoolTensor([True, True, True])
    data2 = data.apply_edge_mask(edge_mask)
    assert data2.num_nodes == 5
    assert data2.num_edges == 3
    # Counts alone would not catch reordered or corrupted edges/features, so
    # compare the tensors as well. torch.equal also checks shape, unlike
    # torch.all(a == b), which can silently broadcast mismatched shapes.
    assert torch.equal(data2.edges, data.edges)
    assert torch.equal(data2.e, data.e)
    assert torch.equal(data2.x, data.x)
    assert torch.equal(data2.g, data.g)
# Example no. 2
# 0
def test_mask_one_edges():
    """Masking out one edge removes it from the edge index and edge features only."""
    deterministic_seed(0)

    edges = torch.LongTensor([[0, 0, 0], [1, 2, 3]])
    # Dropping the middle column (0, 2) leaves columns (0, 1) and (0, 3).
    expected_edges = torch.LongTensor([[0, 0], [1, 3]])

    e = torch.randn(3, 2)
    edge_mask = torch.BoolTensor([True, False, True])
    # Boolean indexing along dim 0 selects the feature rows of surviving edges.
    expected_e = e[edge_mask]

    data = GraphData(torch.randn(5, 5), e, torch.randn(1, 1), edges=edges)

    data2 = data.apply_edge_mask(edge_mask)
    assert data2.num_nodes == 5
    assert data2.num_edges == 2
    # torch.equal checks shape as well as values; torch.all(a == b) would
    # broadcast and could pass with a wrongly shaped result.
    assert torch.equal(data2.edges, expected_edges)
    assert torch.equal(data2.e, expected_e)
    assert torch.equal(data2.g, data.g)
    assert torch.equal(data2.x, data.x)