def test_split_function_node_type():
    """Splitting by ``node_type`` requires a StellarGraph and keeps only that type.

    Passing the raw example graph (not wrapped in ``StellarGraph``) must raise
    ``TypeError``.  After wrapping it, every node returned in any of the splits
    must carry the label ``"movie"``.
    """
    # Example graph:
    g = create_example_graph_1()

    # This doesn't work if g is not a StellarGraph
    with pytest.raises(TypeError):
        train_val_test_split(
            g,
            node_type="movie",
            test_size=1,
            train_size=2,
            targets=None,
            split_equally=False,
            seed=None,
        )

    gs = StellarGraph(g)
    splits = train_val_test_split(
        gs,
        node_type="movie",
        test_size=1,
        train_size=2,
        targets=None,
        split_equally=False,
        seed=None,
    )
    # Read node attributes through the ``G.nodes[n]`` view: the ``G.node``
    # alias was deprecated and removed in networkx 2.4.
    assert all(g.nodes[s]["label"] == "movie" for split in splits for s in split)
def test_split_function_unlabelled():
    """Nodes with no entry in ``targets`` are returned as unlabelled.

    Only nodes after the first three (in iteration order) receive a target, so
    those three end up in the last split.  The labelled nodes fill the
    requested train and test sizes, leaving none for validation.
    """
    # Example graph:
    sg = create_example_graph_1()

    # Leave the first three nodes unlabelled; every later node gets target 1.
    targets = {n: 1 for ii, n in enumerate(sg) if ii > 2}

    splits = train_val_test_split(
        sg,
        node_type=None,
        test_size=2,
        train_size=2,
        targets=targets,
        split_equally=False,
        seed=None,
    )

    # Train and test match the requested sizes (2 each); no labelled nodes are
    # left for validation, and the three untargeted nodes form the last split.
    assert len(splits[0]) == 2
    assert len(splits[1]) == 0
    assert len(splits[2]) == 2
    assert len(splits[3]) == 3
def test_split_function_percent():
    """Fractional ``train_size``/``test_size`` are converted to node counts.

    The fractions are floored to whole numbers of nodes.  For the 7 nodes
    checked here, 3.2/7 gives 3 training nodes and 2.8/7 gives 2 test nodes.
    The remaining 2 nodes land in the validation split.
    """
    # Example graphs:
    for g in [create_example_graph_1(), create_example_graph_2()]:
        splits = train_val_test_split(
            g,
            node_type=None,
            test_size=2.8 / 7,
            train_size=3.2 / 7,
            targets=None,
            seed=None,
        )
        # Floored sizes: 3 train, 2 validation (the remainder), 2 test.
        assert len(splits[0]) == 3
        assert len(splits[1]) == 2
        assert len(splits[2]) == 2
        # With targets=None no nodes are reported in the unlabelled split.
        assert len(splits[3]) == 0

        # Make sure the nodeIDs can be found in the graph
        assert all(s in g for s in it.chain(*splits))
def test_split_function_split_equally():
    """``split_equally=True`` balances the two target classes in training.

    Targets are 0 for roughly the first half of the nodes and 1 for the rest.
    A balanced training split therefore has half (rounded down) of its nodes
    labelled 1.
    """
    graph = create_example_graph_2()
    num_nodes = graph.number_of_nodes()

    # split_equally needs a target for every node: 0 early, 1 late.
    targets = {}
    for idx, node in enumerate(graph):
        targets[node] = int(2 * idx / num_nodes)

    splits = train_val_test_split(
        graph,
        node_type=None,
        test_size=2,
        train_size=4,
        targets=targets,
        split_equally=True,
        seed=None,
    )

    train = splits[0]
    n_label_one = sum(targets[node] for node in train)
    assert n_label_one == len(train) // 2

    # Every returned node ID must belong to the graph.
    assert all(node in graph for node in it.chain.from_iterable(splits))
def test_split_function():
    """Integer ``train_size``/``test_size`` give exact split sizes.

    For each example graph, 3 nodes go to train and 2 to test.  The remaining
    2 nodes form the validation split, and with ``targets=None`` none are
    reported as unlabelled.
    """
    # Example graphs:
    for g in [create_example_graph_1(), create_example_graph_2()]:
        splits = train_val_test_split(
            g,
            node_type=None,
            test_size=2,
            train_size=3,
            targets=None,
            split_equally=False,
            seed=None,
        )
        assert len(splits[0]) == 3
        assert len(splits[1]) == 2
        assert len(splits[2]) == 2
        assert len(splits[3]) == 0

        # Make sure the nodeIDs can be found in the graph
        assert all(s in g for s in it.chain(*splits))