Code example #1
File: EvaluationManager.py Project: GateNLP/CANTM
    def train(self, cache_path=None):
        if self.inspectTest and (not self.splitValidation):
            print('inspecting test, please do not use val acc as early stopping')
            self.valDataIter = self.testDataIter
        elif self.inspectTest and self.splitValidation:
            print('inspectTest and splitValidation cannot be used at the same time')
            print('disabling inspectTest')
            self.inspectTest = False

        if self.splitValidation:
            # self.splitValidation gives the fraction of training documents
            # held out for validation: shuffle the ids, then slice.
            print('splitting training set for validation')
            self.valDataIter = copy.deepcopy(self.trainDataIter)
            train_val_ids = copy.deepcopy(self.trainDataIter.all_ids)
            random.shuffle(train_val_ids)
            split_4_train = 1 - self.splitValidation
            top_n_4_train = math.floor(len(train_val_ids) * split_4_train)
            id_4_train = train_val_ids[:top_n_4_train]
            id_4_val = train_val_ids[top_n_4_train:]
            self.trainDataIter.all_ids = id_4_train
            self.valDataIter.all_ids = id_4_val

        assert not (self.inspectTest and self.splitValidation), 'splitValidation overwrites inspectTest, do not use both at the same time'

        if self.dynamicSampling:
            print('get training data sample weights')
            self.trainDataIter.cal_sample_weights()

        self.trainDataIter._reset_iter()
        trainBatchIter = BatchIterBert(self.trainDataIter,
                                       filling_last_batch=True,
                                       postProcessor=batchPostProcessor,
                                       batch_size=self.batch_size)

        if self.valDataIter:
            self.valDataIter._reset_iter()
            valBatchIter = BatchIterBert(self.valDataIter,
                                         filling_last_batch=False,
                                         postProcessor=batchPostProcessor,
                                         batch_size=self.batch_size)
        else:
            valBatchIter = None

        print(self.vocab_dim)
        net = Model(self.config, vocab_dim=self.vocab_dim)
        self.mUlti = modelUlti(net, gpu=self.gpu)

        #print(next(trainBatchIter))

        self.mUlti.train(trainBatchIter,
                         cache_path=cache_path,
                         num_epohs=self.num_epoches,
                         valBatchIter=valBatchIter,
                         patience=self.patient,
                         earlyStopping=self.earlyStopping)
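The validation split above boils down to a shuffle-and-slice over document ids. A minimal standalone sketch of the same technique (split_ids is a hypothetical helper, not part of CANTM):

import copy
import math
import random

def split_ids(all_ids, val_fraction, seed=None):
    # Shuffle a copy so the caller's list is left untouched, then take the
    # first (1 - val_fraction) share for training and the rest for validation.
    ids = copy.deepcopy(all_ids)
    if seed is not None:
        random.seed(seed)
    random.shuffle(ids)
    top_n_4_train = math.floor(len(ids) * (1 - val_fraction))
    return ids[:top_n_4_train], ids[top_n_4_train:]

train_ids, val_ids = split_ids(list(range(10)), 0.2, seed=42)
print(len(train_ids), len(val_ids))  # 8 2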
Code example #2
File: EvaluationManager.py Project: GateNLP/CANTM
    def train_lda(self, cache_path):
        print(cache_path)
        trainBatchIter = BatchIterBert(self.trainDataIter,
                                       filling_last_batch=False,
                                       postProcessor=batchPostProcessor,
                                       batch_size=1)
        bow_list = []
        for item in trainBatchIter:
            bow = item[1].squeeze().detach().numpy().tolist()
            bow_list.append(self.bow_2_gensim(bow))
        print(len(bow_list))
        #print(self.dictProcess.common_dictionary.id2token)
        # gensim's LdaModel takes an iterable of documents, each a list of
        # (token_id, count) tuples; passing the list directly avoids the
        # ragged np.array conversion, which fails on recent numpy versions.
        lda = LdaModel(bow_list,
                       num_topics=50,
                       passes=200,
                       chunksize=len(bow_list),
                       id2word=self.dictProcess.common_dictionary)
        #print(lda.show_topic(1, topn=10))
        # Collect the top-10 words of each of the 50 topics, one topic per line.
        output_topic_line = ''
        for topic_id in range(50):
            current_topic_list = []
            current_topic = lda.show_topic(topic_id, topn=10)
            for topic_tuple in current_topic:
                current_topic_list.append(topic_tuple[0])
            output_topic_line += ' '.join(current_topic_list) + '\n'
            #print(current_topic_list)

        topic_file = os.path.join(cache_path, 'ldatopic.txt')
        with open(topic_file, 'w') as fo:
            fo.write(output_topic_line)

        testBatchIter = BatchIterBert(self.testDataIter,
                                      filling_last_batch=False,
                                      postProcessor=batchPostProcessor,
                                      batch_size=1)

        test_bow_list = []
        word_count = 0
        for item in testBatchIter:
            bow = item[1].squeeze().detach().numpy().tolist()
            word_count += sum(bow)
            test_bow_list.append(self.bow_2_gensim(bow))

        print(word_count)
        # gensim's per-word log perplexity estimate, then the variational
        # bound converted to per-word perplexity (base 2).
        ppl = lda.log_perplexity(test_bow_list, len(test_bow_list))
        print(ppl)
        bound = lda.bound(test_bow_list)
        print(bound / word_count)
        print(np.exp2(-bound / word_count))
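The final line converts gensim's variational lower bound on the corpus log likelihood into a per-word perplexity via np.exp2(-bound / word_count). A self-contained toy example of the same computation on a throwaway corpus, using only standard gensim APIs:

import numpy as np
from gensim.corpora import Dictionary
from gensim.models import LdaModel

# gensim represents each document as a list of (token_id, count) pairs.
texts = [['virus', 'mask', 'vaccine'], ['mask', 'lockdown'], ['vaccine', 'trial']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

lda = LdaModel(corpus, num_topics=2, id2word=dictionary, passes=10)
word_count = sum(count for doc in corpus for _, count in doc)
bound = lda.bound(corpus)            # variational lower bound on log p(corpus)
print(np.exp2(-bound / word_count))  # per-word perplexity, as in train_lda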
Code example #3
File: EvaluationManager.py Project: GateNLP/CANTM
    def train_test_evaluation(self):
        path = Path(self.cache_path)
        path.mkdir(parents=True, exist_ok=True)
        self.train(cache_path=self.cache_path)
        testBatchIter = BatchIterBert(self.testDataIter,
                                      filling_last_batch=False,
                                      postProcessor=batchPostProcessor,
                                      batch_size=self.batch_size)
        results = self.mUlti.eval(testBatchIter, get_perp=self.get_perp)
        print(results)
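The two pathlib lines at the top are the standard idiom for creating a cache directory tree without failing on reruns:

from pathlib import Path

path = Path('cache/experiment1/fold1')
path.mkdir(parents=True, exist_ok=True)  # creates intermediate dirs; no error if present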
Code example #4
File: EvaluationManager.py Project: GateNLP/CANTM
    def buildDict(self):
        batchiter = BatchIterBert(self.trainDataIter,
                                  filling_last_batch=False,
                                  postProcessor=xonlyBatchProcessor,
                                  batch_size=1)
        common_dictionary = Dictionary(batchiter)
        print(len(common_dictionary))
        if self.testReaderargs:
            print('update vocab from test set')
            batchiter = BatchIterBert(self.testDataIter,
                                      filling_last_batch=False,
                                      postProcessor=xonlyBatchProcessor,
                                      batch_size=1)
            common_dictionary.add_documents(batchiter)
            print(len(common_dictionary))

        # Prune the vocabulary: drop rare tokens (no_below), overly common
        # tokens (no_above), and cap the total size (keep_n).
        common_dictionary.filter_extremes(no_below=self.dict_no_below,
                                          no_above=self.dict_no_above,
                                          keep_n=self.dict_keep_n)
        self.dictProcess = DictionaryProcess(common_dictionary)
        self.postProcessor.dictProcess = self.dictProcess
        self.vocab_dim = len(self.dictProcess)
        self.have_dict = True

        # Report the average bag-of-words length per training document.
        count_list = []
        self.trainDataIter._reset_iter()
        batchiter = BatchIterBert(self.trainDataIter,
                                  filling_last_batch=False,
                                  postProcessor=xonlyBatchProcessor,
                                  batch_size=1)
        for item in batchiter:
            count_list.append(sum(item))
        print(sum(count_list) / len(count_list))
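buildDict relies on gensim's Dictionary to map tokens to integer ids and to prune the vocabulary. A small self-contained illustration with made-up documents:

from gensim.corpora import Dictionary

docs = [['covid', 'vaccine', 'trial'], ['covid', 'mask'], ['mask', 'policy']]
common_dictionary = Dictionary(docs)              # token -> integer id
common_dictionary.add_documents([['lockdown']])   # extend vocab, e.g. from a test set
common_dictionary.filter_extremes(no_below=1, no_above=0.9, keep_n=5000)
print(len(common_dictionary), common_dictionary.token2id)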
Code example #5
File: EvaluationManager.py Project: GateNLP/CANTM
    def cross_fold_evaluation(self):
        kf = KFold(n_splits=self.n_fold)
        fold_index = 1
        results_dict = {
            'accuracy': [],
            'perplexity': [],
            'log_perplexity': [],
            'perplexity_x_only': [],
            'f-measure': {},
        }
        for each_fold in kf.split(self.all_ids):
            train_ids, test_ids = self.reconstruct_ids(each_fold)
            self.trainDataIter.all_ids = train_ids
            self.testDataIter.all_ids = test_ids
            self.testDataIter._reset_iter()
            fold_cache_path = os.path.join(self.cache_path,
                                           'fold' + str(fold_index))
            path = Path(fold_cache_path)
            path.mkdir(parents=True, exist_ok=True)

            if self.trainLDA:
                self.train_lda(cache_path=fold_cache_path)
            else:
                self.train(cache_path=fold_cache_path)

                testBatchIter = BatchIterBert(self.testDataIter,
                                              filling_last_batch=False,
                                              postProcessor=batchPostProcessor,
                                              batch_size=self.batch_size)

                results = self.mUlti.eval(testBatchIter,
                                          get_perp=self.get_perp)
                print(results)
                results_dict['accuracy'].append(results['accuracy'])
                if 'perplexity' in results:
                    results_dict['perplexity'].append(results['perplexity'])
                    results_dict['log_perplexity'].append(
                        results['log_perplexity'])
                    results_dict['perplexity_x_only'].append(
                        results['perplexity_x_only'])

                # results['f-measure'][class] is a 6-tuple in this order.
                metric_names = ['precision', 'recall', 'f-measure',
                                'total_pred', 'total_true', 'matches']
                for f_measure_class in results['f-measure']:
                    class_scores = results_dict['f-measure'].setdefault(
                        f_measure_class, {name: [] for name in metric_names})
                    for idx, name in enumerate(metric_names):
                        class_scores[name].append(
                            results['f-measure'][f_measure_class][idx])
            fold_index += 1

        print(results_dict)
        overall_accuracy = sum(results_dict['accuracy']) / len(
            results_dict['accuracy'])
        if len(results_dict['perplexity']) > 0:
            overall_perplexity = sum(results_dict['perplexity']) / len(
                results_dict['perplexity'])
            print('perplexity: ', overall_perplexity)
            overall_log_perplexity = sum(results_dict['log_perplexity']) / len(
                results_dict['log_perplexity'])
            print('log perplexity: ', overall_log_perplexity)
            overall_perplexity_x = sum(
                results_dict['perplexity_x_only']) / len(
                    results_dict['perplexity_x_only'])
            print('perplexity_x_only: ', overall_perplexity_x)

        macro_precision = get_average_fmeasure_score(results_dict, 'precision')
        macro_recall = get_average_fmeasure_score(results_dict, 'recall')
        macro_fmeasure = get_average_fmeasure_score(results_dict, 'f-measure')

        micro_precision = get_micro_fmeasure(results_dict, 'matches',
                                             'total_pred')
        micro_recall = get_micro_fmeasure(results_dict, 'matches',
                                          'total_true')
        micro_fmeasure = 2 * ((micro_precision * micro_recall) /
                              (micro_precision + micro_recall))

        print('accuracy: ', overall_accuracy)
        print('micro_precision: ', micro_precision)
        print('micro_recall: ', micro_recall)
        print('micro_f-measure: ', micro_fmeasure)
        print('macro_precision: ', macro_precision)
        print('macro_recall: ', macro_recall)
        print('macro_f-measure: ', macro_fmeasure)
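The micro scores at the end pool raw counts across folds and classes before dividing, while the macro scores average the per-class ratios. A minimal sketch of the micro computation (micro_fmeasure is a hypothetical stand-in for get_micro_fmeasure, with made-up counts):

def micro_fmeasure(matches, total_pred, total_true):
    # Pool counts first, divide once: classes with more examples weigh more.
    precision = sum(matches) / sum(total_pred)
    recall = sum(matches) / sum(total_true)
    return 2 * precision * recall / (precision + recall)

print(micro_fmeasure(matches=[8, 9], total_pred=[10, 12], total_true=[11, 10]))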
Code example #6
File: updateTopics.py Project: GateNLP/CANTM
            param.requires_grad = True
            param.data.uniform_(-1.0, 1.0)   # re-initialise and unfreeze selected weights
        elif name in trainable_bias:
            param.requires_grad = True
            param.data.fill_(0)              # zero and unfreeze selected biases
        else:
            param.requires_grad = False      # freeze all remaining parameters

    postProcessor = ReaderPostProcessor(config=config,
                                        word2id=True,
                                        remove_single_list=False,
                                        add_spec_tokens=True,
                                        x_fields=x_fields,
                                        y_field=args.y_field,
                                        max_sent_len=510)  # 512 BERT limit minus [CLS] and [SEP]
    postProcessor.dictProcess = mUlti.bowdict

    testDataIter = dataIter(*testReaderargs,
                            postProcessor=postProcessor,
                            config=config,
                            shuffle=True)

    testBatchIter = BatchIterBert(testDataIter,
                                  filling_last_batch=True,
                                  postProcessor=batchPostProcessor,
                                  batch_size=32)

    mUlti.train(testBatchIter,
                num_epohs=args.num_epoches,
                cache_path=args.cachePath)
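The parameter loop at the top of this snippet is the usual PyTorch pattern for selectively unfreezing parts of a network. A standalone sketch on a toy module (the trainable_bias set here is illustrative, not CANTM's actual parameter list):

import torch.nn as nn

model = nn.Linear(4, 2)
trainable_bias = {'bias'}
for name, param in model.named_parameters():
    if name in trainable_bias:
        param.requires_grad = True
        param.data.fill_(0)          # zero and unfreeze, as in updateTopics.py
    else:
        param.requires_grad = False  # freeze everything else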