def test_partially_empty_batch(self):
    dataset = create_dataloader("task1", shuffle=False).dataset
    # Mark the first example's label as -1 so it is excluded from the loss
    dataset.Y_dict["task1"][0] = -1
    model = MultitaskModel([self.task1])
    loss_dict, count_dict = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
    # Only the remaining labeled examples should be counted
    self.assertEqual(count_dict["task1"], 9)
def test_empty_batch(self):
    dataset = create_dataloader("task1", shuffle=False).dataset
    # Mark every label as -1 so the whole batch is treated as unlabeled
    dataset.Y_dict["task1"] = torch.full_like(dataset.Y_dict["task1"], -1)
    model = MultitaskModel([self.task1])
    loss_dict, count_dict = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
    # With no labeled examples, neither losses nor counts are returned
    self.assertFalse(loss_dict)
    self.assertFalse(count_dict)
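# A hedged usage sketch of calculate_loss outside the tests above: combine the
# per-task losses into a single scalar and guard against a batch in which every
# label is -1 (in that case loss_dict is empty). "model", "X_dict", and
# "Y_dict" are assumed to exist as in the tests; this snippet is illustrative,
# not part of the library.
loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)
if loss_dict:  # empty when the batch contains no labeled examples
    loss = torch.stack(list(loss_dict.values())).sum()
    loss.backward()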
def test_remapped_labels(self):
    # Test additional label keys in the Y_dict
    # Without remapping, the model should ignore them
    task_name = self.task1.name
    X = torch.FloatTensor([[i, i] for i in range(NUM_EXAMPLES)])
    Y = torch.ones(NUM_EXAMPLES).long()

    Y_dict = {task_name: Y, "other_task": Y}
    dataset = DictDataset(
        name="dataset", split="train", X_dict={"data": X}, Y_dict=Y_dict
    )
    dataloader = DictDataLoader(dataset, batch_size=BATCH_SIZE)

    model = MultitaskModel([self.task1])
    loss_dict, count_dict = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
    self.assertIn("task1", loss_dict)

    # Test prediction and scoring without remapping
    results = model.predict(dataloader)
    self.assertIn("task1", results["golds"])
    self.assertNotIn("other_task", results["golds"])
    scores = model.score([dataloader])
    self.assertIn("task1/dataset/train/accuracy", scores)
    self.assertNotIn("other_task/dataset/train/accuracy", scores)

    # Test remapped label sets
    results = model.predict(dataloader, remap_labels={"other_task": task_name})
    self.assertIn("task1", results["golds"])
    self.assertIn("other_task", results["golds"])

    results = model.score([dataloader], remap_labels={"other_task": task_name})
    self.assertIn("task1/dataset/train/accuracy", results)
    self.assertIn("other_task/dataset/train/accuracy", results)
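# A hedged sketch of the remap_labels use case the test above exercises:
# scoring a trained head against an alternate gold label set stored under a
# different key in Y_dict. The key "alt_golds" and the surrounding objects
# (X, Y, model, BATCH_SIZE) are illustrative assumptions, not fixtures from
# this test module.
Y_dict = {"task1": Y, "alt_golds": Y}
dataset = DictDataset(name="eval", split="test", X_dict={"data": X}, Y_dict=Y_dict)
dataloader = DictDataLoader(dataset, batch_size=BATCH_SIZE)
# Without remapping, "alt_golds" is ignored; with remapping, the task1 head is
# also evaluated against it, yielding e.g. "alt_golds/eval/test/accuracy".
scores = model.score([dataloader], remap_labels={"alt_golds": "task1"})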
def fit(self, model: MultitaskModel, dataloaders: List["DictDataLoader"]) -> None:
    """Train a MultitaskModel.

    Parameters
    ----------
    model
        The model to train
    dataloaders
        A list of DataLoaders. These will be split into train, valid, and test
        splits based on the ``split`` attribute of the DataLoaders.
    """
    self._check_dataloaders(dataloaders)

    # Identify the dataloaders to train on
    train_dataloaders = [
        dl
        for dl in dataloaders
        if dl.dataset.split == self.config.train_split  # type: ignore
    ]

    # Calculate the total number of batches per epoch
    self.n_batches_per_epoch = sum(
        [len(dataloader) for dataloader in train_dataloaders]
    )

    # Set training helpers
    self._set_log_writer()
    self._set_checkpointer()
    self._set_log_manager()
    self._set_optimizer(model)
    self._set_lr_scheduler()
    self._set_batch_scheduler()

    # Set to training mode
    model.train()

    logging.info("Start training...")

    self.metrics: Dict[str, float] = dict()
    self._reset_losses()

    for epoch_num in range(self.config.n_epochs):
        batches = tqdm(
            enumerate(self.batch_scheduler.get_batches(train_dataloaders)),
            total=self.n_batches_per_epoch,
            disable=(not self.config.progress_bar),
            desc=f"Epoch {epoch_num}:",
        )
        for batch_num, (batch, dataloader) in batches:
            X_dict, Y_dict = batch

            total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
            batch_size = len(next(iter(Y_dict.values())))

            # Set gradients of all model parameters to zero
            self.optimizer.zero_grad()

            # Perform forward pass and calculate the loss and count
            loss_dict, count_dict = model.calculate_loss(X_dict, Y_dict)

            # Update running loss and count
            for task_name in loss_dict.keys():
                identifier = "/".join(
                    [
                        task_name,
                        dataloader.dataset.name,
                        dataloader.dataset.split,
                        "loss",
                    ]
                )
                self.running_losses[identifier] += (
                    loss_dict[task_name].item() * count_dict[task_name]
                )
                self.running_counts[task_name] += count_dict[task_name]

            # Skip the backward pass if no loss is calculated
            if not loss_dict:
                continue

            # Sum the per-task losses into a single scalar
            loss = torch.stack(list(loss_dict.values())).sum()

            # Perform backward pass to calculate gradients
            loss.backward()

            # Clip gradient norm
            if self.config.grad_clip:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.config.grad_clip
                )

            # Update the parameters
            self.optimizer.step()

            # Update lr using lr scheduler
            self._update_lr_scheduler(total_batch_num)

            # Update metrics
            self.metrics.update(self._logging(model, dataloaders, batch_size))

            batches.set_postfix(self.metrics)

    model = self.log_manager.cleanup(model)
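# A hedged usage sketch for fit. The class name "Trainer" and its keyword-based
# constructor are assumptions (only the fit method is shown above); the config
# names n_epochs, progress_bar, and grad_clip mirror the self.config fields the
# training loop reads. Dataloaders whose dataset.split matches
# config.train_split are trained on; the rest are used only for evaluation and
# logging.
trainer = Trainer(n_epochs=5, progress_bar=True, grad_clip=1.0)
model = MultitaskModel([task1, task2])  # task1/task2: previously defined Tasks
trainer.fit(model, [train_dataloader, valid_dataloader])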