def test_corner_cases(features_shapes, device):
    # Only some graphs have node/edge features, global features are either present on all of them or absent from all
    gfs = features_shapes['global_features_shape']
    graphs = [
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            global_features_shape=gfs),
        add_random_features(Graph(num_nodes=3, num_edges=0),
                            **features_shapes),
        add_random_features(Graph(num_nodes=0, num_edges=0),
                            **features_shapes),
        add_random_features(
            Graph(num_nodes=2,
                  senders=torch.tensor([0, 1]),
                  receivers=torch.tensor([1, 0])), **features_shapes)
    ]
    graphbatch = GraphBatch.from_graphs(graphs).to(device)
    validate_batch(graphbatch)

    for g_orig, g_batch in zip(graphs, graphbatch):
        assert_graphs_equal(g_orig, g_batch.cpu())

    # Global features should be either present on all graphs or absent from all graphs
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            Graph(num_nodes=0, num_edges=0),
            add_random_features(Graph(num_nodes=0, num_edges=0),
                                global_features_shape=10)
        ])
    with pytest.raises(ValueError):
        GraphBatch.from_graphs([
            add_random_features(Graph(num_nodes=0, num_edges=0),
                                global_features_shape=10),
            Graph(num_nodes=0, num_edges=0)
        ])
def test_from_graphs(graphs, features_shapes, device):
    graphs = [
        add_random_features(g, **features_shapes).to(device) for g in graphs
    ]
    graphbatch = GraphBatch.from_graphs(graphs)

    validate_batch(graphbatch)

    assert len(graphs) == len(graphbatch) == graphbatch.num_graphs
    assert [g.num_nodes
            for g in graphs] == graphbatch.num_nodes_by_graph.tolist()
    assert [g.num_edges
            for g in graphs] == graphbatch.num_edges_by_graph.tolist()

    # Test sequential access (__iter__)
    for g, gb in zip(graphs, graphbatch):
        assert_graphs_equal(g, gb)

    # Test random access (__getitem__)
    for i in range(len(graphbatch)):
        assert_graphs_equal(graphs[i], graphbatch[i])