Beispiel #1
0
 def evaluate(self,
              dev_data,
              prefix_metric=None,
              save_dir=None,
              save_result=False,
              file_name=None):
     """Run the model over a validation set and report the resulting metrics.

     Every batch returned by ``self.predict_forward`` is kept in memory so that
     ``self.update_metrics`` (and optionally ``self.save_predict_result``) can
     see the whole evaluation pass.

     Args:
         dev_data: dataset handed to ``self.build_eval_dataloader``.
         prefix_metric: forwarded unchanged to ``self.update_metrics``.
         save_dir: directory forwarded to ``self.save_predict_result``.
         save_result: when True, persist the processed batches.
         file_name: output file name; defaults to ``dev_eval_results.pkl``.
     """
     eval_dataloader = self.build_eval_dataloader(dev_data)
     self.build_record_tracker()
     self.reset_metrics()
     pbar = ProgressBar(n_total=len(eval_dataloader), desc='Evaluating')
     processed = []
     for step, batch in enumerate(eval_dataloader):
         batch = self.predict_forward(batch)
         # Only batches that actually carry a loss contribute to the average.
         batch_loss = batch['loss'] if 'loss' in batch else None
         if batch_loss is not None:
             self.records['loss_meter'].update(batch_loss, n=1)
         processed.append(batch)
         pbar.step(step)
     self.records['result']['eval_loss'] = self.records['loss_meter'].avg
     self.update_metrics(processed, prefix_metric)
     self.print_evaluate_result()
     if save_result:
         target_name = "dev_eval_results.pkl" if file_name is None else file_name
         self.save_predict_result(data=processed,
                                  file_name=target_name,
                                  save_dir=save_dir)
     # Release cached GPU memory held over from the evaluation pass.
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
Beispiel #2
0
 def predict(self, test_data, save_result=True, file_name=None, save_dir=None):
     """Run prediction over the test dataset.

     Args:
         test_data: dataset handed to ``self.build_test_dataloader``. When
             ``file_name`` is None, its ``data_name`` attribute supplies the
             stem of the default output file name.
         save_result: when True, persist the processed batches via
             ``self.save_predict_result``.
         file_name: output file name; defaults to
             ``test_predict_results.<data_name stem>.pkl``.
         save_dir: directory forwarded to ``self.save_predict_result``.

     Returns:
         list: the batches returned by ``self.predict_forward``, in loader order.
     """
     test_dataloader = self.build_test_dataloader(test_data)
     pbar = ProgressBar(n_total=len(test_dataloader), desc='Predicting')
     predictions = []
     for step, batch in enumerate(test_dataloader):
         predictions.append(self.predict_forward(batch))
         pbar.step(step)
     if save_result:
         if file_name is None:
             data_stem = os.path.splitext(os.path.basename(test_data.data_name))[0]
             file_name = f"test_predict_results.{data_stem}.pkl"
         self.save_predict_result(data=predictions, file_name=file_name, save_dir=save_dir)
     return predictions
 def _predict_forward(self, model, data_loader, do_eval=True, **kwargs):
     """Run a token-tagging model over ``data_loader`` and fill ``self.records``.

     Predictions are CRF-decoded tag sequences when ``self.args.use_crf`` is
     set, otherwise the arg-max of the logits over dim 2.

     Args:
         model: the tagging model, optionally wrapped in ``nn.DataParallel``;
             a wrapped model is bypassed and its ``.module`` is called directly.
         data_loader: iterable of raw batches passed to ``self.build_inputs``.
         do_eval: when True the first two outputs are read as
             ``(loss, logits)``; the loss and gold ``labels`` are recorded.
         **kwargs: accepted and ignored.
     """
     self.build_record_object()
     pbar = ProgressBar(n_total=len(data_loader),
                        desc='Evaluating' if do_eval else 'Predicting')
     # Forward pass and CRF decoding both go through the unwrapped module.
     net = model.module if isinstance(model, nn.DataParallel) else model
     for step, batch in enumerate(data_loader):
         model.eval()
         inputs = self.build_inputs(batch)
         with torch.no_grad():
             outputs = net(**inputs)
         if do_eval:
             loss, logits = outputs[:2]
             self.records['loss_meter'].update(loss.item(), n=1)
             self.records['target'].extend(tensor_to_list(inputs['labels']))
         else:
             first = outputs[0]
             # A leading 1-element 1-D tensor is treated as a loss and skipped.
             leading_loss = first.dim() == 1 and first.size(0) == 1
             logits = outputs[1] if leading_loss else outputs[0]
         if self.args.use_crf:
             tags = net.crf.decode(logits, inputs['attention_mask'])
             pred_ids = tags.squeeze(0)
         else:
             pred_ids = torch.argmax(logits, dim=2)
         self.records['preds'].extend(tensor_to_list(pred_ids))
         # Sequence lengths = number of non-zero attention-mask entries per row.
         self.records['input_lens'].extend(
             tensor_to_list(torch.sum(inputs['attention_mask'], 1)))
         pbar(step)
    def _predict_forward(self, model, data_loader, do_eval=True, **kwargs):
        """Run a span-extraction model over ``data_loader`` and fill ``self.records``.

        For each batch the arg-max over the last dimension of the start and end
        logits is taken, and the resulting indices are stored as
        ``(start, end)`` pairs in ``self.records['preds']``.

        Args:
            model: callable model invoked with the dict from ``self.build_inputs``.
            data_loader: iterable of raw batches.
            do_eval: when True the first three outputs are read as
                ``(loss, start_logits, end_logits)``; the mean loss and the gold
                ``(start_positions, end_positions)`` pairs are recorded.
            **kwargs: accepted and ignored.
        """
        self.build_record_object()
        desc = 'Evaluating' if do_eval else 'Predicting'
        pbar = ProgressBar(n_total=len(data_loader), desc=desc)
        for step, batch in enumerate(data_loader):
            model.eval()
            inputs = self.build_inputs(batch)
            with torch.no_grad():
                outputs = model(**inputs)
            if do_eval:
                loss, start_logits, end_logits = outputs[:3]
                gold_starts = tensor_to_list(inputs['start_positions'])
                gold_ends = tensor_to_list(inputs['end_positions'])
                self.records['loss_meter'].update(loss.mean().item(), n=1)
                self.records['target'].extend(zip(gold_starts, gold_ends))
            elif outputs[0].dim() == 1 and outputs[0].size(0) == 1:
                # A leading 1-element 1-D tensor is treated as a loss and skipped.
                _, start_logits, end_logits = outputs[:3]
            else:
                start_logits, end_logits = outputs[:2]

            start_preds = tensor_to_list(torch.argmax(start_logits, -1))
            end_preds = tensor_to_list(torch.argmax(end_logits, -1))
            self.records['preds'].extend(zip(start_preds, end_preds))
            self.records['input_lens'].extend(
                tensor_to_list(torch.sum(inputs['attention_mask'], 1)))
            pbar(step)
 def predict_step(self, model, data_loader, do_eval, **kwargs):
     """Score ``(anchor, positive, negative)`` triplets over ``data_loader``.

     The metric ``DISTANCE2METRIC[self.args.distance_metric]`` is applied to
     ``(anchor, positive)`` and to ``(anchor, negative)``. Each prediction is
     ``1 - (d_pos > d_neg)``: 0 when the positive pair scores strictly higher,
     1 otherwise. Per-batch predictions (and gold labels when ``do_eval``) are
     passed through ``tensor_to_cpu`` and concatenated along dim 0 at the end.

     Args:
         model: callable model whose logits output unpacks into exactly three items.
         data_loader: iterable of raw batches passed to ``self.build_inputs``.
         do_eval: when True the first two outputs are read as ``(loss, logits)``;
             the mean loss and ``labels`` are recorded.
         **kwargs: accepted and ignored.
     """
     self.build_record_object()
     pbar = ProgressBar(n_total=len(data_loader), desc='Evaluating' if do_eval else 'Predicting')
     for step, batch in enumerate(data_loader):
         model.eval()
         inputs = self.build_inputs(batch)
         with torch.no_grad():
             outputs = model(**inputs)
         if do_eval:
             loss, logits = outputs[:2]
             loss = loss.mean()
             self.records['target'].append(tensor_to_cpu(inputs['labels']))
             self.records['loss_meter'].update(loss.item(), n=1)
         else:
             head = outputs[0]
             # Skip a leading 1-element 1-D tensor (a loss). For this model the
             # logits are a triplet, so ``head`` may be a tuple rather than a
             # tensor; check the type before calling tensor methods on it.
             if torch.is_tensor(head) and head.dim() == 1 and head.size(0) == 1:
                 logits = outputs[1]
             else:
                 logits = head
         anchor, positive, negative = logits
         distance_metric = DISTANCE2METRIC[self.args.distance_metric]
         distance_positive = distance_metric(anchor, positive)
         distance_negative = distance_metric(anchor, negative)
         diff_dist = 1 - (distance_positive > distance_negative).int()
         self.records['preds'].append(tensor_to_cpu(diff_dist))
         pbar(step)
     self.records['preds'] = torch.cat(self.records['preds'], dim=0)
     if do_eval:
         self.records['target'] = torch.cat(self.records['target'], dim=0)
Beispiel #6
0
 def _predict_forward(self, model, data_loader, do_eval, **kwargs):
     """Collect model logits (and gold labels when evaluating) over ``data_loader``.

     Per-batch tensors are passed through ``tensor_to_cpu`` and, after the
     loop, concatenated along dim 0 into ``self.records['preds']`` and (when
     ``do_eval``) ``self.records['target']``.

     Args:
         model: callable model invoked with the dict from ``self.build_inputs``.
         data_loader: iterable of raw batches.
         do_eval: when True the first two outputs are read as ``(loss, logits)``;
             the mean loss and ``labels`` are recorded.
         **kwargs: accepted and ignored.
     """
     self.build_record_object()
     mode = 'Evaluating' if do_eval else 'Predicting'
     pbar = ProgressBar(n_total=len(data_loader), desc=mode)
     for step, batch in enumerate(data_loader):
         model.eval()
         inputs = self.build_inputs(batch)
         with torch.no_grad():
             outputs = model(**inputs)
         if not do_eval:
             # A leading 1-element 1-D tensor is treated as a loss and skipped.
             has_loss = outputs[0].dim() == 1 and outputs[0].size(0) == 1
             logits = outputs[1] if has_loss else outputs[0]
         else:
             loss, logits = outputs[:2]
             self.records['target'].append(tensor_to_cpu(inputs['labels']))
             self.records['loss_meter'].update(loss.mean().item(), n=1)
         self.records['preds'].append(tensor_to_cpu(logits))
         pbar(step)
     self.records['preds'] = torch.cat(self.records['preds'], dim=0)
     if do_eval:
         self.records['target'] = torch.cat(self.records['target'], dim=0)
Beispiel #7
0
    def train(self,
              train_data,
              dev_data=None,
              resume_path=None,
              start_epoch=1,
              state_to_save=None):
        """Run the main training loop.

        Args:
            train_data: dataset passed to ``self.build_train_dataloader``.
            dev_data: dataset evaluated whenever ``train_step`` signals logging
                and ``self.opts.evaluate_during_training`` is set.
            resume_path: checkpoint path passed to ``self.resume_from_checkpoint``.
            start_epoch: first epoch number (inclusive); the last epoch is
                ``int(self.num_train_epochs)``.
            state_to_save: extra keyword arguments forwarded to
                ``self.build_state_object`` whenever a checkpoint is taken.
                Defaults to an empty dict.

        Raises:
            TypeError: if the checkpoint or early-stopping monitor key is
                missing from ``self.records['result']`` when it is needed.
        """
        # Fresh dict per call instead of a shared mutable default argument.
        if state_to_save is None:
            state_to_save = {}

        def monitor_value(monitor):
            # Fetch a monitored metric, failing loudly if evaluation never produced it.
            result = self.records['result']
            if monitor not in result:
                msg = (
                    "There were expected keys in the eval result: "
                    f"{', '.join(list(result.keys()))}, "
                    f"but get {monitor}.")
                raise TypeError(msg)
            return result[monitor]

        train_dataloader = self.build_train_dataloader(train_data)
        num_training_steps = len(
            train_dataloader
        ) // self.gradient_accumulation_steps * self.num_train_epochs
        self.steps_in_epoch = len(train_dataloader)
        if self.scheduler is None:
            self.scheduler = self.build_lr_scheduler(num_training_steps)
        self.resume_from_checkpoint(resume_path=resume_path)
        self.build_model_warp()
        self.print_summary(len(train_data), num_training_steps)
        self.optimizer.zero_grad()
        # Added here for reproducibility (even between python 2 and 3).
        seed_everything(self.opts.seed, verbose=False)
        # Negative intervals mean "once per epoch", measured in optimizer updates.
        updates_per_epoch = max(
            1, len(train_dataloader) // self.gradient_accumulation_steps)
        if self.opts.logging_steps < 0:
            self.opts.logging_steps = updates_per_epoch
        if self.opts.save_steps < 0:
            self.opts.save_steps = updates_per_epoch
        self.build_record_tracker()
        self.reset_metrics()
        pbar = ProgressBar(n_total=len(train_dataloader),
                           desc='Training',
                           num_epochs=self.num_train_epochs)
        for epoch in range(start_epoch, int(self.num_train_epochs) + 1):
            pbar.epoch(current_epoch=epoch)
            for step, batch in enumerate(train_dataloader):
                outputs, should_logging, should_save = self.train_step(
                    step, batch)
                if outputs is not None:
                    if self.opts.ema_enable:
                        self.model_ema.update(self.model)
                    pbar.step(step, {'loss': outputs['loss'].item()})
                if (self.opts.logging_steps > 0 and self.global_step > 0
                        and should_logging
                        and self.opts.evaluate_during_training):
                    self.evaluate(dev_data)
                    if self.opts.ema_enable and self.model_ema is not None:
                        self.evaluate(dev_data, prefix_metric='ema')
                    if hasattr(self.writer, 'save'):
                        self.writer.save()
                if (self.opts.save_steps > 0 and self.global_step > 0
                        and should_save and self.model_checkpoint):
                    # model checkpoint
                    state = self.build_state_object(**state_to_save)
                    if self.opts.evaluate_during_training:
                        current = monitor_value(self.model_checkpoint.monitor)
                    else:
                        current = None
                    self.model_checkpoint.step(state=state, current=current)

            # early_stopping
            if self.early_stopping:
                self.early_stopping.step(
                    current=monitor_value(self.early_stopping.monitor))
                if self.early_stopping.stop_training:
                    break
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        if self.writer:
            self.writer.close()
Beispiel #8
0
 def train(self, model, train_dataset, eval_dataset):
     """Main training entry point.

     Args:
         model: model to optimize; it may be replaced by the object returned
             from ``self.build_apex_and_distribute``.
         train_dataset: dataset passed to ``self.build_train_dataloader``.
         eval_dataset: dataset passed to ``self.evaluate`` at each logging step.
     """
     train_dataloader = self.build_train_dataloader(train_dataset)
     t_total = len(
         train_dataloader
     ) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
     optimizer = self.build_optimizer(model)
     scheduler = self.build_scheduler(optimizer, t_total)
     optimizer, scheduler = self.restore_optimizer(optimizer, scheduler)
     model, optimizer = self.build_apex_and_distribute(model, optimizer)
     # Train!
     self.print_training_parameters(model, len(train_dataset), t_total)
     model.zero_grad()
     # Exponential-moving-average weights, only when enabled.
     ema = EMA(model, decay=self.args.ema_decay) if self.args.do_ema else None
     # Added here for reproducibility (even between python 2 and 3).
     seed_everything(self.args.seed)
     # self.global_step is not advanced in this method (presumably in
     # _train_update / _train_step), so it can keep the same value across
     # several micro-batches, e.g. under gradient accumulation. Remember which
     # step was last logged / checkpointed so each happens once per value.
     last_logged_step = None
     last_saved_step = None
     print('Start training.')
     for epoch in range(0, int(self.args.num_train_epochs)):
         self.build_record_object()
         pbar = ProgressBar(n_total=len(train_dataloader), desc='Training')
         for step, batch in enumerate(train_dataloader):
             loss = self._train_step(model, batch, optimizer)
             if (step + 1) % self.args.gradient_accumulation_steps == 0:
                 self._train_update(model, optimizer, loss, scheduler)
                 if ema is not None:
                     ema.update(model)
                 pbar(step, {'loss': loss})
             if (self.args.local_rank in [-1, 0]
                     and self.args.logging_steps > 0
                     and self.global_step > 0
                     and self.global_step != last_logged_step
                     and self.global_step % self.args.logging_steps == 0):
                 last_logged_step = self.global_step
                 # Evaluate with the EMA weights swapped in, then swap back.
                 if ema is not None:
                     ema.apply_shadow(model)
                 self.tb_writer.add_scalar(
                     'Loss/train_epoch_loss',
                     self.records['loss_meter'].avg,
                     self.global_step // self.args.logging_steps)
                 self.evaluate(model, eval_dataset)
                 if ema is not None:
                     ema.restore(model)
                 if hasattr(self.tb_writer, 'save'):
                     self.tb_writer.save()
             if (self.args.local_rank in [-1, 0]
                     and self.args.save_steps > 0
                     and self.global_step > 0
                     and self.global_step != last_saved_step
                     and self.global_step % self.args.save_steps == 0):
                 last_saved_step = self.global_step
                 # model checkpoint
                 if self.model_checkpoint:
                     state = self.build_state_object(
                         model, optimizer, scheduler, self.global_step)
                     self.model_checkpoint.step(
                         state=state,
                         current=self.records['result'][
                             self.model_checkpoint.monitor])
         if not self.scheduler_on_batch:  # epoch scheduler
             scheduler.step()
         # early_stopping
         if self.early_stopping:
             self.early_stopping.step(current=self.records['result'][
                 self.early_stopping.monitor])
             if self.early_stopping.stop_training:
                 break
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
     if self.tb_writer:
         self.tb_writer.close()