def test_collate_dicts(graphs_nx, features_shapes, device):
    """Collate dict samples that mix graphs and tensors into a single batch.

    Each sample is ``{'in': Graph, 'x': Tensor, 'y': Tensor, 'out': Graph}``.
    After ``GraphBatch.collate`` the graph entries must come back graph by
    graph in their original order, and the tensor entries must equal the
    stacked per-sample tensors.
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g), **features_shapes).to(device)
        for g in graphs_nx
    ]
    # Reversed copy: 'in' and 'out' hold the same graphs in a different order.
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)
    samples = [
        {'in': gi, 'x': x, 'y': y, 'out': go}
        for gi, x, y, go in zip(graphs_in, xs, ys, graphs_out)
    ]

    batch = GraphBatch.collate(samples)

    # zip() stops at the shorter input, so check the counts first; otherwise a
    # batch that silently dropped graphs would still pass the per-graph checks.
    assert len(batch['in']) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batch['in']):
        assert_graphs_equal(g1, g2)

    # torch.testing.assert_allclose is deprecated; assert_close is its replacement.
    torch.testing.assert_close(xs, batch['x'])
    torch.testing.assert_close(ys, batch['y'])

    assert len(batch['out']) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batch['out']):
        assert_graphs_equal(g1, g2)
def test_device(graph, features_shapes, device):
    """Check that moving a featured graph to ``device`` relocates its feature tensors.

    Every feature field is either absent (``None``) or lives on ``device``.
    Bringing the moved graph back to CPU must compare equal to the original.
    """
    graph = add_random_features(graph, **features_shapes)
    moved = graph.to(device)

    for field_name in moved._feature_fields:
        feature = getattr(moved, field_name)
        assert feature is None or feature.device == device

    assert_graphs_equal(graph, moved.cpu())
def test_from_graphs(graphs, features_shapes, device):
    """Build a GraphBatch from a list of Graphs and read the graphs back.

    Checks the batch-level counts and the per-graph node/edge counts. Then
    verifies that both iteration and indexing return graphs equal to the
    inputs.
    """
    graphs = [add_random_features(g, **features_shapes).to(device) for g in graphs]
    graphbatch = GraphBatch.from_graphs(graphs)
    validate_batch(graphbatch)

    assert len(graphs) == len(graphbatch) == graphbatch.num_graphs
    expected_nodes = [g.num_nodes for g in graphs]
    assert expected_nodes == graphbatch.num_nodes_by_graph.tolist()
    expected_edges = [g.num_edges for g in graphs]
    assert expected_edges == graphbatch.num_edges_by_graph.tolist()

    # Sequential access (__iter__)
    for original, unbatched in zip(graphs, graphbatch):
        assert_graphs_equal(original, unbatched)

    # Random access (__getitem__); lengths were asserted equal above.
    for idx, original in enumerate(graphs):
        assert_graphs_equal(original, graphbatch[idx])
def test_collate_tuples(graphs_nx, features_shapes, device):
    """Collate tuple samples that mix graphs and tensors into a single batch.

    Each sample is ``(Graph, Tensor, Tensor, Graph)``. After
    ``GraphBatch.collate`` position 0 and 3 must come back graph by graph in
    their original order, and positions 1 and 2 must equal the stacked
    per-sample tensors.
    """
    graphs_in = [
        add_random_features(Graph.from_networkx(g), **features_shapes).to(device)
        for g in graphs_nx
    ]
    # Reversed copy: positions 0 and 3 hold the same graphs in a different order.
    graphs_out = list(reversed(graphs_in))
    xs = torch.rand(len(graphs_in), 10, 32)
    ys = torch.rand(len(graphs_in), 7)
    samples = list(zip(graphs_in, xs, ys, graphs_out))

    batch = GraphBatch.collate(samples)

    # zip() stops at the shorter input, so check the counts first; otherwise a
    # batch that silently dropped graphs would still pass the per-graph checks.
    assert len(batch[0]) == len(graphs_in)
    for g1, g2 in zip(graphs_in, batch[0]):
        assert_graphs_equal(g1, g2)

    # torch.testing.assert_allclose is deprecated; assert_close is its replacement.
    torch.testing.assert_close(xs, batch[1])
    torch.testing.assert_close(ys, batch[2])

    assert len(batch[3]) == len(graphs_out)
    for g1, g2 in zip(graphs_out, batch[3]):
        assert_graphs_equal(g1, g2)
def test_corner_cases(features_shapes, device):
    """Batch degenerate graphs and graphs with partially missing features.

    Only some graphs have node/edge features. Global features are either
    present on all of them or absent from all of them, depending on
    ``features_shapes['global_features_shape']``. The batch must round-trip
    every graph. Mixing graphs with and without global features must raise
    ``ValueError`` whichever one comes first.
    """
    gfs = features_shapes['global_features_shape']
    graphs = [
        # Empty graphs; only a global feature shape is passed.
        add_random_features(Graph(num_nodes=0, num_edges=0), global_features_shape=gfs),
        add_random_features(Graph(num_nodes=0, num_edges=0), global_features_shape=gfs),
        # Nodes but no edges.
        add_random_features(Graph(num_nodes=3, num_edges=0), **features_shapes),
        # Empty graph with every feature shape passed.
        add_random_features(Graph(num_nodes=0, num_edges=0), **features_shapes),
        # Two nodes connected by one edge in each direction.
        add_random_features(
            Graph(num_nodes=2, senders=torch.tensor([0, 1]), receivers=torch.tensor([1, 0])),
            **features_shapes,
        ),
    ]
    graphbatch = GraphBatch.from_graphs(graphs).to(device)
    validate_batch(graphbatch)

    # zip() would silently ignore graphs missing from the batch; compare counts first.
    assert len(graphbatch) == len(graphs)
    for g_orig, g_batch in zip(graphs, graphbatch):
        assert_graphs_equal(g_orig, g_batch.cpu())

    # Global features should be either present on all graphs or absent from all graphs
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            Graph(num_nodes=0, num_edges=0),
            add_random_features(Graph(num_nodes=0, num_edges=0), global_features_shape=10),
        ])
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            add_random_features(Graph(num_nodes=0, num_edges=0), global_features_shape=10),
            Graph(num_nodes=0, num_edges=0),
        ])
def test_from_to_networkxs(graphs_nx, features_shapes, device):
    """Round-trip a list of featured networkx graphs through a GraphBatch.

    Checks batch-level and per-graph node/edge counts, then per-graph equality
    via iteration, via indexing, and after converting the batch back to
    networkx graphs.
    """
    graphs_nx = [add_random_features(g, **features_shapes) for g in graphs_nx]
    graphbatch = GraphBatch.from_networkxs(graphs_nx).to(device)
    validate_batch(graphbatch)

    assert len(graphs_nx) == len(graphbatch) == graphbatch.num_graphs
    assert [g.number_of_nodes() for g in graphs_nx] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.number_of_edges() for g in graphs_nx] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__)
    for g_nx, g in zip(graphs_nx, graphbatch):
        assert_graphs_equal(g_nx, g.cpu())

    # Test random access (__getitem__); lengths were asserted equal above.
    for i, g_nx in enumerate(graphs_nx):
        assert_graphs_equal(g_nx, graphbatch[i].cpu())

    # Test back conversion. Materialize the result so its length can be checked;
    # zip() alone would silently accept a short result.
    graphs_nx_back = list(graphbatch.cpu().to_networkxs())
    assert len(graphs_nx_back) == len(graphs_nx)
    for g1, g2 in zip(graphs_nx, graphs_nx_back):
        assert_graphs_equal(g1, g2)
def test_from_networkx(graph_nx, features_shapes):
    """Check that ``Graph.from_networkx`` reproduces a featured networkx graph.

    Equality is judged by ``assert_graphs_equal``.
    """
    featured_nx = add_random_features(graph_nx, **features_shapes)
    converted = Graph.from_networkx(featured_nx)
    assert_graphs_equal(featured_nx, converted)
def test_to_networkx(graph, features_shapes):
    """Check that ``Graph.to_networkx`` reproduces a featured Graph.

    Equality is judged by ``assert_graphs_equal``.
    """
    featured = add_random_features(graph, **features_shapes)
    converted = featured.to_networkx()
    assert_graphs_equal(featured, converted)