Example #1
 def test_partially_empty_batch(self):
     dataset = create_dataloader("task1", shuffle=False).dataset
     dataset.Y_dict["task1"][0] = -1
     model = MultitaskModel([self.task1])
     loss_dict, count_dict = model.calculate_loss(dataset.X_dict,
                                                  dataset.Y_dict)
     self.assertEqual(count_dict["task1"], 9)
Example #2
    def _logging(
        self,
        model: MultitaskModel,
        dataloaders: List["DictDataLoader"],
        batch_size: int,
    ) -> Metrics:
        """Log and checkpoint if it is time to do so."""

        # Switch to eval mode for evaluation
        model.eval()

        self.log_manager.update(batch_size)

        # Log the loss and lr
        metric_dict: Metrics = dict()
        metric_dict.update(self._aggregate_losses())

        # Evaluate the model and log the metric
        if self.log_manager.trigger_evaluation():

            # Log metrics
            metric_dict.update(
                self._evaluate(model, dataloaders, self.config.valid_split))
            self._log_metrics(metric_dict)
            self._reset_losses()

        # Checkpoint the model
        if self.log_manager.trigger_checkpointing():
            self._checkpoint_model(model, metric_dict)
            self._reset_losses()

        # Switch back to train mode
        model.train()
        return metric_dict
Example #3
 def test_empty_batch(self):
     dataset = create_dataloader("task1", shuffle=False).dataset
     dataset.Y_dict["task1"] = torch.full_like(dataset.Y_dict["task1"], -1)
     model = MultitaskModel([self.task1])
     loss_dict, count_dict = model.calculate_loss(dataset.X_dict,
                                                  dataset.Y_dict)
     self.assertFalse(loss_dict)
     self.assertFalse(count_dict)
Example #4
    def test_score(self):
        model = MultitaskModel([self.task1])
        metrics = model.score([self.dataloader])
        # deterministic random tie breaking alternates predicted labels
        self.assertEqual(metrics["task1/dataset/train/accuracy"], 0.4)

        # test dataframe format
        metrics_df = model.score([self.dataloader], as_dataframe=True)
        self.assertTrue(isinstance(metrics_df, pd.DataFrame))
        self.assertEqual(metrics_df.at[0, "score"], 0.4)
Example #5
 def test_no_input_spec(self):
     # Confirm the model doesn't break when a module does not specify its inputs
     dataset = create_dataloader("task", shuffle=False).dataset
     task = Task(
         name="task",
         module_pool=nn.ModuleDict({"identity": nn.Identity()}),
         op_sequence=[Operation("identity", [])],
     )
     model = MultitaskModel(tasks=[task], dataparallel=False)
     outputs = model.forward(dataset.X_dict, ["task"])
     self.assertIn("_input_", outputs)
Example #6
    def test_save_load(self):
        fd, checkpoint_path = tempfile.mkstemp()

        task1 = create_task("task1")
        task2 = create_task("task2")
        # Make task2's second linear layer have different weights
        task2.module_pool["linear2"] = nn.Linear(2, 2)

        model = MultitaskModel([task1])
        self.assertTrue(
            torch.eq(
                task1.module_pool["linear2"].weight,
                model.module_pool["linear2"].module.weight,
            ).all())
        model.save(checkpoint_path)
        model = MultitaskModel([task2])
        self.assertFalse(
            torch.eq(
                task1.module_pool["linear2"].weight,
                model.module_pool["linear2"].module.weight,
            ).all())
        model.load(checkpoint_path)
        self.assertTrue(
            torch.eq(
                task1.module_pool["linear2"].weight,
                model.module_pool["linear2"].module.weight,
            ).all())

        os.close(fd)
Example #7
    def load_best_model(self, model: MultitaskModel) -> MultitaskModel:
        """Load the best model from the checkpoint."""
        metric = list(self.checkpoint_metric.keys())[0]
        if metric not in self.best_metric_dict:  # pragma: no cover
            logging.info("No best model found, use the original model.")
        else:
            # Load the best model of checkpoint_metric
            best_model_path = (
                f"{self.checkpoint_dir}/best_model_{metric.replace('/', '_')}.pth"
            )
            logging.info(f"Loading the best model from {best_model_path}.")
            model.load(best_model_path)

        return model
Example #8
    def test_score_shuffled(self):
        # Test scoring with a shuffled dataset

        class SimpleVoter(nn.Module):
            def forward(self, x):
                """Set class 0 to -1 if x and 1 otherwise"""
                mask = x % 2 == 0
                out = torch.zeros(x.shape[0], 2)
                out[mask, 0] = 1  # class 0
                out[~mask, 1] = 1  # class 1
                return out

        # Create model
        task_name = "VotingTask"
        module_name = "simple_voter"
        module_pool = nn.ModuleDict({module_name: SimpleVoter()})
        op0 = Operation(module_name=module_name,
                        inputs=[("_input_", "data")],
                        name="op0")
        op_sequence = [op0]
        task = Task(name=task_name,
                    module_pool=module_pool,
                    op_sequence=op_sequence)
        model = MultitaskModel([task])

        # Create dataset
        y_list = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
        x_list = list(range(len(y_list)))
        Y = torch.LongTensor(y_list * 100)
        X = torch.FloatTensor(x_list * 100)
        dataset = DictDataset(name="dataset",
                              split="train",
                              X_dict={"data": X},
                              Y_dict={task_name: Y})

        # Create dataloaders
        dataloader = DictDataLoader(dataset, batch_size=2, shuffle=False)
        scores = model.score([dataloader])

        self.assertEqual(scores["VotingTask/dataset/train/accuracy"], 0.6)

        dataloader_shuffled = DictDataLoader(dataset,
                                             batch_size=2,
                                             shuffle=True)
        scores_shuffled = model.score([dataloader_shuffled])
        self.assertEqual(scores_shuffled["VotingTask/dataset/train/accuracy"],
                         0.6)
Example #9
 def test_twotask_partial_overlap_model(self):
     """Add two tasks with overlapping modules and flows"""
     task1 = create_task("task1", module_suffixes=["A", "A"])
     task2 = create_task("task2", module_suffixes=["A", "B"])
     model = MultitaskModel(tasks=[task1, task2])
     self.assertEqual(len(model.task_names), 2)
     self.assertEqual(len(model.op_sequences), 2)
     self.assertEqual(len(model.module_pool), 3)
Example #10
    def checkpoint(
        self, iteration: float, model: MultitaskModel, metric_dict: Metrics
    ) -> None:
        """Check if iteration and current metrics necessitate a checkpoint.

        Parameters
        ----------
        iteration
            Current training iteration
        model
            Model to checkpoint
        metric_dict
            Current performance metrics for model
        """
        # Check if the checkpoint_runway condition is met
        if iteration < self.checkpoint_runway:
            return
        elif not self.checkpoint_condition_met:
            self.checkpoint_condition_met = True
            logging.info(
                "checkpoint_runway condition has been met. Starting checkpointing."
            )

        checkpoint_path = f"{self.checkpoint_dir}/checkpoint_{iteration}.pth"
        model.save(checkpoint_path)
        logging.info(
            f"Save checkpoint at {iteration} {self.checkpoint_unit} "
            f"at {checkpoint_path}."
        )

        if not set(self.checkpoint_task_metrics.keys()).isdisjoint(
            set(metric_dict.keys())
        ):
            new_best_metrics = self._is_new_best(metric_dict)
            for metric in new_best_metrics:
                copyfile(
                    checkpoint_path,
                    f"{self.checkpoint_dir}/best_model_"
                    f"{metric.replace('/', '_')}.pth",
                )

                logging.info(
                    f"Save best model of metric {metric} at {self.checkpoint_dir}"
                    f"/best_model_{metric.replace('/', '_')}.pth"
                )
Example #11
    def test_predict(self):
        model = MultitaskModel([self.task1])
        results = model.predict(self.dataloader)
        self.assertEqual(sorted(list(results.keys())), ["golds", "probs"])
        np.testing.assert_array_equal(
            results["golds"]["task1"],
            self.dataloader.dataset.Y_dict["task1"].numpy())
        np.testing.assert_array_equal(results["probs"]["task1"],
                                      np.ones((NUM_EXAMPLES, 2)) * 0.5)

        results = model.predict(self.dataloader, return_preds=True)
        self.assertEqual(sorted(list(results.keys())),
                         ["golds", "preds", "probs"])
        # deterministic random tie breaking alternates predicted labels
        np.testing.assert_array_equal(
            results["preds"]["task1"],
            np.array([0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]),
        )
Example #12
    def test_load_on_cleanup(self) -> None:
        log_manager_config = {"counter_unit": "epochs", "evaluation_freq": 1}
        checkpointer = Checkpointer(**log_manager_config, checkpoint_dir=self.test_dir)
        log_manager = LogManager(
            n_batches_per_epoch=2, checkpointer=checkpointer, log_writer=None
        )

        classifier = MultitaskModel([])
        best_classifier = log_manager.cleanup(classifier)
        self.assertEqual(best_classifier, classifier)
Example #13
 def _evaluate(
     self,
     model: MultitaskModel,
     dataloaders: List["DictDataLoader"],
     split: str,
 ) -> Metrics:
     """Evalute the current quality of the model on data for the requested split."""
     loaders = [d for d in dataloaders
                if d.dataset.split in split]  # type: ignore
     return model.score(loaders)
Example #14
 def test_checkpointer_load_best(self) -> None:
     checkpoint_dir = os.path.join(self.test_dir, "clear")
     checkpointer = Checkpointer(
         **log_manager_config,
         checkpoint_dir=checkpoint_dir,
         checkpoint_metric="task/dataset/valid/f1:max",
     )
     model = MultitaskModel([])
     checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8})
     load_model = checkpointer.load_best_model(model)
     self.assertEqual(model, load_model)
Example #15
 def test_bad_tasks(self):
     with self.assertRaisesRegex(ValueError, "Found duplicate task"):
         MultitaskModel(tasks=[self.task1, self.task1])
     with self.assertRaisesRegex(ValueError, "Unrecognized task type"):
         MultitaskModel(tasks=[self.task1, {"fake_task": 42}])
     with self.assertRaisesRegex(ValueError, "Unsuccessful operation"):
         task1 = create_task("task1")
         task1.op_sequence[0].inputs[0] = (0, 0)
         model = MultitaskModel(tasks=[task1])
         X_dict = self.dataloader.dataset.X_dict
         model.forward(X_dict, [task1.name])
Example #16
 def test_checkpointer_min(self) -> None:
     checkpointer = Checkpointer(
         **log_manager_config,
         checkpoint_dir=self.test_dir,
         checkpoint_runway=3,
         checkpoint_metric="task/dataset/valid/f1:min",
     )
     model = MultitaskModel([])
     checkpointer.checkpoint(3, model, {
         "task/dataset/valid/f1": 0.8,
         "task/dataset/valid/f2": 0.5
     })
     self.assertEqual(
         checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.8)
     checkpointer.checkpoint(4, model, {"task/dataset/valid/f1": 0.7})
     self.assertEqual(
         checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.7)
Example #17
 def test_checkpointer_clear(self) -> None:
     checkpoint_dir = os.path.join(self.test_dir, "clear")
     checkpointer = Checkpointer(
         **log_manager_config,
         checkpoint_dir=checkpoint_dir,
         checkpoint_metric="task/dataset/valid/f1:max",
         checkpoint_clear=True,
     )
     model = MultitaskModel([])
     checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8})
     expected_files = [
         "checkpoint_1.pth", "best_model_task_dataset_valid_f1.pth"
     ]
     self.assertEqual(set(os.listdir(checkpoint_dir)), set(expected_files))
     checkpointer.clear()
     expected_files = ["best_model_task_dataset_valid_f1.pth"]
     self.assertEqual(os.listdir(checkpoint_dir), expected_files)
Example #18
    def test_remapped_labels(self):
        # Test additional label keys in the Y_dict
        # Without remapping, model should ignore them
        task_name = self.task1.name
        X = torch.FloatTensor([[i, i] for i in range(NUM_EXAMPLES)])
        Y = torch.ones(NUM_EXAMPLES).long()

        Y_dict = {task_name: Y, "other_task": Y}
        dataset = DictDataset(name="dataset",
                              split="train",
                              X_dict={"data": X},
                              Y_dict=Y_dict)
        dataloader = DictDataLoader(dataset, batch_size=BATCH_SIZE)

        model = MultitaskModel([self.task1])
        loss_dict, count_dict = model.calculate_loss(dataset.X_dict,
                                                     dataset.Y_dict)
        self.assertIn("task1", loss_dict)

        # Test setting without remapping
        results = model.predict(dataloader)
        self.assertIn("task1", results["golds"])
        self.assertNotIn("other_task", results["golds"])
        scores = model.score([dataloader])
        self.assertIn("task1/dataset/train/accuracy", scores)
        self.assertNotIn("other_task/dataset/train/accuracy", scores)

        # Test remapped labelsets
        results = model.predict(dataloader,
                                remap_labels={"other_task": task_name})
        self.assertIn("task1", results["golds"])
        self.assertIn("other_task", results["golds"])
        results = model.score([dataloader],
                              remap_labels={"other_task": task_name})
        self.assertIn("task1/dataset/train/accuracy", results)
        self.assertIn("other_task/dataset/train/accuracy", results)
Example #19
 def test_no_data_parallel(self):
     model = MultitaskModel(tasks=[self.task1, self.task2],
                            dataparallel=False)
     self.assertEqual(len(model.task_names), 2)
     self.assertIsInstance(model.module_pool["linear1A"], nn.Module)
Example #20
 def test_twotask_none_overlap_model(self):
     """Add two tasks with totally separate modules and flows"""
     model = MultitaskModel(tasks=[self.task1, self.task2])
     self.assertEqual(len(model.task_names), 2)
     self.assertEqual(len(model.op_sequences), 2)
     self.assertEqual(len(model.module_pool), 4)
Example #21
 def test_onetask_model(self):
     model = MultitaskModel(tasks=[self.task1])
     self.assertEqual(len(model.task_names), 1)
     self.assertEqual(len(model.op_sequences), 1)
     self.assertEqual(len(model.module_pool), 2)
Example #22
    def fit(self, model: MultitaskModel,
            dataloaders: List["DictDataLoader"]) -> None:
        """Train a MultitaskModel.

        Parameters
        ----------
        model
            The model to train
        dataloaders
            A list of DataLoaders. These will be split into train, valid, and test
            splits based on the ``split`` attribute of the DataLoaders.
        """
        self._check_dataloaders(dataloaders)

        # Identify the dataloaders to train on
        train_dataloaders = [
            dl for dl in dataloaders
            if dl.dataset.split == self.config.train_split  # type: ignore
        ]

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch = sum(
            [len(dataloader) for dataloader in train_dataloaders])

        # Set training helpers
        self._set_log_writer()
        self._set_checkpointer()
        self._set_log_manager()
        self._set_optimizer(model)
        self._set_lr_scheduler()
        self._set_batch_scheduler()

        # Set to training mode
        model.train()

        logging.info("Start training...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        for epoch_num in range(self.config.n_epochs):
            batches = tqdm(
                enumerate(self.batch_scheduler.get_batches(train_dataloaders)),
                total=self.n_batches_per_epoch,
                disable=(not self.config.progress_bar),
                desc=f"Epoch {epoch_num}:",
            )
            for batch_num, (batch, dataloader) in batches:
                X_dict, Y_dict = batch

                total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
                batch_size = len(next(iter(Y_dict.values())))

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

                # Perform forward pass and calculate the loss and count
                loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)

                # Update running loss and count
                for task_name in loss_dict.keys():
                    identifier = "/".join([
                        task_name,
                        dataloader.dataset.name,
                        dataloader.dataset.split,
                        "loss",
                    ])
                    self.running_losses[identifier] += (
                        loss_dict[task_name].item() * count_dict[task_name])
                    self.running_counts[task_name] += count_dict[task_name]

                # Skip the backward pass if no loss is calculated
                if not loss_dict:
                    continue

                # Calculate the average loss
                loss = torch.stack(list(loss_dict.values())).sum()

                # Perform backward pass to calculate gradients
                loss.backward()

                # Clip gradient norm
                if self.config.grad_clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.config.grad_clip)

                # Update the parameters
                self.optimizer.step()

                # Update lr using lr scheduler
                self._update_lr_scheduler(total_batch_num)

                # Update metrics
                self.metrics.update(
                    self._logging(model, dataloaders, batch_size))

                batches.set_postfix(self.metrics)

        model = self.log_manager.cleanup(model)
Example #23
                loss_func=F.mse_loss,
                output_func=identity_fn,
                scorer=Scorer(metrics=["mse"]))
# -

# ## Model

# With our tasks defined, constructing a model is simple: we pass in the list of tasks and the model constructs itself using information from the task flows.
#
# Note that the model uses the names of modules (not the modules themselves) to determine whether two modules specified by separate tasks are the same module (and should share weights) or different modules (with separate weights).
# So because both the `class_task` and `rgb_task` include "base_net" in their module pools, this module will be shared between the two tasks.

# +
from cerbero.models import MultitaskModel

model = MultitaskModel([class_task, rgb_task])
# -
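
# As a quick sanity check (a minimal sketch added here, not part of the original tutorial), we can confirm that the shared module name registered by both tasks shows up only once in the model's module pool.

# +
# "base_net" was named by both class_task and rgb_task, so the model keeps a
# single shared copy of it.
module_names = sorted(model.module_pool.keys())
assert "base_net" in module_names
print(module_names)
# -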

# ### Train Model

# Once the model is constructed, we can train it as we would a single-task model, using the `fit` method of a `Trainer` object. The `Trainer` supports multiple schedules or patterns for sampling from different dataloaders; the default is to sample from them at random in proportion to their number of batches, so that all data points are seen exactly once before any are seen twice.

# +
from cerbero.trainer import Trainer

trainer_config = {
    "progress_bar": True,
    "n_epochs": 15,
    "lr": 2e-3,
    "checkpointing": True
}
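
# A minimal usage sketch (added here, not the original notebook cell; it assumes
# `dataloaders` is the list of DictDataLoaders built earlier in this tutorial,
# which this excerpt does not show):
trainer = Trainer(**trainer_config)
trainer.fit(model, dataloaders)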
Example #24
    op_sequence = [op1, op2]

    task = Task(name=task_name,
                module_pool=module_pool,
                op_sequence=op_sequence)

    return task


dataloaders = [create_dataloader(task_name) for task_name in TASK_NAMES]
tasks = [
    create_task(TASK_NAMES[0], module_suffixes=["A", "A"]),
    create_task(TASK_NAMES[1], module_suffixes=["A", "B"]),
]
model = MultitaskModel([tasks[0]])


class TrainerTest(unittest.TestCase):
    def test_trainer_onetask(self):
        """Train a single-task model"""
        trainer = Trainer(**base_config)
        trainer.fit(model, [dataloaders[0]])

    def test_trainer_twotask(self):
        """Train a model with overlapping modules and flows"""
        multitask_model = MultitaskModel(tasks)
        trainer = Trainer(**base_config)
        trainer.fit(multitask_model, dataloaders)

    def test_trainer_errors(self):
Example #25
 def test_trainer_twotask(self):
     """Train a model with overlapping modules and flows"""
     multitask_model = MultitaskModel(tasks)
     trainer = Trainer(**base_config)
     trainer.fit(multitask_model, dataloaders)