Example #1
    def fit(self, X, lr=0.001, batch_size=256, num_epochs=10, save_path=None):
        '''X: tensor data'''
        num = len(X)
        num_batch = int(math.ceil(1.0 * len(X) / batch_size))
        self.to(self.device)
        print("=====Training DEC=======")
        # optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     self.parameters()),
                              lr=lr,
                              momentum=0.9)

        print("Extracting initial features at %s" %
              (str(datetime.datetime.now())))
        image_z = []
        text_z = []
        for batch_idx in range(num_batch):
            image_batch = X[batch_idx * batch_size:min((batch_idx + 1) *
                                                       batch_size, num)][1]
            text_batch = X[batch_idx * batch_size:min((batch_idx + 1) *
                                                      batch_size, num)][2]
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            image_z.append(_image_z.data.cpu())
            text_z.append(_text_z.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
            torch.cuda.empty_cache()
        image_z = torch.cat(image_z, dim=0)
        text_z = torch.cat(text_z, dim=0)

        print("Initializing cluster centers with kmeans at %s" %
              (str(datetime.datetime.now())))
        image_kmeans = KMeans(self.n_clusters, n_init=20)
        image_pred = image_kmeans.fit_predict(image_z.data.cpu().numpy())
        print("Image kmeans completed at %s" % (str(datetime.datetime.now())))

        text_kmeans = KMeans(self.n_clusters, n_init=20)
        text_pred = text_kmeans.fit_predict(text_z.data.cpu().numpy())
        print("Text kmeans completed at %s" % (str(datetime.datetime.now())))

        image_ind, text_ind = align_cluster(image_pred, text_pred)
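        # align_cluster (helper, defined elsewhere) pairs image and text cluster
        # indices so that index k denotes the same cluster in both modalities;
        # see the sketch after this example.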

        image_cluster_centers = np.zeros_like(image_kmeans.cluster_centers_)
        text_cluster_centers = np.zeros_like(text_kmeans.cluster_centers_)

        for i in range(self.n_clusters):
            image_cluster_centers[i] = image_kmeans.cluster_centers_[
                image_ind[i]]
            text_cluster_centers[i] = text_kmeans.cluster_centers_[text_ind[i]]
        self.image_encoder.mu.data.copy_(torch.Tensor(image_cluster_centers))
        self.image_encoder.mu.data = self.image_encoder.mu.cpu()
        self.text_encoder.mu.data.copy_(torch.Tensor(text_cluster_centers))
        self.text_encoder.mu.data = self.text_encoder.mu.cpu()
        self.train()
        best_loss = float('inf')
        best_epoch = 0

        for epoch in range(num_epochs):
            # update the target distribution p

            image_z = []
            text_z = []
            for batch_idx in range(num_batch):
                image_batch = X[batch_idx * batch_size:min((batch_idx + 1) *
                                                           batch_size, num)][1]
                text_batch = X[batch_idx * batch_size:min((batch_idx + 1) *
                                                          batch_size, num)][2]
                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                image_z.append(_image_z.data.cpu())
                text_z.append(_text_z.data.cpu())
                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
                torch.cuda.empty_cache()
            image_z = torch.cat(image_z, dim=0)
            text_z = torch.cat(text_z, dim=0)

            q, r = self.soft_assignemt(image_z, text_z)
            p = self.target_distribution(q, r).data
            y_pred = torch.argmax(p, dim=1).numpy()
            count_percentage(y_pred)
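            # count_percentage (helper, not shown) presumably logs the fraction
            # of samples assigned to each predicted cluster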
            # train 1 epoch
            train_loss = 0.0
            for batch_idx in range(num_batch):
                image_batch = X[batch_idx * batch_size:min((batch_idx + 1) *
                                                           batch_size, num)][1]
                text_batch = X[batch_idx * batch_size:min((batch_idx + 1) *
                                                          batch_size, num)][2]
                pbatch = p[batch_idx * batch_size:min((batch_idx + 1) *
                                                      batch_size, num)]

                optimizer.zero_grad()
                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                target = Variable(pbatch)

                image_z, text_z = self.forward(image_inputs, text_inputs)
                qbatch, rbatch = self.soft_assignemt(image_z.cpu(),
                                                     text_z.cpu())
                loss = self.loss_function(target, qbatch, rbatch)
                train_loss += loss.data * len(target)
                loss.backward()
                optimizer.step()

                del image_batch, text_batch, image_inputs, text_inputs, image_z, text_z
                torch.cuda.empty_cache()
            train_loss = train_loss / num
            if best_loss > train_loss:
                best_loss = train_loss
                best_epoch = epoch
                if save_path:
                    self.save_model(
                        os.path.join(save_path, "mdec_" +
                                     str(self.image_encoder.z_dim)) + '_' +
                        str(self.n_clusters) + ".pt")
            print("#Epoch %3d: Loss: %.4f Best Loss: %.4f at %s" %
                  (epoch + 1, train_loss, best_loss,
                   str(datetime.datetime.now())))

        print("#Best Epoch %3d: Best Loss: %.4f" % (best_epoch, best_loss))
Example #2
    def fit_predict(self,
                    X,
                    train_dataset,
                    test_dataset,
                    lr=0.001,
                    batch_size=256,
                    num_epochs=10,
                    update_time=1,
                    save_path=None,
                    tol=1e-3,
                    kappa=0.1):
        '''X: tensor data'''
        X_num = len(X)
        X_num_batch = int(math.ceil(1.0 * len(X) / batch_size))
        train_num = len(train_dataset)
        train_num_batch = int(math.ceil(1.0 * len(train_dataset) / batch_size))
        self.to(self.device)
        self.encoder.mu.data = self.encoder.mu.cpu()
        print("=====Training DEC=======")
        trainloader = torch.utils.data.DataLoader(train_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=True)
        validloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False)
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     self.parameters()),
                              lr=lr,
                              momentum=0.9)

        print("Extracting initial features at %s" %
              (str(datetime.datetime.now())))
        z = self.update_z(X, batch_size)
        train_z = self.update_z(train_dataset, batch_size)
        print("Initializing cluster centers with kmeans at %s" %
              (str(datetime.datetime.now())))
        kmeans = KMeans(self.n_clusters, n_init=20)
        kmeans.fit(z.data.cpu().numpy())
        train_pred = kmeans.predict(train_z.data.cpu().numpy())
        print("kmeans completed at %s" % (str(datetime.datetime.now())))

        short_codes = X[:][0]
        train_short_codes = train_dataset[:][0]
        train_labels = train_dataset[:][2].data.cpu().numpy()
        df_train = pd.DataFrame(data=train_labels,
                                index=train_short_codes,
                                columns=['label'])
        _, ind = align_cluster(train_labels, train_pred)

        cluster_centers = np.zeros_like(kmeans.cluster_centers_)
        for i in range(self.n_clusters):
            cluster_centers[i] = kmeans.cluster_centers_[ind[i]]
        self.encoder.mu.data.copy_(torch.Tensor(cluster_centers))
        self.encoder.mu.data = self.encoder.mu.cpu()

        if self.use_prior:
            for label in train_labels:
                self.prior[label] = self.prior[label] + 1
            self.prior = self.prior / len(train_labels)
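            # self.prior now holds the empirical label distribution of the
            # training set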
        for epoch in range(num_epochs):
            # update the target distribution p
            self.train()
            # train 1 epoch
            train_loss = 0.0
            semi_train_loss = 0.0

            adjust_learning_rate(lr, optimizer)
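            # reset to the base learning rate for the semi-supervised phase;
            # the clustering phase below drops it to lr * kappa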

            for batch_idx in range(train_num_batch):
                # semi-supervised phase
                data_batch = train_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, train_num)][1]
                label_batch = train_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, train_num)][2]

                optimizer.zero_grad()
                data_inputs = Variable(data_batch).to(self.device)
                label_inputs = Variable(label_batch)

                _z = self.forward(data_inputs)
                qbatch = self.soft_assignemt(_z.cpu())
                semi_loss = self.semi_loss_function(label_inputs, qbatch)
                semi_train_loss += semi_loss.data * len(label_inputs)
                semi_loss.backward()
                optimizer.step()

                del data_batch, data_inputs, _z

            z = self.update_z(X, batch_size)
            q = self.soft_assignemt(z)
            p = self.target_distribution(q).data

            adjust_learning_rate(lr * kappa, optimizer)

            for batch_idx in range(X_num_batch):
                # clustering phase
                data_batch = X[batch_idx *
                               batch_size:min((batch_idx + 1) *
                                              batch_size, X_num)][1]
                pbatch = p[batch_idx * batch_size:min((batch_idx + 1) *
                                                      batch_size, X_num)]

                optimizer.zero_grad()
                data_inputs = Variable(data_batch).to(self.device)
                p_inputs = Variable(pbatch)

                _z = self.forward(data_inputs)
                qbatch = self.soft_assignemt(_z.cpu())
                loss = self.loss_function(p_inputs, qbatch)
                train_loss += loss.data * len(p_inputs)
                loss.backward()
                optimizer.step()

                del data_batch, data_inputs, _z
            train_loss = train_loss / X_num
            semi_train_loss = semi_train_loss / train_num

            train_pred = torch.argmax(p, dim=1).numpy()
            df_pred = pd.DataFrame(data=train_pred,
                                   index=short_codes,
                                   columns=['pred'])
            df_pred = df_pred.loc[df_train.index]
            train_pred = df_pred['pred']
            train_acc = accuracy_score(train_labels, train_pred)
            train_nmi = normalized_mutual_info_score(
                train_labels, train_pred, average_method='geometric')
            train_f_1 = f1_score(train_labels, train_pred, average='macro')
            print(
                "#Epoch %3d: acc: %.4f, nmi: %.4f, f_1: %.4f, loss: %.4f, semi_loss: %.4f, at %s"
                % (epoch + 1, train_acc, train_nmi, train_f_1, train_loss,
                   semi_train_loss, str(datetime.datetime.now())))
            if epoch == 0:
                train_pred_last = train_pred
            else:
                delta_label = np.sum(train_pred != train_pred_last).astype(
                    np.float32) / len(train_pred)
                train_pred_last = train_pred
                if delta_label < tol:
                    print('delta_label ', delta_label, '< tol ', tol)
                    print("Reach tolerance threshold. Stopping training.")
                    break

        self.eval()
        test_labels = test_dataset[:][2].squeeze(dim=0)
        test_z = self.update_z(test_dataset, batch_size)
        z = torch.cat([z, test_z], dim=0)

        q = self.soft_assignemt(z)
        test_p = self.target_distribution(q).data
        test_pred = torch.argmax(test_p, dim=1).numpy()[X_num:]
        test_acc = accuracy_score(test_labels, test_pred)

        test_short_codes = test_dataset[:][0]
        test_short_codes = np.concatenate([short_codes, test_short_codes],
                                          axis=0)
        df_test = pd.DataFrame(data=torch.argmax(test_p, dim=1).numpy(),
                               index=test_short_codes,
                               columns=['labels'])
        df_test.to_csv('udec_label.csv', encoding='utf-8-sig')
        df_test_p = pd.DataFrame(data=test_p.data.numpy(),
                                 index=test_short_codes)
        df_test_p.to_csv('udec_p.csv', encoding='utf-8-sig')
        test_nmi = normalized_mutual_info_score(test_labels,
                                                test_pred,
                                                average_method='geometric')
        test_f_1 = f1_score(test_labels, test_pred, average='macro')
        print("#Test acc: %.4f, Test nmi: %.4f, Test f_1: %.4f" %
              (test_acc, test_nmi, test_f_1))
        self.acc = test_acc
        self.nmi = test_nmi
        self.f_1 = test_f_1
        if save_path:
            self.save_model(save_path)
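Both `fit` and `fit_predict` lean on `soft_assignemt` (sic) and `target_distribution`, which are defined on the encoder classes and not shown here. A minimal single-modality sketch of the standard DEC formulation (Xie et al., 2016) that these calls appear to follow; `alpha`, the exact signatures, and the multimodal variants taking both q and r are assumptions:

# Hedged sketch of DEC's soft assignment, target distribution, and KL loss.
import torch

def soft_assignment(z, mu, alpha=1.0):
    # Student's t kernel between embeddings z (N x d) and centers mu (K x d)
    dist_sq = torch.sum((z.unsqueeze(1) - mu) ** 2, dim=2)
    q = (1.0 + dist_sq / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(dim=1, keepdim=True)

def target_distribution(q):
    # sharpen q: square it, normalize by cluster frequency, renormalize rows
    weight = q ** 2 / q.sum(dim=0)
    return weight / weight.sum(dim=1, keepdim=True)

def loss_function(p, q):
    # KL(P || Q) averaged over the batch, the usual DEC clustering objective
    return torch.sum(p * torch.log(p / q)) / p.size(0)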
Example #3
    def fit_predict(self,
                    full_dataset,
                    train_dataset,
                    test_dataset,
                    args,
                    CONFIG,
                    lr=0.001,
                    batch_size=256,
                    num_epochs=10,
                    update_time=1,
                    save_path=None,
                    tol=1e-3,
                    kappa=0.1):
        '''full_dataset, train_dataset, test_dataset: tensor datasets'''
        full_num = len(full_dataset)
        full_num_batch = int(math.ceil(1.0 * len(full_dataset) / batch_size))
        train_num = len(train_dataset)
        train_num_batch = int(math.ceil(1.0 * len(train_dataset) / batch_size))
        test_num = len(test_dataset)
        test_num_batch = int(math.ceil(1.0 * len(test_dataset) / batch_size))
        self.to(self.device)
        print("=====Training DEC=======")
        if args.adam:
            optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          self.parameters()),
                                   lr=lr)
        else:
            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         self.parameters()),
                                  lr=lr,
                                  momentum=0.9)
        full_short_codes = full_dataset[:][0]
        train_short_codes = train_dataset[:][0]
        test_short_codes = test_dataset[:][0]
        train_labels = train_dataset[:][3].squeeze(dim=0).data.cpu().numpy()
        test_labels = test_dataset[:][3].squeeze(dim=0).data.cpu().numpy()
        df_train = pd.DataFrame(data=train_labels,
                                index=train_short_codes,
                                columns=['label'])
        df_test = pd.DataFrame(data=test_labels,
                               index=test_short_codes,
                               columns=['label'])

        if not args.resume:
            print("Extracting initial features at %s" %
                  (str(datetime.datetime.now())))
            image_z = []
            text_z = []
            for batch_idx in range(full_num_batch):
                image_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][1]
                text_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][2]
                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                image_z.append(_image_z.data.cpu())
                text_z.append(_text_z.data.cpu())
                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
            image_z = torch.cat(image_z, dim=0)
            text_z = torch.cat(text_z, dim=0)

            train_image_z = []
            train_text_z = []
            for batch_idx in range(train_num_batch):
                image_batch = train_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, train_num)][1]
                text_batch = train_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, train_num)][2]
                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                train_image_z.append(_image_z.data.cpu())
                train_text_z.append(_text_z.data.cpu())
                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
            train_image_z = torch.cat(train_image_z, dim=0)
            train_text_z = torch.cat(train_text_z, dim=0)

            print("Initializing cluster centers with kmeans at %s" %
                  (str(datetime.datetime.now())))
            image_kmeans = KMeans(n_clusters=self.n_clusters,
                                  n_init=20,
                                  random_state=42)
            image_kmeans.fit(image_z.data.cpu().numpy())
            train_image_pred = image_kmeans.predict(
                train_image_z.data.cpu().numpy())
            print("Image kmeans completed at %s" %
                  (str(datetime.datetime.now())))

            text_kmeans = KMeans(n_clusters=self.n_clusters,
                                 n_init=20,
                                 random_state=42)
            text_kmeans.fit(text_z.data.cpu().numpy())
            train_text_pred = text_kmeans.predict(
                train_text_z.data.cpu().numpy())
            print("Text kmeans completed at %s" %
                  (str(datetime.datetime.now())))

            _, image_ind = align_cluster(train_labels, train_image_pred)
            _, text_ind = align_cluster(train_labels, train_text_pred)

            image_cluster_centers = np.zeros_like(
                image_kmeans.cluster_centers_)
            text_cluster_centers = np.zeros_like(text_kmeans.cluster_centers_)
            for i in range(self.n_clusters):
                image_cluster_centers[i] = image_kmeans.cluster_centers_[
                    image_ind[i]]
                text_cluster_centers[i] = text_kmeans.cluster_centers_[
                    text_ind[i]]
            self.image_encoder.mu.data.copy_(
                torch.Tensor(image_cluster_centers))
            self.text_encoder.mu.data.copy_(torch.Tensor(text_cluster_centers))

        if self.use_prior:
            for label in train_labels:
                self.prior[label] = self.prior[label] + 1
            self.prior /= len(train_labels)

        print("Calculating initial p at %s" % (str(datetime.datetime.now())))
        # update p considering short memory
        s = []
        for batch_idx in range(full_num_batch):
            image_batch = full_dataset[batch_idx *
                                       batch_size:min((batch_idx + 1) *
                                                      batch_size, full_num)][1]
            text_batch = full_dataset[batch_idx *
                                      batch_size:min((batch_idx + 1) *
                                                     batch_size, full_num)][2]

            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)

            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            _q, _r = self.soft_assignemt(_image_z, _text_z)
            _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
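            # probabililty_fusion (sic; helper on the class, not shown) fuses
            # the image assignment _q and text assignment _r into a single
            # distribution _s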
            s.append(_s.data.cpu())

            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s

        for batch_idx in range(test_num_batch):
            image_batch = test_dataset[batch_idx *
                                       batch_size:min((batch_idx + 1) *
                                                      batch_size, test_num)][1]
            text_batch = test_dataset[batch_idx *
                                      batch_size:min((batch_idx + 1) *
                                                     batch_size, test_num)][2]

            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)

            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            _q, _r = self.soft_assignemt(_image_z, _text_z)
            _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
            s.append(_s.data.cpu())

            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s

        s = torch.cat(s, dim=0)

        p = self.target_distribution(s)
        initial_pred = torch.argmax(s, dim=1).numpy()
        initial_acc = accuracy_score(test_labels, initial_pred[full_num:])
        initial_nmi = normalized_mutual_info_score(test_labels,
                                                   initial_pred[full_num:],
                                                   average_method='geometric')
        initial_f_1 = f1_score(test_labels,
                               initial_pred[full_num:],
                               average='macro')
        print("#Initial measure: acc: %.4f, nmi: %.4f, f_1: %.4f" %
              (initial_acc, initial_nmi, initial_f_1))
        df_initial = pd.DataFrame(data=initial_pred,
                                  index=np.concatenate(
                                      [full_short_codes, test_short_codes],
                                      axis=0),
                                  columns=['label'])
        df_initial['pred'] = 'pred'
        df_initial.loc[df_train.index, 'pred'] = 'label'
        for idx, row in df_train.iterrows():
            df_initial.loc[idx, 'label'] = row['label']
        df_initial.loc[df_test.index, 'pred'] = 'label'
        for idx, row in df_test.iterrows():
            df_initial.loc[idx, 'label'] = row['label']
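        # df_initial now marks each row as a model prediction ('pred') or
        # ground truth ('label'), presumably so do_tsne can distinguish them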

        if args.tsne:
            print("Conducting initial TSNE at %s" %
                  (str(datetime.datetime.now())))
            do_tsne(p.numpy(), df_initial, self.n_clusters,
                    os.path.join(CONFIG.SVG_PATH, args.gpu, 'epoch_000.png'))
            print("TSNE completed at %s" % (str(datetime.datetime.now())))

        flag_end_training = False
        for epoch in range(num_epochs):
            print("Epoch %d at %s" % (epoch, str(datetime.datetime.now())))
            # update the target distribution p
            self.train()
            # train 1 epoch
            train_unsupervised_loss = 0.0
            train_supervised_image_loss = 0.0
            train_supervised_text_loss = 0.0
            adjust_learning_rate(lr, optimizer)
            for batch_idx in range(train_num_batch):
                # supervised phase
                image_batch = train_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, train_num)][1]
                text_batch = train_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, train_num)][2]
                label_batch = train_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, train_num)][3].squeeze(dim=0)

                optimizer.zero_grad()
                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                label_inputs = Variable(label_batch).to(self.device)

                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
                supervised_image_loss, supervised_text_loss = self.semi_loss_function(
                    label_inputs, qbatch, rbatch)
                train_supervised_image_loss += supervised_image_loss.data * len(
                    label_inputs)
                train_supervised_text_loss += supervised_text_loss.data * len(
                    label_inputs)
                supervised_loss = supervised_image_loss + supervised_text_loss
                supervised_loss.backward()
                optimizer.step()

                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z

            # update p considering short memory
            s = []
            for batch_idx in range(full_num_batch):
                image_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][1]
                text_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][2]

                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)

                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                _q, _r = self.soft_assignemt(_image_z, _text_z)
                _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
                s.append(_s.data.cpu())

                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s
            s = torch.cat(s, dim=0)

            p = self.target_distribution(s)

            adjust_learning_rate(lr * kappa, optimizer)

            for batch_idx in range(full_num_batch):
                # clustering phase
                image_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][1]
                text_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][2]
                pbatch = p[batch_idx * batch_size:min((batch_idx + 1) *
                                                      batch_size, full_num)]

                optimizer.zero_grad()
                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                p_inputs = Variable(pbatch).to(self.device)

                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
                sbatch = self.probabililty_fusion(qbatch, rbatch, _image_z,
                                                  _text_z)
                unsupervised_loss = self.loss_function(p_inputs, sbatch)
                train_unsupervised_loss += unsupervised_loss.data * len(
                    p_inputs)
                unsupervised_loss.backward()
                optimizer.step()

                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
            train_unsupervised_loss /= full_num
            train_supervised_image_loss /= train_num
            train_supervised_text_loss /= train_num

            train_pred = torch.argmax(s, dim=1).numpy()
            df_pred = pd.DataFrame(data=train_pred,
                                   index=full_short_codes,
                                   columns=['pred'])
            df_pred = df_pred.loc[df_train.index]
            train_pred = df_pred['pred']
            train_acc = accuracy_score(train_labels, train_pred)
            train_nmi = normalized_mutual_info_score(
                train_labels, train_pred, average_method='geometric')
            train_f_1 = f1_score(train_labels, train_pred, average='macro')
            print("#Train measure %3d: acc: %.4f, nmi: %.4f, f_1: %.4f" %
                  (epoch + 1, train_acc, train_nmi, train_f_1))
            print(
                "#Train loss %3d: unsup lss: %.4f, super img: %.4f, super txt: %.4f"
                % (epoch + 1, train_unsupervised_loss,
                   train_supervised_image_loss, train_supervised_text_loss))
            if epoch == 0:
                train_pred_last = train_pred
                train_unsupervised_loss_last = train_unsupervised_loss
            else:
                if args.es:
                    if train_unsupervised_loss_last > train_unsupervised_loss and epoch >= 5:
                        print("Reached a local max/min in loss. Stopping training.")
                        flag_end_training = True
                    train_unsupervised_loss_last = train_unsupervised_loss
                else:
                    delta_label = np.sum(train_pred != train_pred_last).astype(
                        np.float32) / len(train_pred)
                    train_pred_last = train_pred
                    if delta_label < tol:
                        print('delta_label ', delta_label, '< tol ', tol)
                        print("Reach tolerance threshold. Stopping training.")
                        flag_end_training = True

            self.eval()
            test_unsupervised_loss = 0.0
            test_supervised_image_loss = 0.0
            test_supervised_text_loss = 0.0
            # update p considering short memory
            s = []
            for batch_idx in range(full_num_batch):
                image_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][1]
                text_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][2]

                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)

                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                _q, _r = self.soft_assignemt(_image_z, _text_z)
                _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
                s.append(_s.data.cpu())

                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s

            for batch_idx in range(test_num_batch):
                image_batch = test_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, test_num)][1]
                text_batch = test_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, test_num)][2]
                label_batch = test_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, test_num)][3].squeeze(dim=0)

                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                label_inputs = Variable(label_batch).to(self.device)

                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
                supervised_image_loss, supervised_text_loss = self.semi_loss_function(
                    label_inputs, qbatch, rbatch)
                test_supervised_image_loss += supervised_image_loss.data * len(
                    label_inputs)
                test_supervised_text_loss += supervised_text_loss.data * len(
                    label_inputs)
                _q, _r = self.soft_assignemt(_image_z, _text_z)
                _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
                s.append(_s.data.cpu())

                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s
            s = torch.cat(s, dim=0)
            test_p = self.target_distribution(s)

            if args.tsne and (epoch + 1) % 5 == 0:
                do_tsne(
                    test_p.numpy(), df_initial, self.n_clusters,
                    os.path.join(CONFIG.SVG_PATH, args.gpu,
                                 'epoch_' + ('%03d' % (epoch + 1)) + '.png'))

            for batch_idx in range(full_num_batch):
                # clustering phase
                image_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][1]
                text_batch = full_dataset[batch_idx * batch_size:min(
                    (batch_idx + 1) * batch_size, full_num)][2]
                pbatch = test_p[batch_idx *
                                batch_size:min((batch_idx + 1) *
                                               batch_size, full_num)]

                image_inputs = Variable(image_batch).to(self.device)
                text_inputs = Variable(text_batch).to(self.device)
                p_inputs = Variable(pbatch).to(self.device)

                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
                sbatch = self.probabililty_fusion(qbatch, rbatch, _image_z,
                                                  _text_z)
                unsupervised_loss = self.loss_function(p_inputs, sbatch)
                test_unsupervised_loss += unsupervised_loss.data * len(
                    p_inputs)
                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
            test_unsupervised_loss /= full_num
            test_supervised_image_loss /= test_num
            test_supervised_text_loss /= test_num

            test_pred = torch.argmax(s, dim=1).numpy()[full_num:]
            test_acc = accuracy_score(test_labels, test_pred)
            test_nmi = normalized_mutual_info_score(test_labels,
                                                    test_pred,
                                                    average_method='geometric')
            test_f_1 = f1_score(test_labels, test_pred, average='macro')
            print("#Test measure %3d: acc: %.4f, nmi: %.4f, f_1: %.4f" %
                  (epoch + 1, test_acc, test_nmi, test_f_1))
            print(
                "#Test loss %3d: unsup lss: %.4f, super img: %.4f, super txt: %.4f"
                % (epoch + 1, test_unsupervised_loss,
                   test_supervised_image_loss, test_supervised_text_loss))
            self.acc = test_acc
            self.nmi = test_nmi
            self.f_1 = test_f_1

            if flag_end_training:
                break

        if save_path and not args.resume:
            self.save_model(save_path)
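Examples #2 and #3 also call an `adjust_learning_rate` helper to switch step sizes between the supervised and clustering phases. Its body is not shown; a minimal sketch under the conventional assumption that it rewrites the rate on every parameter group:

# Hedged sketch of adjust_learning_rate: set a new learning rate on all
# parameter groups of an existing optimizer.
def adjust_learning_rate(lr, optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Calling it with `lr` before the supervised pass and with `lr * kappa` (kappa=0.1 by default) before the clustering pass makes the unsupervised KL updates an order of magnitude smaller than the supervised ones.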