def test_bad_tasks(self):
    """Check that malformed task setups are rejected with a ValueError.

    Three failure modes are exercised:
    - passing the same task object twice -> "Found duplicate task";
    - passing a non-task object in ``tasks`` -> "Unrecognized task type";
    - an operation whose first input was overwritten with ``(0, 0)`` ->
      "Unsuccessful operation" when ``forward`` is called.
    """
    with self.assertRaisesRegex(ValueError, "Found duplicate task"):
        MultitaskClassifier(tasks=[self.task1, self.task1])

    with self.assertRaisesRegex(ValueError, "Unrecognized task type"):
        MultitaskClassifier(tasks=[self.task1, {"fake_task": 42}])

    # Build the broken task and the model OUTSIDE the assertion context so
    # that only the forward pass is expected to raise. Previously all of the
    # setup sat inside the context, so a ValueError from setup with a
    # matching message would have made the test pass for the wrong reason.
    task1 = create_task("task1")
    # Overwrite the first operation's first input with (0, 0); presumably this
    # reference cannot be resolved at forward time, which is what produces
    # the expected "Unsuccessful operation" error.
    task1.op_sequence[0].inputs[0] = (0, 0)
    model = MultitaskClassifier(tasks=[task1])
    X_dict = self.dataloader.dataset.X_dict

    with self.assertRaisesRegex(ValueError, "Unsuccessful operation"):
        model.forward(X_dict, [task1.name])
def test_no_input_spec(self):
    """Confirm the model still runs when a module declares no explicit inputs.

    A single-task model is built around an ``nn.Identity`` module whose
    operation is given an empty input list. The test then checks that the
    outputs returned by ``forward`` contain the ``"_input_"`` entry.
    """
    ds = create_dataloader("task", shuffle=False).dataset

    # One identity module, referenced by an operation with no input spec.
    pool = nn.ModuleDict({"identity": nn.Identity()})
    ops = [Operation("identity", [])]
    identity_task = Task(name="task", module_pool=pool, op_sequence=ops)

    classifier = MultitaskClassifier(tasks=[identity_task], dataparallel=False)
    output_dict = classifier.forward(ds.X_dict, ["task"])

    self.assertIn("_input_", output_dict)