def train(self, cache_path=None):
    if self.inspectTest and (not self.splitValidation):
        print('inspecting the test set, do not use validation accuracy for early stopping')
        self.valDataIter = self.testDataIter
    elif self.inspectTest and self.splitValidation:
        print('inspectTest and splitValidation cannot be used at the same time')
        print('disabling inspectTest')
        self.inspectTest = False

    if self.splitValidation:
        # hold out a fraction of the training ids as a validation set
        print('splitting training data for validation')
        self.valDataIter = copy.deepcopy(self.trainDataIter)
        train_val_ids = copy.deepcopy(self.trainDataIter.all_ids)
        random.shuffle(train_val_ids)
        split_4_train = 1 - self.splitValidation
        top_n_4_train = math.floor(len(train_val_ids) * split_4_train)
        id_4_train = train_val_ids[:top_n_4_train]
        id_4_val = train_val_ids[top_n_4_train:]
        self.trainDataIter.all_ids = id_4_train
        self.valDataIter.all_ids = id_4_val

    assert self.inspectTest != self.splitValidation, \
        'splitValidation overwrites inspectTest, do not use both at the same time'

    if self.dynamicSampling:
        print('getting training data sample weights')
        self.trainDataIter.cal_sample_weights()

    self.trainDataIter._reset_iter()
    trainBatchIter = BatchIterBert(self.trainDataIter, filling_last_batch=True,
                                   postProcessor=batchPostProcessor, batch_size=self.batch_size)
    if self.valDataIter:
        self.valDataIter._reset_iter()
        valBatchIter = BatchIterBert(self.valDataIter, filling_last_batch=False,
                                     postProcessor=batchPostProcessor, batch_size=self.batch_size)
    else:
        valBatchIter = None

    print(self.vocab_dim)
    net = Model(self.config, vocab_dim=self.vocab_dim)
    self.mUlti = modelUlti(net, gpu=self.gpu)
    #print(next(trainBatchIter))
    self.mUlti.train(trainBatchIter, cache_path=cache_path, num_epohs=self.num_epoches,
                     valBatchIter=valBatchIter, patience=self.patient,
                     earlyStopping=self.earlyStopping)
def train_lda(self, cache_path):
    print(cache_path)
    trainBatchIter = BatchIterBert(self.trainDataIter, filling_last_batch=False,
                                   postProcessor=batchPostProcessor, batch_size=1)
    # convert each training document into a gensim-style bag-of-words
    bow_list = []
    for item in trainBatchIter:
        bow = item[1].squeeze().detach().numpy().tolist()
        bow_list.append(self.bow_2_gensim(bow))
    print(len(bow_list))
    #print(self.dictProcess.common_dictionary.id2token)

    # gensim expects a plain iterable of (token_id, count) documents, so the corpus
    # is passed directly rather than wrapped in a (ragged) numpy array
    lda = LdaModel(bow_list, num_topics=50, passes=200, chunksize=len(bow_list),
                   id2word=self.dictProcess.common_dictionary)
    #print(lda.show_topic(1, topn=10))

    # dump the top 10 words of each topic, one topic per line
    output_topic_line = ''
    for topic_id in range(50):
        current_topic_list = []
        current_topic = lda.show_topic(topic_id, topn=10)
        for topic_tuple in current_topic:
            current_topic_list.append(topic_tuple[0])
        output_topic_line += ' '.join(current_topic_list) + '\n'
        #print(current_topic_list)
    topic_file = os.path.join(cache_path, 'ldatopic.txt')
    with open(topic_file, 'w') as fo:
        fo.write(output_topic_line)

    # evaluate held-out perplexity on the test set
    testBatchIter = BatchIterBert(self.testDataIter, filling_last_batch=False,
                                  postProcessor=batchPostProcessor, batch_size=1)
    test_bow_list = []
    word_count = 0
    for item in testBatchIter:
        bow = item[1].squeeze().detach().numpy().tolist()
        word_count += sum(bow)
        test_bow_list.append(self.bow_2_gensim(bow))
    print(word_count)

    ppl = lda.log_perplexity(test_bow_list, len(test_bow_list))
    print(ppl)
    bound = lda.bound(test_bow_list)
    print(bound / word_count)
    print(np.exp2(-bound / word_count))
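# NOTE: self.bow_2_gensim is defined elsewhere in this class. The standalone sketch below
# states the assumption made about it here: gensim corpora are iterables of sparse
# (token_id, count) pairs, so a dense count vector indexed by dictionary id would be
# converted roughly like this. The name bow_2_gensim_sketch is hypothetical and only
# illustrative, not the project's own implementation.
def bow_2_gensim_sketch(bow):
    # keep only the non-zero counts and pair each with its token id
    return [(token_id, int(count)) for token_id, count in enumerate(bow) if count > 0]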
def train_test_evaluation(self):
    path = Path(self.cache_path)
    path.mkdir(parents=True, exist_ok=True)
    self.train(cache_path=self.cache_path)
    testBatchIter = BatchIterBert(self.testDataIter, filling_last_batch=False,
                                  postProcessor=batchPostProcessor, batch_size=self.batch_size)
    results = self.mUlti.eval(testBatchIter, get_perp=self.get_perp)
    print(results)
def buildDict(self):
    batchiter = BatchIterBert(self.trainDataIter, filling_last_batch=False,
                              postProcessor=xonlyBatchProcessor, batch_size=1)
    common_dictionary = Dictionary(batchiter)
    print(len(common_dictionary))
    if self.testReaderargs:
        print('updating vocab from test set')
        batchiter = BatchIterBert(self.testDataIter, filling_last_batch=False,
                                  postProcessor=xonlyBatchProcessor, batch_size=1)
        common_dictionary.add_documents(batchiter)
        print(len(common_dictionary))

    common_dictionary.filter_extremes(no_below=self.dict_no_below,
                                      no_above=self.dict_no_above,
                                      keep_n=self.dict_keep_n)
    self.dictProcess = DictionaryProcess(common_dictionary)
    self.postProcessor.dictProcess = self.dictProcess
    self.vocab_dim = len(self.dictProcess)
    self.have_dict = True

    # report the average number of tokens per training document
    count_list = []
    self.trainDataIter._reset_iter()
    batchiter = BatchIterBert(self.trainDataIter, filling_last_batch=False,
                              postProcessor=xonlyBatchProcessor, batch_size=1)
    for item in batchiter:
        current_count = sum(item)
        count_list.append(current_count)
        #print(current_count)
    print(sum(count_list) / len(count_list))
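# A minimal, self-contained sketch of the gensim Dictionary calls used in buildDict above.
# The toy documents are hypothetical; the API calls (Dictionary, add_documents,
# filter_extremes, doc2bow) are standard gensim.
def _dictionary_usage_sketch():
    from gensim.corpora import Dictionary
    docs = [['topic', 'model', 'topic'], ['neural', 'topic', 'model']]
    d = Dictionary(docs)                   # build the id <-> token mapping from tokenised docs
    d.add_documents([['extra', 'model']])  # extend the vocabulary, as done with the test set above
    d.filter_extremes(no_below=1, no_above=0.9, keep_n=10)  # prune rare and over-frequent tokens
    return d.doc2bow(['topic', 'model', 'model'])  # sparse (token_id, count) pairs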
def cross_fold_evaluation(self):
    kf = KFold(n_splits=self.n_fold)
    fold_index = 1
    results_dict = {
        'accuracy': [],
        'perplexity': [],
        'log_perplexity': [],
        'perplexity_x_only': [],
        'f-measure': {},
    }

    for each_fold in kf.split(self.all_ids):
        train_ids, test_ids = self.reconstruct_ids(each_fold)
        self.trainDataIter.all_ids = train_ids
        self.testDataIter.all_ids = test_ids
        self.testDataIter._reset_iter()

        fold_cache_path = os.path.join(self.cache_path, 'fold' + str(fold_index))
        path = Path(fold_cache_path)
        path.mkdir(parents=True, exist_ok=True)

        if self.trainLDA:
            self.train_lda(cache_path=fold_cache_path)
        else:
            self.train(cache_path=fold_cache_path)

        testBatchIter = BatchIterBert(self.testDataIter, filling_last_batch=False,
                                      postProcessor=batchPostProcessor, batch_size=self.batch_size)
        results = self.mUlti.eval(testBatchIter, get_perp=self.get_perp)
        print(results)

        results_dict['accuracy'].append(results['accuracy'])
        if 'perplexity' in results:
            results_dict['perplexity'].append(results['perplexity'])
            results_dict['log_perplexity'].append(results['log_perplexity'])
            results_dict['perplexity_x_only'].append(results['perplexity_x_only'])

        # each per-class entry in results['f-measure'] is a tuple of
        # (precision, recall, f-measure, total_pred, total_true, matches)
        metric_names = ['precision', 'recall', 'f-measure', 'total_pred', 'total_true', 'matches']
        for f_measure_class in results['f-measure']:
            if f_measure_class not in results_dict['f-measure']:
                results_dict['f-measure'][f_measure_class] = {name: [] for name in metric_names}
            for idx, name in enumerate(metric_names):
                results_dict['f-measure'][f_measure_class][name].append(
                    results['f-measure'][f_measure_class][idx])

        fold_index += 1

    print(results_dict)
    overall_accuracy = sum(results_dict['accuracy']) / len(results_dict['accuracy'])

    if len(results_dict['perplexity']) > 0:
        overall_perplexity = sum(results_dict['perplexity']) / len(results_dict['perplexity'])
        print('perplexity: ', overall_perplexity)
        overall_log_perplexity = sum(results_dict['log_perplexity']) / len(results_dict['log_perplexity'])
        print('log perplexity: ', overall_log_perplexity)
        overall_perplexity_x = sum(results_dict['perplexity_x_only']) / len(results_dict['perplexity_x_only'])
        print('perplexity_x_only: ', overall_perplexity_x)

    macro_precision = get_average_fmeasure_score(results_dict, 'precision')
    macro_recall = get_average_fmeasure_score(results_dict, 'recall')
    macro_fmeasure = get_average_fmeasure_score(results_dict, 'f-measure')

    micro_precision = get_micro_fmeasure(results_dict, 'matches', 'total_pred')
    micro_recall = get_micro_fmeasure(results_dict, 'matches', 'total_true')
    micro_fmeasure = 2 * ((micro_precision * micro_recall) / (micro_precision + micro_recall))

    print('accuracy: ', overall_accuracy)
    print('micro_precision: ', micro_precision)
    print('micro_recall: ', micro_recall)
    print('micro_f-measure: ', micro_fmeasure)
    print('macro_precision: ', macro_precision)
    print('macro_recall: ', macro_recall)
    print('macro_f-measure: ', macro_fmeasure)
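# get_average_fmeasure_score and get_micro_fmeasure are module-level helpers not shown in
# this section. The sketches below spell out the assumption made about them above: the macro
# score averages the per-class, per-fold values, while the micro score divides summed matches
# by the summed denominator (total_pred for precision, total_true for recall). These are
# hypothetical re-implementations, not the project's own code.
def get_average_fmeasure_score_sketch(results_dict, field):
    scores = []
    for class_scores in results_dict['f-measure'].values():
        scores.extend(class_scores[field])
    return sum(scores) / len(scores)

def get_micro_fmeasure_sketch(results_dict, numerator_field, denominator_field):
    numerator = sum(sum(class_scores[numerator_field])
                    for class_scores in results_dict['f-measure'].values())
    denominator = sum(sum(class_scores[denominator_field])
                      for class_scores in results_dict['f-measure'].values())
    return numerator / denominator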
# (fragment) begins inside a parameter-selection loop: listed weights are unfrozen and
# re-initialised uniformly, listed biases are unfrozen and zeroed, everything else is frozen
        param.requires_grad = True
        param.data.uniform_(-1.0, 1.0)
    elif name in trainable_bias:
        param.requires_grad = True
        param.data.fill_(0)
    else:
        param.requires_grad = False

postProcessor = ReaderPostProcessor(config=config, word2id=True, remove_single_list=False,
                                    add_spec_tokens=True, x_fields=x_fields,
                                    y_field=args.y_field, max_sent_len=510)
postProcessor.dictProcess = mUlti.bowdict
testDataIter = dataIter(*testReaderargs, postProcessor=postProcessor, config=config, shuffle=True)
testBatchIter = BatchIterBert(testDataIter, filling_last_batch=True,
                              postProcessor=batchPostProcessor, batch_size=32)
mUlti.train(testBatchIter, num_epohs=args.num_epoches, cache_path=args.cachePath)