コード例 #1
0
ファイル: test_apply_mask.py プロジェクト: jvrana/caldera
def test_mask_no_nodes():
    deterministic_seed(0)

    data = GraphData(
        torch.randn(5, 5),
        torch.randn(3, 2),
        torch.randn(1, 1),
        edges=torch.LongTensor([[0, 0, 0], [1, 2, 3]]),
    )
    node_mask = torch.BoolTensor([True, True, True, True, True])
    data2 = data.apply_node_mask(node_mask)
    assert data2.num_nodes == 5
    assert data2.num_edges == 3
コード例 #2
0
ファイル: test_apply_mask.py プロジェクト: jvrana/caldera
def test_mask_one_node():
    deterministic_seed(0)

    edges = torch.LongTensor([[0, 1, 0], [1, 2, 3]])
    expected_edges = torch.LongTensor([[0], [1]])

    node_mask = torch.BoolTensor([False, True, True, True, True])

    x = torch.randn(5, 5)
    expected_x = x[node_mask]

    e = torch.randn(3, 2)
    expected_e = e[torch.LongTensor([1])]

    data = GraphData(x, e, torch.randn(1, 1), edges=edges)

    data2 = data.apply_node_mask(node_mask)
    assert torch.all(data2.edges == expected_edges)

    print(data2.x)
    print(expected_x)
    assert torch.allclose(data2.x, expected_x)
    assert torch.allclose(data2.e, expected_e)
    assert torch.allclose(data2.g, data.g)