Example #1
 def get_v_adv_loss(self, ul_left_input, ul_right_input, p_mult, power_iterations=1):
     bernoulli = dist.Bernoulli
     prob, left_word_emb, right_word_emb = self.siamese_forward(ul_left_input, ul_right_input)[0:3]
     prob = prob.clamp(min=1e-7, max=1. - 1e-7)
     prob_dist = bernoulli(probs=prob)
     # generate virtual adversarial perturbation
     left_d = cudafy(torch.FloatTensor(left_word_emb.shape).uniform_(0, 1))
     right_d = cudafy(torch.FloatTensor(right_word_emb.shape).uniform_(0, 1))
     left_d.requires_grad, right_d.requires_grad = True, True
     # power iteration: estimate the direction that maximizes the KL divergence
     for _ in range(power_iterations):
         left_d = (0.02 * F.normalize(left_d, p=2, dim=1)).detach().requires_grad_(True)
         right_d = (0.02 * F.normalize(right_d, p=2, dim=1)).detach().requires_grad_(True)
         p_prob = self.siamese_forward(ul_left_input, ul_right_input, left_d, right_d)[0]
         p_prob = p_prob.clamp(min=1e-7, max=1. - 1e-7)
         kl = dist.kl_divergence(prob_dist, bernoulli(probs=p_prob))
         left_gradient, right_gradient = torch.autograd.grad(kl.sum(), [left_d, right_d], retain_graph=True)
         left_d = left_gradient.detach()
         right_d = right_gradient.detach()
     left_d = p_mult * F.normalize(left_d, p=2, dim=1)
     right_d = p_mult * F.normalize(right_d, p=2, dim=1)
     # virtual adversarial loss
     p_prob = self.siamese_forward(ul_left_input, ul_right_input, left_d, right_d)[0].clamp(min=1e-7, max=1. - 1e-7)
     v_adv_losses = dist.kl_divergence(prob_dist, bernoulli(probs=p_prob))
     return torch.mean(v_adv_losses)
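The method above follows the usual virtual adversarial training (VAT) recipe: sample a random direction, refine it with a KL-gradient power iteration, rescale the worst-case direction by p_mult, and penalize the KL divergence it induces. A minimal self-contained sketch of that recipe on a toy binary classifier follows; the toy model, shapes, and step sizes are illustrative assumptions, not the original siamese network.

# VAT sketch on a toy binary classifier (illustration only; the model and
# shapes are assumptions, not the repo's siamese network).
import torch
import torch.nn.functional as F
import torch.distributions as dist

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
x = torch.randn(4, 8)                                   # stand-in for word embeddings

with torch.no_grad():
    prob = torch.sigmoid(model(x)).clamp(1e-7, 1 - 1e-7)
prob_dist = dist.Bernoulli(probs=prob)

d = torch.rand_like(x)                                  # random initial direction
for _ in range(1):                                      # power_iterations
    d = (0.02 * F.normalize(d, p=2, dim=1)).detach().requires_grad_(True)
    p_prob = torch.sigmoid(model(x + d)).clamp(1e-7, 1 - 1e-7)
    kl = dist.kl_divergence(prob_dist, dist.Bernoulli(probs=p_prob)).sum()
    d = torch.autograd.grad(kl, d)[0]

d = 0.02 * F.normalize(d.detach(), p=2, dim=1)          # p_mult-scaled worst-case direction
p_prob = torch.sigmoid(model(x + d)).clamp(1e-7, 1 - 1e-7)
v_adv_loss = dist.kl_divergence(prob_dist, dist.Bernoulli(probs=p_prob)).mean()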
Example #2
 def pred_X(self, data_left, data_right):
     self.eval()
     if isinstance(data_right, np.ndarray):
         data_right = cudafy(torch.from_numpy(np.array(data_right, dtype=np.int64)))
     if isinstance(data_left, np.ndarray):
         data_left = cudafy(torch.from_numpy(np.array(data_left, dtype=np.int64)))
     prediction, _l, _r, encoded_l, encoded_r = self.siamese_forward(data_left, data_right)
     return prediction, encoded_l, encoded_r
Example #3
    def get_triplet_distance_v_adv_loss(self, ori_encoded, ori_word_emb, batch_input, p_mult, power_iterations=1):
        batch_d = cudafy(torch.FloatTensor(ori_word_emb.shape).uniform_(0, 1))
        batch_d.requires_grad = True
        criterion = TripletSemihardLoss()
        batch_size = batch_input.shape[0]
        ori_dist_mat = criterion.pairwise_distance(ori_encoded)
        for _ in range(power_iterations):
            batch_d = (0.02 * F.normalize(batch_d, p=2, dim=1)).detach().requires_grad_(True)
            p_encoded = self.forward_norm(batch_input, batch_d)[1]

            p_dist_mat = criterion.pairwise_distance(p_encoded)
            # squared difference between the clean and perturbed distance matrices
            # (the absolute-difference variant torch.abs(...).sum() worked less well)
            temp_loss = (ori_dist_mat - p_dist_mat).pow(2).sum()

            batch_gradient = torch.autograd.grad(temp_loss, batch_d, retain_graph=True)[0]
            batch_d = batch_gradient.detach()

        # apply the worst-case perturbation, scaled by p_mult
        batch_d = p_mult * F.normalize(batch_d, p=2, dim=1)
        p_encoded = self.forward_norm(batch_input, batch_d)[1]
        # distance-based virtual adversarial loss

        p_dist_mat = criterion.pairwise_distance(p_encoded)

        # squared difference, rescaled so the zero diagonal does not dilute the mean
        # (absolute-difference alternative: torch.abs(ori_dist_mat - p_dist_mat).mean())
        v_adv_loss = (ori_dist_mat - p_dist_mat).pow(2).mean() * (batch_size / (batch_size - 1))

        return v_adv_loss
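This distance-based variant perturbs the word embeddings so that the batch's pairwise distance matrix moves as much as possible, then penalizes the squared change. A small sketch of that comparison follows; torch.cdist stands in for TripletSemihardLoss.pairwise_distance, whose implementation is not part of this excerpt, and the encodings are random placeholders.

# Sketch: compare pairwise distance matrices of clean vs. perturbed encodings.
# torch.cdist stands in for TripletSemihardLoss.pairwise_distance (assumption).
import torch

torch.manual_seed(0)
batch_size, dim = 6, 4
ori_encoded = torch.randn(batch_size, dim)
p_encoded = ori_encoded + 0.05 * torch.randn(batch_size, dim)   # perturbed encodings

ori_dist_mat = torch.cdist(ori_encoded, ori_encoded)            # (batch, batch)
p_dist_mat = torch.cdist(p_encoded, p_encoded)

# squared variant, rescaled as in the example so the zero diagonal does not dilute the mean
v_adv_loss = (ori_dist_mat - p_dist_mat).pow(2).mean() * (batch_size / (batch_size - 1))
print(v_adv_loss.item())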
Example #4
    def get_triplet_v_adv_loss(self, ori_triplet_loss, ori_word_emb, batch_input, labels, margins, p_mult,
                               power_iterations=1):
        batch_d = cudafy(torch.FloatTensor(ori_word_emb.shape).uniform_(0, 1))
        batch_d.requires_grad = True
        for _ in range(power_iterations):
            batch_d = (0.02 * F.normalize(batch_d, p=2, dim=1)).detach().requires_grad_(True)
            p_encoded = self.forward_norm(batch_input, batch_d)[1]
            triplet_loss = self.get_triplet_semihard_loss(p_encoded, labels, self.margin)
            batch_gradient = torch.autograd.grad(triplet_loss, batch_d, retain_graph=True)[0]
            batch_d = batch_gradient.detach()
        # apply the worst-case perturbation, scaled by p_mult
        batch_d = p_mult * F.normalize(batch_d, p=2, dim=1)
        p_encoded = self.forward_norm(batch_input, batch_d)[1]
        triplet_loss = self.get_triplet_semihard_loss(p_encoded, labels, margins)
        # squared difference between the perturbed and original triplet losses
        # (absolute-difference alternative, marked as better: torch.abs(triplet_loss - ori_triplet_loss))
        v_adv_loss = (triplet_loss - ori_triplet_loss).pow(2)

        return v_adv_loss
Example #5
 def pred_vector(self, data):
     self.eval()
     if isinstance(data, list):
         data = cudafy(torch.from_numpy(np.array(data, dtype=np.int64)))
     if self.train_loss_type.startswith('triplet'):
         _, vectors = self.forward_norm(data)
     else:
         _, vectors = self.forward(data)
     return vectors
Example #6
    def train_triplet_same_level(self, dataloader_trainset, batch_size=None, dynamic_margin=True, K_num=4, level=1,
                                 same_v_adv=True):
        if batch_size is None:
            batch_size = self.batch_size
        self.train()
        torch.retain_graph = True  # no-op: retain_graph is a per-call argument to backward()/autograd.grad(), not a torch module flag

        batch_input, data_label, margins = dataloader_trainset.next_triplet_same_level_batch(batch_size, K_num,
                                                                                             dynamic_margin, level)
        batch_input, data_label, margins = cudafy(batch_input), cudafy(data_label), cudafy(margins)
        margins = margins * self.margin if dynamic_margin else self.margin

        word_embed, encoded = self.forward_norm(batch_input)
        loss = self.get_triplet_semihard_loss(encoded, data_label, margins)
        if self.train_loss_type == 'triplet_v_adv' and same_v_adv:
            v_adv_loss = self.get_triplet_v_adv_loss(loss, word_embed, batch_input, data_label, margins, self.p_mult)
            loss += v_adv_loss
        elif self.train_loss_type == 'triplet_v_adv_distance' and same_v_adv:
            v_adv_loss = self.get_triplet_distance_v_adv_loss(encoded, word_embed, batch_input, self.p_mult)
            loss += v_adv_loss

        self.back_propagation(loss)
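get_triplet_semihard_loss receives per-triplet margins here: the weights returned by next_triplet_same_level_batch are multiplied by the base self.margin when dynamic_margin is set. A rough sketch of a margin-weighted triplet loss follows; the semihard mining done by TripletSemihardLoss is not reproduced, and the encodings and weights are made up.

# Sketch: triplet loss with per-triplet margins (the repo's semihard mining in
# get_triplet_semihard_loss is not reproduced here; this is an assumption).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
anchor, positive, negative = torch.randn(3, 5, 8).unbind(0)    # hypothetical encodings
margins = torch.tensor([0.5, 1.0, 1.0, 2.0, 1.0])              # per-triplet margins

d_ap = F.pairwise_distance(anchor, positive)
d_an = F.pairwise_distance(anchor, negative)
loss = F.relu(d_ap - d_an + margins).mean()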
Example #7
    def train_RSN(self, dataloader_trainset, dataloader_testset, batch_size=None, same_ratio=0.06):
        if batch_size is None:
            batch_size = self.batch_size
        self.train()

        data_left, data_right, data_label = dataloader_trainset.next_batch(batch_size, same_ratio=same_ratio)
        data_left, data_right, data_label = cudafy(data_left), cudafy(data_right), cudafy(data_label)
        prediction, left_word_emb, right_word_emb, encoded_l, encoded_r = self.siamese_forward(data_left, data_right)
        if self.train_loss_type == "cross":
            loss = self.get_cross_entropy_loss(prediction, labels=data_label)
        elif self.train_loss_type == "cross_denoise":
            loss = self.get_cross_entropy_loss(prediction, labels=data_label)
            loss += self.get_cond_loss(prediction) * self.p_denoise
        elif self.train_loss_type == "v_adv":
            loss = self.get_cross_entropy_loss(prediction, labels=data_label)
            loss += self.get_v_adv_loss(data_left, data_right, self.p_mult) * self.lambda_s
        elif self.train_loss_type == "v_adv_denoise":
            loss = self.get_cross_entropy_loss(prediction, labels=data_label)
            loss += self.get_v_adv_loss(data_left, data_right, self.p_mult) * self.lambda_s
            loss += self.get_cond_loss(prediction) * self.p_denoise
        else:
            raise NotImplementedError()

        self.back_propagation(loss)
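The RSN branch mixes a supervised cross-entropy term with two unsupervised regularizers: the virtual adversarial loss weighted by lambda_s and a "denoise" term weighted by p_denoise. get_cond_loss is not shown in this excerpt; the sketch below assumes it is the conditional entropy of the predicted pair probability, which is the usual choice for such a term, and uses placeholder predictions and weights.

# Sketch of the loss mix in train_RSN; get_cond_loss is assumed here to be the
# conditional entropy of the predicted pair probability (not shown above).
import torch
import torch.nn.functional as F

def cond_entropy(prob, eps=1e-7):
    # entropy of a Bernoulli prediction; minimizing it encourages confident outputs
    prob = prob.clamp(eps, 1 - eps)
    return -(prob * prob.log() + (1 - prob) * (1 - prob).log()).mean()

prob = torch.tensor([[0.9], [0.2], [0.6]])      # hypothetical pair predictions
labels = torch.tensor([[1.0], [0.0], [1.0]])
lambda_s, p_denoise = 1.0, 1.0                  # placeholder weights

loss = F.binary_cross_entropy(prob, labels)
loss = loss + torch.tensor(0.05) * lambda_s     # placeholder for get_v_adv_loss(...)
loss = loss + cond_entropy(prob) * p_denoise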
Example #8
def train_SN(train_data_file,
             val_data_file,
             test_data_file,
             wordvec_file,
             load_model_name=None,
             save_model_name='SN',
             trainset_loss_type='triplet',
             testset_loss_type='none',
             testset_loss_mask_epoch=3,
             p_cond=0.03,
             p_denoise=1.0,
             rel2id_file=None,
             similarity_file=None,
             dynamic_margin=True,
             margin=1.0,
             louvain_weighted=False,
             level_train=False,
             shallow_to_deep=False,
             same_level_pair_file=None,
             max_len=120,
             pos_emb_dim=5,
             same_ratio=0.06,
             batch_size=64,
             batch_num=10000,
             epoch_num=1,
             val_size=10000,
             select_cluster=None,
             omit_relid=None,
             labeled_sample_num=None,
             squared=True,
             same_level_part=None,
             mask_same_level_epoch=1,
             same_v_adv=False,
             random_init=False,
             seed=42,
             K_num=4,
             evaluate_hierarchy=False,
             train_for_cluster_file=None,
             train_structure_file=None,
             all_structure_file=None,
             to_cluster_data_num=100,
             incre_threshold=0,
             iso_threshold=5,
             avg_link_increment=True,
             modularity_increment=False):
    # preparing saving files.
    if select_cluster is None:
        select_cluster = ['Louvain']
    if load_model_name is not None:

        load_path = os.path.join('model_file',
                                 load_model_name).replace('\\', '/')
    else:
        load_path = None

    save_path = os.path.join('model_file', save_model_name).replace('\\', '/')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    msger = messager(save_path=save_path,
                     types=[
                         'train_data_file', 'val_data_file', 'test_data_file',
                         'load_model_name', 'save_model_name',
                         'trainset_loss_type', 'testset_loss_type',
                         'testset_loss_mask_epoch', 'p_cond', 'p_denoise',
                         'same_ratio', 'labeled_sample_num'
                     ],
                     json_name='train_msg.json')
    msger.record_message([
        train_data_file, val_data_file, test_data_file, load_model_name,
        save_model_name, trainset_loss_type, testset_loss_type,
        testset_loss_mask_epoch, p_cond, p_denoise, same_ratio,
        labeled_sample_num
    ])
    msger.save_json()

    print('-----Data Loading-----')
    # for train
    dataloader_train = dataloader(train_data_file,
                                  wordvec_file,
                                  rel2id_file,
                                  similarity_file,
                                  same_level_pair_file,
                                  max_len=max_len,
                                  random_init=random_init,
                                  seed=seed)
    # for cluster never seen instances
    dataloader_train_for_cluster = dataloader(train_for_cluster_file,
                                              wordvec_file,
                                              rel2id_file,
                                              similarity_file,
                                              same_level_pair_file,
                                              max_len=max_len,
                                              random_init=random_init,
                                              seed=seed)
    # for validation, to select best model
    dataloader_val = dataloader(val_data_file,
                                wordvec_file,
                                rel2id_file,
                                similarity_file,
                                max_len=max_len)
    # for cluster
    dataloader_test = dataloader(test_data_file,
                                 wordvec_file,
                                 rel2id_file,
                                 similarity_file,
                                 max_len=max_len)
    word_emb_dim = dataloader_train._word_emb_dim_()
    word_vec_mat = dataloader_train._word_vec_mat_()
    print('word_emb_dim is {}'.format(word_emb_dim))

    # compile model
    print('-----Model Initializing-----')

    rsn = RSN(word_vec_mat=word_vec_mat,
              max_len=max_len,
              pos_emb_dim=pos_emb_dim,
              dropout=0.2)

    if load_path:
        rsn.load_model(load_path)
    rsn = cudafy(rsn)
    rsn.set_train_op(batch_size=batch_size,
                     train_loss_type=trainset_loss_type,
                     testset_loss_type=testset_loss_type,
                     p_cond=p_cond,
                     p_denoise=p_denoise,
                     p_mult=0.02,
                     squared=squared,
                     margin=margin)

    print('-----Validation Data Preparing-----')

    val_data, val_data_label = dataloader_val._part_data_(100)

    print('-----Clustering Data Preparing-----')
    train_hierarchy_structure_info = json.load(open(train_structure_file))
    all_hierarchy_structure_info = json.load(open(all_structure_file))
    train_hierarchy_cluster_list, gt_hierarchy_cluster_list, train_data_num, test_data_num, train_data, train_label, test_data, test_label = prepare_cluster_list(
        dataloader_train_for_cluster, dataloader_test,
        train_hierarchy_structure_info, all_hierarchy_structure_info,
        to_cluster_data_num)
    batch_num_list = [batch_num] * epoch_num
    best_validation_f1 = 0
    least_epoch = 1
    best_step = 0
    for epoch in range(epoch_num):
        msger = messager(save_path=save_path,
                         types=[
                             'batch_num', 'train_tp', 'train_fp', 'train_fn',
                             'train_tn', 'train_l', 'test_tp', 'test_fp',
                             'test_fn', 'test_tn', 'test_l'
                         ],
                         json_name='SNmsg' + str(epoch) + '.json')
        # for cluster
        # test_data, test_data_label = dataloader_test._data_()
        print('------epoch {}------'.format(epoch))
        print('max batch num to train is {}'.format(batch_num_list[epoch]))
        for i in range(1, batch_num_list[epoch] + 1):
            to_cluster_flag = False
            if trainset_loss_type.startswith("triplet"):
                if level_train and epoch < mask_same_level_epoch:
                    if i <= 1 / same_level_part * batch_num_list[epoch]:
                        rsn.train_triplet_same_level(
                            dataloader_train,
                            batch_size=batch_size,
                            K_num=4,
                            dynamic_margin=dynamic_margin,
                            level=1,
                            same_v_adv=same_v_adv)
                    elif i <= 2 / same_level_part * batch_num_list[epoch]:
                        rsn.train_triplet_same_level(
                            dataloader_train,
                            batch_size=batch_size,
                            K_num=4,
                            dynamic_margin=dynamic_margin,
                            level=2,
                            same_v_adv=same_v_adv)
                    else:
                        rsn.train_triplet_loss(dataloader_train,
                                               batch_size=batch_size,
                                               dynamic_margin=dynamic_margin)
                else:
                    rsn.train_triplet_loss(dataloader_train,
                                           batch_size=batch_size,
                                           dynamic_margin=dynamic_margin)
            else:
                rsn.train_RSN(dataloader_train,
                              dataloader_test,
                              batch_size=batch_size)

            if i % 100 == 0:
                print('temp_batch_num: ', i, ' total_batch_num: ',
                      batch_num_list[epoch])
            if i % 1000 == 0 and epoch >= least_epoch:
                print(save_model_name, 'epoch:', epoch)

                print('Validation:')
                cluster_result, cluster_msg = Louvain_no_isolation(
                    dataset=val_data,
                    edge_measure=rsn.pred_X,
                    weighted=louvain_weighted)
                cluster_eval_b3 = ClusterEvaluation(
                    val_data_label,
                    cluster_result).printEvaluation(print_flag=False)

                cluster_eval_new = ClusterEvaluationNew(
                    val_data_label,
                    cluster_result).printEvaluation(print_flag=False)
                two_f1 = cluster_eval_new['F1']
                if two_f1 > best_validation_f1:  # acc
                    to_cluster_flag = True
                    best_step = i
                    best_validation_f1 = two_f1

            if to_cluster_flag:
                if 'Louvain' in select_cluster:
                    print('-----Top Down Hierarchy Louvain Clustering-----')
                    if avg_link_increment:
                        # link_th_list = [0.5, 1, 2, 5, 10, 15, 20, 50, 100]
                        # link_th_list = [0.05, 0.08, 0.1, 0.12, 0.15, 0.18, 0.2, 0.3, 0.4]
                        # link_th_list = [i * 0.02 for i in range(1, 100)]
                        link_th_list = [0.05]
                        cluster_result, cluster_msg = Louvain_no_isolation(
                            dataset=test_data,
                            edge_measure=rsn.pred_X,
                            weighted=louvain_weighted)

                        predicted_cluster_dict_list = Top_Down_Louvain_with_test_cluster_done_avg_link_list(
                            cluster_result,
                            train_data_num,
                            test_data_num,
                            train_data,
                            test_data,
                            train_hierarchy_cluster_list,
                            rsn.pred_X,
                            link_th_list)
                        best_hyper_score = 0
                        best_eval_info = None
                        for predicted_cluster_dict in predicted_cluster_dict_list:
                            predicted_cluster_list = predicted_cluster_dict[
                                'list']
                            evaluation = HierarchyClusterEvaluation(
                                gt_hierarchy_cluster_list,
                                predicted_cluster_list, test_data_num)
                            eval_info = evaluation.printEvaluation()
                            if eval_info['total_F1'] > best_hyper_score:
                                best_eval_info = eval_info
                                best_hyper_score = eval_info['total_F1']

                    rsn.save_model(save_path=save_path,
                                   global_step=i + epoch * batch_num)
                    print('model and clustering messages saved.')
        print('End: The model is:', save_model_name, trainset_loss_type,
              testset_loss_type, 'p_cond is:', p_cond)
    print(seed)
    print("best step:", best_step)
    print("new metric Info:")
    print("F1(%)")
    print(best_eval_info['match_f1'] * 100)

    print("taxonomy Info:")
    print("Precision(%); Recall(%); F1(%)")
    print(round(best_eval_info['taxonomy_precision'] * 100, 3), "; ",
          round(best_eval_info['taxonomy_recall'] * 100, 3), "; ",
          round(best_eval_info['taxonomy_F1'] * 100, 3))

    print("Total Info:")
    print("Precision(%); Recall(%); F1(%)")
    print(round(best_eval_info['total_precision'] * 100, 3), "; ",
          round(best_eval_info['total_recall'] * 100, 3), "; ",
          round(best_eval_info['total_F1'] * 100, 3))
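For reference, a call to this train_SN might look like the following; every file path is a placeholder, only keyword arguments that appear in the signature above are used, and anything not listed keeps its default.

# Hypothetical invocation of train_SN; the file paths are placeholders.
train_SN(train_data_file='data/train.json',
         val_data_file='data/val.json',
         test_data_file='data/test.json',
         wordvec_file='data/word_vec.json',
         rel2id_file='data/rel2id.json',
         train_for_cluster_file='data/train_for_cluster.json',
         train_structure_file='data/train_structure.json',
         all_structure_file='data/all_structure.json',
         save_model_name='SN_triplet',
         trainset_loss_type='triplet_v_adv',
         dynamic_margin=True,
         batch_size=64,
         batch_num=10000,
         epoch_num=2)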
Example #9
File: train_OHRE.py  Project: thunlp/OHRE
def train_SN(train_data_file,
             val_data_file,
             test_data_file,
             wordvec_file,
             load_model_name=None,
             save_model_name='SN',
             trainset_loss_type='triplet',
             testset_loss_type='none',
             testset_loss_mask_epoch=3,
             p_cond=0.03,
             p_denoise=1.0,
             rel2id_file=None,
             similarity_file=None,
             dynamic_margin=True,
             margin=1.0,
             louvain_weighted=False,
             level_train=False,
             shallow_to_deep=False,
             same_level_pair_file=None,
             max_len=120,
             pos_emb_dim=5,
             same_ratio=0.06,
             batch_size=64,
             batch_num=10000,
             epoch_num=1,
             val_size=10000,
             select_cluster=None,
             omit_relid=None,
             labeled_sample_num=None,
             squared=True,
             same_level_part=None,
             mask_same_level_epoch=1,
             same_v_adv=False,
             random_init=False,
             seed=42,
             K_num=4,
             evaluate_hierarchy=False,
             gt_hierarchy_file=None):
    # preparing saving files.
    if select_cluster is None:
        select_cluster = ['Louvain']
    if load_model_name is not None:
        load_path = os.path.join('model_file',
                                 load_model_name).replace('\\', '/')
    else:
        load_path = None

    save_path = os.path.join('model_file', save_model_name).replace('\\', '/')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    msger = messager(save_path=save_path,
                     types=[
                         'train_data_file', 'val_data_file', 'test_data_file',
                         'load_model_name', 'save_model_name',
                         'trainset_loss_type', 'testset_loss_type',
                         'testset_loss_mask_epoch', 'p_cond', 'p_denoise',
                         'same_ratio', 'labeled_sample_num'
                     ],
                     json_name='train_msg.json')
    msger.record_message([
        train_data_file, val_data_file, test_data_file, load_model_name,
        save_model_name, trainset_loss_type, testset_loss_type,
        testset_loss_mask_epoch, p_cond, p_denoise, same_ratio,
        labeled_sample_num
    ])
    msger.save_json()
    # train data loading
    print('-----Data Loading-----')
    dataloader_train = dataloader(train_data_file,
                                  wordvec_file,
                                  rel2id_file,
                                  similarity_file,
                                  same_level_pair_file,
                                  max_len=max_len,
                                  random_init=random_init,
                                  seed=seed)
    dataloader_val = dataloader(val_data_file,
                                wordvec_file,
                                rel2id_file,
                                similarity_file,
                                max_len=max_len)
    dataloader_test = dataloader(test_data_file,
                                 wordvec_file,
                                 rel2id_file,
                                 similarity_file,
                                 max_len=max_len)
    word_emb_dim = dataloader_train._word_emb_dim_()
    word_vec_mat = dataloader_train._word_vec_mat_()
    print('word_emb_dim is {}'.format(word_emb_dim))

    # compile model
    print('-----Model Initializing-----')

    rsn = RSN(word_vec_mat=word_vec_mat,
              max_len=max_len,
              pos_emb_dim=pos_emb_dim,
              dropout=0.2)
    if load_path:
        rsn.load_model(load_path)
    rsn = cudafy(rsn)
    rsn.set_train_op(batch_size=batch_size,
                     train_loss_type=trainset_loss_type,
                     testset_loss_type=testset_loss_type,
                     p_cond=p_cond,
                     p_denoise=p_denoise,
                     p_mult=0.02,
                     squared=squared,
                     margin=margin)

    print('-----Validation Data Preparing-----')

    val_data, val_data_label = dataloader_val._part_data_(100)

    # intializing parameters
    batch_num_list = [batch_num] * epoch_num
    # clustering_test_time = np.arange(19999, batch_num, 20000).tolist()
    msger_cluster = messager(
        save_path=save_path,
        types=['method', 'temp_batch_num', 'F1', 'precision', 'recall', 'msg'],
        json_name='cluster_msg.json')
    least_epoch = 1
    best_step = 0
    print_flag = True
    best_validation_f1 = 0
    for epoch in range(epoch_num):
        test_data, test_data_label = dataloader_test._data_()
        print('------epoch {}------'.format(epoch))
        print('max batch num to train is {}'.format(batch_num_list[epoch]))
        for i in range(1, batch_num_list[epoch] + 1):
            to_cluster_flag = False
            if trainset_loss_type.startswith("triplet"):
                if level_train and epoch < mask_same_level_epoch:
                    if i <= 1 / same_level_part * batch_num_list[epoch]:
                        rsn.train_triplet_same_level(
                            dataloader_train,
                            batch_size=batch_size,
                            K_num=4,
                            dynamic_margin=dynamic_margin,
                            level=1,
                            same_v_adv=same_v_adv)
                    elif i <= 2 / same_level_part * batch_num_list[epoch]:
                        rsn.train_triplet_same_level(
                            dataloader_train,
                            batch_size=batch_size,
                            K_num=4,
                            dynamic_margin=dynamic_margin,
                            level=2,
                            same_v_adv=same_v_adv)
                    else:
                        rsn.train_triplet_loss(dataloader_train,
                                               batch_size=batch_size,
                                               dynamic_margin=dynamic_margin)
                else:
                    rsn.train_triplet_loss(dataloader_train,
                                           batch_size=batch_size,
                                           dynamic_margin=dynamic_margin)
            else:
                rsn.train_RSN(dataloader_train,
                              dataloader_test,
                              batch_size=batch_size)

            if i % 500 == 0:
                print('temp_batch_num: ', i, ' total_batch_num: ',
                      batch_num_list[epoch])
            if i % 1000 == 0 and epoch >= least_epoch:
                print(save_model_name, 'epoch:', epoch)

                print('Validation:')
                cluster_result, cluster_msg = Louvain_no_isolation(
                    dataset=val_data,
                    edge_measure=rsn.pred_X,
                    weighted=louvain_weighted)
                cluster_eval_new = ClusterEvaluationNew(
                    val_data_label,
                    cluster_result).printEvaluation(print_flag=False)

                cluster_eval_b3 = ClusterEvaluation(
                    val_data_label,
                    cluster_result).printEvaluation(print_flag=False)
                # two_f1 = cluster_eval_new['F1'] + cluster_eval_b3['F1']
                two_f1 = cluster_eval_b3['F1']
                if two_f1 > best_validation_f1:  # acc
                    to_cluster_flag = True
                    best_step = i
                    best_validation_f1 = two_f1

            if to_cluster_flag:
                if 'Louvain' in select_cluster:
                    print('-----Louvain Clustering-----')
                    if not evaluate_hierarchy:
                        cluster_result, cluster_msg = Louvain_no_isolation(
                            dataset=test_data,
                            edge_measure=rsn.pred_X,
                            weighted=louvain_weighted)
                        cluster_eval_new = ClusterEvaluationNew(
                            test_data_label, cluster_result).printEvaluation(
                                print_flag=print_flag)
                        # msger_cluster.record_message(['Louvain_New', i, cluster_eval_new['F1'], cluster_msg])
                        # print("New Metric", cluster_eval)
                        cluster_eval_b3 = ClusterEvaluation(
                            test_data_label, cluster_result).printEvaluation(
                                print_flag=print_flag, extra_info=True)

                        # msger_cluster.record_message(['Louvain', i, cluster_eval_b3['F1'], cluster_eval_b3['precision'],
                        #                               cluster_eval_b3['recall'], cluster_msg])
                        best_cluster_eval_new = cluster_eval_new
                        best_cluster_eval_b3 = cluster_eval_b3
                    rsn.save_model(save_path=save_path,
                                   global_step=i + epoch * batch_num)
                    print('model and clustering messages saved.')

        print('End: The model is:', save_model_name, trainset_loss_type,
              testset_loss_type, 'p_cond is:', p_cond)
    print("best_cluster_eval_new", best_cluster_eval_new)
    print("best_cluster_eval_b3", best_cluster_eval_b3)
    print(seed)
    return best_cluster_eval_new, best_cluster_eval_b3
Example #10
def load_cluster(train_data_file,
                 test_data_file,
                 wordvec_file,
                 load_model_name=None,
                 all_structure_file=None,
                 trainset_loss_type='triplet',
                 testset_loss_type='none',
                 p_cond=0.03,
                 to_cluster_data_num=100,
                 p_denoise=1.0,
                 rel2id_file=None,
                 similarity_file=None,
                 margin=1.0,
                 save_cluster=False,
                 louvain_weighted=False,
                 same_level_pair_file=None,
                 train_for_cluster_file=None,
                 train_structure_file=None,
                 test_infos_file=None,
                 val_hier=False,
                 golden=False,
                 max_len=120,
                 pos_emb_dim=5,
                 batch_size=64,
                 squared=True,
                 random_init=False,
                 seed=42):
    if load_model_name is not None:
        load_path = os.path.join('model_file',
                                 load_model_name).replace('\\', '/')
    else:
        load_path = None

    print('-----Data Loading-----')
    # for train
    dataloader_train = dataloader(train_data_file,
                                  wordvec_file,
                                  rel2id_file,
                                  similarity_file,
                                  same_level_pair_file,
                                  max_len=max_len,
                                  random_init=random_init,
                                  seed=seed)
    # for cluster never seen instances
    dataloader_train_for_cluster = dataloader(train_for_cluster_file,
                                              wordvec_file,
                                              rel2id_file,
                                              similarity_file,
                                              same_level_pair_file,
                                              max_len=max_len)

    dataloader_test = dataloader(test_data_file,
                                 wordvec_file,
                                 rel2id_file,
                                 similarity_file,
                                 max_len=max_len)
    word_emb_dim = dataloader_train._word_emb_dim_()
    word_vec_mat = dataloader_train._word_vec_mat_()
    print('word_emb_dim is {}'.format(word_emb_dim))

    # compile model
    print('-----Model Initializing-----')

    rsn = RSN(word_vec_mat=word_vec_mat,
              max_len=max_len,
              pos_emb_dim=pos_emb_dim,
              dropout=0)
    rsn.set_train_op(batch_size=batch_size,
                     train_loss_type=trainset_loss_type,
                     testset_loss_type=testset_loss_type,
                     p_cond=p_cond,
                     p_denoise=p_denoise,
                     p_mult=0.02,
                     squared=squared,
                     margin=margin)

    if load_path:
        rsn.load_model(load_path + "/RSNbest.pt")
    rsn = cudafy(rsn)
    rsn.eval()
    print('-----Louvain Clustering-----')

    if val_hier:
        print('-----Top Down Hierarchy Expansion-----')
        train_hierarchy_structure_info = json.load(open(train_structure_file))
        all_hierarchy_structure_info = json.load(open(all_structure_file))
        train_hierarchy_cluster_list, gt_hierarchy_cluster_list, train_data_num, test_data_num, train_data, train_label, test_data, test_label = prepare_cluster_list(
            dataloader_train_for_cluster, dataloader_test,
            train_hierarchy_structure_info, all_hierarchy_structure_info,
            to_cluster_data_num)
        link_th_list = [0.2]

        if golden:
            link_th_list = [0.3]
            predicted_cluster_dict_list = Top_Down_Louvain_with_test_cluster_done_avg_link_list_golden(
                gt_hierarchy_cluster_list, train_data_num, test_data_num,
                train_data, test_data, train_hierarchy_cluster_list,
                rsn.pred_X, link_th_list)
        else:
            cluster_result, cluster_msg = Louvain_no_isolation(
                dataset=test_data,
                edge_measure=rsn.pred_X,
                weighted=louvain_weighted)
            predicted_cluster_dict_list = Top_Down_Louvain_with_test_cluster_done_avg_link_list(
                cluster_result, train_data_num, test_data_num, train_data,
                test_data, train_hierarchy_cluster_list, rsn.pred_X,
                link_th_list)
            if save_cluster:
                json.dump(cluster_result, open("cluster_result.json", "w"))
                pickle.dump(predicted_cluster_dict_list,
                            open("predicted_cluster_dict_list.pkl", "wb"))
                pickle.dump(gt_hierarchy_cluster_list,
                            open("gt_hierarchy_cluster_list.pkl", "wb"))
                print("saved results!")
        for predicted_cluster_dict in predicted_cluster_dict_list:
            print("\n\n")
            predicted_cluster_list = predicted_cluster_dict['list']
            print("Isolation threhold", predicted_cluster_dict['iso'])
            print("Average Link threhold", predicted_cluster_dict['link_th'])
            pickle.dump(predicted_cluster_list,
                        open("predicted_cluster_list.pkl", "wb"))
            evaluation = HierarchyClusterEvaluation(gt_hierarchy_cluster_list,
                                                    predicted_cluster_list,
                                                    test_data_num)
            eval_info = evaluation.printEvaluation(print_flag=True)
            HierarchyClusterEvaluationTypes(gt_hierarchy_cluster_list,
                                            predicted_cluster_list,
                                            test_infos_file,
                                            rel2id_file).printEvaluation()
    else:
        test_data, test_data_label = dataloader_test._data_()
        cluster_result, cluster_msg = Louvain_no_isolation(
            dataset=test_data,
            edge_measure=rsn.pred_X,
            weighted=louvain_weighted)

        cluster_eval_b3 = ClusterEvaluation(
            test_data_label, cluster_result).printEvaluation(print_flag=True,
                                                             extra_info=True)

        ClusterEvaluationB3Types(test_data_label, cluster_result,
                                 test_infos_file,
                                 rel2id_file).printEvaluation()
        print("100 times")

        print({k: v * 100 for k, v in cluster_eval_b3.items()})
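The final line scales each B³ score by 100 so it reads as a percentage. As a reminder of what those scores are, here is a tiny self-contained B³ (B-cubed) precision/recall/F1 computation on toy labels; it illustrates the metric, not the repo's ClusterEvaluation implementation.

# Sketch of B-cubed precision/recall/F1 on toy labels (illustrative only;
# not the repo's ClusterEvaluation implementation).
def b_cubed(gold, pred):
    n = len(gold)
    precision = recall = 0.0
    for i in range(n):
        same_pred = [j for j in range(n) if pred[j] == pred[i]]
        same_gold = [j for j in range(n) if gold[j] == gold[i]]
        correct = len(set(same_pred) & set(same_gold))
        precision += correct / len(same_pred)
        recall += correct / len(same_gold)
    precision, recall = precision / n, recall / n
    f1 = 2 * precision * recall / (precision + recall)
    return {'precision': precision, 'recall': recall, 'F1': f1}

scores = b_cubed(gold=[0, 0, 1, 1], pred=[0, 0, 0, 1])
print({k: v * 100 for k, v in scores.items()})   # report as percentages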