class Text_IDEC(object):
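    """
    IDEC-style deep clustering over text features: a pretrained feedforward
    autoencoder is fine-tuned jointly with a cluster assignment layer by
    minimizing reconstruction loss plus a gamma-weighted clustering loss,
    with optional VAT regularization and semi-supervised constraints.
    """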

    def __init__(self, root_dir, batch_size=256, n_clusters=4, fd_hidden_dim=10, layer_norm=True, lr=0.001,
                 direct_update=False, maxiter=2e4, update_interval=140, tol=0.001, gamma=0.1,
                 fine_tune_infersent=False, use_vat=False, use_tensorboard=False, semi_supervised=False, split_sents=False, id=0, verbose=True):
        # model's settings
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.fd_hidden_dim = fd_hidden_dim
        self.n_clusters = n_clusters
        self.layer_norm = layer_norm
        self.use_vat = use_vat
        self.semi_supervised = semi_supervised
        self.lr = lr
        self.direct_update = direct_update
        self.maxiter = maxiter
        self.update_interval = update_interval
        self.tol = tol
        self.gamma = gamma
        self.fine_tune_infersent = fine_tune_infersent
        self.verbose = verbose
        self.use_tensorboard = use_tensorboard
        self.id = id
        self.use_cuda = torch.cuda.is_available()
        self.split_sents = split_sents
        # data loader
        self.corpus_loader = Corpus_Loader(self.root_dir,
                                           layer_norm=self.layer_norm,
                                           verbose=self.verbose,
                                           use_cuda=self.use_cuda,
                                           semi_supervised=self.semi_supervised,
                                           split_sents=self.split_sents)
        # model's components
        self.kmeans = None
        # self.fd_ae = extract_sdae_text(dim=fd_hidden_dim)
        self.fd_ae = extract_sdae_model(input_dim=cfg.INPUT_DIM, hidden_dims=cfg.HIDDEN_DIMS)

        self.cluster_layer = None
        self.ae_criteron = nn.MSELoss()
        self.cluster_criteron = F.binary_cross_entropy
        self.optimizer = None
        # model's state
        self.current_p = None
        self.current_q = None
        self.current_pred_labels = None
        self.past_pred_labels = None
        self.current_cluster_acc = None
        self.current_cluster_nmi = None
        self.current_cluster_ari = None
        # model's logger
        self.logger_tensorboard = None
        # initialize model's parameters and update current state
        self.initialize_model()
        self.initialize_tensorboard()

    def initialize_tensorboard(self):
        outputdir = get_output_dir(self.root_dir)
        logging_dir = os.path.join(outputdir, 'runs', 'clustering')
        if not os.path.exists(logging_dir):
            os.makedirs(logging_dir)
        self.logger_tensorboard = tensorboard_logger.Logger(os.path.join(logging_dir, '{}'.format(self.id)))

    def initialize_model(self):
        if self.verbose:
            print('Loading pretrained feedforward autoencoder')
        self.load_pretrained_fd_autoencoder()
        if self.verbose:
            print('Kmeans by hidden features')
        self.initialize_kmeans()
        if self.verbose:
            print('Kmeans cluster acc is {}'.format(self.current_cluster_acc))
            print('Kmeans cluster nmi is {}'.format(self.current_cluster_nmi))
            print('Kmeans cluster ari is {}'.format(self.current_cluster_ari))
            print('Initializing cluster layer by Kmeans centers')
        self.initialize_cluster_layer()
        if self.verbose:
            print('Initializing Adam optimizer, learning rate is {}'.format(self.lr))
        self.initialize_optimizer()
        if self.verbose:
            print('Updating target distribution')
        self.update_target_distribution()

    def load_pretrained_fd_autoencoder(self):
        """
        load pretrained stack denoise autoencoder
        """
        # outputdir = get_output_dir(self.root_dir)
        # NOTE: the checkpoint is loaded directly from root_dir rather than
        # from the derived output directory
        outputdir = self.root_dir
        net_filename = os.path.join(outputdir, cfg.PRETRAINED_FAE_FILENAME)
        # map_location avoids device-mismatch problems when a CUDA-trained
        # checkpoint is loaded on a CPU-only machine
        checkpoint = torch.load(net_filename, map_location=lambda storage, loc: storage)
        self.fd_ae.load_state_dict(checkpoint['state_dict'])
        if self.use_cuda:
            self.fd_ae.cuda()

    def initialize_optimizer(self):
        params = [
            {'params': self.fd_ae.parameters()},
            {'params': self.cluster_layer.parameters()}
        ]
        if self.fine_tune_infersent:
            params.append({'params': self.corpus_loader.infersent.parameters(), 'lr': 0.001 * self.lr})
        self.optimizer = optim.Adam(params, lr=self.lr)

    def initialize_kmeans(self):
        features = self.__get_initial_hidden_features()
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
        self.current_pred_labels = kmeans.fit_predict(features)
        self.update_cluster_acc()
        self.kmeans = kmeans

    def __get_initial_hidden_features(self):
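        """
        Encode the fixed sentence features through the autoencoder in
        mini-batches and return the hidden representations as a numpy array.
        """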
        batch_size = self.batch_size
        features_numpy = self.corpus_loader.get_fixed_features()
        data_size = self.corpus_loader.data_size
        hidden_feat = np.zeros((data_size, self.fd_hidden_dim))
        for index in range(0, data_size, batch_size):
            data_batch = features_numpy[index: index+batch_size]
            data_batch = Variable(torch.Tensor(data_batch))
            if self.use_cuda:
                data_batch = data_batch.cuda()
            hidden_batch, _ = self.fd_ae(data_batch)
            hidden_batch = hidden_batch.data.cpu().numpy()
            hidden_feat[index: index+batch_size] = hidden_batch
        return hidden_feat

    def get_current_hidden_features(self):
        # public alias: hidden features under the current autoencoder weights
        return self.__get_initial_hidden_features()

    def initialize_cluster_layer(self):
        self.cluster_layer = ClusterNet(torch.Tensor(self.kmeans.cluster_centers_.astype(np.float32)))
        if self.use_cuda:
            self.cluster_layer.cuda()

    def get_batch_target_distribution(self, batch_id):
        batch_target_distribution = self.current_p[batch_id]
        batch_target_distribution = Variable(torch.Tensor(batch_target_distribution))
        if self.use_cuda:
            batch_target_distribution = batch_target_distribution.cuda()
        return batch_target_distribution

    def update_target_distribution(self):
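        """
        Run the full training set through the model to refresh the soft
        assignments q and the target distribution p, then update predicted
        labels and clustering metrics.
        """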
        data_size = self.corpus_loader.data_size
        all_q = np.zeros((data_size, self.n_clusters))
        tmp_size = 0
        for current_batch in self.corpus_loader.train_data_iter(self.batch_size):
            id_batch = current_batch[2]
            if self.fine_tune_infersent:
                sent_feat = current_batch[3]
            else:
                sent_feat = current_batch[0]
            hidden_feat, _ = self.fd_ae(sent_feat)
            q_batch = self.cluster_layer(hidden_feat)
            q_batch = q_batch.cpu().data.numpy()
            all_q[id_batch] = q_batch
            tmp_size += len(id_batch)
        assert tmp_size == data_size
        all_p = self.target_distribution_numpy(all_q)
        self.current_p = all_p
        self.current_q = all_q
        self.update_pred_labels()
        self.update_cluster_acc()

    def update_pred_labels(self):
        # Warning: the first time this runs, self.past_pred_labels ends up equal
        # to self.current_pred_labels. Do not call it multiple times in a row
        # without refreshing self.current_q in between.
        self.past_pred_labels = self.current_pred_labels
        self.current_pred_labels = np.argmax(self.current_q, axis=1)

    def update_cluster_acc(self):
        from sklearn.metrics import normalized_mutual_info_score
        from sklearn.metrics import adjusted_rand_score
        true_labels = np.array(self.corpus_loader.train_labels)
        self.current_cluster_acc = cluster_acc(true_labels, self.current_pred_labels)
        self.current_cluster_nmi = normalized_mutual_info_score(true_labels, self.current_pred_labels)
        # ARI requires adjusted_rand_score; adjusted_mutual_info_score computes AMI
        self.current_cluster_ari = adjusted_rand_score(true_labels, self.current_pred_labels)

    @staticmethod
    def target_distribution_torch(q):
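        """
        DEC target distribution: p_ij = (q_ij ** 2 / f_j) / sum_k (q_ik ** 2 / f_k),
        where f_j = sum_i q_ij is the soft cluster frequency.
        """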
        p = torch.pow(q, 2) / torch.sum(q, dim=0).unsqueeze(0)
        p = p / torch.sum(p, dim=1).unsqueeze(1)
        # p = torch.t(torch.t(p) / torch.sum(p, dim=1))
        return Variable(p.data)

    @staticmethod
    def target_distribution_numpy(q):
        p = np.power(q, 2) / np.sum(q, axis=0, keepdims=True)
        p = p / np.sum(p, axis=1, keepdims=True)
        return p

    def vat(self, x_batch, xi=0.1, Ip=1):
        """
        Virtual adversarial training (VAT) loss: estimate, via Ip power-iteration
        steps, the input perturbation that most changes the soft cluster
        assignments, then return the divergence between predictions on the clean
        and the perturbed input.
        """
        # detach x_batch so no gradient flows back into the inputs
        x_batch = Variable(x_batch.data)
        hidden_batch, _ = self.fd_ae(x_batch)
        q_batch = self.cluster_layer(hidden_batch)
        q_batch = Variable(q_batch.data)
        # initialize residue d to normalized random vector
        d = torch.randn(x_batch.size())
        if self.use_cuda:
            d = d.cuda()
        d = d / (torch.norm(d, p=2, dim=1, keepdim=True) + 1e-8)
        # zero all parameter gradients before the power iteration
        self.model_zero_grad()
        for i in range(Ip):
            d = nn.Parameter(d)
            tmp_x_batch = x_batch + xi * d
            tmp_hidden_batch, _ = self.fd_ae(tmp_x_batch)
            tmp_q_batch = self.cluster_layer(tmp_hidden_batch)
            tmp_loss = F.binary_cross_entropy(tmp_q_batch, q_batch)
            tmp_loss.backward()
            d = d.grad.data
            d = d / (torch.norm(d, p=2, dim=1, keepdim=True) + 1e-8)
            self.model_zero_grad()
        # computing vat loss
        d = Variable(d)
        tmp_x_batch = x_batch + xi * d
        tmp_hidden_batch, _ = self.fd_ae(tmp_x_batch)
        tmp_q_batch = self.cluster_layer(tmp_hidden_batch)
        tmp_loss = F.binary_cross_entropy(tmp_q_batch, q_batch)
        return tmp_loss

    def whether_convergence(self):
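        # converged when the fraction of samples whose label changed is below tol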
        delta_label = np.sum(self.past_pred_labels != self.current_pred_labels) / float(len(self.current_pred_labels))
        return delta_label < self.tol

    def model_zero_grad(self):
        self.cluster_layer.zero_grad()
        self.fd_ae.zero_grad()
        if self.fine_tune_infersent:
            self.corpus_loader.infersent.zero_grad()

    def clustering(self):
        if self.semi_supervised:
            train_data_iter = self.corpus_loader.train_data_iter(self.batch_size,
                                                                 return_variable_features=self.fine_tune_infersent,
                                                                 shuffle=False,
                                                                 infinite=True)
            constraints_data_iter = self.corpus_loader.constraint_data_iter(self.batch_size,
                                                                            shuffle=True,
                                                                            infinite=True)
            ite = 0
            tmp_ite_cons = 0
            while True:
                if random.random() > 0.95:
                    self.model_zero_grad()
                    feat_batch1, feat_batch2 = next(constraints_data_iter)
                    hidden_batch1, output_feat1 = self.fd_ae(feat_batch1)
                    hidden_batch2, output_feat2 = self.fd_ae(feat_batch2)
                    ae_loss1 = self.ae_criteron(output_feat1, feat_batch1)
                    ae_loss2 = self.ae_criteron(output_feat2, feat_batch2)
                    q_batch1 = self.cluster_layer(hidden_batch1)
                    q_batch2 = self.cluster_layer(hidden_batch2)
                    if random.random() > 0.5:
                        q_batch1, q_batch2 = q_batch2, q_batch1
                    q_batch2 = Variable(q_batch2.data)
                    k_loss = self.cluster_criteron(q_batch1, q_batch2)
                    loss = 2 * self.gamma * k_loss + ae_loss1 + ae_loss2

                    if self.use_tensorboard:
                        self.logger_tensorboard.log_value('cons_loss', loss.data[0], tmp_ite_cons)
                        self.logger_tensorboard.log_value('cons_kl_loss', k_loss.data[0], tmp_ite_cons)
                    loss.backward()
                    self.optimizer.step()
                    tmp_ite_cons += 1
                else:
                    if ite % self.update_interval == (self.update_interval - 1):
                        self.update_target_distribution()
                        print('Iter {} acc {} nmi {} ari {}'.format(ite, self.current_cluster_acc, self.current_cluster_nmi, self.current_cluster_ari))
                        if self.use_tensorboard:
                            self.logger_tensorboard.log_value('acc', self.current_cluster_acc, ite)
                        if ite > 0 and self.whether_convergence():
                            break

                    current_batch = next(train_data_iter)
                    fixed_feat_batch = current_batch[0]
                    id_batch = current_batch[2]
                    if self.fine_tune_infersent:
                        sent_feat_batch = current_batch[3]
                    else:
                        sent_feat_batch = fixed_feat_batch

                    self.model_zero_grad()
                    hidden_batch, output_batch = self.fd_ae(sent_feat_batch)
                    q_batch = self.cluster_layer(hidden_batch)
                    if self.direct_update:
                        p_batch = self.target_distribution_torch(q_batch)
                    else:
                        p_batch = self.get_batch_target_distribution(id_batch)
                    ae_loss = self.ae_criteron(output_batch, fixed_feat_batch)
                    cluster_loss = self.cluster_criteron(q_batch, p_batch)
                    if self.use_vat:
                        vat_loss = self.vat(sent_feat_batch)
                    else:
                        vat_loss = 0
                    loss = self.gamma * (cluster_loss + vat_loss) + ae_loss
                    if self.use_tensorboard:
                        self.logger_tensorboard.log_value('cluster_loss', cluster_loss.data[0], ite)
                        self.logger_tensorboard.log_value('ae_loss', ae_loss.data[0], ite)
                        if self.use_vat:
                            self.logger_tensorboard.log_value('vat_loss', vat_loss.data[0], ite)
                        self.logger_tensorboard.log_value('loss', loss.data[0], ite)
                    loss.backward()
                    self.optimizer.step()
                    ite += 1
                    if ite >= int(self.maxiter):
                        break
        else:
            train_data_iter = self.corpus_loader.train_data_iter(self.batch_size,
                                                                 # return_variable_features=self.fine_tune_infersent,
                                                                 shuffle=False,
                                                                 infinite=True)
            for ite in range(int(self.maxiter)):
                if ite % self.update_interval == (self.update_interval - 1):
                    self.update_target_distribution()
                    print('Iter {} acc {} nmi {} ari {}'.format(ite, self.current_cluster_acc, self.current_cluster_nmi, self.current_cluster_ari))
                    if self.use_tensorboard:
                        self.logger_tensorboard.log_value('acc', self.current_cluster_acc, ite)
                    if ite > 0 and self.whether_convergence():
                        break

                current_batch = next(train_data_iter)
                fixed_feat_batch = current_batch[0]
                id_batch = current_batch[2]
                if self.fine_tune_infersent:
                    sent_feat_batch = current_batch[3]
                else:
                    sent_feat_batch = fixed_feat_batch

                self.model_zero_grad()
                hidden_batch, output_batch = self.fd_ae(sent_feat_batch)
                q_batch = self.cluster_layer(hidden_batch)
                if self.direct_update:
                    p_batch = self.target_distribution_torch(q_batch)
                else:
                    p_batch = self.get_batch_target_distribution(id_batch)
                # reconstruction loss disabled in this branch
                # ae_loss = self.ae_criteron(output_batch, fixed_feat_batch)
                ae_loss = 0.0
                cluster_loss = self.cluster_criteron(q_batch, p_batch)
                if self.use_vat:
                    vat_loss = self.vat(sent_feat_batch)
                else:
                    vat_loss = 0
                loss = self.gamma * (cluster_loss + vat_loss) + ae_loss
                if self.use_tensorboard:
                    self.logger_tensorboard.log_value('cluster_loss', cluster_loss.data[0], ite)
                    # ae_loss logging disabled along with the reconstruction loss
                    # self.logger_tensorboard.log_value('ae_loss', ae_loss.data[0], ite)
                    if self.use_vat:
                        self.logger_tensorboard.log_value('vat_loss', vat_loss.data[0], ite)
                    self.logger_tensorboard.log_value('loss', loss.data[0], ite)
                loss.backward()
                self.optimizer.step()
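

# Example usage of Text_IDEC (a minimal sketch; assumes `root_dir` contains the
# corpus and a pretrained autoencoder checkpoint named cfg.PRETRAINED_FAE_FILENAME):
#
#   idec = Text_IDEC(root_dir='/path/to/corpus', n_clusters=4, verbose=True)
#   idec.clustering()
#   hidden_feat = idec.get_current_hidden_features()
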
class EnhancedKMeans(object):
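    """
    KMeans followed by DEC-style refinement: KMeans centers initialize a
    ClusterNet layer whose centers are then trained on fixed features by
    minimizing the divergence between soft assignments q and the sharpened
    target distribution p.
    """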
    def __init__(self,
                 n_clusters=4,
                 update_interval=2,
                 tol=0.001,
                 lr=0.001,
                 maxiter=2e4,
                 batch_size=64,
                 max_jobs=10,
                 use_cuda=torch.cuda.is_available(),
                 logger=None,
                 verbose=False):
        self.n_clusters = n_clusters
        self.feat_dim = None
        self.data_size = None
        self.update_interval = update_interval
        self.tol = tol
        self.lr = lr
        self.maxiter = maxiter
        self.batch_size = batch_size
        self.max_jobs = max_jobs
        self.use_cuda = use_cuda
        self.verbose = verbose
        self.logger = logger
        if logger is not None:
            assert isinstance(self.logger, EKMLogger)
        self.kmeans = None
        self.cluster_layer = None
        self.optimizer = None
        self.last_pred = None
        self.current_p = None
        self.current_q = None

    def __initialize_models(self, feat, labels=None):
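        """
        Fit KMeans on the input features and use its cluster centers to
        initialize the cluster layer and its optimizer.
        """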
        self.data_size = feat.shape[0]
        self.feat_dim = feat.shape[1]
        if self.verbose:
            print('Pretraining Cluster Centers by KMeans')
        self.kmeans = KMeans(n_clusters=self.n_clusters,
                             n_init=20,
                             n_jobs=self.max_jobs,
                             verbose=False)
        self.last_pred = self.kmeans.fit_predict(feat)

        if labels is not None:
            tmp_acc = cluster_acc(labels, self.last_pred)
            if self.verbose:
                print('KMeans acc is {}'.format(tmp_acc))

        if self.verbose:
            print('Building Cluster Layer')
        # self.cluster_layer = ClusterNet(torch.Tensor(self.kmeans.cluster_centers_.astype(np.float32)))
        self.cluster_layer = ClusterNet(torch.from_numpy(self.kmeans.cluster_centers_.astype(np.float32)))
        if self.use_cuda:
            self.cluster_layer.cuda()
        if self.verbose:
            print('Building Optimizer')
        self.optimizer = optim.Adam(self.cluster_layer.parameters(), lr=self.lr)
        # self.optimizer = optim.SGD(self.cluster_layer.parameters(), lr=self.lr)

    def __update_target_distribute(self, feat):
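        """
        Recompute soft assignments q over all features in mini-batches and
        refresh the target distribution p.
        """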
        if self.verbose:
            print('Updating Target Distribution')
        all_q = np.zeros((self.data_size, self.n_clusters))
        tmp_size = 0
        for i in range(0, self.data_size, self.batch_size):
            tmp_feat = feat[i:i+self.batch_size].astype(np.float32)
            tmp_feat = Variable(torch.from_numpy(tmp_feat))
            if self.use_cuda:
                tmp_feat = tmp_feat.cuda()
            q_batch = self.cluster_layer(tmp_feat)
            q_batch = q_batch.cpu().data.numpy()
            all_q[i:i+self.batch_size] = q_batch
            tmp_size += len(q_batch)
        assert tmp_size == self.data_size
        self.current_q = all_q
        self.current_p = self.__get_target_distribution(self.current_q)

    @staticmethod
    def __get_target_distribution(q):
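        # p_ij = (q_ij ** 2 / f_j) / sum_k (q_ik ** 2 / f_k), with f_j = sum_i q_ij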
        p = np.power(q, 2) / np.sum(q, axis=0, keepdims=True)
        p = p / np.sum(p, axis=1, keepdims=True)
        return p

    @staticmethod
    def __get_label_pred(q):
        pred = np.argmax(q, axis=1)
        return pred

    def __whether_convergence(self, pred_cur, pred_last):
        delta_label = np.sum(pred_cur != pred_last) / float(len(pred_cur))
        return delta_label < self.tol

    def fit(self, feat, labels=None):
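        """
        Alternate gradient steps on the clustering loss with periodic
        refreshes of the target distribution; stop once fewer than a tol
        fraction of labels changed between refreshes, or after maxiter steps.
        """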
        self.__initialize_models(feat, labels=labels)
        self.__update_target_distribute(feat)

        if self.verbose:
            print('Begin to Iterate')
        index = 0
        for ite in range(int(self.maxiter)):
            if ite % self.update_interval == (self.update_interval - 1):
                self.__update_target_distribute(feat)
                tmp_pred_cur = self.__get_label_pred(self.current_q)
                acc = None
                if labels is not None:
                    acc = cluster_acc(labels, tmp_pred_cur)
                    if self.logger is not None:
                        self.logger.record_acc(acc, ite)
                if self.verbose:
                    if acc is not None:
                        print('Iter {} Acc {}'.format(ite, acc))
                    else:
                        print('Update Target Distribution in Iter {}'.format(ite))

                if ite > 0 and self.__whether_convergence(tmp_pred_cur, self.last_pred):
                    break
                self.last_pred = tmp_pred_cur

            if index + self.batch_size > self.data_size:
                feat_batch = feat[index:]
                p_batch = self.current_p[index:]
                index = 0
            else:
                feat_batch = feat[index: index + self.batch_size]
                p_batch = self.current_p[index: index + self.batch_size]
                # advance the batch cursor (otherwise training would reuse the first batch)
                index += self.batch_size
            feat_batch = Variable(torch.from_numpy(feat_batch.astype(np.float32)))
            p_batch = Variable(torch.from_numpy(p_batch.astype(np.float32)))
            if self.use_cuda:
                feat_batch = feat_batch.cuda()
                p_batch = p_batch.cuda()

            self.cluster_layer.zero_grad()
            q_batch = self.cluster_layer(feat_batch)
            cluster_loss = F.binary_cross_entropy(q_batch, p_batch)
            if self.logger is not None:
                self.logger.record_loss(cluster_loss.data[0], ite)
            cluster_loss.backward()
            self.optimizer.step()
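

if __name__ == '__main__':
    # Minimal smoke test on synthetic blobs (an illustrative sketch only; the
    # real pipeline feeds autoencoder hidden features into fit()).
    rng = np.random.RandomState(0)
    feat = np.vstack([rng.randn(100, 10) + 3.0 * c for c in range(4)]).astype(np.float32)
    labels = np.repeat(np.arange(4), 100)
    ekm = EnhancedKMeans(n_clusters=4, batch_size=64, maxiter=500, verbose=True)
    ekm.fit(feat, labels=labels)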