コード例 #1
0
ファイル: unet.py プロジェクト: sami-ets/DeepNormalize
    def _train_s(self, S: ModelTrainer, inputs, target, backward=True):
        """Run one training step of the segmentation trainer ``S``.

        Gradients held by ``S`` are cleared, ``inputs`` is forwarded through
        ``S`` and soft-maxed over dim 1, the "DiceLoss" criterion is evaluated
        against a 4-class one-hot encoding of ``target``, and the train loss
        and metrics tracked by ``S`` are updated. When ``backward`` is true the
        mean loss is back-propagated and ``S.step()`` is called.

        Args:
            S: Trainer wrapping the segmentation model.
            inputs: Batch passed to ``S.forward``.
            target: Label tensor; dim 1 is squeezed out and the result is cast
                to long (assumes integer class labels in [0, 4) -- TODO
                confirm against the caller).
            backward: If false, only the loss/metric bookkeeping is done;
                ``backward()`` and ``S.step()`` are not called.

        Returns:
            Tuple ``(seg_pred, loss_S)``: the soft-maxed prediction and the
            loss exactly as returned by ``S.compute_loss`` (not reduced here).
        """
        S.zero_grad()

        # Squeeze and cast once; the result feeds both the one-hot encoding
        # used by the loss and the label targets used by the metrics.
        labels = torch.squeeze(target, dim=1).long()
        labels_ohe = to_onehot(labels, num_classes=4)

        seg_pred = torch.nn.functional.softmax(S.forward(inputs), dim=1)

        loss_S = S.compute_loss("DiceLoss", seg_pred, labels_ohe)
        # Reduce once and reuse the result for bookkeeping and for backprop.
        mean_loss = loss_S.mean()
        S.update_train_loss("DiceLoss", mean_loss)

        metrics = S.compute_metrics(seg_pred, labels)
        # Only the "Dice" and "IoU" entries are reduced to their mean before
        # being handed to the trainer; other entries are passed through as-is.
        metrics["Dice"] = metrics["Dice"].mean()
        metrics["IoU"] = metrics["IoU"].mean()
        S.update_train_metrics(metrics)

        if backward:
            mean_loss.backward()
            S.step()

        return seg_pred, loss_S
コード例 #2
0
class ModelTrainerTest(unittest.TestCase):
    MODEL_NAME_1 = "Harry Potter"
    MODEL_NAME_2 = "Hagrid"
    LOSS_NAME = "CrossEntropy"

    def setUp(self) -> None:
        """Build the ``ModelTrainer`` under test from spied/mocked collaborators.

        The criterion, optimizer and metric computers are wrapped with
        ``spy`` and the scheduler is created with ``mock``, so the tests can
        verify calls made on them.
        """
        self._model = SimpleNet()
        self._criterion_mock = spy(nn.CrossEntropyLoss())
        self._optimizer_mock = spy(SGD(self._model.parameters(), lr=0.001))
        self._scheduler_mock = mock(lr_scheduler)
        self._accuracy_computer_mock = spy(Accuracy())
        self._recall_computer_mock = spy(Recall())
        self._gradient_clipping_strategy = None

        criterions = {self.LOSS_NAME: self._criterion_mock}
        metric_computers = {
            "Accuracy": self._accuracy_computer_mock,
            "Recall": self._recall_computer_mock,
        }
        self._model_trainer = ModelTrainer(
            self.MODEL_NAME_1,
            self._model,
            criterions,
            self._optimizer_mock,
            self._scheduler_mock,
            metric_computers,
            self._gradient_clipping_strategy,
        )

        # NOTE(review): the trainer above was built with a clipping strategy of
        # None; this mock is assigned only afterwards and is never handed to
        # that trainer -- confirm this ordering is intentional.
        self._gradient_clipping_strategy = mock(GradientClippingStrategy)

    def tearDown(self) -> None:
        """Delegate to the parent class's ``tearDown``; no extra cleanup is done here."""
        super().tearDown()

    def test_should_compute_train_metric_and_update_state(self):
        """Updating train metrics must change train metric state only.

        Two metric computations are pushed through ``update_train_metrics``.
        The test checks the train and step-train "Accuracy" after each one,
        verifies the calls made on the accuracy spy, and finally asserts that
        valid/test metrics and every loss entry still equal ``ZERO``.
        """
        trainer = self._model_trainer

        # First update: both train and step-train accuracy are expected to be ONE.
        first_metrics = trainer.compute_metrics(MODEL_PREDICTION_CLASS_0,
                                                TARGET_CLASS_0)
        trainer.update_train_metrics(first_metrics)
        assert_that(trainer.train_metrics["Accuracy"], equal_to(ONE))
        assert_that(trainer.step_train_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update(
            (MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))

        # Second update: step accuracy is expected to drop to ZERO while the
        # train accuracy is expected to be ONE / 2.
        second_metrics = trainer.compute_metrics(MODEL_PREDICTION_CLASS_1,
                                                 TARGET_CLASS_0)
        trainer.update_train_metrics(second_metrics)
        assert_that(trainer.train_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(trainer.step_train_metrics["Accuracy"], equal_to(ZERO))

        verify(self._accuracy_computer_mock, times=2).compute()

        # State that this test never updates must still read ZERO. Attributes
        # are looked up by name so each one is read right before its assert.
        for metrics_attr in ("valid_metrics", "test_metrics",
                             "step_valid_metrics", "step_test_metrics"):
            assert_that(getattr(trainer, metrics_attr)["Accuracy"],
                        equal_to(ZERO))

        for loss_attr in ("step_train_loss", "step_valid_loss",
                          "step_test_loss", "train_loss", "valid_loss",
                          "test_loss"):
            assert_that(getattr(trainer, loss_attr),
                        equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_valid_metric_and_update_state(self):
        """Updating valid metrics must change valid metric state only.

        Two metric computations are pushed through ``update_valid_metrics``.
        The test checks the valid and step-valid "Accuracy" after each one,
        verifies the calls made on the accuracy spy, and finally asserts that
        train/test metrics and every loss entry still equal ``ZERO``.
        """
        trainer = self._model_trainer

        # First update: both valid and step-valid accuracy are expected to be ONE.
        first_metrics = trainer.compute_metrics(MODEL_PREDICTION_CLASS_0,
                                                TARGET_CLASS_0)
        trainer.update_valid_metrics(first_metrics)
        assert_that(trainer.valid_metrics["Accuracy"], equal_to(ONE))
        assert_that(trainer.step_valid_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update(
            (MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))

        # Second update: step accuracy is expected to drop to ZERO while the
        # valid accuracy is expected to be ONE / 2.
        second_metrics = trainer.compute_metrics(MODEL_PREDICTION_CLASS_1,
                                                 TARGET_CLASS_0)
        trainer.update_valid_metrics(second_metrics)
        assert_that(trainer.valid_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))

        verify(self._accuracy_computer_mock, times=2).compute()

        # State that this test never updates must still read ZERO. Attributes
        # are looked up by name so each one is read right before its assert.
        for metrics_attr in ("train_metrics", "test_metrics",
                             "step_train_metrics", "step_test_metrics"):
            assert_that(getattr(trainer, metrics_attr)["Accuracy"],
                        equal_to(ZERO))

        for loss_attr in ("step_train_loss", "step_valid_loss",
                          "step_test_loss", "train_loss", "valid_loss",
                          "test_loss"):
            assert_that(getattr(trainer, loss_attr),
                        equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_test_metric_and_update_state(self):
        """Updating test metrics must change test metric state only.

        Two metric computations are pushed through ``update_test_metrics``.
        The test checks the test and step-test "Accuracy" after each one,
        verifies the calls made on the accuracy spy, and finally asserts that
        train/valid metrics and every loss entry still equal ``ZERO``.
        """
        trainer = self._model_trainer

        # First update: both test and step-test accuracy are expected to be ONE.
        first_metrics = trainer.compute_metrics(MODEL_PREDICTION_CLASS_0,
                                                TARGET_CLASS_0)
        trainer.update_test_metrics(first_metrics)
        assert_that(trainer.test_metrics["Accuracy"], equal_to(ONE))
        assert_that(trainer.step_test_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update(
            (MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))

        # Second update: step accuracy is expected to drop to ZERO while the
        # test accuracy is expected to be ONE / 2.
        second_metrics = trainer.compute_metrics(MODEL_PREDICTION_CLASS_1,
                                                 TARGET_CLASS_0)
        trainer.update_test_metrics(second_metrics)
        assert_that(trainer.test_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(trainer.step_test_metrics["Accuracy"], equal_to(ZERO))

        verify(self._accuracy_computer_mock, times=2).compute()

        # State that this test never updates must still read ZERO. Attributes
        # are looked up by name so each one is read right before its assert.
        for metrics_attr in ("train_metrics", "valid_metrics",
                             "step_train_metrics", "step_valid_metrics"):
            assert_that(getattr(trainer, metrics_attr)["Accuracy"],
                        equal_to(ZERO))

        for loss_attr in ("step_train_loss", "step_valid_loss",
                          "step_test_loss", "train_loss", "valid_loss",
                          "test_loss"):
            assert_that(getattr(trainer, loss_attr),
                        equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_train_loss_and_update_state(self):
        """Computing the train loss must change train loss state only.

        ``compute_and_update_train_loss`` is called twice. After each call the
        returned loss, the step-train loss and the train loss are compared
        against the expected constants within ``DELTA``. Afterwards every
        metric and the valid/test loss entries must still equal ``ZERO``.
        """
        trainer = self._model_trainer
        loss_name = self.LOSS_NAME

        # First call: all three readings are expected to be at the minimum.
        first_loss = trainer.compute_and_update_train_loss(
            loss_name, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(first_loss._loss,
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_train_loss[loss_name],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.train_loss[loss_name],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        # Second call: returned and step losses are expected at the maximum,
        # while the train loss is expected at the averaged constant.
        second_loss = trainer.compute_and_update_train_loss(
            loss_name, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(second_loss._loss,
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_train_loss[loss_name],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.train_loss[loss_name],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        # State that this test never updates must still read ZERO. Attributes
        # are looked up by name so each one is read right before its assert.
        for metrics_attr in ("train_metrics", "valid_metrics", "test_metrics",
                             "step_train_metrics", "step_valid_metrics",
                             "step_test_metrics"):
            assert_that(getattr(trainer, metrics_attr)["Accuracy"],
                        equal_to(ZERO))

        for loss_attr in ("step_valid_loss", "step_test_loss",
                          "valid_loss", "test_loss"):
            assert_that(getattr(trainer, loss_attr),
                        equal_to({loss_name: ZERO}))

    def test_should_compute_valid_loss_and_update_state(self):
        """Computing the valid loss must change valid loss state only.

        ``compute_and_update_valid_loss`` is called twice. After each call the
        returned loss, the step-valid loss and the valid loss are compared
        against the expected constants within ``DELTA``. Afterwards every
        metric and the train/test loss entries must still equal ``ZERO``.
        """
        trainer = self._model_trainer
        loss_name = self.LOSS_NAME

        # First call: all three readings are expected to be at the minimum.
        first_loss = trainer.compute_and_update_valid_loss(
            loss_name, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(first_loss._loss,
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_valid_loss[loss_name],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.valid_loss[loss_name],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        # Second call: returned and step losses are expected at the maximum,
        # while the valid loss is expected at the averaged constant.
        second_loss = trainer.compute_and_update_valid_loss(
            loss_name, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(second_loss._loss,
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_valid_loss[loss_name],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.valid_loss[loss_name],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        # State that this test never updates must still read ZERO. Attributes
        # are looked up by name so each one is read right before its assert.
        for metrics_attr in ("train_metrics", "valid_metrics", "test_metrics",
                             "step_train_metrics", "step_valid_metrics",
                             "step_test_metrics"):
            assert_that(getattr(trainer, metrics_attr)["Accuracy"],
                        equal_to(ZERO))

        for loss_attr in ("step_train_loss", "step_test_loss",
                          "train_loss", "test_loss"):
            assert_that(getattr(trainer, loss_attr),
                        equal_to({loss_name: ZERO}))

    def test_should_compute_test_loss_and_update_state(self):
        """Computing the test loss must change test loss state only.

        The first call goes through ``compute_and_update_test_loss`` (single
        named loss); the second goes through
        ``compute_and_update_test_losses``, whose result is indexed by loss
        name. After each call the returned loss, the step-test loss and the
        test loss are compared against the expected constants within
        ``DELTA``. Afterwards every metric and the train/valid loss entries
        must still equal ``ZERO``.
        """
        trainer = self._model_trainer
        loss_name = self.LOSS_NAME

        # First call: all three readings are expected to be at the minimum.
        single_loss = trainer.compute_and_update_test_loss(
            loss_name, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(single_loss._loss,
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_test_loss[loss_name],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.test_loss[loss_name],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        # Second call uses the plural API and takes no loss name; returned and
        # step losses are expected at the maximum, the test loss at the
        # averaged constant.
        losses_by_name = trainer.compute_and_update_test_losses(
            MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(losses_by_name[loss_name]._loss,
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_test_loss[loss_name],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.test_loss[loss_name],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        # State that this test never updates must still read ZERO. Attributes
        # are looked up by name so each one is read right before its assert.
        for metrics_attr in ("train_metrics", "valid_metrics", "test_metrics",
                             "step_train_metrics", "step_valid_metrics",
                             "step_test_metrics"):
            assert_that(getattr(trainer, metrics_attr)["Accuracy"],
                        equal_to(ZERO))

        for loss_attr in ("step_train_loss", "step_valid_loss",
                          "train_loss", "valid_loss"):
            assert_that(getattr(trainer, loss_attr),
                        equal_to({loss_name: ZERO}))