def test_mask_no_nodes():
    """An all-True node mask must leave the graph's node and edge counts intact."""
    deterministic_seed(0)
    # Draw features in the same order as before so the seeded RNG stream is identical.
    node_feats = torch.randn(5, 5)
    edge_feats = torch.randn(3, 2)
    global_feats = torch.randn(1, 1)
    # Three edges, all leaving node 0 (to 1, 2 and 3); node 4 has no incident edge.
    edge_index = torch.LongTensor([[0, 0, 0], [1, 2, 3]])
    graph = GraphData(node_feats, edge_feats, global_feats, edges=edge_index)

    keep_all = torch.ones(5, dtype=torch.bool)
    masked = graph.apply_node_mask(keep_all)

    assert masked.num_nodes == 5
    assert masked.num_edges == 3
def test_mask_one_node():
    """Masking out node 0 removes its edges and keeps the remaining data.

    The graph has edges (0, 1), (1, 2) and (0, 3). Dropping node 0 removes
    the first and third edges. The surviving edge (1, 2) is expected to be
    renumbered to (0, 1), and its features are expected to be row 1 of ``e``.
    """
    deterministic_seed(0)
    edges = torch.LongTensor([[0, 1, 0], [1, 2, 3]])
    expected_edges = torch.LongTensor([[0], [1]])
    node_mask = torch.BoolTensor([False, True, True, True, True])
    # RNG draws stay in the original order: x, then e, then g.
    x = torch.randn(5, 5)
    expected_x = x[node_mask]
    e = torch.randn(3, 2)
    # Only edge index 1, i.e. (1, 2), survives the mask.
    expected_e = e[torch.LongTensor([1])]
    data = GraphData(x, e, torch.randn(1, 1), edges=edges)
    data2 = data.apply_node_mask(node_mask)

    # Check shapes first: `==` and `allclose` broadcast, so e.g. an empty
    # (2, 0) edge tensor would broadcast against (2, 1) and pass vacuously.
    assert data2.edges.shape == expected_edges.shape
    assert torch.all(data2.edges == expected_edges)
    assert data2.x.shape == expected_x.shape
    assert torch.allclose(data2.x, expected_x)
    assert data2.e.shape == expected_e.shape
    assert torch.allclose(data2.e, expected_e)
    assert torch.allclose(data2.g, data.g)