Example #1
    def fit_predict(self, X, batch_size=256):
        num = len(X)
        num_batch = int(math.ceil(1.0 * num / batch_size))
        self.to(self.device)
        # keep the cluster centers (mu) on the CPU: the soft assignment
        # below runs on CPU copies of the encoded features
        self.image_encoder.mu.data = self.image_encoder.mu.cpu()
        self.text_encoder.mu.data = self.text_encoder.mu.cpu()

        self.eval()
        image_z = []
        text_z = []
        w = []
        short_codes = []
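        # X is assumed to yield (short_codes, images, texts) when sliced;
        # encode it in mini-batches, moving each result to the CPU and
        # freeing GPU memory as we go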
        for batch_idx in range(num_batch):
            batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
            short_codes.append(batch[0])
            image_batch = batch[1]
            text_batch = batch[2]
            image_inputs = image_batch.to(self.device)
            text_inputs = text_batch.to(self.device)
            _image_z, _text_z, _w = self.forward(image_inputs, text_inputs)
            image_z.append(_image_z.data.cpu())
            text_z.append(_text_z.data.cpu())
            w.append(_w.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _w
            torch.cuda.empty_cache()
        short_codes = np.concatenate(short_codes, axis=0)
        image_z = torch.cat(image_z, dim=0)
        text_z = torch.cat(text_z, dim=0)
        w = torch.cat(w, dim=0)

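        # q, r: soft assignments of the image and text embeddings to the
        # cluster centers; p: the target distribution derived from them,
        # here also taking the per-sample weights w returned by forward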
        q, r = self.soft_assignemt(image_z, text_z)
        p = self.target_distribution(q, r, w).data
        # take the most likely cluster per sample, plus its probability
        y_confidence, y_pred = torch.max(p, dim=1)
        y_confidence = y_confidence.numpy()
        y_pred = y_pred.numpy()
        p = p.numpy()
        w = w.numpy()
        count_percentage(y_pred)
        return short_codes, y_pred, y_confidence, p, w
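A minimal usage sketch for `fit_predict` (hypothetical: `model` and `dataset` are illustrative names, and the dataset is assumed to yield (short_codes, images, texts) when sliced, as the method above expects):

# hypothetical driver code, not part of the example above
short_codes, y_pred, y_conf, p, w = model.fit_predict(dataset, batch_size=256)
for code, label, conf in zip(short_codes[:5], y_pred[:5], y_conf[:5]):
    print("%s -> cluster %d (confidence %.3f)" % (code, label, conf))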
Example #2
    def fit(self, X, lr=0.001, batch_size=256, num_epochs=10, save_path=None):
        '''X: tensor data'''
        num = len(X)
        num_batch = int(math.ceil(1.0 * num / batch_size))
        self.to(self.device)
        print("===== Training DEC =====")
        # optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                     self.parameters()),
                              lr=lr,
                              momentum=0.9)

        print("Extracting initial features at %s" %
              (str(datetime.datetime.now())))
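        # pass 1: encode the whole dataset once so that kmeans can
        # initialize the cluster centers from the full set of embeddings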
        image_z = []
        text_z = []
        for batch_idx in range(num_batch):
            batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
            image_batch = batch[1]
            text_batch = batch[2]
            image_inputs = image_batch.to(self.device)
            text_inputs = text_batch.to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            image_z.append(_image_z.data.cpu())
            text_z.append(_text_z.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
            torch.cuda.empty_cache()
        image_z = torch.cat(image_z, dim=0)
        text_z = torch.cat(text_z, dim=0)

        print("Initializing cluster centers with kmeans at %s" %
              (str(datetime.datetime.now())))
        image_kmeans = KMeans(self.n_clusters, n_init=20)
        image_pred = image_kmeans.fit_predict(image_z.numpy())
        print("Image kmeans completed at %s" % (str(datetime.datetime.now())))

        text_kmeans = KMeans(self.n_clusters, n_init=20)
        text_pred = text_kmeans.fit_predict(text_z.numpy())
        print("Text kmeans completed at %s" % (str(datetime.datetime.now())))

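        # align_cluster presumably matches image and text cluster indices
        # (e.g. by Hungarian matching on their co-occurrence counts) so that
        # index i refers to the same cluster in both modalities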
        image_ind, text_ind = align_cluster(image_pred, text_pred)

        image_cluster_centers = np.zeros_like(image_kmeans.cluster_centers_)
        text_cluster_centers = np.zeros_like(text_kmeans.cluster_centers_)

        for i in range(self.n_clusters):
            image_cluster_centers[i] = image_kmeans.cluster_centers_[
                image_ind[i]]
            text_cluster_centers[i] = text_kmeans.cluster_centers_[text_ind[i]]
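        # install the aligned kmeans centroids as the initial values of the
        # learnable cluster centers mu (kept on CPU, as in fit_predict)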
        self.image_encoder.mu.data.copy_(torch.Tensor(image_cluster_centers))
        self.image_encoder.mu.data = self.image_encoder.mu.cpu()
        self.text_encoder.mu.data.copy_(torch.Tensor(text_cluster_centers))
        self.text_encoder.mu.data = self.text_encoder.mu.cpu()
        self.train()
        best_loss = float('inf')
        best_epoch = 0

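        # DEC-style alternation: each epoch rebuilds the detached target
        # distribution p from the current embeddings, then trains one pass
        # to pull the soft assignments toward p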
        for epoch in range(num_epochs):
            # update the target distribution p

            image_z = []
            text_z = []
            for batch_idx in range(num_batch):
                batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
                image_batch = batch[1]
                text_batch = batch[2]
                image_inputs = image_batch.to(self.device)
                text_inputs = text_batch.to(self.device)
                _image_z, _text_z = self.forward(image_inputs, text_inputs)
                image_z.append(_image_z.data.cpu())
                text_z.append(_text_z.data.cpu())
                del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
                torch.cuda.empty_cache()
            image_z = torch.cat(image_z, dim=0)
            text_z = torch.cat(text_z, dim=0)

            q, r = self.soft_assignemt(image_z, text_z)
            p = self.target_distribution(q, r).data
            y_pred = torch.argmax(p, dim=1).numpy()
            count_percentage(y_pred)
            # train 1 epoch
            train_loss = 0.0
            for batch_idx in range(num_batch):
                batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
                image_batch = batch[1]
                text_batch = batch[2]
                pbatch = p[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]

                optimizer.zero_grad()
                image_inputs = image_batch.to(self.device)
                text_inputs = text_batch.to(self.device)
                target = pbatch

                image_z, text_z = self.forward(image_inputs, text_inputs)
                qbatch, rbatch = self.soft_assignemt(image_z.cpu(),
                                                     text_z.cpu())
                loss = self.loss_function(target, qbatch, rbatch)
                train_loss += loss.item() * len(target)
                loss.backward()
                optimizer.step()

                del image_batch, text_batch, image_inputs, text_inputs, image_z, text_z
                torch.cuda.empty_cache()
            train_loss = train_loss / num
            if best_loss > train_loss:
                best_loss = train_loss
                best_epoch = epoch
                if save_path:
                    self.save_model(
                        os.path.join(save_path, "mdec_" +
                                     str(self.image_encoder.z_dim) + '_' +
                                     str(self.n_clusters) + ".pt"))
            print("#Epoch %3d: Loss: %.4f Best Loss: %.4f at %s" %
                  (epoch + 1, train_loss, best_loss,
                   str(datetime.datetime.now())))

        print("#Best Epoch %3d: Best Loss: %.4f" % (best_epoch, best_loss))