def _train_s(self, S: ModelTrainer, inputs, target, backward=True, num_classes=4):
    """Run a single training step of the segmentation model.

    Args:
        S: ModelTrainer wrapping the segmentation network together with its
            losses, optimizer and metric computers.
        inputs: batch of inputs fed to ``S.forward``.
        target: ground-truth label tensor with a singleton dimension at
            dim=1 (it is squeezed before use) -- assumes integer class
            labels after ``.long()``; TODO confirm exact shape with callers.
        backward: when True (default), backpropagate the mean Dice loss and
            step the optimizer; when False only losses/metrics are updated.
        num_classes: number of segmentation classes used for the one-hot
            encoding of the target. Defaults to 4, the previously
            hard-coded value, so existing callers are unaffected.

    Returns:
        Tuple ``(seg_pred, loss_S)``: the softmaxed class probabilities and
        the (unreduced) Dice loss tensor.
    """
    S.zero_grad()
    # Squeeze/cast the target once and reuse it for both the one-hot
    # encoding (loss) and the index form (metrics); the original code
    # performed the squeeze twice.
    target = torch.squeeze(target, dim=1).long()
    target_ohe = to_onehot(target, num_classes=num_classes)
    seg_pred = torch.nn.functional.softmax(S.forward(inputs), dim=1)
    loss_S = S.compute_loss("DiceLoss", seg_pred, target_ohe)
    # Reduce the loss once; the same tensor is recorded and backpropagated.
    mean_loss = loss_S.mean()
    S.update_train_loss("DiceLoss", mean_loss)
    metrics = S.compute_metrics(seg_pred, target)
    metrics["Dice"] = metrics["Dice"].mean()
    metrics["IoU"] = metrics["IoU"].mean()
    S.update_train_metrics(metrics)
    if backward:
        mean_loss.backward()
        S.step()
    return seg_pred, loss_S
class ModelTrainerTest(unittest.TestCase):
    """Tests for ModelTrainer's loss and metric bookkeeping.

    Each test drives one of the train/valid/test accumulators and then
    verifies that every *other* accumulator remained at zero.
    """

    MODEL_NAME_1 = "Harry Potter"
    MODEL_NAME_2 = "Hagrid"
    LOSS_NAME = "CrossEntropy"

    def setUp(self) -> None:
        self._model = SimpleNet()
        self._criterion_mock = spy(nn.CrossEntropyLoss())
        self._optimizer_mock = spy(SGD(self._model.parameters(), lr=0.001))
        self._scheduler_mock = mock(lr_scheduler)
        self._accuracy_computer_mock = spy(Accuracy())
        self._recall_computer_mock = spy(Recall())
        # The trainer under test is deliberately built WITHOUT gradient
        # clipping. The original setUp re-assigned
        # ``self._gradient_clipping_strategy = mock(GradientClippingStrategy)``
        # *after* the ModelTrainer was already constructed, so that mock was
        # never injected anywhere; the dead assignment has been removed.
        self._gradient_clipping_strategy = None
        self._model_trainer = ModelTrainer(
            self.MODEL_NAME_1, self._model,
            {self.LOSS_NAME: self._criterion_mock},
            self._optimizer_mock, self._scheduler_mock,
            {"Accuracy": self._accuracy_computer_mock,
             "Recall": self._recall_computer_mock},
            self._gradient_clipping_strategy)

    def tearDown(self) -> None:
        super().tearDown()

    def test_should_compute_train_metric_and_update_state(self):
        """Train metrics accumulate; all other state stays at zero."""
        metric = self._model_trainer.compute_metrics(MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        self._model_trainer.update_train_metrics(metric)
        assert_that(self._model_trainer.train_metrics["Accuracy"], equal_to(ONE))
        assert_that(self._model_trainer.step_train_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update((MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))
        metric = self._model_trainer.compute_metrics(MODEL_PREDICTION_CLASS_1, TARGET_CLASS_0)
        self._model_trainer.update_train_metrics(metric)
        # One correct + one wrong prediction -> running accuracy of 1/2,
        # while the step (last-batch) accuracy drops to zero.
        assert_that(self._model_trainer.train_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(self._model_trainer.step_train_metrics["Accuracy"], equal_to(ZERO))
        verify(self._accuracy_computer_mock, times=2).compute()
        # Everything that was not updated must still be zero.
        assert_that(self._model_trainer.valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_test_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.test_loss, equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_valid_metric_and_update_state(self):
        """Valid metrics accumulate; all other state stays at zero."""
        metric = self._model_trainer.compute_metrics(MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        self._model_trainer.update_valid_metrics(metric)
        assert_that(self._model_trainer.valid_metrics["Accuracy"], equal_to(ONE))
        assert_that(self._model_trainer.step_valid_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update((MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))
        metric = self._model_trainer.compute_metrics(MODEL_PREDICTION_CLASS_1, TARGET_CLASS_0)
        self._model_trainer.update_valid_metrics(metric)
        assert_that(self._model_trainer.valid_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(self._model_trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))
        verify(self._accuracy_computer_mock, times=2).compute()
        # Everything that was not updated must still be zero.
        assert_that(self._model_trainer.train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_test_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.test_loss, equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_test_metric_and_update_state(self):
        """Test metrics accumulate; all other state stays at zero."""
        metric = self._model_trainer.compute_metrics(MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        self._model_trainer.update_test_metrics(metric)
        assert_that(self._model_trainer.test_metrics["Accuracy"], equal_to(ONE))
        assert_that(self._model_trainer.step_test_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update((MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))
        metric = self._model_trainer.compute_metrics(MODEL_PREDICTION_CLASS_1, TARGET_CLASS_0)
        self._model_trainer.update_test_metrics(metric)
        assert_that(self._model_trainer.test_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(self._model_trainer.step_test_metrics["Accuracy"], equal_to(ZERO))
        verify(self._accuracy_computer_mock, times=2).compute()
        # Everything that was not updated must still be zero.
        assert_that(self._model_trainer.train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_test_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.test_loss, equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_train_loss_and_update_state(self):
        """Train loss accumulates (step vs. running average); rest stays zero."""
        loss = self._model_trainer.compute_and_update_train_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(loss._loss, close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.step_train_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.train_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        loss = self._model_trainer.compute_and_update_train_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        # Second step hits the maximum loss; the step value reflects it while
        # the running value is the average of min and max.
        assert_that(loss._loss, close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.step_train_loss[self.LOSS_NAME],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.train_loss[self.LOSS_NAME],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        # Everything that was not updated must still be zero.
        assert_that(self._model_trainer.train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_test_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.test_loss, equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_valid_loss_and_update_state(self):
        """Valid loss accumulates (step vs. running average); rest stays zero."""
        loss = self._model_trainer.compute_and_update_valid_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(loss._loss, close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.step_valid_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.valid_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        loss = self._model_trainer.compute_and_update_valid_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(loss._loss, close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.step_valid_loss[self.LOSS_NAME],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.valid_loss[self.LOSS_NAME],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        # Everything that was not updated must still be zero.
        assert_that(self._model_trainer.train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_test_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.test_loss, equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_test_loss_and_update_state(self):
        """Test loss accumulates via both singular and plural APIs; rest stays zero."""
        loss = self._model_trainer.compute_and_update_test_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(loss._loss, close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.step_test_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.test_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        # NOTE(review): unlike the train/valid tests, the second step uses the
        # plural ``compute_and_update_test_losses`` entry point (no loss name,
        # returns a dict) -- presumably deliberate coverage of both APIs.
        loss = self._model_trainer.compute_and_update_test_losses(
            MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(loss[self.LOSS_NAME]._loss,
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.step_test_loss[self.LOSS_NAME],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(self._model_trainer.test_loss[self.LOSS_NAME],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        # Everything that was not updated must still be zero.
        assert_that(self._model_trainer.train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_test_metrics["Accuracy"], equal_to(ZERO))
        assert_that(self._model_trainer.step_train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.step_valid_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.train_loss, equal_to({self.LOSS_NAME: ZERO}))
        assert_that(self._model_trainer.valid_loss, equal_to({self.LOSS_NAME: ZERO}))