def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger, history: History,
                   ratio_width=None, gradient_accumulation=1, encoder_grad_norm=None, decoder_grad_norm=None,
                   patience=0.5, eval_trn=False, **kwargs):
    self.model.train()
    # The optimizer argument is a triple: one optimizer/scheduler pair for the
    # shared encoder plus a dict of per-task decoder optimizers.
    encoder_optimizer, encoder_scheduler, decoder_optimizers = optimizer
    timer = CountdownTimer(len(trn))
    total_loss = 0
    self.reset_metrics(metric)
    model = self.model_
    # Materialize the parameter generators as lists so clipping works on every
    # update, not just the first (parameters() returns a one-shot generator).
    encoder_parameters = list(model.encoder.parameters())
    decoder_parameters = list(model.decoders.parameters())
    for idx, (task_name, batch) in enumerate(trn):
        decoder_optimizer = decoder_optimizers.get(task_name, None)
        output_dict, _ = self.feed_batch(batch, task_name)
        loss = self.compute_loss(batch, output_dict[task_name]['output'], criterion[task_name],
                                 self.tasks[task_name])
        if gradient_accumulation and gradient_accumulation > 1:
            # Scale the loss so accumulated gradients average over the virtual batch.
            loss /= gradient_accumulation
        loss.backward()
        total_loss += float(loss.item())
        # Only update parameters once every `gradient_accumulation` batches.
        if history.step(gradient_accumulation):
            if self.config.get('grad_norm', None):
                clip_grad_norm(model, self.config.grad_norm)
            if encoder_grad_norm:
                torch.nn.utils.clip_grad_norm_(encoder_parameters, encoder_grad_norm)
            if decoder_grad_norm:
                torch.nn.utils.clip_grad_norm_(decoder_parameters, decoder_grad_norm)
            encoder_optimizer.step()
            encoder_optimizer.zero_grad()
            encoder_scheduler.step()
            if decoder_optimizer:
                # A per-task decoder optimizer may come paired with its own scheduler.
                if isinstance(decoder_optimizer, tuple):
                    decoder_optimizer, decoder_scheduler = decoder_optimizer
                else:
                    decoder_scheduler = None
                decoder_optimizer.step()
                decoder_optimizer.zero_grad()
                if decoder_scheduler:
                    decoder_scheduler.step()
        if eval_trn:
            # Decoding and scoring the training set is optional since it slows training down.
            self.decode_output(output_dict, batch, task_name)
            self.update_metrics(batch, output_dict, metric, task_name)
        timer.log(self.report_metrics(total_loss / (timer.current + 1), metric if eval_trn else None),
                  ratio_percentage=None, ratio_width=ratio_width, logger=logger)
        del loss
        del output_dict
    return total_loss / timer.total
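
# A minimal sketch of how the `optimizer` triple consumed above might be
# constructed: one AdamW/scheduler pair for the shared encoder and one
# optimizer (optionally paired with a scheduler) per task decoder. The builder
# below, its name, and the learning rates are illustrative assumptions, not
# this project's actual factory.
def build_mtl_optimizers(model, tasks, num_training_steps, encoder_lr=5e-5, decoder_lr=1e-3):
    encoder_optimizer = torch.optim.AdamW(model.encoder.parameters(), lr=encoder_lr)
    encoder_scheduler = torch.optim.lr_scheduler.LinearLR(
        encoder_optimizer, start_factor=1.0, end_factor=0.0, total_iters=num_training_steps)
    decoder_optimizers = {}
    for task_name in tasks:
        # Assumes model.decoders is an nn.ModuleDict keyed by task name.
        decoder_optimizer = torch.optim.Adam(model.decoders[task_name].parameters(), lr=decoder_lr)
        decoder_scheduler = torch.optim.lr_scheduler.LinearLR(
            decoder_optimizer, start_factor=1.0, end_factor=0.0, total_iters=num_training_steps)
        decoder_optimizers[task_name] = (decoder_optimizer, decoder_scheduler)
    return encoder_optimizer, encoder_scheduler, decoder_optimizers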

def _step(self, optimizer, scheduler, grad_norm, transformer_grad_norm, lambda_scheduler):
    clip_grad_norm(self.model, grad_norm, self.model.encoder.transformer, transformer_grad_norm)
    optimizer.step()
    scheduler.step()
    if lambda_scheduler:
        lambda_scheduler.step()
    optimizer.zero_grad()

def fit_dataloader(self, trn: DataLoader, criterion, optimizer, metric, logger: logging.Logger,
                   linear_scheduler=None, **kwargs):
    self.model.train()
    timer = CountdownTimer(len(trn))
    total_loss = 0
    self.reset_metrics()
    for batch in trn:
        optimizer.zero_grad()
        output_dict = self.feed_batch(batch)
        loss = output_dict['loss']
        loss.backward()
        if self.config.grad_norm:
            clip_grad_norm(self.model, self.config.grad_norm)
        optimizer.step()
        if linear_scheduler:
            # Advance the learning-rate schedule once per batch.
            linear_scheduler.step()
        total_loss += loss.item()
        timer.log(self.report_metrics(total_loss / (timer.current + 1)), ratio_percentage=None, logger=logger)
        del loss
    return total_loss / timer.total
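
# A minimal sketch of building the `linear_scheduler` argument above, assuming
# the linear warmup-then-decay schedule from the HuggingFace `transformers`
# package; the project may construct it differently.
from transformers import get_linear_schedule_with_warmup

def build_linear_scheduler(optimizer, num_training_steps, warmup_ratio=0.1):
    # Warm the learning rate up over the first `warmup_ratio` of steps, then
    # decay it linearly to zero; fit_dataloader steps it once per batch.
    return get_linear_schedule_with_warmup(optimizer,
                                           num_warmup_steps=int(num_training_steps * warmup_ratio),
                                           num_training_steps=num_training_steps)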

def _step(self, optimizer, scheduler, grad_norm):
    clip_grad_norm(self.model, grad_norm)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
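
# All call sites above rely on a shared `clip_grad_norm` helper. A minimal
# sketch consistent with those calls is given below; it is an assumption about
# the helper's behavior, not the project's actual implementation. When a
# transformer submodule and a separate norm are supplied, the transformer's
# parameters get their own clipping budget and the rest of the model uses
# `grad_norm`.
def clip_grad_norm(model: torch.nn.Module, grad_norm, transformer: torch.nn.Module = None,
                   transformer_grad_norm=None):
    if transformer is not None and transformer_grad_norm is not None:
        transformer_ids = {id(p) for p in transformer.parameters()}
        torch.nn.utils.clip_grad_norm_(transformer.parameters(), transformer_grad_norm)
        rest = [p for p in model.parameters() if id(p) not in transformer_ids]
        if grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(rest, grad_norm)
    elif grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)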