def train(self):
    """Run the full training loop.

    Loads preprocessed train/dev data, (optionally) resumes from a saved
    checkpoint, then iterates for ``self.opt['EPOCH']`` epochs.  Every 1500
    batches (and on the last batch of each epoch, and on the very first batch
    when resuming) it evaluates on the dev set, and saves the model plus its
    predictions whenever the dev F1 improves.
    """
    self.isTrain = True  # flag training mode for the rest of the object
    self.getSaveFolder()
    self.saveConf()
    # Vocabulary, character vocabulary and pretrained embedding matrix come
    # from the preprocessing object.
    self.vocab, self.char_vocab, vocab_embedding = self.preproc.load_data()
    self.log('-----------------------------------------------')
    self.log("Initializing model...")
    self.setup_model(vocab_embedding)

    # Resume mode: load previously saved weights before continuing.
    if 'RESUME' in self.opt:
        model_path = os.path.join(self.opt['datadir'], self.opt['MODEL_PATH'])
        self.load_model(model_path)

    # Load the preprocessed training set (JSON produced by the preprocessor).
    print('Loading train json...')
    with open(
            os.path.join(self.opt['FEATURE_FOLDER'],
                         self.data_prefix + 'train-preprocessed.json'),
            'r') as f:
        train_data = json.load(f)

    # Load the preprocessed dev set.
    print('Loading dev json...')
    with open(
            os.path.join(self.opt['FEATURE_FOLDER'],
                         self.data_prefix + 'dev-preprocessed.json'),
            'r') as f:
        dev_data = json.load(f)

    best_f1_score = 0.0  # best dev F1 observed so far in this run
    numEpochs = self.opt['EPOCH']
    for epoch in range(self.epoch_start, numEpochs):
        self.log('Epoch {}'.format(epoch))
        # Put the network in training mode (enables dropout etc.).
        self.network.train()
        startTime = datetime.now()
        train_batches = BatchGen(self.opt, train_data['data'], self.use_cuda,
                                 self.vocab, self.char_vocab)
        # NOTE(review): dev_batches is built once per epoch but may be
        # iterated several times below (evaluation can trigger more than once
        # per epoch) — confirm BatchGen is re-iterable, not a one-shot
        # generator.
        dev_batches = BatchGen(self.opt, dev_data['data'], self.use_cuda,
                               self.vocab, self.char_vocab, evaluation=True)
        for i, batch in enumerate(train_batches):
            # Evaluate on dev: at the last batch of the epoch, on the first
            # batch when resuming, or every 1500 batches.
            if i == len(train_batches) - 1 or (
                    epoch == 0 and i == 0 and ('RESUME' in self.opt)) or (i > 0 and i % 1500 == 0):
                print('Saving folder is', self.saveFolder)
                print('Evaluating on dev set...')
                predictions = []
                confidence = []
                dev_answer = []
                final_json = []
                for j, dev_batch in enumerate(dev_batches):
                    # predict() yields answer text, a confidence score and a
                    # JSON-serializable record per example.
                    phrase, phrase_score, pred_json = self.predict(
                        dev_batch)
                    final_json.extend(pred_json)
                    predictions.extend(phrase)
                    confidence.extend(phrase_score)
                    dev_answer.extend(dev_batch[-3])  # answer_str
                # Aggregate metrics (F1 etc.) plus one F1 per instance.
                result, all_f1s = score(predictions, dev_answer, final_json)
                f1 = result['f1']
                # New best model: checkpoint it and dump its predictions.
                if f1 > best_f1_score:
                    model_file = os.path.join(self.saveFolder,
                                              'best_model.pt')
                    self.save_for_predict(model_file, epoch)
                    best_f1_score = f1
                    pred_json_file = os.path.join(self.saveFolder,
                                                  'prediction.json')
                    with open(pred_json_file, 'w') as output_file:
                        json.dump(final_json, output_file)
                    # Also record the per-instance F1 breakdown.
                    score_per_instance = []
                    for instance, s in zip(final_json, all_f1s):
                        score_per_instance.append({
                            'id': instance['id'],
                            'turn_id': instance['turn_id'],
                            'f1': s
                        })
                    score_per_instance_json_file = os.path.join(
                        self.saveFolder, 'score_per_instance.json')
                    with open(score_per_instance_json_file, 'w') as output_file:
                        json.dump(score_per_instance, output_file)
                self.log("Epoch {0} - dev: F1: {1:.3f} (best F1: {2:.3f})".
                         format(epoch, f1, best_f1_score))
                self.log("Results breakdown\n{0}".format(result))
            # Forward/backward pass and parameter update for this batch.
            self.update(batch)
            if i % 100 == 0:
                # Log running loss and a rough ETA for the rest of the epoch.
                self.log(
                    'updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.
                    format(
                        self.updates, self.train_loss.avg,
                        str((datetime.now() - startTime) / (i + 1) *
                            (len(train_batches) - i - 1)).split('.')[0]))
        print("PROGRESS: {0:.2f}%".format(100.0 * (epoch + 1) / numEpochs))
        print('Config file is at ' + self.opt['confFile'])
def train(self):
    """Run the full training loop, batch by batch.

    For each batch, compute the current predictions, backpropagate and update
    the parameters.  Every 1500 batches, run predict() on the dev data and
    compute accuracy scores.  The parameters of the highest-scoring model so
    far are saved in the run_id folder.

    NOTE(review): this file appears to define ``train`` twice; this second
    definition shadows the earlier one at runtime — confirm which version is
    intended to be kept.
    """
    self.isTrain = True  # flag training mode
    self.getSaveFolder()
    self.saveConf()
    # Obtain the vocabularies and embedding matrix from CoQAPreprocess.
    self.vocab, self.char_vocab, vocab_embedding = self.preproc.load_data(
    )
    self.log('-----------------------------------------------')
    self.log('Initializing model...')
    self.setup_model(vocab_embedding)  # initialize the model

    if 'RESUME' in self.opt:
        # Resume mode: load the previously stored model.
        model_path = os.path.join(self.opt['datadir'],
                                  self.opt['MODEL_PATH'])
        self.load_model(model_path)

    print('Loading train json')
    # Load the preprocessed training data.
    with open(
            os.path.join(self.opt['FEATURE_FOLDER'],
                         self.data_prefix + 'train-preprocessed.json'),
            'r') as f:
        train_data = json.load(f)

    print('Loading dev json')
    # Load the preprocessed dev data.
    with open(
            os.path.join(self.opt['FEATURE_FOLDER'],
                         self.data_prefix + 'dev-preprocessed.json'),
            'r') as f:
        dev_data = json.load(f)

    best_f1_score = 0.0  # highest F1 score on the dev set seen in training
    numEpochs = self.opt['EPOCH']  # EPOCH in the config is the epoch count
    for epoch in range(self.epoch_start, numEpochs):
        self.log('Epoch {}'.format(epoch))
        # Training mode: enables dropout and similar layers.
        self.network.train()
        startTime = datetime.now()
        # Batch iterator over the training data.
        train_batches = BatchGen(self.opt, train_data['data'],
                                 self.use_cuda, self.vocab, self.char_vocab)
        # Batch iterator over the dev data.
        # NOTE(review): constructed once per epoch but possibly iterated
        # several times below — confirm BatchGen is re-iterable.
        dev_batches = BatchGen(self.opt, dev_data['data'], self.use_cuda,
                               self.vocab, self.char_vocab, evaluation=True)
        for i, batch in enumerate(train_batches):
            # At the end of each epoch, on the first batch when resuming, or
            # every 1500 batches: predict on the dev data and compute scores.
            if i == len(train_batches) - 1 or (
                    epoch == 0 and i == 0 and ('RESUME' in self.opt)) or (i > 0 and i % 1500 == 0):
                print('Saving folder is', self.saveFolder)
                print('Evaluating on dev set...')
                predictions = []
                confidence = []
                dev_answer = []
                final_json = []
                for j, dev_batch in enumerate(dev_batches):
                    # Prediction yields the answer text, a confidence score
                    # for the answer, and a JSON-format result.
                    phrase, phrase_score, pred_json = self.predict(
                        dev_batch)
                    final_json.extend(pred_json)
                    predictions.extend(phrase)
                    confidence.extend(phrase_score)
                    dev_answer.extend(dev_batch[-3])  # answer_str
                # Compute exact-match (EM) and F1 scores.
                result, all_f1s = score(pred=predictions,
                                        truth=dev_answer,
                                        final_json=final_json)
                f1 = result['f1']
                # If this F1 beats every previous model, save this model.
                if f1 > best_f1_score:
                    model_file = os.path.join(self.saveFolder,
                                              'best_model.pt')
                    self.save_for_predict(model_file, epoch)
                    best_f1_score = f1
                    pred_json_file = os.path.join(self.saveFolder,
                                                  'prediction.json')
                    with open(pred_json_file, 'w') as output_file:
                        json.dump(final_json, output_file)
                    # Record the per-instance F1 breakdown alongside.
                    score_per_instance = []
                    for instance, s in zip(final_json, all_f1s):
                        score_per_instance.append({
                            'id': instance['id'],
                            'turn_id': instance['turn_id'],
                            'f1': s
                        })
                    score_per_instance_json_file = os.path.join(
                        self.saveFolder, 'score_per_instance.json')
                    with open(score_per_instance_json_file, 'w') as output_file:
                        json.dump(score_per_instance, output_file)
                self.log('Epoch {0} - dev F1: {1:.3f} (best F1: {2:.3f})'.
                         format(epoch, f1, best_f1_score))
                self.log('Results breakdown\n{0}'.format(result))
            # Forward pass, backpropagation and parameter update for this
            # batch.
            self.update(batch)
            if i % 100 == 0:
                # Log running loss and an estimate of the remaining time for
                # this epoch.
                self.log(
                    'updates[{0: 6}] train loss[{1: .5f}] remaining[{2}]'.
                    format(
                        self.updates, self.train_loss.avg,
                        str((datetime.now() - startTime) / (i + 1) *
                            (len(train_batches) - i - 1)).split('.')[0]))
        print('PROGRESS: {0:.2F}%'.format(100.0 * (epoch + 1) / numEpochs))
        print('Config file is at ' + self.opt['confFile'])