def test_from_to_networkxs(graphs_nx, features_shapes, device):
    """Check the networkx -> GraphBatch -> networkx round trip.

    The input graphs are decorated with random features, batched with
    ``GraphBatch.from_networkxs`` and moved to ``device``. The batch is then
    compared against the inputs through its size attributes, through
    iteration, through indexing, and through ``to_networkxs``.

    Args:
        graphs_nx: Iterable of graphs exposing ``number_of_nodes()`` and
            ``number_of_edges()`` (presumably a pytest fixture of networkx
            graphs -- confirm against the fixture definition).
        features_shapes: Mapping unpacked as keyword arguments into
            ``add_random_features``.
        device: Target passed to ``GraphBatch.to``.
    """
    graphs_nx = [add_random_features(g, **features_shapes) for g in graphs_nx]
    graphbatch = GraphBatch.from_networkxs(graphs_nx).to(device)

    validate_batch(graphbatch)

    expected_count = len(graphs_nx)
    assert expected_count == len(graphbatch) == graphbatch.num_graphs

    expected_num_nodes = [g.number_of_nodes() for g in graphs_nx]
    assert expected_num_nodes == graphbatch.num_nodes_by_graph.tolist()

    expected_num_edges = [g.number_of_edges() for g in graphs_nx]
    assert expected_num_edges == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__)
    # zip() stops at the shorter input, so an iterator that yields too few
    # graphs would go unnoticed; pin the count before comparing pairwise.
    graphs_iterated = list(graphbatch)
    assert len(graphs_iterated) == expected_count
    for g_nx, g in zip(graphs_nx, graphs_iterated):
        assert_graphs_equal(g_nx, g.cpu())

    # Test random access (__getitem__)
    for i, g_nx in enumerate(graphs_nx):
        assert_graphs_equal(g_nx, graphbatch[i].cpu())

    # Test back conversion
    # Same truncation concern as above: check the count explicitly.
    graphs_nx_back = list(graphbatch.cpu().to_networkxs())
    assert len(graphs_nx_back) == expected_count
    for g1, g2 in zip(graphs_nx, graphs_nx_back):
        assert_graphs_equal(g1, g2)
# Example no. 2
# 0
def graphbatch() -> GraphBatch:
    """Return a ``GraphBatch`` built from the values of ``graphs_for_test()``."""
    test_graphs = graphs_for_test()
    batch = GraphBatch.from_networkxs(test_graphs.values())
    return batch