def test_from_to_networkxs(graphs_nx, features_shapes, device):
    """Round-trip a list of networkx graphs through ``GraphBatch``.

    Random features are attached to every input graph via
    ``add_random_features(g, **features_shapes)``. The list is then converted
    with ``GraphBatch.from_networkxs`` and moved to ``device``. Finally the
    batch is checked for:

    * consistent graph / node / edge counts,
    * equality with the inputs via sequential access (``__iter__``),
    * equality with the inputs via random access (``__getitem__``),
    * equality with the inputs after converting back with ``to_networkxs``.
    """
    graphs_nx = [add_random_features(g, **features_shapes) for g in graphs_nx]
    graphbatch = GraphBatch.from_networkxs(graphs_nx).to(device)
    validate_batch(graphbatch)

    assert len(graphs_nx) == len(graphbatch) == graphbatch.num_graphs
    assert [g.number_of_nodes() for g in graphs_nx] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.number_of_edges() for g in graphs_nx] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__). Materialize first and check the
    # length, because zip() would silently truncate a short iterator.
    graphs_iter = list(graphbatch)
    assert len(graphs_iter) == len(graphs_nx)
    for g_nx, g in zip(graphs_nx, graphs_iter):
        assert_graphs_equal(g_nx, g.cpu())

    # Test random access (__getitem__)
    for i, g_nx in enumerate(graphs_nx):
        assert_graphs_equal(g_nx, graphbatch[i].cpu())

    # Test back conversion; again guard against zip() hiding a length mismatch.
    graphs_nx_back = list(graphbatch.cpu().to_networkxs())
    assert len(graphs_nx_back) == len(graphs_nx)
    for g1, g2 in zip(graphs_nx, graphs_nx_back):
        assert_graphs_equal(g1, g2)
def graphbatch() -> GraphBatch:
    """Build a ``GraphBatch`` from every graph returned by ``graphs_for_test()``.

    Only the mapping's values are used; its keys are ignored.
    """
    # NOTE(review): presumably graphs_for_test() returns a dict of name -> nx graph; confirm.
    test_graphs = graphs_for_test()
    batch = GraphBatch.from_networkxs(test_graphs.values())
    return batch