class Group:
    """A named collection of samples, some of which carry seed labels.

    Samples are accumulated with ``add``; ``cluster`` then partitions them into
    sub-clusters, splitting any sub-group in which two seeds with different
    labels fall together (a "collision").

    Attributes:
        name: identifier given at construction.
        cnt: number of samples added so far.
        seeding_counter: ``Counter`` of seed labels; unlabelled samples are
            not counted.
        X: samples, in insertion order.
        y_seed: seed label per sample (``None`` for unlabelled samples),
            parallel to ``X``.
        clustering_model: the ``DividableClustering`` built by ``cluster``,
            or ``None`` before ``cluster`` has run.
        splitting_score: sum of the KMeans cluster counts used to split
            colliding sub-groups during the most recent ``cluster`` call.
    """

    def __init__(self, name):
        self.name = name
        self.cnt = 0
        self.seeding_counter = Counter()
        self.X = []
        self.y_seed = []
        self.clustering_model = None
        self.splitting_score = 0

    def seeding_cnt(self):
        """Return the number of seeded (labelled) samples in the group."""
        return sum(self.seeding_counter.values())

    def add(self, x, y=None):
        """Append sample ``x`` with optional seed label ``y`` (``None`` = unlabelled)."""
        self.cnt += 1
        self.X.append(x)
        self.y_seed.append(y)

        if y is not None:
            self.seeding_counter[y] += 1

    def major(self):
        """Return ``(label, count)`` for the most common seed label.

        Raises:
            IndexError: if the group has no seeded samples.
        """
        return self.seeding_counter.most_common(1).pop()

    def get_seeds(self, X, y_seed):
        """Return the ``(x, y)`` pairs of parallel ``X``/``y_seed`` whose label is not ``None``."""
        return [(x, y) for x, y in zip(X, y_seed) if y is not None]

    def seeds(self):
        """Return the ``(x, y)`` pairs of this group's seeded samples."""
        return self.get_seeds(self.X, self.y_seed)

    def has_collision(self, X, y_seed, model=None):
        """Return True if two seeds assigned to the same cluster have different labels.

        Args:
            X: samples.
            y_seed: seed label per sample (``None`` = unlabelled), parallel to ``X``.
            model: optional fitted model whose ``predict`` assigns each seeded
                sample to a cluster. When ``None``, all seeds are treated as a
                single cluster.

        Returns:
            False when there are no seeds at all.
        """
        seeds = self.get_seeds(X, y_seed)

        # no seeds, no collision
        if not seeds:
            return False

        seed_x, seed_y = zip(*seeds)

        if model is None:
            seed_groups = [0] * len(seed_x)
        else:
            seed_groups = model.predict(seed_x)

        # remember the first seed label seen in each cluster; any later seed
        # in that cluster with a different label is a collision
        y_by_group = {}
        for group, y in zip(seed_groups, seed_y):
            if y_by_group.setdefault(group, y) != y:
                return True

        return False

    @staticmethod
    def _biased_search(low_cnt, high_cnt, try_split):
        """Search ``[low_cnt, high_cnt]`` for the smallest k where ``try_split(k)`` succeeds.

        ``try_split(k)`` must return split labels on success and ``None`` on
        failure. Each probe is placed 1/4 of the way into the remaining range,
        biasing the search toward small k. The search assumes success is
        monotone in k; in any case the result is the smallest successful probe.

        Returns:
            ``(k, labels)`` for the smallest successful probe, or
            ``(None, None)`` if no probe succeeded.
        """
        best_cnt, best_labels = None, None
        while low_cnt <= high_cnt:
            # 1/4 biased binary search
            cluster_cnt = int((high_cnt - low_cnt) * 1/4 + low_cnt)
            labels = try_split(cluster_cnt)
            if labels is None:
                low_cnt = cluster_cnt + 1
            else:
                # keep the k that actually produced these labels, then look lower
                best_cnt, best_labels = cluster_cnt, labels
                high_cnt = cluster_cnt - 1
        return best_cnt, best_labels

    def _kmeans_split(self, X, y_seed, cluster_cnt):
        """Fit ``KMeans(cluster_cnt)`` on ``X``; return its labels, or ``None`` if seeds collide."""
        kmeans = KMeans(cluster_cnt)
        kmeans.fit(X)
        if self.has_collision(X, y_seed, kmeans):
            return None
        return kmeans.labels_

    def cluster(self, method='ward'):
        """Cluster ``self.X`` in two tiers and store the result in ``self.clustering_model``.

        First tier: the labels from ``agglomerative_l_method(self.X, method=method)``
        initialise a ``DividableClustering``. Second tier: each first-tier
        sub-group containing a seed collision is split with KMeans, using the
        smallest cluster count found by ``_biased_search`` that leaves no
        collision. Finally ``relabel()`` is called on the model.

        Args:
            method: forwarded to ``agglomerative_l_method`` (presumably the
                linkage criterion — confirm against that helper).
        """
        assert len(self.X) == len(self.y_seed)

        # the score describes the model built below, not previous calls
        self.splitting_score = 0

        l_method = agglomerative_l_method(self.X, method=method)

        # first tier clustering, using agglomerative clustering
        self.clustering_model = DividableClustering()
        self.clustering_model.fit(self.X, l_method.labels_)

        # second tier, using kmeans
        # NOTE(review): visits labels 0..latest_label-1 as they exist before
        # any split; if latest_label is the highest used label rather than the
        # next free one, the last first-tier label is never checked — confirm.
        for suspect_label in range(self.clustering_model.latest_label):
            X = []
            y_seed = []
            for x, idx in self.clustering_model.get_X_with_idx(suspect_label):
                X.append(x)
                y_seed.append(self.y_seed[idx])

            # no collision in this sub-group
            if not self.has_collision(X, y_seed):
                continue

            # there are collisions in this sub-group: find the fewest KMeans
            # clusters that separate the conflicting seeds
            split_cnt, split_labels = self._biased_search(
                2, len(X), lambda k: self._kmeans_split(X, y_seed, k))

            if split_labels is None:
                # no probed k separated the seeds (e.g. duplicate points with
                # different seed labels); leave this sub-group unsplit
                print('no collision-free split, cnt:', len(X), 'main cnt:', self.cnt)
                continue

            self.splitting_score += split_cnt
            print('split sub_clusters_cnt:', split_cnt, 'cnt:', len(X), 'main cnt:', self.cnt)
            self.clustering_model.split(suspect_label, split_labels)

        self.clustering_model.relabel()
# Demo of DividableClustering on the dataset returned by get_iris(): build an
# agglomerative first-tier clustering, split cluster 0 with KMeans, then
# relabel, printing cluster sizes and label sets along the way.
dataset = get_iris()

# first tier; labels 0, 1 and 2 are inspected below (presumably 3 clusters)
agg = AgglomerativeClustering(3)
agg.fit(dataset.X)

model = DividableClustering()
model.fit(dataset.X, agg.labels_)

# sizes of the initial clusters
print(len(model.X_by_label[0]))
print(len(model.X_by_label[1]))
print(len(model.X_by_label[2]))

# re-cluster only the members of cluster 0 with KMeans
kmeans = KMeans(3)
kmeans.fit(model.get_X(0))

model.split(0, kmeans.labels_)

# assumes split() gives the new sub-clusters fresh labels 3, 4, 5 — TODO confirm
print(len(model.X_by_label[3]))
print(len(model.X_by_label[4]))
print(len(model.X_by_label[5]))

print(model.X_by_label.keys())

model.relabel()

# presumably relabel() compacts labels to 0..4 (3 initial - 1 split + 3 new) — verify
print(model.X_by_label.keys())
print(len(model.X_by_label[0]))
print(len(model.X_by_label[1]))
print(len(model.X_by_label[2]))
print(len(model.X_by_label[3]))
print(len(model.X_by_label[4]))