def fit_predict(self, X, batch_size=256):
    num = len(X)
    num_batch = int(math.ceil(1.0 * num / batch_size))
    self.to(self.device)
    # keep the cluster centers on the CPU, where the soft assignment is computed
    self.image_encoder.mu.data = self.image_encoder.mu.cpu()
    self.text_encoder.mu.data = self.text_encoder.mu.cpu()
    self.eval()

    image_z = []
    text_z = []
    w = []
    short_codes = []
    for batch_idx in range(num_batch):
        # slicing X yields a (short_codes, image_batch, text_batch) tuple
        batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
        short_codes.append(batch[0])
        image_inputs = Variable(batch[1]).to(self.device)
        text_inputs = Variable(batch[2]).to(self.device)
        # this forward pass returns both latent codes and the modality weight w
        _image_z, _text_z, _w = self.forward(image_inputs, text_inputs)
        image_z.append(_image_z.data.cpu())
        text_z.append(_text_z.data.cpu())
        w.append(_w.data.cpu())
        del batch, image_inputs, text_inputs, _image_z, _text_z, _w
        torch.cuda.empty_cache()

    short_codes = np.concatenate(short_codes, axis=0)
    image_z = torch.cat(image_z, dim=0)
    text_z = torch.cat(text_z, dim=0)
    w = torch.cat(w, dim=0)

    # soft assignments q (image) and r (text), sharpened into the target distribution p
    q, r = self.soft_assignemt(image_z, text_z)
    p = self.target_distribution(q, r, w).data
    # hard labels and their confidences are the argmax / max of p per sample
    y_confidence, y_pred = torch.max(p, dim=1)
    y_confidence = y_confidence.numpy()
    y_pred = y_pred.numpy()
    p = p.numpy()
    w = w.numpy()
    count_percentage(y_pred)
    return short_codes, y_pred, y_confidence, p, w
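# `count_percentage` is called above but not defined in this section. A minimal
# sketch of what such a helper plausibly does, assuming it only reports the
# relative size of each predicted cluster (name and behavior inferred, not
# taken from this repo):
#
#     import numpy as np
#
#     def count_percentage(y_pred):
#         # print each cluster label with its share of the predictions,
#         # e.g. "cluster 0: 1234 (24.68%)" -- purely diagnostic output
#         labels, counts = np.unique(y_pred, return_counts=True)
#         total = counts.sum()
#         for label, count in zip(labels, counts):
#             print("cluster %d: %d (%.2f%%)" % (label, count, 100.0 * count / total))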
def fit(self, X, lr=0.001, batch_size=256, num_epochs=10, save_path=None):
    '''X: tensor data'''
    num = len(X)
    num_batch = int(math.ceil(1.0 * num / batch_size))
    self.to(self.device)
    print("=====Training DEC=======")
    # optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.parameters()),
                          lr=lr, momentum=0.9)

    print("Extracting initial features at %s" % str(datetime.datetime.now()))
    image_z = []
    text_z = []
    for batch_idx in range(num_batch):
        batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
        image_inputs = Variable(batch[1]).to(self.device)
        text_inputs = Variable(batch[2]).to(self.device)
        _image_z, _text_z = self.forward(image_inputs, text_inputs)
        image_z.append(_image_z.data.cpu())
        text_z.append(_text_z.data.cpu())
        del batch, image_inputs, text_inputs, _image_z, _text_z
        torch.cuda.empty_cache()
    image_z = torch.cat(image_z, dim=0)
    text_z = torch.cat(text_z, dim=0)

    print("Initializing cluster centers with kmeans at %s" % str(datetime.datetime.now()))
    image_kmeans = KMeans(self.n_clusters, n_init=20)
    image_pred = image_kmeans.fit_predict(image_z.data.cpu().numpy())
    print("Image kmeans completed at %s" % str(datetime.datetime.now()))
    text_kmeans = KMeans(self.n_clusters, n_init=20)
    text_pred = text_kmeans.fit_predict(text_z.data.cpu().numpy())
    print("Text kmeans completed at %s" % str(datetime.datetime.now()))

    # k-means label IDs are arbitrary per modality, so match image clusters to
    # text clusters before copying centroids into the shared index space
    image_ind, text_ind = align_cluster(image_pred, text_pred)
    image_cluster_centers = np.zeros_like(image_kmeans.cluster_centers_)
    text_cluster_centers = np.zeros_like(text_kmeans.cluster_centers_)
    for i in range(self.n_clusters):
        image_cluster_centers[i] = image_kmeans.cluster_centers_[image_ind[i]]
        text_cluster_centers[i] = text_kmeans.cluster_centers_[text_ind[i]]
    self.image_encoder.mu.data.copy_(torch.Tensor(image_cluster_centers))
    self.image_encoder.mu.data = self.image_encoder.mu.cpu()
    self.text_encoder.mu.data.copy_(torch.Tensor(text_cluster_centers))
    self.text_encoder.mu.data = self.text_encoder.mu.cpu()

    self.train()
    best_loss = 99999.
    best_epoch = 0
    for epoch in range(num_epochs):
        # update the target distribution p over the whole dataset
        image_z = []
        text_z = []
        for batch_idx in range(num_batch):
            batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
            image_inputs = Variable(batch[1]).to(self.device)
            text_inputs = Variable(batch[2]).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            image_z.append(_image_z.data.cpu())
            text_z.append(_text_z.data.cpu())
            del batch, image_inputs, text_inputs, _image_z, _text_z
            torch.cuda.empty_cache()
        image_z = torch.cat(image_z, dim=0)
        text_z = torch.cat(text_z, dim=0)
        q, r = self.soft_assignemt(image_z, text_z)
        p = self.target_distribution(q, r).data
        y_pred = torch.argmax(p, dim=1).numpy()
        count_percentage(y_pred)

        # train one epoch against the (fixed) target distribution
        train_loss = 0.0
        for batch_idx in range(num_batch):
            batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
            pbatch = p[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
            optimizer.zero_grad()
            image_inputs = Variable(batch[1]).to(self.device)
            text_inputs = Variable(batch[2]).to(self.device)
            target = Variable(pbatch)
            image_z, text_z = self.forward(image_inputs, text_inputs)
            qbatch, rbatch = self.soft_assignemt(image_z.cpu(), text_z.cpu())
            loss = self.loss_function(target, qbatch, rbatch)
            train_loss += loss.data * len(target)
            loss.backward()
            optimizer.step()
            del batch, image_inputs, text_inputs, image_z, text_z
            torch.cuda.empty_cache()
        train_loss = train_loss / num

        if best_loss > train_loss:
            best_loss = train_loss
            # store the epoch 1-indexed to match the per-epoch log below
            best_epoch = epoch + 1
            if save_path:
                self.save_model(os.path.join(
                    save_path,
                    "mdec_" + str(self.image_encoder.z_dim) + '_' + str(self.n_clusters) + ".pt"))
        print("#Epoch %3d: Loss: %.4f Best Loss: %.4f at %s" % (
            epoch + 1, train_loss, best_loss, str(datetime.datetime.now())))
    print("#Best Epoch %3d: Best Loss: %.4f" % (best_epoch, best_loss))
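# `align_cluster` is also used without being defined here. A common way to
# match two clusterings is the Hungarian algorithm on the label co-occurrence
# (contingency) matrix; the sketch below is one plausible implementation under
# that assumption, not necessarily this repo's:
#
#     import numpy as np
#     from scipy.optimize import linear_sum_assignment
#
#     def align_cluster(image_pred, text_pred):
#         # contingency[i, j] counts samples assigned to image cluster i
#         # and text cluster j at the same time
#         n_clusters = max(image_pred.max(), text_pred.max()) + 1
#         contingency = np.zeros((n_clusters, n_clusters), dtype=np.int64)
#         for i, j in zip(image_pred, text_pred):
#             contingency[i, j] += 1
#         # maximize co-occurrence (minimize its negation) with the Hungarian method
#         image_ind, text_ind = linear_sum_assignment(-contingency)
#         return image_ind, text_ind
#
# With this pairing, row index i in `fit` picks image centroid image_ind[i]
# and the text centroid text_ind[i] matched to it, which is exactly how the
# centroid-copy loop above consumes the two index arrays.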