Beispiel #1
0
 def evaluate(self,
              dev_data,
              prefix_metric=None,
              save_dir=None,
              save_result=False,
              file_name=None):
     '''
     Evaluate the model on a validation set.

     Args:
         dev_data: validation dataset; wrapped via build_eval_dataloader.
         prefix_metric: optional prefix forwarded to update_metrics.
         save_dir: directory for the pickled predictions (used with save_result).
         save_result: when True, pickle every predicted batch to disk.
         file_name: output file name; defaults to 'dev_eval_results.pkl'.
     '''
     all_batch_list = []
     eval_dataloader = self.build_eval_dataloader(dev_data)
     self.build_record_tracker()
     self.reset_metrics()
     pbar = ProgressBar(n_total=len(eval_dataloader), desc='Evaluating')
     for step, batch in enumerate(eval_dataloader):
         batch = self.predict_forward(batch)
         # Not every predict_forward implementation reports a loss.
         if 'loss' in batch and batch['loss'] is not None:
             self.records['loss_meter'].update(batch['loss'], n=1)
         all_batch_list.append(batch)
         pbar.step(step)
     # NOTE(review): avg reflects only batches that reported a loss — confirm
     # loss_meter.avg is well-defined when no batch had one.
     self.records['result']['eval_loss'] = self.records['loss_meter'].avg
     self.update_metrics(all_batch_list, prefix_metric)
     self.print_evaluate_result()
     if save_result:
         if file_name is None:
             # Was an f-string with no placeholders; plain literal suffices.
             file_name = "dev_eval_results.pkl"
         self.save_predict_result(data=all_batch_list,
                                  file_name=file_name,
                                  save_dir=save_dir)
     # Release cached GPU memory after a full evaluation pass.
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
Beispiel #2
0
 def predict(self, test_data, save_result=True, file_name=None, save_dir=None):
     '''
     Run inference over the test set and collect the per-batch outputs.

     Returns the list of predicted batches; optionally pickles it to disk.
     '''
     test_dataloader = self.build_test_dataloader(test_data)
     pbar = ProgressBar(n_total=len(test_dataloader), desc='Predicting')
     predictions = []
     for step, batch in enumerate(test_dataloader):
         predictions.append(self.predict_forward(batch))
         pbar.step(step)
     if save_result:
         if file_name is None:
             # Derive a default name from the test file's basename (sans ext).
             stem = os.path.splitext(os.path.basename(test_data.data_name))[0]
             file_name = f"test_predict_results.{stem}.pkl"
         self.save_predict_result(data=predictions,
                                  file_name=file_name,
                                  save_dir=save_dir)
     return predictions
Beispiel #3
0
    def train(self,
              train_data,
              dev_data=None,
              resume_path=None,
              start_epoch=1,
              state_to_save=None):
        '''
        Main training loop: builds the dataloader/scheduler, optionally resumes
        from a checkpoint, then iterates epochs with periodic evaluation,
        checkpointing and early stopping.

        Args:
            train_data: training dataset; wrapped via build_train_dataloader.
            dev_data: validation set, used when evaluate_during_training is on.
            resume_path: checkpoint path to resume from (may be None).
            start_epoch: first epoch number (1-based).
            state_to_save: extra objects merged into the checkpoint state.

        Raises:
            TypeError: if the checkpoint/early-stopping monitor key is missing
                from the evaluation results. (Type kept for caller
                compatibility, though KeyError would be more conventional.)
        '''
        # Fix: `state_to_save=dict()` was a shared mutable default argument.
        if state_to_save is None:
            state_to_save = {}
        train_dataloader = self.build_train_dataloader(train_data)
        num_training_steps = len(
            train_dataloader
        ) // self.gradient_accumulation_steps * self.num_train_epochs
        self.steps_in_epoch = len(train_dataloader)
        if self.scheduler is None:
            self.scheduler = self.build_lr_scheduler(num_training_steps)
        self.resume_from_checkpoint(resume_path=resume_path)
        self.build_model_warp()
        self.print_summary(len(train_data), num_training_steps)
        self.optimizer.zero_grad()
        seed_everything(
            self.opts.seed, verbose=False
        )  # Added here for reproducibility (even between python 2 and 3)
        # Negative logging/save steps mean "once per epoch", measured in
        # optimizer steps (hence the division by gradient_accumulation_steps).
        if self.opts.logging_steps < 0:
            self.opts.logging_steps = len(
                train_dataloader) // self.gradient_accumulation_steps
            self.opts.logging_steps = max(1, self.opts.logging_steps)
        if self.opts.save_steps < 0:
            self.opts.save_steps = len(
                train_dataloader) // self.gradient_accumulation_steps
            self.opts.save_steps = max(1, self.opts.save_steps)
        self.build_record_tracker()
        self.reset_metrics()
        pbar = ProgressBar(n_total=len(train_dataloader),
                           desc='Training',
                           num_epochs=self.num_train_epochs)
        for epoch in range(start_epoch, int(self.num_train_epochs) + 1):
            pbar.epoch(current_epoch=epoch)
            for step, batch in enumerate(train_dataloader):
                outputs, should_logging, should_save = self.train_step(
                    step, batch)
                # outputs is None on accumulation sub-steps that skip the
                # optimizer update — presumably; confirm against train_step.
                if outputs is not None:
                    if self.opts.ema_enable:
                        self.model_ema.update(self.model)
                    pbar.step(step, {'loss': outputs['loss'].item()})
                if (self.opts.logging_steps > 0 and self.global_step > 0) and \
                        should_logging and self.opts.evaluate_during_training:
                    self.evaluate(dev_data)
                    if self.opts.ema_enable and self.model_ema is not None:
                        self.evaluate(dev_data, prefix_metric='ema')
                    if hasattr(self.writer, 'save'):
                        self.writer.save()
                if (self.opts.save_steps > 0
                        and self.global_step > 0) and should_save:
                    # model checkpoint
                    if self.model_checkpoint:
                        state = self.build_state_object(**state_to_save)
                        if self.opts.evaluate_during_training:
                            if self.model_checkpoint.monitor not in self.records[
                                    'result']:
                                # Fix: message was phrased backwards (it swapped
                                # the expected key and the available keys).
                                msg = (
                                    f"Monitor metric "
                                    f"'{self.model_checkpoint.monitor}' not "
                                    "found in the eval result; available keys: "
                                    f"{', '.join(list(self.records['result'].keys()))}."
                                )
                                raise TypeError(msg)
                            self.model_checkpoint.step(
                                state=state,
                                current=self.records['result'][
                                    self.model_checkpoint.monitor])
                        else:
                            self.model_checkpoint.step(state=state,
                                                       current=None)

            # early_stopping
            if self.early_stopping:
                if self.early_stopping.monitor not in self.records['result']:
                    # Fix: same backwards phrasing as the checkpoint message.
                    msg = (
                        f"Monitor metric '{self.early_stopping.monitor}' not "
                        "found in the eval result; available keys: "
                        f"{', '.join(list(self.records['result'].keys()))}.")
                    raise TypeError(msg)
                self.early_stopping.step(current=self.records['result'][
                    self.early_stopping.monitor])
                if self.early_stopping.stop_training:
                    break
            # Release cached GPU memory between epochs.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        if self.writer:
            self.writer.close()