Example #1
    def test_partially_empty_batch(self):
        # Mark the first example as unlabeled (-1); only the other 9 should count
        dataset = create_dataloader("task1", shuffle=False).dataset
        dataset.Y_dict["task1"][0] = -1
        model = MultitaskModel([self.task1])
        loss_dict, count_dict = model.calculate_loss(dataset.X_dict,
                                                     dataset.Y_dict)
        self.assertEqual(count_dict["task1"], 9)
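The create_dataloader helper and the NUM_EXAMPLES/BATCH_SIZE constants these
tests rely on are defined elsewhere. A minimal sketch of what they might look
like, reusing the DictDataset and DictDataLoader calls from Example #3; the
constant values, the shuffle keyword, and the helper's exact signature are
assumptions (NUM_EXAMPLES = 10 is inferred from the count of 9 asserted above):

    import torch

    NUM_EXAMPLES = 10
    BATCH_SIZE = 2

    def create_dataloader(task_name="task1", shuffle=True):
        # Ten trivial two-feature examples, all labeled 1 for the given task
        X = torch.FloatTensor([[i, i] for i in range(NUM_EXAMPLES)])
        Y = torch.ones(NUM_EXAMPLES).long()
        dataset = DictDataset(name="dataset",
                              split="train",
                              X_dict={"data": X},
                              Y_dict={task_name: Y})
        return DictDataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle)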
Example #2
    def test_empty_batch(self):
        # Mark every example as unlabeled (-1); no loss or count should be produced
        dataset = create_dataloader("task1", shuffle=False).dataset
        dataset.Y_dict["task1"] = torch.full_like(dataset.Y_dict["task1"], -1)
        model = MultitaskModel([self.task1])
        loss_dict, count_dict = model.calculate_loss(dataset.X_dict,
                                                     dataset.Y_dict)
        self.assertFalse(loss_dict)
        self.assertFalse(count_dict)
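Together, these two tests pin down the masking behavior: examples labeled -1
are excluded from both the loss and the per-task count, and a batch with no
labeled examples contributes nothing at all. Per task, that logic looks roughly
like the following illustrative sketch (not the model's actual calculate_loss),
given a logits tensor, a label tensor, and a loss criterion:

    def masked_loss(logits, labels, criterion):
        # Keep only examples that carry a real label; -1 marks "unlabeled"
        active = labels != -1
        count = int(active.sum())
        if count == 0:
            # Fully unlabeled: contribute nothing to loss_dict or count_dict
            return None, 0
        return criterion(logits[active], labels[active]), count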
Example #3
    def test_remapped_labels(self):
        # Test additional label keys in the Y_dict
        # Without remapping, model should ignore them
        task_name = self.task1.name
        X = torch.FloatTensor([[i, i] for i in range(NUM_EXAMPLES)])
        Y = torch.ones(NUM_EXAMPLES).long()

        Y_dict = {task_name: Y, "other_task": Y}
        dataset = DictDataset(name="dataset",
                              split="train",
                              X_dict={"data": X},
                              Y_dict=Y_dict)
        dataloader = DictDataLoader(dataset, batch_size=BATCH_SIZE)

        model = MultitaskModel([self.task1])
        loss_dict, count_dict = model.calculate_loss(dataset.X_dict,
                                                     dataset.Y_dict)
        self.assertIn("task1", loss_dict)

        # Test predicting and scoring without remapping
        results = model.predict(dataloader)
        self.assertIn("task1", results["golds"])
        self.assertNotIn("other_task", results["golds"])
        scores = model.score([dataloader])
        self.assertIn("task1/dataset/train/accuracy", scores)
        self.assertNotIn("other_task/dataset/train/accuracy", scores)

        # Test remapped label sets
        results = model.predict(dataloader,
                                remap_labels={"other_task": task_name})
        self.assertIn("task1", results["golds"])
        self.assertIn("other_task", results["golds"])
        results = model.score([dataloader],
                              remap_labels={"other_task": task_name})
        self.assertIn("task1/dataset/train/accuracy", results)
        self.assertIn("other_task/dataset/train/accuracy", results)
Example #4
    def fit(self, model: MultitaskModel,
            dataloaders: List["DictDataLoader"]) -> None:
        """Train a MultitaskModel.

        Parameters
        ----------
        model
            The model to train
        dataloaders
            A list of DataLoaders. These will be split into train, valid, and
            test splits based on the ``split`` attribute of the DataLoaders.
        """
        self._check_dataloaders(dataloaders)

        # Identify the dataloaders to train on
        train_dataloaders = [
            dl for dl in dataloaders
            if dl.dataset.split == self.config.train_split  # type: ignore
        ]

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch = sum(
            [len(dataloader) for dataloader in train_dataloaders])

        # Set training helpers
        self._set_log_writer()
        self._set_checkpointer()
        self._set_log_manager()
        self._set_optimizer(model)
        self._set_lr_scheduler()
        self._set_batch_scheduler()

        # Set to training mode
        model.train()

        logging.info("Start training...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        for epoch_num in range(self.config.n_epochs):
            batches = tqdm(
                enumerate(self.batch_scheduler.get_batches(train_dataloaders)),
                total=self.n_batches_per_epoch,
                disable=(not self.config.progress_bar),
                desc=f"Epoch {epoch_num}:",
            )
            for batch_num, (batch, dataloader) in batches:
                X_dict, Y_dict = batch

                total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
                batch_size = len(next(iter(Y_dict.values())))

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

                # Perform the forward pass and calculate the loss and count
                loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)

                # Update running loss and count
                for task_name in loss_dict.keys():
                    identifier = "/".join([
                        task_name,
                        dataloader.dataset.name,
                        dataloader.dataset.split,
                        "loss",
                    ])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * count_dict[task_name])
                    self.running_counts[task_name] += count_dict[task_name]

                # Skip the backward pass if no loss was calculated
                if not loss_dict:
                    continue

                # Calculate the average loss
                loss = torch.stack(list(loss_dict.values())).sum()

                # Perform backward pass to calculate gradients
                loss.backward()

                # Clip gradient norm
                if self.config.grad_clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.config.grad_clip)

                # Update the parameters
                self.optimizer.step()

                # Update lr using lr scheduler
                self._update_lr_scheduler(total_batch_num)

                # Update metrics
                self.metrics.update(
                    self._logging(model, dataloaders, batch_size))

                batches.set_postfix(self.metrics)

        model = self.log_manager.cleanup(model)
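For completeness, a typical way to drive this training loop; the Trainer
constructor keywords mirror the config fields referenced above (n_epochs,
progress_bar), but passing them as constructor arguments, and the variable
names, are assumptions:

    model = MultitaskModel([task1, task2])
    dataloaders = [train_dataloader, valid_dataloader]  # each dataset carries its split
    trainer = Trainer(n_epochs=5, progress_bar=True)
    trainer.fit(model, dataloaders)  # trains only on dataloaders whose split == train_split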