def test_height_linkage_tree():
    # Build a tree with every registered tree builder on a 10x10 grid
    # connectivity and check that merges plus leaves add up to the expected
    # total number of nodes.
    # NOTE(review): the test name refers to the *height* of the tree being
    # sorted, but no height/distance ordering is asserted here -- confirm
    # whether such a check was intended.
    rng = np.random.RandomState(0)
    # Builtin ``bool`` instead of the ``np.bool`` alias, which was deprecated
    # in NumPy 1.20 and removed in 1.24. Only the mask's shape is used below.
    mask = np.ones([10, 10], dtype=bool)
    X = rng.randn(50, 100)
    connectivity = grid_to_graph(*mask.shape)
    # X.T is passed to the builders, so X.shape[1] is the number of rows they
    # receive; the count is loop-invariant and computed once.
    n_nodes = 2 * X.shape[1] - 1
    for linkage_func in _TREE_BUILDERS.values():
        # Unpacking also verifies that each builder returns four values.
        children, _, n_leaves, _ = linkage_func(X.T, connectivity)
        assert len(children) + n_leaves == n_nodes
def test_n_components():
    # Every tree builder must report five components (second element of its
    # return value) when given an identity connectivity matrix.
    random_state = np.random.RandomState(0)
    samples = random_state.rand(5, 5)

    # Identity matrix: each of the five samples is connected only to itself.
    identity_connectivity = np.eye(5)

    for build_tree in _TREE_BUILDERS.values():
        # Wrap the builder so warnings emitted during the call are silenced.
        quiet_build_tree = ignore_warnings(build_tree)
        tree = quiet_build_tree(samples, identity_connectivity)
        n_components = tree[1]
        assert n_components == 5
def test_structured_linkage_tree():
    # Check structured (connectivity-constrained) trees for every registered
    # tree builder: the node count must be as expected, and invalid inputs
    # must raise ValueError.
    rng = np.random.RandomState(0)
    # Builtin ``bool`` instead of the ``np.bool`` alias, which was deprecated
    # in NumPy 1.20 and removed in 1.24.
    mask = np.ones([10, 10], dtype=bool)
    # Avoiding a mask with only 'True' entries
    # NOTE(review): only ``mask.shape`` is used below, so this hole does not
    # influence the connectivity graph -- confirm whether that is intended.
    mask[4:7, 4:7] = 0
    X = rng.randn(50, 100)
    connectivity = grid_to_graph(*mask.shape)
    # X.T is passed to the builders, so X.shape[1] is the number of rows they
    # receive; the count is loop-invariant and computed once.
    n_nodes = 2 * X.shape[1] - 1
    for tree_builder in _TREE_BUILDERS.values():
        # Unpacking also verifies that each builder returns four values.
        children, _, n_leaves, _ = tree_builder(X.T, connectivity)
        assert len(children) + n_leaves == n_nodes
        # Check that the builder raises a ValueError with a connectivity
        # matrix of the wrong shape
        with pytest.raises(ValueError):
            tree_builder(X.T, np.ones((4, 4)))
        # Check that fitting with no samples raises an error
        with pytest.raises(ValueError):
            tree_builder(X.T[:0], connectivity)
def test_unstructured_linkage_tree():
    # Check unstructured trees (no connectivity argument): ward_tree first,
    # then every registered tree builder, each on both the full matrix and a
    # single row of it.
    random_state = np.random.RandomState(0)
    X = random_state.randn(50, 100)
    candidate_inputs = (X, X[0])
    expected_n_nodes = 2 * X.shape[1] - 1

    # Same order as before: ward_tree on its own, then all registered builders.
    builders = [ward_tree] + list(_TREE_BUILDERS.values())

    for builder in builders:
        for data in candidate_inputs:
            # n_clusters is specified only for the sake of raising a warning
            # and exercising the warning code path.
            with ignore_warnings():
                result = assert_warns(
                    UserWarning, builder, data.T, n_clusters=10)
            children, _, n_leaves, _ = result
            assert len(children) + n_leaves == expected_n_nodes