示例#1
0
    def get_v_adv_loss(self, model: nn.Module, ul_left_input, ul_right_input, p_mult, power_iterations=1):
        """Compute the virtual adversarial (VAT) loss of a siamese model.

        The clean prediction ``prob`` of ``model.siamese_forward`` is compared,
        via the KL divergence of Bernoulli distributions, with the prediction
        obtained when perturbations are passed as the extra ``siamese_forward``
        arguments. The perturbation direction is refined with
        ``power_iterations`` gradient rounds and finally rescaled to
        ``p_mult * F.normalize(..., dim=1)``.

        Args:
            model: module whose ``siamese_forward(left, right[, left_d, right_d])``
                returns (prob, left_word_emb, right_word_emb, ...) as its
                first outputs.
            ul_left_input: left side of the (unlabeled) input pairs.
            ul_right_input: right side of the (unlabeled) input pairs.
            p_mult: scale applied to the normalised final perturbation.
            power_iterations: number of gradient refinement rounds.

        Returns:
            Scalar tensor: mean KL divergence between the clean and the
            perturbed Bernoulli predictions.
        """
        bernoulli = dist.Bernoulli
        prob, left_word_emb, right_word_emb = model.siamese_forward(ul_left_input, ul_right_input)[0:3]
        # Keep probabilities away from exactly 0/1 so the Bernoulli KL stays finite.
        prob = prob.clamp(min=1e-7, max=1. - 1e-7)
        prob_dist = bernoulli(probs=prob)
        # Random initial perturbation with the same shape as the word embeddings.
        left_d, _ = tl.cudafy(torch.FloatTensor(left_word_emb.shape).uniform_(0, 1))
        right_d, _ = tl.cudafy(torch.FloatTensor(right_word_emb.shape).uniform_(0, 1))
        for _ in range(power_iterations):
            # Turn the scaled perturbation into a fresh leaf that requires grad.
            # From the 2nd round on, left_d/right_d are detached gradients; the
            # old code never re-enabled grad, so autograd.grad failed here.
            left_d = (0.02 * F.normalize(left_d, p=2, dim=1)).detach().requires_grad_(True)
            right_d = (0.02 * F.normalize(right_d, p=2, dim=1)).detach().requires_grad_(True)
            p_prob = model.siamese_forward(ul_left_input, ul_right_input, left_d, right_d)[0]
            p_prob = p_prob.clamp(min=1e-7, max=1. - 1e-7)
            # Let failures surface: the old bare except left `kl` undefined
            # (or stale from a previous round).
            kl = dist.kl_divergence(prob_dist, bernoulli(probs=p_prob))
            # retain_graph: the graph behind `prob_dist` is reused by the final loss.
            left_gradient, right_gradient = torch.autograd.grad(kl.sum(), [left_d, right_d], retain_graph=True)
            left_d = left_gradient.detach()
            right_d = right_gradient.detach()
        left_d = p_mult * F.normalize(left_d, p=2, dim=1)
        right_d = p_mult * F.normalize(right_d, p=2, dim=1)
        # virtual adversarial loss
        p_prob = model.siamese_forward(ul_left_input, ul_right_input, left_d, right_d)[0].clamp(min=1e-7, max=1. - 1e-7)
        v_adv_losses = dist.kl_divergence(prob_dist, bernoulli(probs=p_prob))

        return torch.mean(v_adv_losses)
示例#2
0
def pairwise_distance(embeddings, squared=False):
    """Compute the matrix of pairwise Euclidean distances between rows.

    Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. Entries whose squared value
    comes out <= 0 (numerical noise, identical rows) are forced to exactly 0,
    and the diagonal is always 0.

    Args:
        embeddings: 2-D tensor of shape [num_data, dim].
        squared: if True return squared distances, otherwise distances.

    Returns:
        Tensor of shape [num_data, num_data], on the same device and with the
        same dtype as the distance computation.
    """
    pairwise_distances_squared = torch.sum(embeddings ** 2, dim=1, keepdim=True) + \
                                 torch.sum(embeddings.t() ** 2, dim=0, keepdim=True) - \
                                 2.0 * torch.matmul(embeddings, embeddings.t())

    error_mask = pairwise_distances_squared <= 0.0
    if squared:
        pairwise_distances = pairwise_distances_squared.clamp(min=0)
    else:
        # Clamp before sqrt so the gradient stays finite near 0.
        pairwise_distances = pairwise_distances_squared.clamp(min=1e-16).sqrt()

    pairwise_distances = torch.mul(pairwise_distances, ~error_mask)

    num_data = embeddings.shape[0]
    # Explicitly set diagonals to zero. The mask is created directly on the
    # distance tensor's device/dtype, so one code path serves CPU and any GPU.
    mask_offdiagonals = 1.0 - torch.eye(num_data,
                                        dtype=pairwise_distances.dtype,
                                        device=pairwise_distances.device)

    pairwise_distances = torch.mul(pairwise_distances, mask_offdiagonals)

    return pairwise_distances
示例#3
0
    def compute_gt_cluster_score(self, pairwise_distances, labels):
        """Compute the ground-truth facility location score.

        For every distinct value in ``labels`` the distance matrix is restricted
        to that cluster's members. The column sums of the restricted matrix
        (total travel distance for each candidate centre) are taken, and the
        negated smallest sum is added to the running score.

        Args:
          pairwise_distances: 2-D distance matrix. The gathered sub-matrix is
            fed to ``torch.sum``, so presumably a torch tensor.
          labels: 1-D torch tensor of ground-truth cluster assignments.

        Returns:
          gt_cluster_score: 1-element float64 tensor holding the summed score
            (created from a numpy zero and passed through ``tl.cudafy``).
        """
        class_ids = torch.unique(labels)
        score = tl.cudafy(torch.from_numpy(np.array([0.0])))[0]

        for class_id in class_ids:
            # Positions of the samples assigned to this ground-truth class.
            member_ids = torch.where(labels == class_id)[0]
            # Select the members, transpose, select again, transpose back: the
            # members x members sub-matrix (assuming tl.gather indexes along
            # dim 0 — TODO confirm).
            sub_matrix = tl.gather(tl.gather(pairwise_distances, member_ids).T,
                                   member_ids).T
            # Total travel distance for every candidate centre of the cluster.
            travel_totals = torch.sum(sub_matrix, 0)
            score += -1.0 * torch.min(travel_totals)

        return score
示例#4
0
    def cos_smi(self, data_left, data_right):
        """Cosine similarity between the encodings of two paired input batches.

        Switches the module to eval mode and converts any non-tensor input to
        an int64 tensor passed through ``tl.cudafy``. This is the same rule
        ``pred_vector`` uses; previously only numpy arrays were converted and
        lists reached ``forward_norm`` unconverted. Both sides are then encoded
        with ``forward_norm``.

        Args:
            data_left: left batch (tensor or integer array-like).
            data_right: right batch (tensor or integer array-like).

        Returns:
            (rns, vector_l, vector_r): per-row cosine similarity and the two
            encoded batches.
        """
        self.eval()
        if not isinstance(data_right, torch.Tensor):
            data_right, _ = tl.cudafy(
                torch.from_numpy(np.array(data_right, dtype=np.int64)))
        if not isinstance(data_left, torch.Tensor):
            data_left, _ = tl.cudafy(
                torch.from_numpy(np.array(data_left, dtype=np.int64)))
        _, vector_l = self.forward_norm(data_left)
        _, vector_r = self.forward_norm(data_right)

        # Row-wise L2 lengths of both encodings.
        length_l = torch.sum(torch.pow(vector_l, 2), dim=1).sqrt()
        length_r = torch.sum(torch.pow(vector_r, 2), dim=1).sqrt()

        rns = torch.sum(torch.mul(vector_l, vector_r), dim=1) / torch.mul(
            length_l, length_r).float()

        return rns, vector_l, vector_r
示例#5
0
    def pred_vector(self, data, opt):
        """Return encoded feature vectors for ``data`` (module switched to eval mode).

        Args:
            data: a tensor, or anything ``np.array`` can turn into an int64
                array (converted and passed through ``tl.cudafy``).
            opt: options. If ``opt.train_loss_type`` starts with 'Siamese',
                ``self.forward`` is used; otherwise ``self.forward_norm``.

        Returns:
            The second of the two outputs of the selected encoder.
        """
        self.eval()
        inputs = data
        if not isinstance(inputs, torch.Tensor):
            inputs, _ = tl.cudafy(torch.from_numpy(np.array(inputs, dtype=np.int64)))

        is_siamese = opt.train_loss_type.startswith('Siamese')
        encode = self.forward if is_siamese else self.forward_norm
        _, vectors = encode(inputs)
        return vectors
示例#6
0
    def pred_X(self, data_left, data_right):
        """Similarity of paired inputs derived from the distance of their encodings.

        Switches the module to eval mode and converts any non-tensor input to
        an int64 tensor passed through ``tl.cudafy``. This is the same rule
        ``pred_vector`` uses; previously only numpy arrays were converted. Both
        sides are encoded with ``forward_norm``, and the row-wise Euclidean
        distance d is mapped to ``1 - d / 2``. When ``self.squared`` is set,
        the squared distance is mapped to ``1 - d**2 / 4`` instead. These
        scalings assume unit-norm encodings (d in [0, 2]).

        Args:
            data_left: left batch (tensor or integer array-like).
            data_right: right batch (tensor or integer array-like).

        Returns:
            (rns, vector_l, vector_r): per-row similarity and the two encoded
            batches.
        """
        self.eval()
        if not isinstance(data_right, torch.Tensor):
            data_right, _ = tl.cudafy(
                torch.from_numpy(np.array(data_right, dtype=np.int64)))
        if not isinstance(data_left, torch.Tensor):
            data_left, _ = tl.cudafy(
                torch.from_numpy(np.array(data_left, dtype=np.int64)))
        _, vector_l = self.forward_norm(data_left)
        _, vector_r = self.forward_norm(data_right)

        distances_squared = torch.sum(torch.pow(vector_l - vector_r, 2), dim=1)
        if not self.squared:
            prediction = distances_squared.sqrt()
            # the euclidean dist between two normalized vector is in [0,2]
            rns = 1 - prediction / 2.0
        else:
            # the euclidean dist(squared) between two normalized vector is in [0,4]
            prediction = distances_squared
            rns = 1 - prediction / 4.0
        return rns, vector_l, vector_r
示例#7
0
def process_and_train_FL(model: BasicModel, opt: config.Option):
    """Train ``model`` with periodic clustering validation, then run a final test.

    Steps:
      1. Record the run configuration to a JSON file under
         ``<opt.save_dir>/model_file/<opt.save_model_name>``.
      2. Build train/val/test dataloaders (BERT variants when ``opt.BERT`` is
         truthy, word-vector variants otherwise).
      3. Initialise the model (pretrained word vectors on the non-BERT path,
         optional checkpoint ``opt.load_model_name``) and move it to the GPU
         via ``tl.cudafy``.
      4. For every epoch run ``opt.batch_num[epoch]`` steps of
         ``model.train_self``. Every 100 steps the average loss is printed or
         logged and drives the early-stop counter. Every 200 steps the
         validation set is clustered (K-means for fewrel datasets, mean-shift
         otherwise), F1/precision/recall/NMI are recorded, and the model is
         saved when the tracked F1 improves. With ``opt.record_test`` the
         test F1 is tracked; otherwise the validation F1 is.
      5. Cluster the test set. Fewrel uses ``k_means_cluster_evaluation`` on
         the current model; other datasets reload ``<save_model_name>_best.pt``
         and use mean-shift. A summary is printed at the end.

    Args:
        model: project model exposing train_self / pred_vector / save_model /
            load_model / lr_decay / set_embedding_weight.
        opt: run configuration (``config.Option``).
    """
    # preparing saving files.
    save_path = os.path.join(opt.save_dir + '/model_file',
                             opt.save_model_name).replace('\\', '/')
    print("model file save path: ", save_path)

    os.makedirs(save_path, exist_ok=True)
    msger = messager(
        save_path=save_path,
        types=[
            'train_data_file', 'val_data_file', 'test_data_file',
            'load_model_name', 'save_model_name', 'trainset_loss_type',
            'testset_loss_type', 'class_num_ratio'
        ],
        json_name='train_information_msg_' +
        time.strftime('%m{}%d{}_%H:%M'.format('月', '日')) + '.json')
    msger.record_message([
        opt.train_data_file, opt.val_data_file, opt.test_data_file,
        opt.load_model_name, opt.save_model_name, opt.train_loss_type,
        opt.testset_loss_type, opt.class_num_ratio
    ])
    msger.save_json()

    # train data loading
    print('-----Data Loading-----')
    if opt.BERT:
        dataloader_train = dataloader_BERT(opt.train_data_file,
                                           opt.wordvec_file,
                                           opt.rel2id_file,
                                           opt.similarity_file,
                                           opt.same_level_pair_file,
                                           max_len=opt.max_len,
                                           random_init=opt.random_init,
                                           seed=opt.seed)
        dataloader_val = dataloader_BERT(opt.val_data_file,
                                         opt.wordvec_file,
                                         opt.rel2id_file,
                                         opt.similarity_file,
                                         max_len=opt.max_len)
        dataloader_test = dataloader_BERT(opt.test_data_file,
                                          opt.wordvec_file,
                                          opt.rel2id_file,
                                          opt.similarity_file,
                                          max_len=opt.max_len)
    else:
        dataloader_train = dataloader(opt.train_data_file,
                                      opt.wordvec_file,
                                      opt.rel2id_file,
                                      opt.similarity_file,
                                      opt.same_level_pair_file,
                                      max_len=opt.max_len,
                                      random_init=opt.random_init,
                                      seed=opt.seed,
                                      data_type=opt.data_type)
        dataloader_val = dataloader(opt.val_data_file,
                                    opt.wordvec_file,
                                    opt.rel2id_file,
                                    opt.similarity_file,
                                    max_len=opt.max_len,
                                    data_type=opt.data_type)
        dataloader_test = dataloader(opt.test_data_file,
                                     opt.wordvec_file,
                                     opt.rel2id_file,
                                     opt.similarity_file,
                                     max_len=opt.max_len,
                                     data_type=opt.data_type)
    word_emb_dim = dataloader_train._word_emb_dim_()
    word_vec_mat = dataloader_train._word_vec_mat_()  # numpy.array float32
    print('word_emb_dim is {}'.format(word_emb_dim))

    # compile model
    print('-----Model Initializing-----')
    # Same truthiness test as the loader selection above, so a non-BERT
    # loader is always paired with the pretrained word-vector matrix.
    if not opt.BERT:
        model.set_embedding_weight(word_vec_mat)

    if opt.load_model_name is not None:
        model.load_model(opt.load_model_name)

    # NOTE(review): if CUDA_VISIBLE_DEVICES=opt.gpu takes effect, the chosen
    # GPU becomes visible as index 0, so set_device(int(opt.gpu)) may fail for
    # opt.gpu != "0". Confirm whether CUDA is already initialised at this point.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
    if torch.cuda.is_available():
        torch.cuda.set_device(int(opt.gpu))

    model, cuda_flag = tl.cudafy(model)
    if not cuda_flag:
        print("There is no gpu,use default cpu")

    count = tl.count_parameters(model)
    print("num of parameters:", count)

    # If the dataset is imbalanced (e.g. nyt_su or trex), load all test/dev
    # data to perform the open setting. A plain membership test replaces the
    # old try/str.index/bare-except, which also hid real errors raised by
    # _data_() and silently fell back to partial data.
    print('-----Validation Data Preparing-----')
    if 'imbalance' in opt.data_type:
        print("try load all imbalance dev data!")
        val_data, val_data_label = dataloader_val._data_()
    else:
        print("load part of data")
        # fewrel and the other datasets both sample via _part_data_(100);
        # the original note says fewrel dev has 16 classes x 100 samples.
        val_data, val_data_label = dataloader_val._part_data_(100)

    print("-------Test Data Preparing--------")
    if 'imbalance' in opt.data_type:
        print("try load all imbalance test data!")
        test_data, test_data_label = dataloader_test._data_()
    else:
        print("load part of data")
        if opt.data_type.startswith('fewrel'):
            test_data, test_data_label = dataloader_test._data_()
        else:
            # NOTE(review): the dev split samples with _part_data_(100); confirm
            # that _data_(100) is the intended call here.
            test_data, test_data_label = dataloader_test._data_(
                100)  # sample as the dev setting

    print("val_data:", len(val_data))
    print("val_data_label:", len(set(val_data_label)))
    print("test_data:", len(test_data))
    print("test_data_label:", len(set(test_data_label)))

    # intializing parameters
    batch_num_list = opt.batch_num

    msger_cluster = messager(
        save_path=save_path,
        types=[
            'method', 'temp_epoch', 'temp_batch_num', 'temp_batch_size',
            'temp_lr', 'NMI', 'F1', 'precision', 'recall', 'msg'
        ],
        json_name='Validation_cluster_msg_' +
        time.strftime('%m{}%d{}_%H:%M'.format('月', '日')) + '.json')

    if opt.record_test:
        msger_test = messager(
            save_path=save_path,
            types=[
                'temp_globle_step', 'temp_batch_size', 'temp_learning_rate',
                'NMI', 'F1', 'precision', 'recall', 'msg'
            ],
            json_name='Test_cluster_msg_' +
            time.strftime('%m{}%d{}_%H:%M'.format('月', '日')) + '.json')

    if opt.whether_visualize:
        loger = SummaryWriter(comment=opt.save_model_name)
    else:
        loger = None

    best_batch_step = 0
    best_epoch = 0
    # Initialised up front: it is printed by the early-stop report and the
    # final summary, which raised UnboundLocalError when no evaluation had
    # improved the best F1 yet.
    best_batch_size = None
    batch_size_chose = -1
    best_validation_f1 = 0
    best_test_f1 = 0
    loss_list = []
    global_step = 0
    for epoch in range(opt.epoch_num):
        print('------epoch {}------'.format(epoch))
        print('max batch num to train is {}'.format(batch_num_list[epoch]))
        loss_reduce = 10000.
        early_stop_record = 0
        for i in range(1, batch_num_list[epoch] + 1):
            global_step += 1
            loss_list = model.train_self(opt,
                                         dataloader_train,
                                         loss_list,
                                         loger,
                                         batch_chose=batch_size_chose,
                                         global_step=global_step,
                                         temp_epoch=epoch)

            # print loss & record loss
            if i % 100 == 0:
                # Average over the losses actually recorded. train_self only
                # appends tensor losses, and leftovers can carry over across
                # epochs, so the count is not always exactly 100.
                ave_loss = sum(loss_list) / max(len(loss_list), 1)
                print('temp_batch_num: ', i, ' total_batch_num: ',
                      batch_num_list[epoch], " ave_loss: ", ave_loss,
                      ' temp learning rate: ', opt.lr[opt.lr_chose])
                # empty the loss list
                loss_list = []
                # visualize
                if opt.whether_visualize:
                    loger.add_scalar('all_epoch_loss',
                                     ave_loss,
                                     global_step=global_step)
                # early stop
                if opt.early_stop is not None:
                    if ave_loss < loss_reduce:
                        early_stop_record = 0
                        loss_reduce = ave_loss
                    else:
                        early_stop_record += 1
                    if early_stop_record == opt.early_stop:
                        # NOTE(review): despite the message, training is NOT
                        # interrupted here (no break/return). This only reports
                        # the current test-set clustering. Confirm intent.
                        print(
                            "~~~~~~~~~ The loss can't be reduced in {} step, early stop! ~~~~~~~~~~~~"
                            .format(opt.early_stop * 100))
                        cluster_result, cluster_msg, cluster_center, features = K_means_BERT(
                            test_data, model.pred_vector, test_data_label,
                            opt) if opt.BERT else K_means(
                                test_data, model.pred_vector,
                                len(np.unique(test_data_label)), opt)
                        cluster_test_b3 = ClusterEvaluation(
                            test_data_label,
                            cluster_result).printEvaluation(extra_info=True,
                                                            print_flag=True)
                        print("learning rate decay num:", opt.lr_decay_num)
                        print("learning rate decay step:", opt.lr_decay_record)
                        print("best_epoch:", best_epoch)
                        print("best_step:", best_batch_step)
                        print("best_batch_size:", best_batch_size)
                        print("best_cluster_eval_b3:", best_validation_f1)
                        print("seed:", opt.seed)

            # clustering & validation
            if i % 200 == 0:
                print(opt.save_model_name, 'epoch:', epoch)
                with torch.no_grad():
                    # fewrel -> K-means ; nyt+su -> Mean-Shift
                    if opt.dataset.startswith("fewrel"):
                        print("chose k-means >>>")
                        # K-means is run opt.eval_num times; keep the run with
                        # the best validation F1.
                        F_score = -1.0
                        best_cluster_result = None
                        best_cluster_msg = None
                        best_cluster_center = None
                        best_features = None
                        best_cluster_eval_b3 = None
                        for iterion in range(opt.eval_num):
                            K_num = opt.K_num if opt.K_num != 0 else len(
                                np.unique(val_data_label))
                            cluster_result, cluster_msg, cluster_center, features = K_means_BERT(
                                val_data, model.pred_vector, val_data_label,
                                opt) if opt.BERT else K_means(
                                    val_data, model.pred_vector, K_num, opt)

                            cluster_eval_b3 = ClusterEvaluation(
                                val_data_label,
                                cluster_result).printEvaluation(
                                    print_flag=False)

                            if F_score < cluster_eval_b3['F1']:
                                F_score = cluster_eval_b3['F1']
                                best_cluster_result = cluster_result
                                best_cluster_msg = cluster_msg
                                best_cluster_center = cluster_center
                                best_features = features
                                best_cluster_eval_b3 = cluster_eval_b3

                        cluster_result = best_cluster_result
                        cluster_msg = best_cluster_msg
                        cluster_center = best_cluster_center
                        features = best_features
                        cluster_eval_b3 = best_cluster_eval_b3

                    else:
                        print("chose mean-shift >>>")
                        cluster_result, cluster_msg, cluster_center, features = mean_shift_BERT(
                            val_data, model.pred_vector, val_data_label,
                            opt) if opt.BERT else mean_shift(
                                val_data, model.pred_vector, opt)

                        cluster_eval_b3 = ClusterEvaluation(
                            val_data_label,
                            cluster_result).printEvaluation(print_flag=False,
                                                            extra_info=True)

                    NMI_score = normalized_mutual_info_score(
                        val_data_label, cluster_result)

                    print("NMI:{} ,F1:{} ,precision:{} ,recall:{}".format(
                        NMI_score,
                        cluster_eval_b3['F1'],
                        cluster_eval_b3['precision'],
                        cluster_eval_b3['recall'],
                    ))

                    msger_cluster.record_message([
                        opt.select_cluster, epoch, i,
                        opt.batch_size[batch_size_chose], model.lr, NMI_score,
                        cluster_eval_b3['F1'], cluster_eval_b3['precision'],
                        cluster_eval_b3['recall'], cluster_msg
                    ])

                    msger_cluster.save_json()

                    two_f1 = cluster_eval_b3['F1']
                    if two_f1 > best_validation_f1:  # acc
                        # Without test tracking, checkpoints follow the validation F1.
                        if not opt.record_test:
                            model.save_model(model_name=opt.save_model_name,
                                             global_step=global_step)
                        best_batch_step = i
                        best_epoch = epoch
                        best_batch_size = opt.batch_size[batch_size_chose]
                        best_validation_f1 = two_f1

                    if opt.whether_visualize:
                        loger.add_embedding(features,
                                            metadata=val_data_label,
                                            label_img=None,
                                            global_step=global_step,
                                            tag='ground_truth',
                                            metadata_header=None)
                        loger.add_embedding(features,
                                            metadata=cluster_result,
                                            label_img=None,
                                            global_step=global_step,
                                            tag='prediction',
                                            metadata_header=None)
                        loger.add_scalar('all_epoch_NMI',
                                         NMI_score,
                                         global_step=global_step)
                        loger.add_scalar('all_epoch_F1',
                                         cluster_eval_b3['F1'],
                                         global_step=global_step)
                        loger.add_scalar('all_epoch_precision',
                                         cluster_eval_b3['precision'],
                                         global_step=global_step)
                        loger.add_scalar('all_epoch_recall',
                                         cluster_eval_b3['recall'],
                                         global_step=global_step)

                    if opt.record_test:
                        if opt.dataset.startswith("fewrel"):
                            cluster_result, cluster_msg, cluster_center, features = K_means_BERT(
                                test_data, model.pred_vector, test_data_label,
                                opt) if opt.BERT else K_means(
                                    test_data, model.pred_vector,
                                    len(np.unique(test_data_label)), opt)

                            cluster_test_b3 = ClusterEvaluation(
                                test_data_label,
                                cluster_result).printEvaluation(
                                    print_flag=False)
                        else:
                            cluster_result, cluster_msg, cluster_center, features = mean_shift_BERT(
                                test_data, model.pred_vector, test_data_label,
                                opt) if opt.BERT else mean_shift(
                                    test_data, model.pred_vector, opt)

                            cluster_test_b3 = ClusterEvaluation(
                                test_data_label,
                                cluster_result).printEvaluation(
                                    print_flag=False, extra_info=True)

                        msger_test.record_message([
                            global_step, opt.batch_size[batch_size_chose],
                            opt.lr[opt.lr_chose], NMI_score,
                            cluster_test_b3['F1'],
                            cluster_test_b3['precision'],
                            cluster_test_b3['recall'], cluster_msg
                        ])
                        msger_test.save_json()
                        print('test messages saved.')

                        # With test tracking, checkpoints follow the test F1.
                        if cluster_test_b3['F1'] > best_test_f1:
                            model.save_model(model_name=opt.save_model_name,
                                             global_step=global_step)
                            best_batch_step = i
                            best_epoch = epoch
                            best_batch_size = opt.batch_size[batch_size_chose]
                            best_test_f1 = cluster_test_b3['F1']

        model.lr_decay(opt)
        opt.lr_decay_record.append(global_step)

        print('End: The model is:', opt.save_model_name, opt.train_loss_type,
              opt.testset_loss_type)

    if opt.dataset.startswith("fewrel"):
        print('\n-----K-means Clustering test-----')
        best_test_b3, NMI_score = k_means_cluster_evaluation(
            model, opt, test_data, test_data_label, loger)
    else:
        print("\n-----------Mean_shift Clustering test:---------------")
        model.load_model(opt.save_model_name + "_best.pt")
        cluster_result_ms, cluster_msg_ms, _, _ = mean_shift_BERT(
            test_data, model.pred_vector, test_data_label,
            opt) if opt.BERT else mean_shift(test_data, model.pred_vector, opt)
        cluster_eval_b3_ms = ClusterEvaluation(
            test_data_label,
            cluster_result_ms).printEvaluation(print_flag=opt.print_losses,
                                               extra_info=True)
        NMI_score_ms = normalized_mutual_info_score(test_data_label,
                                                    cluster_result_ms)

        best_test_b3 = cluster_eval_b3_ms
        NMI_score = NMI_score_ms

        if opt.whether_visualize:
            loger.add_scalar('test_NMI_MeanShift', NMI_score_ms, global_step=0)
            loger.add_scalar('test_F1_MeanShift',
                             cluster_eval_b3_ms['F1'],
                             global_step=0)

    print("learning rate decay num:", opt.lr_decay_num)
    print("learning rate decay step:", opt.lr_decay_record)
    print("best_epoch:", best_epoch)
    print("best_step:", best_batch_step)
    print("best_batch_size:", best_batch_size)
    print("best_cluster_eval_b3:", best_validation_f1)
    print("best_cluster_test_b3:", best_test_b3)
    print("best_NMI_score:", NMI_score)
    print("seed:", opt.seed)
示例#8
0
    def train_self(self,
                   opt,
                   dataloader_train,
                   loss_list=None,
                   loger=None,
                   batch_chose=0,
                   global_step=None,
                   temp_epoch=0,
                   chose_decay=False):
        """Run one optimisation step of the BERT + FFL clustering model.

        Args:
            opt: options (batch sizes, class ratios, shuffling/augmentation
                flags, visualisation flag, ...).
            dataloader_train: object providing ``next_batch_cluster``.
            loss_list: list the step's scalar loss is appended to. A fresh
                list is used when None.
            loger: tensorboard-style writer. Only used when
                ``opt.whether_visualize`` is set.
            batch_chose: index into ``opt.batch_size`` / ``opt.class_num_ratio``.
            global_step: global step counter. It is passed to the loss and
                triggers the one-off graph logging at step 1.
            temp_epoch: current epoch. Learning-rate decay needs ``> 0``.
            chose_decay: enables the learning-rate decay below.

        Returns:
            ``loss_list``, with this step's loss appended when it is a tensor.
        """
        # The None default used to crash on loss_list.append below.
        if loss_list is None:
            loss_list = []
        batch_size = opt.batch_size[batch_chose]
        class_num_ratio = opt.class_num_ratio[batch_chose]
        assert batch_size is not None
        self.train()

        # learning rate decay: shrink the BERT lr by 10x and rebuild Adam
        if temp_epoch > 0 and self.lr > 1e-8 and chose_decay:
            print("lr decay to {}!".format(self.lr * 0.1))
            self.lr = self.lr * 0.1
            ffl_params = list(self.FFL.parameters())
            param = [
                {
                    'params':
                    filter(lambda p: p.requires_grad,
                           self.Bert_model.parameters()),
                    'lr':
                    self.lr
                },
                # Only the first FFL parameter tensor gets weight decay.
                {
                    'params': ffl_params[0],
                    'weight_decay': 1e-3,
                    'lr': self.lr_linear
                },
                {
                    'params': ffl_params[1],
                    'lr': self.lr_linear
                },
            ]
            # torch.optim optimizers have no ``.to()`` method; the old
            # ``.to(opt.device)`` call raised AttributeError. The optimizer
            # works on the parameters wherever they already live.
            self.optimizer = optim.Adam(param)

        # batch_sentence is the BERT input; batch_data is presumably the
        # [batch_size, sequence] token tensor (per the original note).
        batch_data, batch_label, batch_sentence, cluster_label = dataloader_train.next_batch_cluster(
            batch_size, class_num_ratio, opt.batch_shuffle,
            opt.inclass_augment)

        # BERT forward pass over the batch sentences
        marker_temper_data, b_input_ids, batch_label = self.bert_forward(
            batch_sentence, batch_label)
        batch_label, _ = tl.cudafy(batch_label)

        # Go through fully connect layer and RELU, then norm layer
        features = self.FFL(marker_temper_data)
        features = self.norm_layer(features)

        # Loss calculation
        loss = self.ml.cluster_loss(opt, features, batch_label,
                                    global_step)  # check

        # Backward (only when the metric loss produced a tensor)
        if isinstance(loss, torch.Tensor):
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_list.append(loss.item())

        # Graph logging is best-effort: a tracing failure must not stop training.
        if opt.whether_visualize and global_step == 1:
            try:
                loger.add_graph(self, input_to_model=batch_data)
            except Exception:
                print("*tensorboard : add graph failed")

        return loss_list
示例#9
0
    def train_self(self,
                   opt,
                   dataloader_train,
                   loss_list=None,
                   loger=None,
                   batch_chose=0,
                   global_step=None,
                   temp_epoch=1):
        """Run one optimisation step of the word-embedding clustering model.

        When ``opt.VAT != 0`` and ``temp_epoch >= opt.warm_up``, the loss is
        the RLL term computed over the batch distance matrix (averaged over
        anchors) plus ``opt.lambda_V`` times a virtual adversarial term. That
        term is the mean half-distance between clean and perturbed encodings.
        Otherwise the loss comes from ``MetricLoss.metric_loss().cluster_loss``.

        Args:
            opt: options (batch sizes, VAT/RLL hyper-parameters, flags, ...).
            dataloader_train: object providing ``next_batch_cluster``.
            loss_list: list the step's scalar loss is appended to. A fresh
                list is used when None.
            loger: tensorboard-style writer. Only used when
                ``opt.whether_visualize`` is set.
            batch_chose: index into ``opt.batch_size`` / ``opt.class_num_ratio``.
            global_step: global step counter. It is passed to the metric loss
                and triggers the one-off graph logging at step 1.
            temp_epoch: current epoch, compared against ``opt.warm_up``.

        Returns:
            ``loss_list``, with this step's loss appended when it is a tensor.
        """
        # The None default used to crash on loss_list.append below.
        if loss_list is None:
            loss_list = []
        batch_size = opt.batch_size[batch_chose]
        class_num_ratio = opt.class_num_ratio[batch_chose]

        assert batch_size is not None
        self.train()

        batch_data, batch_label, cluster_label = dataloader_train.next_batch_cluster(
            batch_size, class_num_ratio, opt.batch_shuffle,
            opt.inclass_augment)
        batch_data, _ = tl.cudafy(batch_data)
        batch_label, _ = tl.cudafy(batch_label)

        wordembed, features = self.forward_norm(
            batch_data)  # [batch_size,embedding_dim]

        if opt.VAT != 0 and temp_epoch >= opt.warm_up:
            # add VAT
            total_loss = 0.0
            labels = batch_label
            margin = opt.margin
            alpha = opt.alpha_rank
            tval = opt.temp_neg
            encode_ori = features

            dist_mat = pairwise_distance(encode_ori,
                                         opt.squared)  # [batch,batch]

            # compute score_ori and RLL, one anchor (row) at a time
            for anchor in range(dist_mat.shape[0]):
                is_pos = labels.eq(labels[anchor])
                is_pos[anchor] = 0  # the anchor is not its own positive
                is_neg = labels.ne(labels[anchor])

                dist_ap = dist_mat[anchor][is_pos]
                dist_an = dist_mat[anchor][is_neg]

                # positives: hinge on distance + (margin - alpha), averaged
                ap_is_pos = torch.clamp(torch.add(dist_ap, margin - alpha),
                                        min=0.0)
                ap_pos_num = ap_is_pos.size(0) + 1e-5
                ap_pos_val_sum = torch.sum(ap_is_pos)
                loss_ap = torch.div(ap_pos_val_sum, float(ap_pos_num))

                # negatives closer than alpha: exponentially weighted hinge
                an_is_pos = torch.lt(dist_an, alpha)
                an_less_alpha = dist_an[an_is_pos]
                an_weight = torch.exp(tval * (-1 * an_less_alpha + alpha))
                an_weight_sum = torch.sum(an_weight) + 1e-5
                an_dist_lm = alpha - an_less_alpha
                an_ln_sum = torch.sum(torch.mul(an_dist_lm, an_weight))
                loss_an = torch.div(an_ln_sum, an_weight_sum)
                total_loss = total_loss + loss_ap + loss_an

            # Random initial perturbation with the word-embedding shape.
            disturb, _ = tl.cudafy(
                torch.FloatTensor(wordembed.shape).uniform_(0, 1))

            # power iteration: refine the perturbation along the gradient
            for _ in range(opt.power_iterations):
                disturb.requires_grad = True
                disturb = (opt.p_mult) * F.normalize(disturb, p=2, dim=1)
                _, encode_disturb = self.forward_norm(batch_data, disturb)
                dist_el = torch.sum(torch.pow(encode_ori - encode_disturb, 2),
                                    dim=1).sqrt()
                diff = (dist_el / 2.0).clamp(0, 1.0 - 1e-7)
                disturb_gradient = torch.autograd.grad(diff.sum(),
                                                       disturb,
                                                       retain_graph=True)[0]
                disturb = disturb_gradient.detach()

            disturb = opt.p_mult * F.normalize(disturb, p=2, dim=1)

            # virtual adversarial loss
            _, encode_final = self.forward_norm(batch_data, disturb)

            # compare clean and perturbed encodings
            dist_el = torch.sum(torch.pow(encode_ori - encode_final, 2),
                                dim=1).sqrt()
            diff = (dist_el / 2.0).clamp(0, 1.0 - 1e-7)
            v_adv_losses = torch.mean(diff)
            loss = total_loss * 1.0 / dist_mat.size(
                0) + v_adv_losses * opt.lambda_V
            assert torch.mean(v_adv_losses).item() > 0.0
        else:
            self.ml = MetricLoss.metric_loss()
            loss = self.ml.cluster_loss(opt, features, batch_label,
                                        global_step)

        if isinstance(loss, torch.Tensor):
            self.optimizer.zero_grad()
            loss.backward()
            # Freeze the last embedding row (presumably a padding/unknown
            # token — TODO confirm) by zeroing its gradient.
            self.word_emb.word_embedding.weight.grad[-1] = 0
            self.optimizer.step()
            loss_list.append(loss.item())

        # Graph logging is best-effort, as in the BERT train_self: a tracing
        # failure must not stop training.
        if opt.whether_visualize and global_step == 1:
            try:
                loger.add_graph(self, input_to_model=batch_data)
            except Exception:
                print("*tensorboard : add graph failed")

        return loss_list