Beispiel #1
0
    def _logging(
        self,
        model: MultitaskModel,
        dataloaders: List["DictDataLoader"],
        batch_size: int,
    ) -> Metrics:
        """Run any logging/checkpointing that is due and return the metrics.

        The model is switched to eval mode for the duration of the call and
        restored to train mode before returning.
        """
        # Evaluation must run with the model in inference mode.
        model.eval()

        # Advance the log manager's counters by this batch's size.
        self.log_manager.update(batch_size)

        # Start from the running losses aggregated so far.
        metrics: Metrics = {}
        metrics.update(self._aggregate_losses())

        # If an evaluation is due, add validation metrics, flush the log,
        # and reset the running-loss accumulators.
        if self.log_manager.trigger_evaluation():
            valid_metrics = self._evaluate(
                model, dataloaders, self.config.valid_split)
            metrics.update(valid_metrics)
            self._log_metrics(metrics)
            self._reset_losses()

        # If a checkpoint is due, save the model alongside its metrics.
        if self.log_manager.trigger_checkpointing():
            self._checkpoint_model(model, metrics)
            self._reset_losses()

        # Hand control back to the training loop in train mode.
        model.train()
        return metrics
Beispiel #2
0
    def fit(self, model: MultitaskModel,
            dataloaders: List["DictDataLoader"]) -> None:
        """Train a MultitaskModel.

        Runs ``self.config.n_epochs`` epochs over the training dataloaders,
        performing one optimizer step per batch and delegating periodic
        logging/evaluation/checkpointing to ``self._logging``.

        Parameters
        ----------
        model
            The model to train
        dataloaders
            A list of DataLoaders. These will split into train, valid, and test splits
            based on the ``split`` attribute of the DataLoaders.
        """
        self._check_dataloaders(dataloaders)

        # Identify the dataloaders to train on (dataset split matches the
        # configured train split)
        train_dataloaders = [
            dl for dl in dataloaders
            if dl.dataset.split == self.config.train_split  # type: ignore
        ]

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch = sum(
            [len(dataloader) for dataloader in train_dataloaders])

        # Set training helpers
        self._set_log_writer()
        self._set_checkpointer()
        self._set_log_manager()
        self._set_optimizer(model)
        self._set_lr_scheduler()
        self._set_batch_scheduler()

        # Set to training mode
        model.train()

        logging.info("Start training...")

        # Running metrics displayed on the progress bar and updated from
        # the return value of self._logging
        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        for epoch_num in range(self.config.n_epochs):
            batches = tqdm(
                enumerate(self.batch_scheduler.get_batches(train_dataloaders)),
                total=self.n_batches_per_epoch,
                disable=(not self.config.progress_bar),
                desc=f"Epoch {epoch_num}:",
            )
            for batch_num, (batch, dataloader) in batches:
                X_dict, Y_dict = batch

                # Global step index, used to drive the lr scheduler
                total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
                # Batch size inferred from the first label tensor; assumes
                # Y_dict is non-empty and all values share one length — TODO confirm
                batch_size = len(next(iter(Y_dict.values())))

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

                # Perform forward pass and calculate the loss and count
                loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)

                # Update running loss and count; losses are accumulated as
                # count-weighted sums so they can be averaged at log time
                for task_name in loss_dict.keys():
                    identifier = "/".join([
                        task_name,
                        dataloader.dataset.name,
                        dataloader.dataset.split,
                        "loss",
                    ])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * count_dict[task_name])
                    self.running_counts[task_name] += count_dict[task_name]

                # Skip the backward pass if no loss is calculated
                if not loss_dict:
                    continue

                # Combine the per-task losses into a single scalar (a sum,
                # not an average, despite the per-task weighting above)
                loss = torch.stack(list(loss_dict.values())).sum()

                # Perform backward pass to calculate gradients
                loss.backward()

                # Clip gradient norm (skipped when grad_clip is falsy)
                if self.config.grad_clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.config.grad_clip)

                # Update the parameters
                self.optimizer.step()

                # Update lr using lr scheduler
                self._update_lr_scheduler(total_batch_num)

                # Update metrics (also triggers any due evaluation/checkpoint)
                self.metrics.update(
                    self._logging(model, dataloaders, batch_size))

                batches.set_postfix(self.metrics)

        # NOTE(review): cleanup presumably restores the best checkpoint and
        # closes writers — confirm against the log manager implementation
        model = self.log_manager.cleanup(model)