def test_partially_empty_batch(self):
    dataset = create_dataloader("task1", shuffle=False).dataset
    dataset.Y_dict["task1"][0] = -1
    model = MultitaskModel([self.task1])
    loss_dict, count_dict = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
    self.assertEqual(count_dict["task1"], 9)
def _logging(
    self,
    model: MultitaskModel,
    dataloaders: List["DictDataLoader"],
    batch_size: int,
) -> Metrics:
    """Log and checkpoint if it is time to do so."""
    # Switch to eval mode for evaluation
    model.eval()

    self.log_manager.update(batch_size)

    # Log the loss and lr
    metric_dict: Metrics = dict()
    metric_dict.update(self._aggregate_losses())

    # Evaluate the model and log the metric
    if self.log_manager.trigger_evaluation():
        # Log metrics
        metric_dict.update(
            self._evaluate(model, dataloaders, self.config.valid_split)
        )
        self._log_metrics(metric_dict)
        self._reset_losses()

    # Checkpoint the model
    if self.log_manager.trigger_checkpointing():
        self._checkpoint_model(model, metric_dict)
        self._reset_losses()

    # Switch back to train mode
    model.train()
    return metric_dict
def test_empty_batch(self):
    dataset = create_dataloader("task1", shuffle=False).dataset
    dataset.Y_dict["task1"] = torch.full_like(dataset.Y_dict["task1"], -1)
    model = MultitaskModel([self.task1])
    loss_dict, count_dict = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
    self.assertFalse(loss_dict)
    self.assertFalse(count_dict)
def test_score(self):
    model = MultitaskModel([self.task1])
    metrics = model.score([self.dataloader])
    # deterministic random tie breaking alternates predicted labels
    self.assertEqual(metrics["task1/dataset/train/accuracy"], 0.4)

    # test dataframe format
    metrics_df = model.score([self.dataloader], as_dataframe=True)
    self.assertTrue(isinstance(metrics_df, pd.DataFrame))
    self.assertEqual(metrics_df.at[0, "score"], 0.4)
def test_no_input_spec(self):
    # Confirm model doesn't break when a module does not specify specific inputs
    dataset = create_dataloader("task", shuffle=False).dataset
    task = Task(
        name="task",
        module_pool=nn.ModuleDict({"identity": nn.Identity()}),
        op_sequence=[Operation("identity", [])],
    )
    model = MultitaskModel(tasks=[task], dataparallel=False)
    outputs = model.forward(dataset.X_dict, ["task"])
    self.assertIn("_input_", outputs)
def test_save_load(self):
    fd, checkpoint_path = tempfile.mkstemp()

    task1 = create_task("task1")
    task2 = create_task("task2")
    # Make task2's second linear layer have different weights
    task2.module_pool["linear2"] = nn.Linear(2, 2)

    model = MultitaskModel([task1])
    self.assertTrue(
        torch.eq(
            task1.module_pool["linear2"].weight,
            model.module_pool["linear2"].module.weight,
        ).all()
    )
    model.save(checkpoint_path)

    model = MultitaskModel([task2])
    self.assertFalse(
        torch.eq(
            task1.module_pool["linear2"].weight,
            model.module_pool["linear2"].module.weight,
        ).all()
    )

    model.load(checkpoint_path)
    self.assertTrue(
        torch.eq(
            task1.module_pool["linear2"].weight,
            model.module_pool["linear2"].module.weight,
        ).all()
    )

    os.close(fd)
def load_best_model(self, model: MultitaskModel) -> MultitaskModel:
    """Load the best model from the checkpoint."""
    metric = list(self.checkpoint_metric.keys())[0]
    if metric not in self.best_metric_dict:  # pragma: no cover
        logging.info("No best model found; using the original model.")
    else:
        # Load the best model of checkpoint_metric
        best_model_path = (
            f"{self.checkpoint_dir}/best_model_{metric.replace('/', '_')}.pth"
        )
        logging.info(f"Loading the best model from {best_model_path}.")
        model.load(best_model_path)

    return model
def test_score_shuffled(self):
    # Test scoring with a shuffled dataset

    class SimpleVoter(nn.Module):
        def forward(self, x):
            """Vote for class 0 if x is even and for class 1 otherwise."""
            mask = x % 2 == 0
            out = torch.zeros(x.shape[0], 2)
            out[mask, 0] = 1  # class 0
            out[~mask, 1] = 1  # class 1
            return out

    # Create model
    task_name = "VotingTask"
    module_name = "simple_voter"
    module_pool = nn.ModuleDict({module_name: SimpleVoter()})
    op0 = Operation(
        module_name=module_name, inputs=[("_input_", "data")], name="op0"
    )
    op_sequence = [op0]
    task = Task(name=task_name, module_pool=module_pool, op_sequence=op_sequence)
    model = MultitaskModel([task])

    # Create dataset
    y_list = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    x_list = [i for i in range(len(y_list))]
    Y = torch.LongTensor(y_list * 100)
    X = torch.FloatTensor(x_list * 100)
    dataset = DictDataset(
        name="dataset", split="train", X_dict={"data": X}, Y_dict={task_name: Y}
    )

    # Create dataloaders
    dataloader = DictDataLoader(dataset, batch_size=2, shuffle=False)
    scores = model.score([dataloader])
    self.assertEqual(scores["VotingTask/dataset/train/accuracy"], 0.6)

    dataloader_shuffled = DictDataLoader(dataset, batch_size=2, shuffle=True)
    scores_shuffled = model.score([dataloader_shuffled])
    self.assertEqual(scores_shuffled["VotingTask/dataset/train/accuracy"], 0.6)
def test_twotask_partial_overlap_model(self):
    """Add two tasks with overlapping modules and flows"""
    task1 = create_task("task1", module_suffixes=["A", "A"])
    task2 = create_task("task2", module_suffixes=["A", "B"])
    model = MultitaskModel(tasks=[task1, task2])
    self.assertEqual(len(model.task_names), 2)
    self.assertEqual(len(model.op_sequences), 2)
    self.assertEqual(len(model.module_pool), 3)
def checkpoint(
    self, iteration: float, model: MultitaskModel, metric_dict: Metrics
) -> None:
    """Check if iteration and current metrics necessitate a checkpoint.

    Parameters
    ----------
    iteration
        Current training iteration
    model
        Model to checkpoint
    metric_dict
        Current performance metrics for model
    """
    # Check if the checkpoint_runway condition is met
    if iteration < self.checkpoint_runway:
        return
    elif not self.checkpoint_condition_met and iteration >= self.checkpoint_runway:
        self.checkpoint_condition_met = True
        logging.info(
            "checkpoint_runway condition has been met. Start checkpointing."
        )

    checkpoint_path = f"{self.checkpoint_dir}/checkpoint_{iteration}.pth"
    model.save(checkpoint_path)
    logging.info(
        f"Saved checkpoint at {iteration} {self.checkpoint_unit} "
        f"to {checkpoint_path}."
    )

    if not set(self.checkpoint_task_metrics.keys()).isdisjoint(
        set(metric_dict.keys())
    ):
        new_best_metrics = self._is_new_best(metric_dict)
        for metric in new_best_metrics:
            copyfile(
                checkpoint_path,
                f"{self.checkpoint_dir}/best_model_"
                f"{metric.replace('/', '_')}.pth",
            )
            logging.info(
                f"Saved best model for metric {metric} to {self.checkpoint_dir}"
                f"/best_model_{metric.replace('/', '_')}.pth"
            )
def test_predict(self):
    model = MultitaskModel([self.task1])

    results = model.predict(self.dataloader)
    self.assertEqual(sorted(list(results.keys())), ["golds", "probs"])
    np.testing.assert_array_equal(
        results["golds"]["task1"], self.dataloader.dataset.Y_dict["task1"].numpy()
    )
    np.testing.assert_array_equal(
        results["probs"]["task1"], np.ones((NUM_EXAMPLES, 2)) * 0.5
    )

    results = model.predict(self.dataloader, return_preds=True)
    self.assertEqual(sorted(list(results.keys())), ["golds", "preds", "probs"])
    # deterministic random tie breaking alternates predicted labels
    np.testing.assert_array_equal(
        results["preds"]["task1"],
        np.array([0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]),
    )
def test_load_on_cleanup(self) -> None:
    log_manager_config = {"counter_unit": "epochs", "evaluation_freq": 1}
    checkpointer = Checkpointer(**log_manager_config, checkpoint_dir=self.test_dir)
    log_manager = LogManager(
        n_batches_per_epoch=2, checkpointer=checkpointer, log_writer=None
    )

    classifier = MultitaskModel([])
    best_classifier = log_manager.cleanup(classifier)
    self.assertEqual(best_classifier, classifier)
def _evaluate(
    self,
    model: MultitaskModel,
    dataloaders: List["DictDataLoader"],
    split: str,
) -> Metrics:
    """Evaluate the current quality of the model on data for the requested split."""
    loaders = [d for d in dataloaders if d.dataset.split in split]  # type: ignore
    return model.score(loaders)
def test_checkpointer_load_best(self) -> None:
    checkpoint_dir = os.path.join(self.test_dir, "clear")
    checkpointer = Checkpointer(
        **log_manager_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_metric="task/dataset/valid/f1:max",
    )
    model = MultitaskModel([])
    checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8})
    load_model = checkpointer.load_best_model(model)
    self.assertEqual(model, load_model)
def test_bad_tasks(self):
    with self.assertRaisesRegex(ValueError, "Found duplicate task"):
        MultitaskModel(tasks=[self.task1, self.task1])
    with self.assertRaisesRegex(ValueError, "Unrecognized task type"):
        MultitaskModel(tasks=[self.task1, {"fake_task": 42}])
    with self.assertRaisesRegex(ValueError, "Unsuccessful operation"):
        task1 = create_task("task1")
        task1.op_sequence[0].inputs[0] = (0, 0)
        model = MultitaskModel(tasks=[task1])
        X_dict = self.dataloader.dataset.X_dict
        model.forward(X_dict, [task1.name])
def test_checkpointer_min(self) -> None:
    checkpointer = Checkpointer(
        **log_manager_config,
        checkpoint_dir=self.test_dir,
        checkpoint_runway=3,
        checkpoint_metric="task/dataset/valid/f1:min",
    )
    model = MultitaskModel([])
    checkpointer.checkpoint(
        3, model, {"task/dataset/valid/f1": 0.8, "task/dataset/valid/f2": 0.5}
    )
    self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.8)

    checkpointer.checkpoint(4, model, {"task/dataset/valid/f1": 0.7})
    self.assertEqual(checkpointer.best_metric_dict["task/dataset/valid/f1"], 0.7)
def test_checkpointer_clear(self) -> None:
    checkpoint_dir = os.path.join(self.test_dir, "clear")
    checkpointer = Checkpointer(
        **log_manager_config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_metric="task/dataset/valid/f1:max",
        checkpoint_clear=True,
    )
    model = MultitaskModel([])
    checkpointer.checkpoint(1, model, {"task/dataset/valid/f1": 0.8})

    expected_files = ["checkpoint_1.pth", "best_model_task_dataset_valid_f1.pth"]
    self.assertEqual(set(os.listdir(checkpoint_dir)), set(expected_files))

    checkpointer.clear()
    expected_files = ["best_model_task_dataset_valid_f1.pth"]
    self.assertEqual(os.listdir(checkpoint_dir), expected_files)
def test_remapped_labels(self):
    # Test additional label keys in the Y_dict
    # Without remapping, model should ignore them
    task_name = self.task1.name
    X = torch.FloatTensor([[i, i] for i in range(NUM_EXAMPLES)])
    Y = torch.ones(NUM_EXAMPLES).long()

    Y_dict = {task_name: Y, "other_task": Y}
    dataset = DictDataset(
        name="dataset", split="train", X_dict={"data": X}, Y_dict=Y_dict
    )
    dataloader = DictDataLoader(dataset, batch_size=BATCH_SIZE)

    model = MultitaskModel([self.task1])
    loss_dict, count_dict = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
    self.assertIn("task1", loss_dict)

    # Test setting without remapping
    results = model.predict(dataloader)
    self.assertIn("task1", results["golds"])
    self.assertNotIn("other_task", results["golds"])
    scores = model.score([dataloader])
    self.assertIn("task1/dataset/train/accuracy", scores)
    self.assertNotIn("other_task/dataset/train/accuracy", scores)

    # Test remapped labelsets
    results = model.predict(dataloader, remap_labels={"other_task": task_name})
    self.assertIn("task1", results["golds"])
    self.assertIn("other_task", results["golds"])

    results = model.score([dataloader], remap_labels={"other_task": task_name})
    self.assertIn("task1/dataset/train/accuracy", results)
    self.assertIn("other_task/dataset/train/accuracy", results)
def test_no_data_parallel(self):
    model = MultitaskModel(tasks=[self.task1, self.task2], dataparallel=False)
    self.assertEqual(len(model.task_names), 2)
    self.assertIsInstance(model.module_pool["linear1A"], nn.Module)
def test_twotask_none_overlap_model(self):
    """Add two tasks with totally separate modules and flows"""
    model = MultitaskModel(tasks=[self.task1, self.task2])
    self.assertEqual(len(model.task_names), 2)
    self.assertEqual(len(model.op_sequences), 2)
    self.assertEqual(len(model.module_pool), 4)
def test_onetask_model(self):
    model = MultitaskModel(tasks=[self.task1])
    self.assertEqual(len(model.task_names), 1)
    self.assertEqual(len(model.op_sequences), 1)
    self.assertEqual(len(model.module_pool), 2)
def fit(self, model: MultitaskModel, dataloaders: List["DictDataLoader"]) -> None:
    """Train a MultitaskModel.

    Parameters
    ----------
    model
        The model to train
    dataloaders
        A list of DataLoaders. These will be split into train, valid, and test
        splits based on the ``split`` attribute of the DataLoaders.
    """
    self._check_dataloaders(dataloaders)

    # Identify the dataloaders to train on
    train_dataloaders = [
        dl
        for dl in dataloaders
        if dl.dataset.split == self.config.train_split  # type: ignore
    ]

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch = sum(
        [len(dataloader) for dataloader in train_dataloaders]
    )

    # Set training helpers
    self._set_log_writer()
    self._set_checkpointer()
    self._set_log_manager()
    self._set_optimizer(model)
    self._set_lr_scheduler()
    self._set_batch_scheduler()

    # Set to training mode
    model.train()

    logging.info("Start training...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    for epoch_num in range(self.config.n_epochs):
        batches = tqdm(
            enumerate(self.batch_scheduler.get_batches(train_dataloaders)),
            total=self.n_batches_per_epoch,
            disable=(not self.config.progress_bar),
            desc=f"Epoch {epoch_num}:",
        )
        for batch_num, (batch, dataloader) in batches:
            X_dict, Y_dict = batch

            total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
            batch_size = len(next(iter(Y_dict.values())))

            # Set gradients of all model parameters to zero
            self.optimizer.zero_grad()

            # Perform forward pass and calculate the loss and count
            loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)

            # Update running loss and count
            for task_name in loss_dict.keys():
                identifier = "/".join(
                    [
                        task_name,
                        dataloader.dataset.name,
                        dataloader.dataset.split,
                        "loss",
                    ]
                )
                self.running_losses[identifier] += (
                    loss_dict[task_name].item() * count_dict[task_name]
                )
                self.running_counts[task_name] += count_dict[task_name]

            # Skip the backward pass if no loss is calculated
            if not loss_dict:
                continue

            # Calculate the average loss
            loss = torch.stack(list(loss_dict.values())).sum()

            # Perform backward pass to calculate gradients
            loss.backward()

            # Clip gradient norm
            if self.config.grad_clip:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.grad_clip
                )

            # Update the parameters
            self.optimizer.step()

            # Update lr using lr scheduler
            self._update_lr_scheduler(total_batch_num)

            # Update metrics
            self.metrics.update(self._logging(model, dataloaders, batch_size))

            batches.set_postfix(self.metrics)

    model = self.log_manager.cleanup(model)
    loss_func=F.mse_loss,
    output_func=identity_fn,
    scorer=Scorer(metrics=["mse"]),
)
# -

# ## Model

# With our tasks defined, constructing a model is simple: we pass in the list of tasks, and the model constructs itself using the information in their task flows.
#
# Note that the model uses the names of modules (not the modules themselves) to determine whether two modules specified by separate tasks are the same module (and should share weights) or different modules (with separate weights).
# Because both the `class_task` and `rgb_task` include "base_net" in their module pools, this module will be shared between the two tasks.

# +
from cerbero.models import MultitaskModel

model = MultitaskModel([class_task, rgb_task])
# -

# ### Train Model

# Once the model is constructed, we can train it as we would a single-task model, using the `fit` method of a `Trainer` object. The `Trainer` supports multiple schedules or patterns for sampling from the different dataloaders; the default is to sample from them at random, proportional to their number of batches, so that every data point is seen exactly once before any is seen twice.

# +
from cerbero.trainer import Trainer

trainer_config = {
    "progress_bar": True,
    "n_epochs": 15,
    "lr": 2e-3,
    "checkpointing": True,
}
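
# A minimal sketch of how this config is used, assuming the train and valid
# DictDataLoaders built earlier in the notebook are collected in a list named
# `dataloaders` (that variable name is an assumption, not part of the original
# tutorial):
trainer = Trainer(**trainer_config)  # build the trainer from the config above
trainer.fit(model, dataloaders)      # train the multitask model on all splits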
    op_sequence = [op1, op2]
    task = Task(name=task_name, module_pool=module_pool, op_sequence=op_sequence)
    return task


dataloaders = [create_dataloader(task_name) for task_name in TASK_NAMES]
tasks = [
    create_task(TASK_NAMES[0], module_suffixes=["A", "A"]),
    create_task(TASK_NAMES[1], module_suffixes=["A", "B"]),
]
model = MultitaskModel([tasks[0]])


class TrainerTest(unittest.TestCase):
    def test_trainer_onetask(self):
        """Train a single-task model"""
        trainer = Trainer(**base_config)
        trainer.fit(model, [dataloaders[0]])

    def test_trainer_twotask(self):
        """Train a model with overlapping modules and flows"""
        multitask_model = MultitaskModel(tasks)
        trainer = Trainer(**base_config)
        trainer.fit(multitask_model, dataloaders)

    def test_trainer_errors(self):