Example #1
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    history: History,
                    ratio_width=None,
                    gradient_accumulation=1,
                    encoder_grad_norm=None,
                    decoder_grad_norm=None,
                    patience=0.5,
                    eval_trn=False,
                    **kwargs):
     self.model.train()
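     # `optimizer` bundles the encoder optimizer/scheduler with a dict of per-task decoder optimizers.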
     encoder_optimizer, encoder_scheduler, decoder_optimizers = optimizer
     timer = CountdownTimer(len(trn))
     total_loss = 0
     self.reset_metrics(metric)
     model = self.model_
     encoder_parameters = model.encoder.parameters()
     decoder_parameters = model.decoders.parameters()
     for idx, (task_name, batch) in enumerate(trn):
         decoder_optimizer = decoder_optimizers.get(task_name, None)
         output_dict, _ = self.feed_batch(batch, task_name)
         loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
                                  self.tasks[task_name])
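         # Scale the loss so gradients accumulated over several micro-batches average out.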
         if gradient_accumulation and gradient_accumulation > 1:
             loss /= gradient_accumulation
         loss.backward()
         total_loss += float(loss.item())
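         # Parameters are updated only once every `gradient_accumulation` micro-batches.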
         if history.step(gradient_accumulation):
             if self.config.get('grad_norm', None):
                 clip_grad_norm(model, self.config.grad_norm)
             if encoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(encoder_parameters, encoder_grad_norm)
             if decoder_grad_norm:
                 torch.nn.utils.clip_grad_norm_(decoder_parameters, decoder_grad_norm)
             encoder_optimizer.step()
             encoder_optimizer.zero_grad()
             encoder_scheduler.step()
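             # A task-specific decoder optimizer may be packed together with its own scheduler.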
             if decoder_optimizer:
                 if isinstance(decoder_optimizer, tuple):
                     decoder_optimizer, decoder_scheduler = decoder_optimizer
                 else:
                     decoder_scheduler = None
                 decoder_optimizer.step()
                 decoder_optimizer.zero_grad()
                 if decoder_scheduler:
                     decoder_scheduler.step()
         if eval_trn:
             self.decode_output(output_dict, batch, task_name)
             self.update_metrics(batch, output_dict, metric, task_name)
         timer.log(self.report_metrics(total_loss / (timer.current + 1), metric if eval_trn else None),
                   ratio_percentage=None,
                   ratio_width=ratio_width,
                   logger=logger)
         del loss
         del output_dict
     return total_loss / timer.total
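
The update in Example #1 is gated by History.step(gradient_accumulation), which, as used above, fires once per accumulation window. As a point of reference, a minimal sketch of the same gradient-accumulation pattern in plain PyTorch, with hypothetical model, loader, criterion, optimizer and scheduler arguments, could look like this:

import torch
from torch import nn
from torch.utils.data import DataLoader

def train_epoch(model: nn.Module, loader: DataLoader, criterion, optimizer, scheduler,
                gradient_accumulation: int = 1, grad_norm: float = None):
    model.train()
    total_loss = 0.0
    for step, (inputs, targets) in enumerate(loader, start=1):
        loss = criterion(model(inputs), targets)
        # Average the loss over the accumulation window so the accumulated
        # gradient matches that of one large batch.
        (loss / gradient_accumulation).backward()
        total_loss += loss.item()
        # Update parameters only once every `gradient_accumulation` micro-batches.
        if step % gradient_accumulation == 0:
            if grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
    return total_loss / len(loader)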
Example #2
 def _step(self, optimizer, scheduler, grad_norm, transformer_grad_norm, lambda_scheduler):
     clip_grad_norm(self.model, grad_norm, self.model.encoder.transformer, transformer_grad_norm)
     optimizer.step()
     scheduler.step()
     if lambda_scheduler:
         lambda_scheduler.step()
     optimizer.zero_grad()
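
Every example here calls a project-level clip_grad_norm helper instead of torch.nn.utils.clip_grad_norm_ directly. Its implementation is not shown; a plausible sketch that matches the call in Example #2, clipping the transformer sub-module with its own threshold and the remaining parameters with the global one, might be:

import torch
from torch import nn

def clip_grad_norm(model: nn.Module, grad_norm,
                   transformer: nn.Module = None, transformer_grad_norm=None):
    # Sketch under assumed semantics: clip the transformer sub-module separately
    # from the rest of the model; the real helper lives in the project's utilities.
    if transformer is None or transformer_grad_norm is None:
        if grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
        return
    transformer_params = set(transformer.parameters())
    other_params = [p for p in model.parameters() if p not in transformer_params]
    if grad_norm is not None and other_params:
        torch.nn.utils.clip_grad_norm_(other_params, grad_norm)
    torch.nn.utils.clip_grad_norm_(transformer.parameters(), transformer_grad_norm)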
Example #3
 def fit_dataloader(self,
                    trn: DataLoader,
                    criterion,
                    optimizer,
                    metric,
                    logger: logging.Logger,
                    linear_scheduler=None,
                    **kwargs):
     self.model.train()
     timer = CountdownTimer(len(trn))
     total_loss = 0
     self.reset_metrics()
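     # Standard single-task loop: forward, backward, optional clipping, then optimizer/scheduler step.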
     for batch in trn:
         optimizer.zero_grad()
         output_dict = self.feed_batch(batch)
         loss = output_dict['loss']
         loss.backward()
         if self.config.grad_norm:
             clip_grad_norm(self.model, self.config.grad_norm)
         optimizer.step()
         if linear_scheduler:
             linear_scheduler.step()
         total_loss += loss.item()
         timer.log(self.report_metrics(total_loss / (timer.current + 1)),
                   ratio_percentage=None,
                   logger=logger)
         del loss
     return total_loss / timer.total
Example #4
 def _step(self, optimizer, scheduler, grad_norm):
     clip_grad_norm(self.model, grad_norm)
     optimizer.step()
     scheduler.step()
     optimizer.zero_grad()
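
Examples #2 and #4 factor the parameter update out of the training loop into a _step helper that clips, steps the optimizer, advances the scheduler and then clears the gradients. A self-contained illustration of how such a helper is typically driven (all class, attribute and argument names below are hypothetical, not taken from the examples):

import torch
from torch import nn

class Trainer:
    def __init__(self, model: nn.Module, lr: float = 1e-3):
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda _: 1.0)

    def _step(self, optimizer, scheduler, grad_norm):
        # Same ordering as Example #4: clip, update, advance the schedule, reset the grads.
        if grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_norm)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    def fit_batch(self, inputs, targets, criterion, grad_norm=1.0):
        loss = criterion(self.model(inputs), targets)
        loss.backward()
        self._step(self.optimizer, self.scheduler, grad_norm)
        return loss.item()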