def get_v_adv_loss(self, model: nn.Module, ul_left_input, ul_right_input, p_mult, power_iterations=1):
    bernoulli = dist.Bernoulli
    prob, left_word_emb, right_word_emb = model.siamese_forward(ul_left_input, ul_right_input)[0:3]
    prob = prob.clamp(min=1e-7, max=1. - 1e-7)
    prob_dist = bernoulli(probs=prob)

    # generate the virtual adversarial perturbation from random noise
    left_d, _ = tl.cudafy(torch.FloatTensor(left_word_emb.shape).uniform_(0, 1))
    right_d, _ = tl.cudafy(torch.FloatTensor(right_word_emb.shape).uniform_(0, 1))

    # power iteration: refine the perturbation along the gradient of the KL
    # divergence between the clean and the perturbed predictions
    for _ in range(power_iterations):
        left_d = left_d.requires_grad_(True)
        right_d = right_d.requires_grad_(True)
        left_d = 0.02 * F.normalize(left_d, p=2, dim=1)
        right_d = 0.02 * F.normalize(right_d, p=2, dim=1)
        p_prob = model.siamese_forward(ul_left_input, ul_right_input, left_d, right_d)[0]
        p_prob = p_prob.clamp(min=1e-7, max=1. - 1e-7)
        kl = dist.kl_divergence(prob_dist, bernoulli(probs=p_prob))
        left_gradient, right_gradient = torch.autograd.grad(
            kl.sum(), [left_d, right_d], retain_graph=True)
        left_d = left_gradient.detach()
        right_d = right_gradient.detach()

    # virtual adversarial loss with the final perturbation scaled to p_mult
    left_d = p_mult * F.normalize(left_d, p=2, dim=1)
    right_d = p_mult * F.normalize(right_d, p=2, dim=1)
    p_prob = model.siamese_forward(ul_left_input, ul_right_input, left_d,
                                   right_d)[0].clamp(min=1e-7, max=1. - 1e-7)
    v_adv_losses = dist.kl_divergence(prob_dist, bernoulli(probs=p_prob))
    return torch.mean(v_adv_losses)
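
# Illustrative only (not part of the original training loop): a minimal sketch of
# how get_v_adv_loss could be combined with a supervised siamese loss in one
# semi-supervised step. `supervised_criterion`, `lambda_v` and the batch variables
# are hypothetical placeholders; only model.siamese_forward and get_v_adv_loss
# come from this file.
def _example_semi_supervised_step(self, model, optimizer, supervised_criterion,
                                  left, right, labels, ul_left, ul_right,
                                  lambda_v=1.0, p_mult=0.02):
    prob = model.siamese_forward(left, right)[0].clamp(min=1e-7, max=1. - 1e-7)
    sup_loss = supervised_criterion(prob, labels)   # e.g. nn.BCELoss()
    vat_loss = self.get_v_adv_loss(model, ul_left, ul_right, p_mult)
    loss = sup_loss + lambda_v * vat_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()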
def pairwise_distance(embeddings, squared=False):
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed for every pair of rows
    pairwise_distances_squared = torch.sum(embeddings ** 2, dim=1, keepdim=True) + \
        torch.sum(embeddings.t() ** 2, dim=0, keepdim=True) - \
        2.0 * torch.matmul(embeddings, embeddings.t())

    # entries that are numerically <= 0 are treated as exact zeros
    error_mask = pairwise_distances_squared <= 0.0
    if squared:
        pairwise_distances = pairwise_distances_squared.clamp(min=0)
    else:
        pairwise_distances = pairwise_distances_squared.clamp(min=1e-16).sqrt()
    pairwise_distances = torch.mul(pairwise_distances, ~error_mask)

    num_data = embeddings.shape[0]
    # Explicitly set diagonals to zero.
    if pairwise_distances.is_cuda:
        mask_offdiagonals = torch.ones_like(pairwise_distances) - torch.diag(
            tl.cudafy(torch.ones([num_data]))[0])
    else:
        mask_offdiagonals = torch.ones_like(pairwise_distances) - torch.diag(
            torch.ones([num_data]))
    pairwise_distances = torch.mul(pairwise_distances, mask_offdiagonals)
    return pairwise_distances
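
# Illustrative sanity check for pairwise_distance (a hypothetical helper, not
# called anywhere in the repository): on a small random embedding matrix the
# result should be symmetric, non-negative, and have an exactly-zero diagonal.
def _example_pairwise_distance_check():
    emb = F.normalize(torch.randn(4, 8), p=2, dim=1)
    dists = pairwise_distance(emb, squared=False)
    assert torch.allclose(dists, dists.t(), atol=1e-5)
    assert bool(torch.all(dists >= 0))
    assert bool(torch.all(torch.diag(dists) == 0))
    return dists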
def compute_gt_cluster_score(self, pairwise_distances, labels):
    """Compute the ground truth facility location score.

    Loop over each unique class and accumulate its facility location score.

    Args:
        pairwise_distances: 2-D tensor of pairwise distances.
        labels: 1-D tensor of ground truth cluster assignments.
    Returns:
        gt_cluster_score: scalar tensor with the summed per-cluster scores.
    """
    unique_class_ids = torch.unique(labels)
    num_classes = len(unique_class_ids)
    gt_cluster_score = tl.cudafy(torch.from_numpy(np.array([0.0])))[0]
    for i in range(num_classes):
        # For each cluster, select its intra-cluster distance sub-matrix and
        # score the best single facility (medoid) for that cluster.
        mask = labels == unique_class_ids[i]
        this_cluster_ids = torch.where(mask)[0]
        temp = (tl.gather(pairwise_distances, this_cluster_ids)).T
        pairwise_distances_subset = (tl.gather(temp, this_cluster_ids)).T
        this_cluster_score = -1.0 * torch.min(
            torch.sum(pairwise_distances_subset, 0))
        gt_cluster_score += this_cluster_score
    return gt_cluster_score
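
# Illustrative pure-torch equivalent of the sub-matrix selection used above,
# assuming tl.gather(mat, ids) selects the rows of `mat` indexed by `ids`
# (an assumption about the tl helper, not confirmed here). For one class it
# returns the same facility-location style score: the negated best column sum
# of the intra-cluster distance block.
def _example_cluster_score(pairwise_distances, labels, class_id):
    ids = torch.where(labels == class_id)[0]
    subset = pairwise_distances.index_select(0, ids).index_select(1, ids)
    return -1.0 * torch.min(torch.sum(subset, 0))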
def cos_smi(self, data_left, data_right):
    self.eval()
    if isinstance(data_right, np.ndarray):
        data_right, _ = tl.cudafy(
            torch.from_numpy(np.array(data_right, dtype=np.int64)))
    if isinstance(data_left, np.ndarray):
        data_left, _ = tl.cudafy(
            torch.from_numpy(np.array(data_left, dtype=np.int64)))
    _, vector_l = self.forward_norm(data_left)
    _, vector_r = self.forward_norm(data_right)
    length_l = torch.sum(torch.pow(vector_l, 2), dim=1).sqrt()
    length_r = torch.sum(torch.pow(vector_r, 2), dim=1).sqrt()
    rns = torch.sum(torch.mul(vector_l, vector_r), dim=1) / torch.mul(
        length_l, length_r).float()
    return rns, vector_l, vector_r
def pred_vector(self, data, opt):
    self.eval()
    if not isinstance(data, torch.Tensor):
        data, _ = tl.cudafy(
            torch.from_numpy(np.array(data, dtype=np.int64)))
    if opt.train_loss_type.startswith('Siamese'):
        _, vectors = self.forward(data)
    else:
        _, vectors = self.forward_norm(data)
    return vectors
def pred_X(self, data_left, data_right):
    self.eval()
    if isinstance(data_right, np.ndarray):
        data_right, _ = tl.cudafy(
            torch.from_numpy(np.array(data_right, dtype=np.int64)))
    if isinstance(data_left, np.ndarray):
        data_left, _ = tl.cudafy(
            torch.from_numpy(np.array(data_left, dtype=np.int64)))
    _, vector_l = self.forward_norm(data_left)
    _, vector_r = self.forward_norm(data_right)
    distances_squared = torch.sum(torch.pow(vector_l - vector_r, 2), dim=1)
    if not self.squared:
        prediction = distances_squared.sqrt()
        # the euclidean distance between two normalized vectors is in [0, 2]
        rns = 1 - prediction / 2.0
    else:
        # the squared euclidean distance between two normalized vectors is in [0, 4]
        prediction = distances_squared
        rns = 1 - prediction / 4.0
    return rns, vector_l, vector_r
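
# Illustrative check of the distance-to-similarity mapping used by pred_X
# (a hypothetical helper, not used by the model): for unit-norm vectors
# ||a - b||^2 = 2 - 2 * cos(a, b), so the squared distance lies in [0, 4] and
# 1 - d^2 / 4 equals (1 + cos) / 2, a similarity in [0, 1].
def _example_distance_to_similarity():
    a = F.normalize(torch.randn(1, 16), p=2, dim=1)
    b = F.normalize(torch.randn(1, 16), p=2, dim=1)
    sq_dist = torch.sum((a - b) ** 2, dim=1)   # squared-distance branch of pred_X
    sim = 1 - sq_dist / 4.0
    cos = torch.sum(a * b, dim=1)
    assert torch.allclose(sim, (1 + cos) / 2.0, atol=1e-6)
    return sim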
def process_and_train_FL(model: BasicModel, opt: config.Option):
    # prepare the save directory and record the training configuration
    save_path = os.path.join(opt.save_dir + '/model_file',
                             opt.save_model_name).replace('\\', '/')
    print("model file save path: ", save_path)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    msger = messager(
        save_path=save_path,
        types=[
            'train_data_file', 'val_data_file', 'test_data_file',
            'load_model_name', 'save_model_name', 'trainset_loss_type',
            'testset_loss_type', 'class_num_ratio'
        ],
        json_name='train_information_msg_' +
        time.strftime('%m{}%d{}_%H:%M'.format('月', '日')) + '.json')
    msger.record_message([
        opt.train_data_file, opt.val_data_file, opt.test_data_file,
        opt.load_model_name, opt.save_model_name, opt.train_loss_type,
        opt.testset_loss_type, opt.class_num_ratio
    ])
    msger.save_json()

    # train data loading
    print('-----Data Loading-----')
    if opt.BERT:
        dataloader_train = dataloader_BERT(opt.train_data_file,
                                           opt.wordvec_file,
                                           opt.rel2id_file,
                                           opt.similarity_file,
                                           opt.same_level_pair_file,
                                           max_len=opt.max_len,
                                           random_init=opt.random_init,
                                           seed=opt.seed)
        dataloader_val = dataloader_BERT(opt.val_data_file,
                                         opt.wordvec_file,
                                         opt.rel2id_file,
                                         opt.similarity_file,
                                         max_len=opt.max_len)
        dataloader_test = dataloader_BERT(opt.test_data_file,
                                          opt.wordvec_file,
                                          opt.rel2id_file,
                                          opt.similarity_file,
                                          max_len=opt.max_len)
    else:
        dataloader_train = dataloader(opt.train_data_file,
                                      opt.wordvec_file,
                                      opt.rel2id_file,
                                      opt.similarity_file,
                                      opt.same_level_pair_file,
                                      max_len=opt.max_len,
                                      random_init=opt.random_init,
                                      seed=opt.seed,
                                      data_type=opt.data_type)
        dataloader_val = dataloader(opt.val_data_file,
                                    opt.wordvec_file,
                                    opt.rel2id_file,
                                    opt.similarity_file,
                                    max_len=opt.max_len,
                                    data_type=opt.data_type)
        dataloader_test = dataloader(opt.test_data_file,
                                     opt.wordvec_file,
                                     opt.rel2id_file,
                                     opt.similarity_file,
                                     max_len=opt.max_len,
                                     data_type=opt.data_type)
    word_emb_dim = dataloader_train._word_emb_dim_()
    word_vec_mat = dataloader_train._word_vec_mat_()  # numpy.array, float32
    print('word_emb_dim is {}'.format(word_emb_dim))

    # model initialization
    print('-----Model Initializing-----')
    if not opt.BERT:
        model.set_embedding_weight(word_vec_mat)
    if opt.load_model_name is not None:
        model.load_model(opt.load_model_name)
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
    if torch.cuda.is_available():
        torch.cuda.set_device(int(opt.gpu))
    model, cuda_flag = tl.cudafy(model)
    if not cuda_flag:
        print("There is no gpu, use default cpu")
    count = tl.count_parameters(model)
    print("num of parameters:", count)

    # if the dataset is imbalanced (e.g. nyt_su or trex), load all test/dev data
    # to perform the open setting
    print('-----Validation Data Preparing-----')
    if 'imbalance' in opt.data_type:
        print("load all imbalanced dev data!")
        val_data, val_data_label = dataloader_val._data_()
    else:
        print("load part of data")
        if opt.data_type.startswith('fewrel'):
            # 16 relation classes in the fewrel validation data, 100 samples per class
            val_data, val_data_label = dataloader_val._part_data_(100)
        else:
            # other data sets have the problem of label imbalance
            # (for nyt_fb: sample 5 instances per relation, giving 490 dev instances)
            val_data, val_data_label = dataloader_val._part_data_(100)

    print("-------Test Data Preparing--------")
    if 'imbalance' in opt.data_type:
        print("load all imbalanced test data!")
        test_data, test_data_label = dataloader_test._data_()
    else:
        print("load part of data")
        if opt.data_type.startswith('fewrel'):
            test_data, test_data_label = dataloader_test._data_()
        else:
            # sample as in the dev setting
            test_data, test_data_label = dataloader_test._data_(100)
    print("val_data:", len(val_data))
    print("val_data_label:", len(set(val_data_label)))
    print("test_data:", len(test_data))
    print("test_data_label:", len(set(test_data_label)))

    # initializing parameters
    batch_num_list = opt.batch_num
    msger_cluster = messager(
        save_path=save_path,
        types=[
            'method', 'temp_epoch', 'temp_batch_num', 'temp_batch_size',
            'temp_lr', 'NMI', 'F1', 'precision', 'recall', 'msg'
        ],
        json_name='Validation_cluster_msg_' +
        time.strftime('%m{}%d{}_%H:%M'.format('月', '日')) + '.json')
    if opt.record_test:
        msger_test = messager(
            save_path=save_path,
            types=[
                'temp_globle_step', 'temp_batch_size', 'temp_learning_rate',
                'NMI', 'F1', 'precision', 'recall', 'msg'
            ],
            json_name='Test_cluster_msg_' +
            time.strftime('%m{}%d{}_%H:%M'.format('月', '日')) + '.json')
    if opt.whether_visualize:
        loger = SummaryWriter(comment=opt.save_model_name)
    else:
        loger = None
    best_batch_step = 0
    best_epoch = 0
    best_batch_size = 0  # initialized here so the reports below never hit an undefined name
    batch_size_chose = -1
    print_flag = opt.print_losses
    best_validation_f1 = 0
    best_test_f1 = 0
    loss_list = []
    global_step = 0

    for epoch in range(opt.epoch_num):
        print('------epoch {}------'.format(epoch))
        print('max batch num to train is {}'.format(batch_num_list[epoch]))
        loss_reduce = 10000.
        early_stop_record = 0
        for i in range(1, batch_num_list[epoch] + 1):
            global_step += 1
            loss_list = model.train_self(opt,
                                         dataloader_train,
                                         loss_list,
                                         loger,
                                         batch_chose=batch_size_chose,
                                         global_step=global_step,
                                         temp_epoch=epoch)

            # print & record the loss
            if i % 100 == 0:
                ave_loss = sum(loss_list) / 100.
                print('temp_batch_num: ', i, ' total_batch_num: ',
                      batch_num_list[epoch], " ave_loss: ", ave_loss,
                      ' temp learning rate: ', opt.lr[opt.lr_chose])
                # empty the loss list
                loss_list = []
                # visualize
                if opt.whether_visualize:
                    loger.add_scalar('all_epoch_loss',
                                     ave_loss,
                                     global_step=global_step)
                # early stop
                if opt.early_stop is not None:
                    if ave_loss < loss_reduce:
                        early_stop_record = 0
                        loss_reduce = ave_loss
                    else:
                        early_stop_record += 1
                    if early_stop_record == opt.early_stop:
                        print(
                            "~~~~~~~~~ The loss can't be reduced in {} steps, early stop! ~~~~~~~~~~~~"
                            .format(opt.early_stop * 100))
~~~~~~~~~~~~" .format(opt.early_stop * 100)) cluster_result, cluster_msg, cluster_center, features = K_means_BERT( test_data, model.pred_vector, test_data_label, opt) if opt.BERT else K_means( test_data, model.pred_vector, len(np.unique(test_data_label)), opt) cluster_test_b3 = ClusterEvaluation( test_data_label, cluster_result).printEvaluation(extra_info=True, print_flag=True) print("learning rate decay num:", opt.lr_decay_num) print("learning rate decay step:", opt.lr_decay_record) print("best_epoch:", best_epoch) print("best_step:", best_batch_step) print("best_batch_size:", best_batch_size) print("best_cluster_eval_b3:", best_validation_f1) print("seed:", opt.seed) # clustering & validation if i % 200 == 0: print(opt.save_model_name, 'epoch:', epoch) with torch.no_grad(): # fewrel -> K-means ; nyt+su -> Mean-Shift if opt.dataset.startswith("fewrel"): print("chose k-means >>>") F_score = -1.0 best_cluster_result = None best_cluster_msg = None best_cluster_center = None best_features = None best_cluster_eval_b3 = None for iterion in range(opt.eval_num): K_num = opt.K_num if opt.K_num != 0 else len( np.unique(val_data_label)) cluster_result, cluster_msg, cluster_center, features = K_means_BERT( val_data, model.pred_vector, val_data_label, opt) if opt.BERT else K_means( val_data, model.pred_vector, K_num, opt) cluster_eval_b3 = ClusterEvaluation( val_data_label, cluster_result).printEvaluation( print_flag=False) if F_score < cluster_eval_b3['F1']: F_score = cluster_eval_b3['F1'] best_cluster_result = cluster_result best_cluster_msg = cluster_msg best_cluster_center = cluster_center best_features = features best_cluster_eval_b3 = cluster_eval_b3 cluster_result = best_cluster_result cluster_msg = best_cluster_msg cluster_center = best_cluster_center features = best_features cluster_eval_b3 = best_cluster_eval_b3 else: print("chose mean-shift >>>") cluster_result, cluster_msg, cluster_center, features = mean_shift_BERT( val_data, model.pred_vector, val_data_label, opt) if opt.BERT else mean_shift( val_data, model.pred_vector, opt) cluster_eval_b3 = ClusterEvaluation( val_data_label, cluster_result).printEvaluation(print_flag=False, extra_info=True) NMI_score = normalized_mutual_info_score( val_data_label, cluster_result) print("NMI:{} ,F1:{} ,precision:{} ,recall:{}".format( NMI_score, cluster_eval_b3['F1'], cluster_eval_b3['precision'], cluster_eval_b3['recall'], )) msger_cluster.record_message([ opt.select_cluster, epoch, i, opt.batch_size[batch_size_chose], model.lr, NMI_score, cluster_eval_b3['F1'], cluster_eval_b3['precision'], cluster_eval_b3['recall'], cluster_msg ]) msger_cluster.save_json() two_f1 = cluster_eval_b3['F1'] if two_f1 > best_validation_f1: # acc if opt.record_test == False: model.save_model(model_name=opt.save_model_name, global_step=global_step) best_batch_step = i best_epoch = epoch best_batch_size = opt.batch_size[batch_size_chose] best_validation_f1 = two_f1 if opt.whether_visualize: loger.add_embedding(features, metadata=val_data_label, label_img=None, global_step=global_step, tag='ground_truth', metadata_header=None) loger.add_embedding(features, metadata=cluster_result, label_img=None, global_step=global_step, tag='prediction', metadata_header=None) loger.add_scalar('all_epoch_NMI', NMI_score, global_step=global_step) loger.add_scalar('all_epoch_F1', cluster_eval_b3['F1'], global_step=global_step) loger.add_scalar('all_epoch_precision', cluster_eval_b3['precision'], global_step=global_step) loger.add_scalar('all_epoch_recall', cluster_eval_b3['recall'], 
                    if opt.record_test:
                        if opt.dataset.startswith("fewrel"):
                            cluster_result, cluster_msg, cluster_center, features = K_means_BERT(
                                test_data, model.pred_vector, test_data_label,
                                opt) if opt.BERT else K_means(
                                    test_data, model.pred_vector,
                                    len(np.unique(test_data_label)), opt)
                            cluster_test_b3 = ClusterEvaluation(
                                test_data_label,
                                cluster_result).printEvaluation(print_flag=False)
                        else:
                            cluster_result, cluster_msg, cluster_center, features = mean_shift_BERT(
                                test_data, model.pred_vector, test_data_label,
                                opt) if opt.BERT else mean_shift(
                                    test_data, model.pred_vector, opt)
                            cluster_test_b3 = ClusterEvaluation(
                                test_data_label, cluster_result).printEvaluation(
                                    print_flag=False, extra_info=True)
                        msger_test.record_message([
                            global_step, opt.batch_size[batch_size_chose],
                            opt.lr[opt.lr_chose], NMI_score,
                            cluster_test_b3['F1'], cluster_test_b3['precision'],
                            cluster_test_b3['recall'], cluster_msg
                        ])
                        msger_test.save_json()
                        print('test messages saved.')
                        if cluster_test_b3['F1'] > best_test_f1:
                            model.save_model(model_name=opt.save_model_name,
                                             global_step=global_step)
                            best_batch_step = i
                            best_epoch = epoch
                            best_batch_size = opt.batch_size[batch_size_chose]
                            best_test_f1 = cluster_test_b3['F1']

        model.lr_decay(opt)
        opt.lr_decay_record.append(global_step)

    print('End: The model is:', opt.save_model_name, opt.train_loss_type,
          opt.testset_loss_type)
    if opt.dataset.startswith("fewrel"):
        print('\n-----K-means Clustering test-----')
        best_test_b3, NMI_score = k_means_cluster_evaluation(
            model, opt, test_data, test_data_label, loger)
    else:
        print("\n-----------Mean_shift Clustering test:---------------")
        model.load_model(opt.save_model_name + "_best.pt")
        cluster_result_ms, cluster_msg_ms, _, _ = mean_shift_BERT(
            test_data, model.pred_vector, test_data_label,
            opt) if opt.BERT else mean_shift(test_data, model.pred_vector, opt)
        cluster_eval_b3_ms = ClusterEvaluation(
            test_data_label, cluster_result_ms).printEvaluation(
                print_flag=opt.print_losses, extra_info=True)
        NMI_score_ms = normalized_mutual_info_score(test_data_label,
                                                    cluster_result_ms)
        best_test_b3 = cluster_eval_b3_ms
        NMI_score = NMI_score_ms
        if opt.whether_visualize:
            loger.add_scalar('test_NMI_MeanShift', NMI_score_ms, global_step=0)
            loger.add_scalar('test_F1_MeanShift',
                             cluster_eval_b3_ms['F1'],
                             global_step=0)
    print("learning rate decay num:", opt.lr_decay_num)
    print("learning rate decay step:", opt.lr_decay_record)
    print("best_epoch:", best_epoch)
    print("best_step:", best_batch_step)
    print("best_batch_size:", best_batch_size)
    print("best_cluster_eval_b3:", best_validation_f1)
    print("best_cluster_test_b3:", best_test_b3)
    print("best_NMI_score:", NMI_score)
    print("seed:", opt.seed)
def train_self(self, opt, dataloader_train, loss_list=None, loger=None,
               batch_chose=0, global_step=None, temp_epoch=0, chose_decay=False):
    # e.g. batch size 60, class_num_ratio 0.5
    batch_size = opt.batch_size[batch_chose]
    class_num_ratio = opt.class_num_ratio[batch_chose]
    assert batch_size is not None
    self.train()

    # learning rate decay
    if temp_epoch > 0 and self.lr > 1e-8 and chose_decay:
        print("lr decay to {}!".format(self.lr * 0.1))
        self.lr = self.lr * 0.1
    param = []
    param += [{
        'params': filter(lambda p: p.requires_grad, self.Bert_model.parameters()),
        'lr': self.lr
    }]
    param += [{
        'params': list(self.FFL.parameters())[0],
        'weight_decay': 1e-3,
        'lr': self.lr_linear
    }]
    param += [{
        'params': list(self.FFL.parameters())[1],
        'lr': self.lr_linear
    }]
    # re-create the optimizer so the (possibly decayed) learning rates take
    # effect; optimizers have no .to() method and the parameters already live
    # on the right device
    self.optimizer = optim.Adam(param)

    # torch tensor batch_data [batch_size, sequence]; batch_sentence is the BERT input
    batch_data, batch_label, batch_sentence, cluster_label = dataloader_train.next_batch_cluster(
        batch_size, class_num_ratio, opt.batch_shuffle, opt.inclass_augment)

    # BERT forward with a batch size of 8
    marker_temper_data, b_input_ids, batch_label = self.bert_forward(
        batch_sentence, batch_label)
    batch_label, _ = tl.cudafy(batch_label)

    # go through the fully connected layer and ReLU, then the norm layer
    features = self.FFL(marker_temper_data)
    features = self.norm_layer(features)

    # loss calculation
    loss = self.ml.cluster_loss(opt, features, batch_label, global_step)

    # backward
    if isinstance(loss, torch.Tensor):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        loss_list.append(loss.item())
    if opt.whether_visualize and global_step == 1:
        try:
            loger.add_graph(self, input_to_model=batch_data)
        except Exception:
            print("*tensorboard : add graph failed")
    return loss_list
def train_self(self, opt, dataloader_train, loss_list=None, loger=None,
               batch_chose=0, global_step=None, temp_epoch=1):
    batch_size = opt.batch_size[batch_chose]
    class_num_ratio = opt.class_num_ratio[batch_chose]
    assert batch_size is not None
    self.train()

    batch_data, batch_label, cluster_label = dataloader_train.next_batch_cluster(
        batch_size, class_num_ratio, opt.batch_shuffle, opt.inclass_augment)
    batch_data, _ = tl.cudafy(batch_data)
    batch_label, _ = tl.cudafy(batch_label)
    wordembed, features = self.forward_norm(batch_data)  # [batch_size, embedding_dim]

    if opt.VAT != 0 and temp_epoch >= opt.warm_up:
        # ranked list loss on the clean embeddings plus virtual adversarial training
        total_loss = 0.0
        labels = batch_label
        margin = opt.margin
        alpha = opt.alpha_rank
        tval = opt.temp_neg
        encode_ori = features
        dist_mat = pairwise_distance(encode_ori, opt.squared)  # [batch, batch]

        # compute the ranked list term anchor by anchor
        for anchor in range(dist_mat.shape[0]):
            is_pos = labels.eq(labels[anchor])
            is_pos[anchor] = 0
            is_neg = labels.ne(labels[anchor])
            dist_ap = dist_mat[anchor][is_pos]
            dist_an = dist_mat[anchor][is_neg]

            # positives: penalize positive pairs farther than (alpha - margin)
            ap_is_pos = torch.clamp(torch.add(dist_ap, margin - alpha), min=0.0)
            ap_pos_num = ap_is_pos.size(0) + 1e-5
            ap_pos_val_sum = torch.sum(ap_is_pos)
            loss_ap = torch.div(ap_pos_val_sum, float(ap_pos_num))

            # negatives: weight negatives that fall inside the alpha boundary
            an_is_pos = torch.lt(dist_an, alpha)
            an_less_alpha = dist_an[an_is_pos]
            an_weight = torch.exp(tval * (-1 * an_less_alpha + alpha))
            an_weight_sum = torch.sum(an_weight) + 1e-5
            an_dist_lm = alpha - an_less_alpha
            an_ln_sum = torch.sum(torch.mul(an_dist_lm, an_weight))
            loss_an = torch.div(an_ln_sum, an_weight_sum)

            total_loss = total_loss + loss_ap + loss_an

        # generate the virtual adversarial perturbation on the word embeddings
        disturb, _ = tl.cudafy(
            torch.FloatTensor(wordembed.shape).uniform_(0, 1))
        for _ in range(opt.power_iterations):
            disturb.requires_grad = True
            disturb = (opt.p_mult) * F.normalize(disturb, p=2, dim=1)
            _, encode_disturb = self.forward_norm(batch_data, disturb)
            dist_el = torch.sum(torch.pow(encode_ori - encode_disturb, 2), dim=1).sqrt()
            diff = (dist_el / 2.0).clamp(0, 1.0 - 1e-7)
            disturb_gradient = torch.autograd.grad(diff.sum(),
                                                   disturb,
                                                   retain_graph=True)[0]
            disturb = disturb_gradient.detach()
        disturb = opt.p_mult * F.normalize(disturb, p=2, dim=1)

        # virtual adversarial loss: distance between clean and perturbed encodings
        _, encode_final = self.forward_norm(batch_data, disturb)
        dist_el = torch.sum(torch.pow(encode_ori - encode_final, 2), dim=1).sqrt()
        diff = (dist_el / 2.0).clamp(0, 1.0 - 1e-7)
        v_adv_losses = torch.mean(diff)

        loss = total_loss * 1.0 / dist_mat.size(0) + v_adv_losses * opt.lambda_V
        assert torch.mean(v_adv_losses).item() > 0.0
    else:
        self.ml = MetricLoss.metric_loss()
        loss = self.ml.cluster_loss(opt, features, batch_label, global_step)

    if isinstance(loss, torch.Tensor):
        self.optimizer.zero_grad()
        loss.backward()
        # zero the gradient of the last word-embedding row so it is never updated
        self.word_emb.word_embedding.weight.grad[-1] = 0
        self.optimizer.step()
        loss_list.append(loss.item())
    if opt.whether_visualize and global_step == 1:
        loger.add_graph(self, input_to_model=batch_data)
    return loss_list
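
# Illustrative stand-alone version of the per-anchor ranked-list term computed in
# the VAT branch of train_self above; margin / alpha / tval mirror opt.margin,
# opt.alpha_rank and opt.temp_neg. A minimal sketch with hypothetical inputs, not
# an additional code path of the model.
def _example_ranked_list_term(dist_ap, dist_an, margin, alpha, tval):
    # positives: penalize positive pairs whose distance exceeds (alpha - margin)
    ap_term = torch.clamp(dist_ap + margin - alpha, min=0.0)
    loss_ap = torch.sum(ap_term) / (ap_term.size(0) + 1e-5)
    # negatives: weight negatives closer than alpha by how far they violate the boundary
    an_less = dist_an[dist_an < alpha]
    an_weight = torch.exp(tval * (alpha - an_less))
    loss_an = torch.sum((alpha - an_less) * an_weight) / (torch.sum(an_weight) + 1e-5)
    return loss_ap + loss_an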