def cut_straight(dendrogram: np.ndarray, n_clusters: Optional[int] = None, threshold: Optional[float] = None,
                 sort_clusters: bool = True, return_dendrogram: bool = False) \
        -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Cut a dendrogram and return the corresponding clustering.

    The cut level is derived from ``n_clusters`` and/or ``threshold``; only
    merges whose height is strictly below the cut level are applied.

    Parameters
    ----------
    dendrogram :
        Dendrogram. One row per merge; columns 0 and 1 hold the indices of the
        merged nodes and column 2 the height of the merge.
    n_clusters :
        Number of clusters (optional).
        The number of clusters can be larger than n_clusters in case of equal heights in the dendrogram.
    threshold :
        Threshold on height (optional).
        If both n_clusters and threshold are ``None``, n_clusters is set to 2.
        If both are given, the higher of the two resulting cut levels is used.
    sort_clusters :
        If ``True``, sorts clusters in decreasing order of size.
    return_dendrogram :
        If ``True``, returns the dendrogram formed by the clusters up to the root.

    Returns
    -------
    labels : np.ndarray
        Cluster of each node.
    dendrogram_aggregate : np.ndarray
        Dendrogram starting from clusters (leaves = clusters).

    Raises
    ------
    ValueError
        If ``return_dendrogram`` is ``True`` and the third column of the
        dendrogram is not non-decreasing.

    Example
    -------
    >>> from sknetwork.hierarchy import cut_straight
    >>> dendrogram = np.array([[0, 1, 0, 2], [2, 3, 1, 3]])
    >>> cut_straight(dendrogram)
    array([0, 0, 1])
    """
    check_dendrogram(dendrogram)
    n = dendrogram.shape[0] + 1
    heights = dendrogram[:, 2]

    if return_dendrogram and not np.all(np.diff(heights) >= 0):
        raise ValueError(
            "The third column of the dendrogram must be non-decreasing.")

    if n_clusters is None:
        # Without any constraint, default to 2 clusters; with a threshold only,
        # start from the lowest possible cut and let the threshold raise it.
        n_clusters = 2 if threshold is None else n
    else:
        check_n_clusters(n_clusters, n, n_min=1)

    if n_clusters == 1:
        # There are only n - 1 heights, so index n - 1 would be out of bounds.
        # A single cluster means every merge must be applied: cut above all heights.
        cut = np.inf
    else:
        cut = np.sort(heights)[n - n_clusters]
    if threshold is not None:
        cut = max(cut, threshold)

    # Replay the merges below the cut; a merge is skipped when one of its
    # children was not itself formed (i.e. is not a current cluster).
    cluster = {i: [i] for i in range(n)}
    for t in range(n - 1):
        i = int(dendrogram[t][0])
        j = int(dendrogram[t][1])
        if dendrogram[t][2] < cut and i in cluster and j in cluster:
            cluster[n + t] = cluster.pop(i) + cluster.pop(j)

    return get_labels(dendrogram, cluster, sort_clusters, return_dendrogram)
def cut_straight(
    dendrogram: np.ndarray,
    n_clusters: int = 2,
    sort_clusters: bool = True,
    return_dendrogram: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Cut a dendrogram and return the corresponding clustering.

    The cut level is the height found at position ``n - n_clusters`` of the
    sorted heights; only merges whose height is strictly below that level are
    applied, so ties in height can yield more than ``n_clusters`` clusters.

    Parameters
    ----------
    dendrogram:
        Dendrogram. One row per merge; columns 0 and 1 hold the indices of the
        merged nodes and column 2 the height of the merge.
    n_clusters :
        Number of clusters.
    sort_clusters :
        If ``True``, sorts clusters in decreasing order of size.
    return_dendrogram :
        If ``True``, returns the dendrogram formed by the clusters up to the root.

    Returns
    -------
    labels : np.ndarray
        Cluster of each node.
    dendrogram_aggregate : np.ndarray
        Dendrogram starting from clusters (leaves = clusters).

    Raises
    ------
    ValueError
        If ``return_dendrogram`` is ``True`` and the third column of the
        dendrogram is not non-decreasing.

    Example
    -------
    >>> from sknetwork.hierarchy import cut_straight
    >>> dendrogram = np.array([[0, 1, 0, 2], [2, 3, 1, 3]])
    >>> cut_straight(dendrogram)
    array([0, 0, 1])
    """
    check_dendrogram(dendrogram)
    n = dendrogram.shape[0] + 1
    check_n_clusters(n_clusters, n, n_min=1)
    heights = dendrogram[:, 2]

    if return_dendrogram and not np.all(np.diff(heights) >= 0):
        raise ValueError(
            "The third column of the dendrogram must be non-decreasing.")

    if n_clusters == 1:
        # There are only n - 1 heights, so index n - 1 would be out of bounds.
        # A single cluster means every merge must be applied: cut above all heights.
        cut = np.inf
    else:
        cut = np.sort(heights)[n - n_clusters]

    # Replay the merges below the cut; a merge is skipped when one of its
    # children was not itself formed (i.e. is not a current cluster).
    cluster = {i: [i] for i in range(n)}
    for t in range(n - 1):
        i = int(dendrogram[t][0])
        j = int(dendrogram[t][1])
        if dendrogram[t][2] < cut and i in cluster and j in cluster:
            cluster[n + t] = cluster.pop(i) + cluster.pop(j)

    return get_labels(dendrogram, cluster, sort_clusters, return_dendrogram)
def cut_balanced(
    dendrogram: np.ndarray,
    max_cluster_size: int = 20,
    sort_clusters: bool = True,
    return_dendrogram: bool = False
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Cut a dendrogram under a cluster-size constraint and return the clustering.

    Merges are replayed in dendrogram order; a merge is applied only when both
    of its children are current clusters and the merged cluster would not
    exceed ``max_cluster_size`` nodes.

    Parameters
    ----------
    dendrogram:
        Dendrogram. One row per merge; columns 0 and 1 hold the indices of the
        merged nodes.
    max_cluster_size :
        Maximum size of each cluster.
    sort_clusters :
        If ``True``, sort labels in decreasing order of cluster size.
    return_dendrogram :
        If ``True``, returns the dendrogram formed by the clusters up to the root.

    Returns
    -------
    labels : np.ndarray
        Label of each node.
    dendrogram_aggregate : np.ndarray
        Dendrogram starting from clusters (leaves = clusters).

    Raises
    ------
    ValueError
        If ``max_cluster_size`` is smaller than 2 or larger than the number of nodes.

    Example
    -------
    >>> from sknetwork.hierarchy import cut_balanced
    >>> dendrogram = np.array([[0, 1, 0, 2], [2, 3, 1, 3]])
    >>> cut_balanced(dendrogram, 2)
    array([0, 0, 1])
    """
    check_dendrogram(dendrogram)
    n_nodes = dendrogram.shape[0] + 1

    if max_cluster_size < 2 or max_cluster_size > n_nodes:
        raise ValueError(
            "The maximum cluster size must be between 2 and the number of nodes."
        )

    # Every node starts as its own singleton cluster, keyed by node index.
    groups = {leaf: [leaf] for leaf in range(n_nodes)}

    for step, merge in enumerate(dendrogram):
        left = int(merge[0])
        right = int(merge[1])
        # A child that was never formed (its own merge was refused) blocks this merge.
        if left not in groups or right not in groups:
            continue
        merged_size = len(groups[left]) + len(groups[right])
        if merged_size <= max_cluster_size:
            # The new cluster takes the index of the internal node created at this step.
            groups[n_nodes + step] = groups.pop(left) + groups.pop(right)

    return get_labels(dendrogram, groups, sort_clusters, return_dendrogram)