示例#1
0
def test_split_function_node_type():
    """Splitting by ``node_type`` requires a StellarGraph and keeps only that type.

    Passing the raw example graph must raise ``TypeError``. Once it is wrapped
    in ``StellarGraph``, every node returned in any split must carry the
    ``"movie"`` label.
    """
    # Example graph (not yet wrapped in a StellarGraph):
    g = create_example_graph_1()

    # This doesn't work if g is not a StellarGraph
    with pytest.raises(TypeError):
        train_val_test_split(
            g,
            node_type="movie",
            test_size=1,
            train_size=2,
            targets=None,
            split_equally=False,
            seed=None,
        )

    gs = StellarGraph(g)
    splits = train_val_test_split(
        gs,
        node_type="movie",
        test_size=1,
        train_size=2,
        targets=None,
        split_equally=False,
        seed=None,
    )
    # Check labels on the underlying graph. NOTE(review): assumes g is a
    # networkx graph. ``G.nodes[n]`` is the networkx 2.x node-attribute
    # accessor; the old ``G.node`` alias was removed in networkx 2.4.
    assert all(g.nodes[s]["label"] == "movie" for split in splits for s in split)
示例#2
0
def test_split_function_unlabelled():
    """Nodes with no entry in ``targets`` are kept out of train/val/test.

    The first three nodes (in iteration order) of the example graph get no
    target. All remaining nodes get target ``1``.
    """
    # Example graph:
    sg = create_example_graph_1()

    # Leave the first three nodes (positions 0-2) unlabelled; the rest get 1.
    targets = {n: 1 for ii, n in enumerate(sg) if ii > 2}

    splits = train_val_test_split(
        sg,
        node_type=None,
        test_size=2,
        train_size=2,
        targets=targets,
        split_equally=False,
        seed=None,
    )

    # Train and test receive the 2 nodes each that were requested.
    # Validation is expected to be empty (presumably because the labelled nodes
    # are all consumed by train and test). The three nodes without a target
    # are returned separately in splits[3].
    assert len(splits[0]) == 2
    assert len(splits[1]) == 0
    assert len(splits[2]) == 2
    assert len(splits[3]) == 3
示例#3
0
def test_split_function_percent():
    """Fractional ``test_size``/``train_size`` are turned into node counts by flooring."""
    # Run the same checks on both example graphs.
    for g in [create_example_graph_1(), create_example_graph_2()]:

        # Test splits by proportion - note the floor of the
        # number of samples
        splits = train_val_test_split(
            g,
            node_type=None,
            test_size=2.8 / 7,
            train_size=3.2 / 7,
            targets=None,
            seed=None,
        )

        # Fractions are expressed over 7 (presumably the example graphs' node
        # count): floor(3.2) = 3 train and floor(2.8) = 2 test nodes. The
        # remaining 2 nodes go to validation, and with targets=None the last
        # split is empty.
        assert len(splits[0]) == 3
        assert len(splits[1]) == 2
        assert len(splits[2]) == 2
        assert len(splits[3]) == 0

        # Make sure the nodeIDs can be found in the graph
        assert all(s in g for s in it.chain.from_iterable(splits))
示例#4
0
def test_split_function_split_equally():
    """With ``split_equally=True`` the training split is balanced between targets."""
    graph = create_example_graph_2()

    # Binary targets are required: nodes in the first half of the iteration
    # order get 0, nodes in the second half get 1.
    n_nodes = graph.number_of_nodes()
    targets = {}
    for position, node in enumerate(graph):
        targets[node] = int(2 * position / n_nodes)

    node_splits = train_val_test_split(
        graph,
        node_type=None,
        test_size=2,
        train_size=4,
        targets=targets,
        split_equally=True,
        seed=None,
    )
    train_nodes = node_splits[0]

    # A balanced training split has (the floor of) half its nodes with target 1.
    n_positive = sum(targets[node] for node in train_nodes)
    assert n_positive == len(train_nodes) // 2

    # Every returned node ID must belong to the graph.
    every_node = it.chain(*node_splits)
    assert all(node in graph for node in every_node)
示例#5
0
def test_split_function():
    """Integer ``test_size``/``train_size`` are used as exact node counts."""
    # Run the same checks on both example graphs.
    for g in [create_example_graph_1(), create_example_graph_2()]:

        splits = train_val_test_split(
            g,
            node_type=None,
            test_size=2,
            train_size=3,
            targets=None,
            split_equally=False,
            seed=None,
        )
        # 3 train and 2 test nodes, as requested. The remaining 2 nodes go to
        # validation, and with targets=None the last split is empty.
        assert len(splits[0]) == 3
        assert len(splits[1]) == 2
        assert len(splits[2]) == 2
        assert len(splits[3]) == 0

        # Make sure the nodeIDs can be found in the graph
        assert all(s in g for s in it.chain.from_iterable(splits))