def evaluate(self, dev_data, prefix_metric=None, save_dir=None, save_result=False, file_name=None):
    '''Evaluate the model on a validation set.

    Runs ``predict_forward`` over every batch of ``dev_data``, averages any
    per-batch loss, refreshes the configured metrics and prints the result.

    Args:
        dev_data: validation dataset handed to ``build_eval_dataloader``.
        prefix_metric: optional prefix forwarded to ``update_metrics`` so the
            same metrics can be recorded under different names (e.g. 'ema').
        save_dir: directory passed to ``save_predict_result`` when saving.
        save_result: if True, pickle the raw per-batch outputs to disk.
        file_name: output file name; defaults to "dev_eval_results.pkl".
    '''
    all_batch_list = []
    eval_dataloader = self.build_eval_dataloader(dev_data)
    # Start from a clean record/metric state for this evaluation pass.
    self.build_record_tracker()
    self.reset_metrics()
    pbar = ProgressBar(n_total=len(eval_dataloader), desc='Evaluating')
    for step, batch in enumerate(eval_dataloader):
        batch = self.predict_forward(batch)
        # Not every forward pass yields a loss (e.g. unlabeled data).
        if 'loss' in batch and batch['loss'] is not None:
            self.records['loss_meter'].update(batch['loss'], n=1)
        all_batch_list.append(batch)
        pbar.step(step)
    self.records['result']['eval_loss'] = self.records['loss_meter'].avg
    self.update_metrics(all_batch_list, prefix_metric)
    self.print_evaluate_result()
    if save_result:
        if file_name is None:
            # Fixed: was an f-string with no placeholders (ruff F541).
            file_name = "dev_eval_results.pkl"
        self.save_predict_result(data=all_batch_list, file_name=file_name, save_dir=save_dir)
    if torch.cuda.is_available():
        # Release cached GPU memory accumulated during evaluation.
        torch.cuda.empty_cache()
def predict(self, test_data, save_result=True, file_name=None, save_dir=None):
    '''Run inference over the test dataset and return the per-batch outputs.

    Args:
        test_data: test dataset handed to ``build_test_dataloader``.
        save_result: if True, pickle the collected outputs to disk.
        file_name: output file name; when None, one is derived from
            ``test_data.data_name``.
        save_dir: directory passed to ``save_predict_result``.

    Returns:
        A list with one ``predict_forward`` result per batch.
    '''
    loader = self.build_test_dataloader(test_data)
    progress = ProgressBar(n_total=len(loader), desc='Predicting')
    all_batch_list = []
    for idx, raw_batch in enumerate(loader):
        all_batch_list.append(self.predict_forward(raw_batch))
        progress.step(idx)
    if save_result:
        if file_name is None:
            file_name = f"test_predict_results.{os.path.splitext(os.path.basename(test_data.data_name))[0]}.pkl"
        self.save_predict_result(data=all_batch_list, file_name=file_name, save_dir=save_dir)
    return all_batch_list
def train(self, train_data, dev_data=None, resume_path=None, start_epoch=1, state_to_save=dict()):
    '''Main training loop.

    Builds the train dataloader and LR scheduler, optionally resumes from a
    checkpoint, then iterates epochs calling ``train_step`` per batch with
    periodic in-training evaluation, checkpointing and early stopping.

    Args:
        train_data: training dataset handed to ``build_train_dataloader``.
        dev_data: validation set used when ``opts.evaluate_during_training``.
        resume_path: checkpoint path forwarded to ``resume_from_checkpoint``.
        start_epoch: first (1-based) epoch index to run.
        state_to_save: extra entries merged into the checkpoint state via
            ``build_state_object``. NOTE(review): mutable default argument —
            it is only read here (never mutated), but worth confirming.
    '''
    train_dataloader = self.build_train_dataloader(train_data)
    # Total optimizer updates = (batches per epoch // accumulation) * epochs.
    num_training_steps = len(
        train_dataloader
    ) // self.gradient_accumulation_steps * self.num_train_epochs
    self.steps_in_epoch = len(train_dataloader)
    if self.scheduler is None:
        self.scheduler = self.build_lr_scheduler(num_training_steps)
    self.resume_from_checkpoint(resume_path=resume_path)
    self.build_model_warp()
    self.print_summary(len(train_data), num_training_steps)
    self.optimizer.zero_grad()
    seed_everything(
        self.opts.seed, verbose=False
    )  # Re-seed here for reproducibility (even between python 2 and 3)
    # A negative value means "once per epoch", measured in optimizer updates.
    if self.opts.logging_steps < 0:
        self.opts.logging_steps = len(
            train_dataloader) // self.gradient_accumulation_steps
        self.opts.logging_steps = max(1, self.opts.logging_steps)
    if self.opts.save_steps < 0:
        self.opts.save_steps = len(
            train_dataloader) // self.gradient_accumulation_steps
        self.opts.save_steps = max(1, self.opts.save_steps)
    self.build_record_tracker()
    self.reset_metrics()
    pbar = ProgressBar(n_total=len(train_dataloader),
                       desc='Training',
                       num_epochs=self.num_train_epochs)
    for epoch in range(start_epoch, int(self.num_train_epochs) + 1):
        pbar.epoch(current_epoch=epoch)
        for step, batch in enumerate(train_dataloader):
            outputs, should_logging, should_save = self.train_step(
                step, batch)
            # train_step may return None outputs; skip EMA/progress then.
            if outputs is not None:
                if self.opts.ema_enable:
                    self.model_ema.update(self.model)
                pbar.step(step, {'loss': outputs['loss'].item()})
            if (self.opts.logging_steps > 0 and self.global_step > 0) and \
                    should_logging and self.opts.evaluate_during_training:
                self.evaluate(dev_data)
                if self.opts.ema_enable and self.model_ema is not None:
                    # Evaluate EMA weights separately under the 'ema' prefix.
                    self.evaluate(dev_data, prefix_metric='ema')
                if hasattr(self.writer, 'save'):
                    self.writer.save()
            if (self.opts.save_steps > 0
                    and self.global_step > 0) and should_save:
                # model checkpoint
                if self.model_checkpoint:
                    state = self.build_state_object(**state_to_save)
                    if self.opts.evaluate_during_training:
                        # The monitored key only exists after evaluate() has
                        # populated records['result'].
                        if self.model_checkpoint.monitor not in self.records[
                                'result']:
                            msg = (
                                "There were expected keys in the eval result: "
                                f"{', '.join(list(self.records['result'].keys()))}, "
                                f"but get {self.model_checkpoint.monitor}."
                            )
                            raise TypeError(msg)
                        self.model_checkpoint.step(
                            state=state,
                            current=self.records['result'][
                                self.model_checkpoint.monitor])
                    else:
                        self.model_checkpoint.step(state=state, current=None)
        # early_stopping
        # NOTE(review): checked once per epoch; 'break' ends training by
        # leaving the epoch loop — confirm the monitor key is refreshed by
        # in-training evaluation before this point.
        if self.early_stopping:
            if self.early_stopping.monitor not in self.records['result']:
                msg = (
                    "There were expected keys in the eval result: "
                    f"{', '.join(list(self.records['result'].keys()))}, "
                    f"but get {self.early_stopping.monitor}.")
                raise TypeError(msg)
            self.early_stopping.step(current=self.records['result'][
                self.early_stopping.monitor])
            if self.early_stopping.stop_training:
                break
        if torch.cuda.is_available():
            # Release cached GPU memory between epochs.
            torch.cuda.empty_cache()
    if self.writer:
        self.writer.close()