def fit(self, X, lr=0.001, batch_size=256, num_epochs=10, save_path=None):
    '''X: tensor data'''
    num = len(X)
    num_batch = int(math.ceil(1.0 * len(X) / batch_size))
    self.to(self.device)
    print("=====Training DEC=======")
    # optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.parameters()),
                          lr=lr, momentum=0.9)

    print("Extracting initial features at %s" % (str(datetime.datetime.now())))
    image_z = []
    text_z = []
    for batch_idx in range(num_batch):
        image_batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)][1]
        text_batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)][2]
        image_inputs = Variable(image_batch).to(self.device)
        text_inputs = Variable(text_batch).to(self.device)
        _image_z, _text_z = self.forward(image_inputs, text_inputs)
        image_z.append(_image_z.data.cpu())
        text_z.append(_text_z.data.cpu())
        del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
        torch.cuda.empty_cache()
    image_z = torch.cat(image_z, dim=0)
    text_z = torch.cat(text_z, dim=0)

    print("Initializing cluster centers with kmeans at %s" % (str(datetime.datetime.now())))
    image_kmeans = KMeans(self.n_clusters, n_init=20)
    image_pred = image_kmeans.fit_predict(image_z.data.cpu().numpy())
    print("Image kmeans completed at %s" % (str(datetime.datetime.now())))
    text_kmeans = KMeans(self.n_clusters, n_init=20)
    text_pred = text_kmeans.fit_predict(text_z.data.cpu().numpy())
    print("Text kmeans completed at %s" % (str(datetime.datetime.now())))

    # Align the two modality-specific clusterings so that cluster i in the
    # image space and cluster i in the text space refer to the same group,
    # then initialize the cluster centers mu from the permuted kmeans centers.
    image_ind, text_ind = align_cluster(image_pred, text_pred)
    image_cluster_centers = np.zeros_like(image_kmeans.cluster_centers_)
    text_cluster_centers = np.zeros_like(text_kmeans.cluster_centers_)
    for i in range(self.n_clusters):
        image_cluster_centers[i] = image_kmeans.cluster_centers_[image_ind[i]]
        text_cluster_centers[i] = text_kmeans.cluster_centers_[text_ind[i]]
    self.image_encoder.mu.data.copy_(torch.Tensor(image_cluster_centers))
    self.image_encoder.mu.data = self.image_encoder.mu.cpu()
    self.text_encoder.mu.data.copy_(torch.Tensor(text_cluster_centers))
    self.text_encoder.mu.data = self.text_encoder.mu.cpu()

    self.train()
    best_loss = 99999.
    best_epoch = 0
    for epoch in range(num_epochs):
        # update the target distribution p from the current soft assignments
        image_z = []
        text_z = []
        for batch_idx in range(num_batch):
            image_batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)][1]
            text_batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)][2]
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            image_z.append(_image_z.data.cpu())
            text_z.append(_text_z.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
            torch.cuda.empty_cache()
        image_z = torch.cat(image_z, dim=0)
        text_z = torch.cat(text_z, dim=0)
        q, r = self.soft_assignemt(image_z, text_z)
        p = self.target_distribution(q, r).data
        y_pred = torch.argmax(p, dim=1).numpy()
        count_percentage(y_pred)

        # train 1 epoch
        train_loss = 0.0
        for batch_idx in range(num_batch):
            image_batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)][1]
            text_batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)][2]
            pbatch = p[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
            optimizer.zero_grad()
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            target = Variable(pbatch)
            image_z, text_z = self.forward(image_inputs, text_inputs)
            qbatch, rbatch = self.soft_assignemt(image_z.cpu(), text_z.cpu())
            loss = self.loss_function(target, qbatch, rbatch)
            train_loss += loss.data * len(target)
            loss.backward()
            optimizer.step()
            del image_batch, text_batch, image_inputs, text_inputs, image_z, text_z
            torch.cuda.empty_cache()
        train_loss = train_loss / num

        if best_loss > train_loss:
            best_loss = train_loss
            best_epoch = epoch
            if save_path:
                self.save_model(os.path.join(
                    save_path,
                    "mdec_" + str(self.image_encoder.z_dim) + '_' + str(self.n_clusters) + ".pt"))
        print("#Epoch %3d: Loss: %.4f Best Loss: %.4f at %s" %
              (epoch + 1, train_loss, best_loss, str(datetime.datetime.now())))
    print("#Best Epoch %3d: Best Loss: %.4f" % (best_epoch + 1, best_loss))
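
# ---------------------------------------------------------------------------
# `align_cluster` is called above but not defined in this section. Below is a
# minimal sketch of a compatible helper, assuming the usual Hungarian-matching
# formulation: build the contingency table of two cluster assignments and
# return index arrays such that cluster ind_a[i] in the first assignment is
# matched to cluster ind_b[i] in the second. This is an illustrative
# assumption, not necessarily the codebase's actual implementation.
# ---------------------------------------------------------------------------
from scipy.optimize import linear_sum_assignment

def align_cluster(pred_a, pred_b):
    n_clusters = max(pred_a.max(), pred_b.max()) + 1
    w = np.zeros((n_clusters, n_clusters), dtype=np.int64)
    for a, b in zip(pred_a, pred_b):
        w[a, b] += 1  # co-occurrence counts of the two assignments
    ind_a, ind_b = linear_sum_assignment(-w)  # maximize total overlap
    return ind_a, ind_b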
def fit_predict(self, X, train_dataset, test_dataset, lr=0.001, batch_size=256,
                num_epochs=10, update_time=1, save_path=None, tol=1e-3, kappa=0.1):
    '''X: tensor data'''
    X_num = len(X)
    X_num_batch = int(math.ceil(1.0 * len(X) / batch_size))
    train_num = len(train_dataset)
    train_num_batch = int(math.ceil(1.0 * len(train_dataset) / batch_size))
    self.to(self.device)
    self.encoder.mu.data = self.encoder.mu.cpu()
    print("=====Training DEC=======")
    # NOTE: these loaders are created but unused; the loops below slice the
    # datasets directly in fixed-size batches.
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    validloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.parameters()),
                          lr=lr, momentum=0.9)

    print("Extracting initial features at %s" % (str(datetime.datetime.now())))
    z = self.update_z(X, batch_size)
    train_z = self.update_z(train_dataset, batch_size)

    print("Initializing cluster centers with kmeans at %s" % (str(datetime.datetime.now())))
    kmeans = KMeans(self.n_clusters, n_init=20)
    kmeans.fit(z.data.cpu().numpy())
    train_pred = kmeans.predict(train_z.data.cpu().numpy())
    print("kmeans completed at %s" % (str(datetime.datetime.now())))

    short_codes = X[:][0]
    train_short_codes = train_dataset[:][0]
    train_labels = train_dataset[:][2].data.cpu().numpy()
    df_train = pd.DataFrame(data=train_labels, index=train_short_codes, columns=['label'])

    # Permute the kmeans centers so that cluster index i corresponds to label
    # id i on the labeled training set, then initialize mu with them.
    _, ind = align_cluster(train_labels, train_pred)
    cluster_centers = np.zeros_like(kmeans.cluster_centers_)
    for i in range(self.n_clusters):
        cluster_centers[i] = kmeans.cluster_centers_[ind[i]]
    self.encoder.mu.data.copy_(torch.Tensor(cluster_centers))
    self.encoder.mu.data = self.encoder.mu.cpu()

    if self.use_prior:
        # empirical label distribution of the training set
        for label in train_labels:
            self.prior[label] = self.prior[label] + 1
        self.prior = self.prior / len(train_labels)

    for epoch in range(num_epochs):
        self.train()
        train_loss = 0.0
        semi_train_loss = 0.0

        # semi-supervised phase: fit the soft assignments to the known labels
        adjust_learning_rate(lr, optimizer)
        for batch_idx in range(train_num_batch):
            data_batch = train_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, train_num)][1]
            label_batch = train_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, train_num)][2]
            optimizer.zero_grad()
            data_inputs = Variable(data_batch).to(self.device)
            label_inputs = Variable(label_batch)
            _z = self.forward(data_inputs)
            qbatch = self.soft_assignemt(_z.cpu())
            semi_loss = self.semi_loss_function(label_inputs, qbatch)
            semi_train_loss += semi_loss.data * len(label_inputs)
            semi_loss.backward()
            optimizer.step()
            del data_batch, data_inputs, _z

        # update the target distribution p
        z = self.update_z(X, batch_size)
        q = self.soft_assignemt(z)
        p = self.target_distribution(q).data

        # clustering phase: train on the full data with a damped learning rate
        adjust_learning_rate(lr * kappa, optimizer)
        for batch_idx in range(X_num_batch):
            data_batch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, X_num)][1]
            pbatch = p[batch_idx * batch_size:min((batch_idx + 1) * batch_size, X_num)]
            optimizer.zero_grad()
            data_inputs = Variable(data_batch).to(self.device)
            p_inputs = Variable(pbatch)
            _z = self.forward(data_inputs)
            qbatch = self.soft_assignemt(_z.cpu())
            loss = self.loss_function(p_inputs, qbatch)
            train_loss += loss.data * len(p_inputs)
            loss.backward()
            optimizer.step()
            del data_batch, data_inputs, _z
        train_loss = train_loss / X_num
        semi_train_loss = semi_train_loss / train_num

        train_pred = torch.argmax(p, dim=1).numpy()
        df_pred = pd.DataFrame(data=train_pred, index=short_codes, columns=['pred'])
        df_pred = df_pred.loc[df_train.index]
        train_pred = df_pred['pred']
        train_acc = accuracy_score(train_labels, train_pred)
        train_nmi = normalized_mutual_info_score(train_labels, train_pred,
                                                 average_method='geometric')
        train_f_1 = f1_score(train_labels, train_pred, average='macro')
        print("#Epoch %3d: acc: %.4f, nmi: %.4f, f_1: %.4f, loss: %.4f, semi_loss: %.4f, at %s" %
              (epoch + 1, train_acc, train_nmi, train_f_1, train_loss, semi_train_loss,
               str(datetime.datetime.now())))

        # stop once assignments change for less than `tol` of the samples
        if epoch == 0:
            train_pred_last = train_pred
        else:
            delta_label = np.sum(train_pred != train_pred_last).astype(np.float32) / len(train_pred)
            train_pred_last = train_pred
            if delta_label < tol:
                print('delta_label ', delta_label, '< tol ', tol)
                print("Reach tolerance threshold. Stopping training.")
                break

    self.eval()
    test_labels = test_dataset[:][2].squeeze(dim=0)
    test_z = self.update_z(test_dataset, batch_size)
    z = torch.cat([z, test_z], dim=0)
    q = self.soft_assignemt(z)
    test_p = self.target_distribution(q).data
    test_pred = torch.argmax(test_p, dim=1).numpy()[X_num:]
    test_acc = accuracy_score(test_labels, test_pred)

    test_short_codes = test_dataset[:][0]
    test_short_codes = np.concatenate([short_codes, test_short_codes], axis=0)
    df_test = pd.DataFrame(data=torch.argmax(test_p, dim=1).numpy(),
                           index=test_short_codes, columns=['labels'])
    df_test.to_csv('udec_label.csv', encoding='utf-8-sig')
    df_test_p = pd.DataFrame(data=test_p.data.numpy(), index=test_short_codes)
    df_test_p.to_csv('udec_p.csv', encoding='utf-8-sig')
    test_nmi = normalized_mutual_info_score(test_labels, test_pred, average_method='geometric')
    test_f_1 = f1_score(test_labels, test_pred, average='macro')
    print("#Test acc: %.4f, Test nmi: %.4f, Test f_1: %.4f" % (test_acc, test_nmi, test_f_1))
    self.acc = test_acc
    self.nmi = test_nmi
    self.f_1 = test_f_1
    if save_path:
        self.save_model(save_path)
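
# ---------------------------------------------------------------------------
# `target_distribution` and `adjust_learning_rate` are called above but not
# defined in this section. Minimal sketches follow, assuming the standard DEC
# target distribution of Xie et al. (2016), p_ij = (q_ij^2 / f_j) / Z_i with
# f_j = sum_i q_ij; the codebase's versions (e.g. the two-argument
# target_distribution(q, r) used in `fit`) may differ.
# ---------------------------------------------------------------------------
def target_distribution(q):
    # sharpen the soft assignments and renormalize each row to sum to 1
    weight = q ** 2 / q.sum(dim=0)  # divide by the soft cluster frequency f_j
    return (weight.t() / weight.sum(dim=1)).t()

def adjust_learning_rate(lr, optimizer):
    # set the given learning rate on every parameter group (assumed behavior)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr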
def fit_predict(self, full_dataset, train_dataset, test_dataset, args, CONFIG,
                lr=0.001, batch_size=256, num_epochs=10, update_time=1,
                save_path=None, tol=1e-3, kappa=0.1):
    '''full_dataset: tensor data'''
    full_num = len(full_dataset)
    full_num_batch = int(math.ceil(1.0 * len(full_dataset) / batch_size))
    train_num = len(train_dataset)
    train_num_batch = int(math.ceil(1.0 * len(train_dataset) / batch_size))
    test_num = len(test_dataset)
    test_num_batch = int(math.ceil(1.0 * len(test_dataset) / batch_size))
    self.to(self.device)
    print("=====Training DEC=======")
    if args.adam:
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
    else:
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.parameters()),
                              lr=lr, momentum=0.9)

    full_short_codes = full_dataset[:][0]
    train_short_codes = train_dataset[:][0]
    test_short_codes = test_dataset[:][0]
    train_labels = train_dataset[:][3].squeeze(dim=0).data.cpu().numpy()
    test_labels = test_dataset[:][3].squeeze(dim=0).data.cpu().numpy()
    df_train = pd.DataFrame(data=train_labels, index=train_short_codes, columns=['label'])
    df_test = pd.DataFrame(data=test_labels, index=test_short_codes, columns=['label'])

    if not args.resume:
        print("Extracting initial features at %s" % (str(datetime.datetime.now())))
        image_z = []
        text_z = []
        for batch_idx in range(full_num_batch):
            image_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][1]
            text_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][2]
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            image_z.append(_image_z.data.cpu())
            text_z.append(_text_z.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
        image_z = torch.cat(image_z, dim=0)
        text_z = torch.cat(text_z, dim=0)

        train_image_z = []
        train_text_z = []
        for batch_idx in range(train_num_batch):
            image_batch = train_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, train_num)][1]
            text_batch = train_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, train_num)][2]
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            train_image_z.append(_image_z.data.cpu())
            train_text_z.append(_text_z.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
        train_image_z = torch.cat(train_image_z, dim=0)
        train_text_z = torch.cat(train_text_z, dim=0)

        print("Initializing cluster centers with kmeans at %s" % (str(datetime.datetime.now())))
        image_kmeans = KMeans(n_clusters=self.n_clusters, n_init=20, random_state=42)
        image_kmeans.fit(image_z.data.cpu().numpy())
        train_image_pred = image_kmeans.predict(train_image_z.data.cpu().numpy())
        print("Image kmeans completed at %s" % (str(datetime.datetime.now())))
        text_kmeans = KMeans(n_clusters=self.n_clusters, n_init=20, random_state=42)
        text_kmeans.fit(text_z.data.cpu().numpy())
        train_text_pred = text_kmeans.predict(train_text_z.data.cpu().numpy())
        print("Text kmeans completed at %s" % (str(datetime.datetime.now())))

        # Align both modality clusterings to the label ids of the training set
        # and initialize the cluster centers mu from the permuted kmeans centers.
        _, image_ind = align_cluster(train_labels, train_image_pred)
        _, text_ind = align_cluster(train_labels, train_text_pred)
        image_cluster_centers = np.zeros_like(image_kmeans.cluster_centers_)
        text_cluster_centers = np.zeros_like(text_kmeans.cluster_centers_)
        for i in range(self.n_clusters):
            image_cluster_centers[i] = image_kmeans.cluster_centers_[image_ind[i]]
            text_cluster_centers[i] = text_kmeans.cluster_centers_[text_ind[i]]
        self.image_encoder.mu.data.copy_(torch.Tensor(image_cluster_centers))
        self.text_encoder.mu.data.copy_(torch.Tensor(text_cluster_centers))

        if self.use_prior:
            # empirical label distribution of the training set
            for label in train_labels:
                self.prior[label] = self.prior[label] + 1
            self.prior /= len(train_labels)

    print("Calculating initial p at %s" % (str(datetime.datetime.now())))
    # update p considering short memory
    s = []
    for batch_idx in range(full_num_batch):
        image_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][1]
        text_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][2]
        image_inputs = Variable(image_batch).to(self.device)
        text_inputs = Variable(text_batch).to(self.device)
        _image_z, _text_z = self.forward(image_inputs, text_inputs)
        _q, _r = self.soft_assignemt(_image_z, _text_z)
        _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
        s.append(_s.data.cpu())
        del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s
    for batch_idx in range(test_num_batch):
        image_batch = test_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, test_num)][1]
        text_batch = test_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, test_num)][2]
        image_inputs = Variable(image_batch).to(self.device)
        text_inputs = Variable(text_batch).to(self.device)
        _image_z, _text_z = self.forward(image_inputs, text_inputs)
        _q, _r = self.soft_assignemt(_image_z, _text_z)
        _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
        s.append(_s.data.cpu())
        del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s
    s = torch.cat(s, dim=0)
    p = self.target_distribution(s)

    initial_pred = torch.argmax(s, dim=1).numpy()
    initial_acc = accuracy_score(test_labels, initial_pred[full_num:])
    initial_nmi = normalized_mutual_info_score(test_labels, initial_pred[full_num:],
                                               average_method='geometric')
    initial_f_1 = f1_score(test_labels, initial_pred[full_num:], average='macro')
    print("#Initial measure: acc: %.4f, nmi: %.4f, f_1: %.4f" %
          (initial_acc, initial_nmi, initial_f_1))

    df_initial = pd.DataFrame(data=initial_pred, index=full_short_codes + test_short_codes,
                              columns=['label'])
    df_initial['pred'] = 'pred'
    df_initial.loc[df_train.index, 'pred'] = 'label'
    for idx, row in df_train.iterrows():
        df_initial.loc[idx, 'label'] = row['label']
    df_initial.loc[df_test.index, 'pred'] = 'label'
    for idx, row in df_test.iterrows():
        df_initial.loc[idx, 'label'] = row['label']
    if args.tsne:
        print("Conducting initial TSNE at %s" % (str(datetime.datetime.now())))
        do_tsne(p.numpy(), df_initial, self.n_clusters,
                os.path.join(CONFIG.SVG_PATH, args.gpu, 'epoch_000.png'))
        print("TSNE completed at %s" % (str(datetime.datetime.now())))

    flag_end_training = False
    for epoch in range(num_epochs):
        print("Epoch %d at %s" % (epoch, str(datetime.datetime.now())))
        self.train()
        train_unsupervised_loss = 0.0
        train_supervised_image_loss = 0.0
        train_supervised_text_loss = 0.0

        # supervised phase: fit both modalities to the known labels
        adjust_learning_rate(lr, optimizer)
        for batch_idx in range(train_num_batch):
            image_batch = train_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, train_num)][1]
            text_batch = train_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, train_num)][2]
            label_batch = train_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, train_num)][3].squeeze(dim=0)
            optimizer.zero_grad()
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            label_inputs = Variable(label_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
            supervised_image_loss, supervised_text_loss = self.semi_loss_function(label_inputs, qbatch, rbatch)
            train_supervised_image_loss += supervised_image_loss.data * len(label_inputs)
            train_supervised_text_loss += supervised_text_loss.data * len(label_inputs)
            supervised_loss = supervised_image_loss + supervised_text_loss
            supervised_loss.backward()
            optimizer.step()
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z

        # update p considering short memory
        s = []
        for batch_idx in range(full_num_batch):
            image_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][1]
            text_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][2]
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            _q, _r = self.soft_assignemt(_image_z, _text_z)
            _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
            s.append(_s.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s
        s = torch.cat(s, dim=0)
        p = self.target_distribution(s)

        # clustering phase: train on the full data with a damped learning rate
        adjust_learning_rate(lr * kappa, optimizer)
        for batch_idx in range(full_num_batch):
            image_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][1]
            text_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][2]
            pbatch = p[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)]
            optimizer.zero_grad()
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            p_inputs = Variable(pbatch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
            sbatch = self.probabililty_fusion(qbatch, rbatch, _image_z, _text_z)
            unsupervised_loss = self.loss_function(p_inputs, sbatch)
            train_unsupervised_loss += unsupervised_loss.data * len(p_inputs)
            unsupervised_loss.backward()
            optimizer.step()
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
        train_unsupervised_loss /= full_num
        train_supervised_image_loss /= train_num
        train_supervised_text_loss /= train_num

        train_pred = torch.argmax(s, dim=1).numpy()
        df_pred = pd.DataFrame(data=train_pred, index=full_short_codes, columns=['pred'])
        df_pred = df_pred.loc[df_train.index]
        train_pred = df_pred['pred']
        train_acc = accuracy_score(train_labels, train_pred)
        train_nmi = normalized_mutual_info_score(train_labels, train_pred,
                                                 average_method='geometric')
        train_f_1 = f1_score(train_labels, train_pred, average='macro')
        print("#Train measure %3d: acc: %.4f, nmi: %.4f, f_1: %.4f" %
              (epoch + 1, train_acc, train_nmi, train_f_1))
        print("#Train loss %3d: unsup lss: %.4f, super img: %.4f, super txt: %.4f" %
              (epoch + 1, train_unsupervised_loss, train_supervised_image_loss,
               train_supervised_text_loss))

        # early stopping: on the loss trend (args.es) or on label-change tolerance
        if epoch == 0:
            train_pred_last = train_pred
            train_unsupervised_loss_last = train_unsupervised_loss
        else:
            if args.es:
                if train_unsupervised_loss_last > train_unsupervised_loss and epoch >= 5:
                    print("Reach local max/min loss. Stopping training.")
                    flag_end_training = True
                train_unsupervised_loss_last = train_unsupervised_loss
            else:
                delta_label = np.sum(train_pred != train_pred_last).astype(np.float32) / len(train_pred)
                train_pred_last = train_pred
                if delta_label < tol:
                    print('delta_label ', delta_label, '< tol ', tol)
                    print("Reach tolerance threshold. Stopping training.")
                    flag_end_training = True

        self.eval()
        test_unsupervised_loss = 0.0
        test_supervised_image_loss = 0.0
        test_supervised_text_loss = 0.0
        # update p considering short memory
        s = []
        for batch_idx in range(full_num_batch):
            image_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][1]
            text_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][2]
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            _q, _r = self.soft_assignemt(_image_z, _text_z)
            _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
            s.append(_s.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s
        for batch_idx in range(test_num_batch):
            image_batch = test_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, test_num)][1]
            text_batch = test_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, test_num)][2]
            label_batch = test_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, test_num)][3].squeeze(dim=0)
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            label_inputs = Variable(label_batch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
            supervised_image_loss, supervised_text_loss = self.semi_loss_function(label_inputs, qbatch, rbatch)
            test_supervised_image_loss += supervised_image_loss.data * len(label_inputs)
            test_supervised_text_loss += supervised_text_loss.data * len(label_inputs)
            _q, _r = self.soft_assignemt(_image_z, _text_z)
            _s = self.probabililty_fusion(_q, _r, _image_z, _text_z)
            s.append(_s.data.cpu())
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z, _q, _r, _s
        s = torch.cat(s, dim=0)
        test_p = self.target_distribution(s)
        if args.tsne and (epoch + 1) % 5 == 0:
            do_tsne(test_p.numpy(), df_initial, self.n_clusters,
                    os.path.join(CONFIG.SVG_PATH, args.gpu,
                                 'epoch_' + ('%03d' % (epoch + 1)) + '.png'))

        # clustering phase on the full data, evaluation only (no optimizer step)
        for batch_idx in range(full_num_batch):
            image_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][1]
            text_batch = full_dataset[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)][2]
            pbatch = test_p[batch_idx * batch_size:min((batch_idx + 1) * batch_size, full_num)]
            image_inputs = Variable(image_batch).to(self.device)
            text_inputs = Variable(text_batch).to(self.device)
            p_inputs = Variable(pbatch).to(self.device)
            _image_z, _text_z = self.forward(image_inputs, text_inputs)
            qbatch, rbatch = self.soft_assignemt(_image_z, _text_z)
            sbatch = self.probabililty_fusion(qbatch, rbatch, _image_z, _text_z)
            unsupervised_loss = self.loss_function(p_inputs, sbatch)
            test_unsupervised_loss += unsupervised_loss.data * len(p_inputs)
            del image_batch, text_batch, image_inputs, text_inputs, _image_z, _text_z
        test_unsupervised_loss /= full_num
        test_supervised_image_loss /= test_num
        test_supervised_text_loss /= test_num

        test_pred = torch.argmax(s, dim=1).numpy()[full_num:]
        test_acc = accuracy_score(test_labels, test_pred)
        test_nmi = normalized_mutual_info_score(test_labels, test_pred,
                                                average_method='geometric')
        test_f_1 = f1_score(test_labels, test_pred, average='macro')
        print("#Test measure %3d: acc: %.4f, nmi: %.4f, f_1: %.4f" %
              (epoch + 1, test_acc, test_nmi, test_f_1))
        print("#Test loss %3d: unsup lss: %.4f, super img: %.4f, super txt: %.4f" %
              (epoch + 1, test_unsupervised_loss, test_supervised_image_loss,
               test_supervised_text_loss))
        self.acc = test_acc
        self.nmi = test_nmi
        self.f_1 = test_f_1
        if flag_end_training:
            break

    if save_path and not args.resume:
        self.save_model(save_path)
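
# ---------------------------------------------------------------------------
# `soft_assignemt` (spelling kept to match the calls above) computes DEC's
# soft cluster assignments but is not defined in this section. Below is a
# minimal single-modality sketch using the Student's t-kernel of Xie et al.
# (2016) with one degree of freedom; the multimodal variant used above
# presumably applies the same kernel per modality and returns the (q, r)
# pair. This is an illustrative assumption, not necessarily the codebase's
# actual implementation.
# ---------------------------------------------------------------------------
def soft_assignemt(z, mu, alpha=1.0):
    # q_ij ∝ (1 + ||z_i - mu_j||^2 / alpha) ** (-(alpha + 1) / 2)
    dist = torch.sum((z.unsqueeze(1) - mu) ** 2, dim=2)  # (N, K) squared distances
    q = (1.0 + dist / alpha) ** (-(alpha + 1.0) / 2.0)
    return q / q.sum(dim=1, keepdim=True)  # normalize rows to sum to 1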