class Group:
    """A named bag of samples, some of which carry a seed label.

    Samples are stored in ``X`` with a parallel ``y_seed`` list in which
    ``None`` marks an unseeded sample.  ``cluster()`` builds a two-tier
    clustering in which no resulting cluster contains seeds with different
    labels (a "collision").
    """

    def __init__(self, name):
        """Create an empty group identified by *name*."""
        self.name = name
        self.cnt = 0  # number of samples added so far
        self.seeding_counter = Counter()  # seed label -> number of seeded samples
        self.X = []
        self.y_seed = []  # parallel to X; None means "no seed"
        self.clustering_model = None  # populated by cluster()
        self.splitting_score = 0  # accumulated sub-cluster counts of applied splits

    def seeding_cnt(self):
        """Return the total number of seeded samples in the group."""
        return sum(self.seeding_counter.values())

    def add(self, x, y=None):
        """Append sample *x*; *y* is its seed label, or None if unseeded."""
        self.cnt += 1
        self.X.append(x)
        self.y_seed.append(y)
        if y is not None:
            self.seeding_counter[y] += 1

    def major(self):
        """Return ``(label, count)`` for the most frequent seed label.

        Raises:
            IndexError: if no seeded sample has been added.
        """
        return self.seeding_counter.most_common(1)[0]

    def get_seeds(self, X, y_seed):
        """Return the ``(x, y)`` pairs of *X*/*y_seed* whose seed is not None."""
        return [(x, y) for x, y in zip(X, y_seed) if y is not None]

    def seeds(self):
        """Return the ``(x, y)`` pairs of this group's own seeded samples."""
        return self.get_seeds(self.X, self.y_seed)

    def has_collision(self, X, y_seed, model=None):
        """Tell whether two differently-labelled seeds share a cluster.

        Args:
            X: samples.
            y_seed: seed labels parallel to *X* (None for unseeded samples).
            model: optional object whose ``predict(samples)`` returns one
                cluster label per sample.  When omitted, all samples are
                treated as belonging to a single cluster.

        Returns:
            True if some cluster contains seeds with different labels,
            False otherwise (including when there are no seeds at all).
        """
        seeds = self.get_seeds(X, y_seed)
        # No seeds, no collision.
        if not seeds:
            return False

        seed_x, seed_y = zip(*seeds)
        if model is None:
            seed_groups = [0] * len(seed_x)
        else:
            seed_groups = model.predict(seed_x)

        # Remember the first seed label seen in each cluster; any later,
        # different seed label in the same cluster is a collision.
        y_by_label = {}
        for label, y in zip(seed_groups, seed_y):
            if y_by_label.setdefault(label, y) != y:
                return True
        return False

    def cluster(self, method='ward'):
        """Cluster the group's samples so that no cluster mixes seed labels.

        First tier: ``agglomerative_l_method`` labels all samples and the
        result is loaded into a ``DividableClustering``.  Second tier: every
        first-tier cluster that contains a collision is split with k-means,
        searching for a small cluster count that removes the collision.
        The result is stored in ``self.clustering_model``; the cluster count
        of every applied split is added to ``self.splitting_score``.

        Args:
            method: forwarded to ``agglomerative_l_method``.
        """
        import warnings

        assert len(self.X) == len(self.y_seed)

        # First tier: agglomerative clustering.
        l_method = agglomerative_l_method(self.X, method=method)
        self.clustering_model = DividableClustering()
        self.clustering_model.fit(self.X, l_method.labels_)

        # Second tier: k-means on each first-tier cluster that has a collision.
        # The range is evaluated once, so only the labels present before any
        # split are visited.
        for suspect_label in range(self.clustering_model.latest_label):
            X = []
            y_seed = []
            for x, idx in self.clustering_model.get_X_with_idx(suspect_label):
                X.append(x)
                y_seed.append(self.y_seed[idx])

            # No collision in this sub-group: nothing to split.
            if not self.has_collision(X, y_seed):
                continue

            # Biased binary search over the cluster count: each probe sits a
            # quarter of the way into [low_cnt, high_cnt], favouring small
            # counts.  The search treats "collision-free" as monotonic in the
            # cluster count.
            low_cnt = 2
            high_cnt = len(X)
            best_cnt = None
            best_labels = None
            while low_cnt <= high_cnt:
                cluster_cnt = low_cnt + (high_cnt - low_cnt) // 4
                kmeans = KMeans(cluster_cnt)
                kmeans.fit(X)
                if not self.has_collision(X, y_seed, kmeans):
                    # Keep the count together with its labels: the final probe
                    # of the search may be a failing one, so the loop variable
                    # alone does not identify the split that gets applied.
                    best_cnt = cluster_cnt
                    best_labels = kmeans.labels_
                    high_cnt = cluster_cnt - 1
                else:
                    low_cnt = cluster_cnt + 1

            if best_labels is None:
                # No probed cluster count separated the conflicting seeds.
                # NOTE(review): split() presumably needs a label per sample,
                # so the sub-cluster is left unsplit rather than passing None.
                warnings.warn(
                    f"Group {self.name!r}: no collision-free k-means split "
                    f"found for sub-cluster {suspect_label}; leaving it unsplit",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue

            self.splitting_score += best_cnt
            print('split sub_clusters_cnt:', best_cnt, 'cnt:', len(X), 'main cnt:', self.cnt)
            self.clustering_model.split(suspect_label, best_labels)

        self.clustering_model.relabel()
# Demo: exercise DividableClustering.split() and relabel() on the iris data,
# printing cluster sizes and label keys at each stage.
dataset = get_iris()

# Start from three agglomerative clusters.
agg = AgglomerativeClustering(3)
agg.fit(dataset.X)

model = DividableClustering()
model.fit(dataset.X, agg.labels_)
for _label in (0, 1, 2):
    print(len(model.X_by_label[_label]))

# Split cluster 0 with a 3-cluster k-means; the sizes are then read back
# under labels 3, 4 and 5.
kmeans = KMeans(3)
kmeans.fit(model.get_X(0))
model.split(0, kmeans.labels_)
for _label in (3, 4, 5):
    print(len(model.X_by_label[_label]))
print(model.X_by_label.keys())

# Show the label keys again after relabel(), then the sizes under labels 0-4.
model.relabel()
print(model.X_by_label.keys())
for _label in range(5):
    print(len(model.X_by_label[_label]))