def _logging(
    self,
    model: MultitaskModel,
    dataloaders: List["DictDataLoader"],
    batch_size: int,
) -> Metrics:
    """Log and checkpoint if it is time to do so."""
    # Switch to eval mode for evaluation
    model.eval()

    self.log_manager.update(batch_size)

    # Log the loss and lr
    metric_dict: Metrics = dict()
    metric_dict.update(self._aggregate_losses())

    # Evaluate the model and log the metrics
    if self.log_manager.trigger_evaluation():
        metric_dict.update(
            self._evaluate(model, dataloaders, self.config.valid_split)
        )
        self._log_metrics(metric_dict)
        self._reset_losses()

    # Checkpoint the model
    if self.log_manager.trigger_checkpointing():
        self._checkpoint_model(model, metric_dict)
        self._reset_losses()

    # Switch back to train mode
    model.train()
    return metric_dict
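
# A sketch of what the returned metric_dict might look like (task, dataset, and
# metric names here are hypothetical): running losses use the
# "task/dataset/split/metric" key convention built in fit() below, and the
# evaluation metrics appear to follow the same shape.
#
#     {
#         "task_a/my_dataset/train/loss": 0.42,
#         "task_a/my_dataset/valid/accuracy": 0.91,
#     }
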
def fit(self, model: MultitaskModel, dataloaders: List["DictDataLoader"]) -> None:
    """Train a MultitaskModel.

    Parameters
    ----------
    model
        The model to train
    dataloaders
        A list of DataLoaders. These will be split into train, valid, and test
        splits based on the ``split`` attribute of the DataLoaders.
    """
    self._check_dataloaders(dataloaders)

    # Identify the dataloaders to train on
    train_dataloaders = [
        dl
        for dl in dataloaders
        if dl.dataset.split == self.config.train_split  # type: ignore
    ]

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch = sum(
        len(dataloader) for dataloader in train_dataloaders
    )

    # Set training helpers
    self._set_log_writer()
    self._set_checkpointer()
    self._set_log_manager()
    self._set_optimizer(model)
    self._set_lr_scheduler()
    self._set_batch_scheduler()

    # Set to training mode
    model.train()

    logging.info("Start training...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    for epoch_num in range(self.config.n_epochs):
        batches = tqdm(
            enumerate(self.batch_scheduler.get_batches(train_dataloaders)),
            total=self.n_batches_per_epoch,
            disable=(not self.config.progress_bar),
            desc=f"Epoch {epoch_num}:",
        )
        for batch_num, (batch, dataloader) in batches:
            X_dict, Y_dict = batch

            total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
            batch_size = len(next(iter(Y_dict.values())))

            # Set gradients of all model parameters to zero
            self.optimizer.zero_grad()

            # Perform forward pass and calculate the loss and count
            loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)

            # Update running loss and count
            for task_name in loss_dict.keys():
                identifier = "/".join(
                    [
                        task_name,
                        dataloader.dataset.name,
                        dataloader.dataset.split,
                        "loss",
                    ]
                )
                self.running_losses[identifier] += (
                    loss_dict[task_name].item() * count_dict[task_name]
                )
                self.running_counts[task_name] += count_dict[task_name]

            # Skip the backward pass if no loss is calculated
            if not loss_dict:
                continue

            # Sum the losses across tasks
            loss = torch.stack(list(loss_dict.values())).sum()

            # Perform backward pass to calculate gradients
            loss.backward()

            # Clip gradient norm
            if self.config.grad_clip:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.grad_clip
                )

            # Update the parameters
            self.optimizer.step()

            # Update lr using lr scheduler
            self._update_lr_scheduler(total_batch_num)

            # Update metrics
            self.metrics.update(self._logging(model, dataloaders, batch_size))

            batches.set_postfix(self.metrics)

    model = self.log_manager.cleanup(model)
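
# A minimal usage sketch (hypothetical names throughout): assumes a Trainer-style
# class wrapping these methods that accepts config options such as n_epochs and
# progress_bar as constructor kwargs, a MultitaskModel built elsewhere, and
# DictDataLoaders whose datasets carry the `name` and `split` attributes that
# fit() relies on.
#
#     trainer = Trainer(n_epochs=5, progress_bar=True)  # hypothetical kwargs
#     trainer.fit(model, [train_loader, valid_loader])
#     print(trainer.metrics)  # "task/dataset/split/metric" entries
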