Beispiel #1
0
    def _train_g(self, G: ModelTrainer, real, backward=True):
        """Run one reconstruction step of ``G`` on ``real``.

        ``G`` is applied to ``real`` and its output is passed through a
        sigmoid. That prediction is compared with ``real`` itself: the loss
        registered under "MSELoss" is computed and recorded as a train loss,
        and the "MeanSquaredError" metric is recorded after dividing it by
        32768.

        Args:
            G: Trainer whose ``forward`` produces the prediction and which
                records the loss/metric.
            real: Input batch; also used as the reconstruction target.
            backward: When True, back-propagate the loss and call
                ``G.step()``. When False, only the loss/metric bookkeeping
                is performed.

        Returns:
            The sigmoid-activated prediction produced from ``real``.
        """
        G.zero_grad()

        # torch.sigmoid replaces torch.nn.functional.sigmoid, which is
        # deprecated and warns on every call; the result is identical.
        gen_pred = torch.sigmoid(G.forward(real))

        loss_G = G.compute_and_update_train_loss("MSELoss", gen_pred, real)

        metric = G.compute_metric("MeanSquaredError", gen_pred, real)
        # NOTE(review): 32768 == 32 ** 3 == 2 ** 15; presumably the number of
        # elements per sample used to normalise the metric -- confirm.
        G.update_train_metric("MeanSquaredError", metric / 32768)

        if backward:
            loss_G.backward()
            G.step()

        return gen_pred
Beispiel #2
0
class ModelTrainerTest(unittest.TestCase):
    """Tests for the metric and loss bookkeeping of ``ModelTrainer``.

    Each test updates exactly one split (train / valid / test) of either the
    metrics or the losses, checks the updated values, and then checks that
    every other tracked metric and loss is still ZERO.
    """
    MODEL_NAME_1 = "Harry Potter"
    MODEL_NAME_2 = "Hagrid"
    LOSS_NAME = "CrossEntropy"

    def setUp(self) -> None:
        self._model = SimpleNet()
        self._criterion_mock = spy(nn.CrossEntropyLoss())
        self._optimizer_mock = spy(SGD(self._model.parameters(), lr=0.001))
        self._scheduler_mock = mock(lr_scheduler)
        self._accuracy_computer_mock = spy(Accuracy())
        self._recall_computer_mock = spy(Recall())
        self._gradient_clipping_strategy = None

        self._model_trainer = ModelTrainer(
            self.MODEL_NAME_1, self._model,
            {self.LOSS_NAME: self._criterion_mock}, self._optimizer_mock,
            self._scheduler_mock, {
                "Accuracy": self._accuracy_computer_mock,
                "Recall": self._recall_computer_mock
            }, self._gradient_clipping_strategy)

        # NOTE(review): this mock is created *after* the trainer above was
        # built with ``None``, so self._model_trainer never uses it. It is
        # kept in case other tests reference it directly -- confirm.
        self._gradient_clipping_strategy = mock(GradientClippingStrategy)

    def _assert_accuracies_are_zero(self, *metric_dicts):
        """Assert that the "Accuracy" entry of each given dict is ZERO."""
        for metrics in metric_dicts:
            assert_that(metrics["Accuracy"], equal_to(ZERO))

    def _assert_losses_are_zero(self, *loss_dicts):
        """Assert that each given loss dict equals ``{LOSS_NAME: ZERO}``."""
        for losses in loss_dicts:
            assert_that(losses, equal_to({self.LOSS_NAME: ZERO}))

    def test_should_compute_train_metric_and_update_state(self):
        trainer = self._model_trainer

        metric = trainer.compute_metrics(MODEL_PREDICTION_CLASS_0,
                                         TARGET_CLASS_0)
        trainer.update_train_metrics(metric)
        assert_that(trainer.train_metrics["Accuracy"], equal_to(ONE))
        assert_that(trainer.step_train_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update(
            (MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))

        metric = trainer.compute_metrics(MODEL_PREDICTION_CLASS_1,
                                         TARGET_CLASS_0)
        trainer.update_train_metrics(metric)
        # train_metrics reflects both updates (ONE then ZERO -> ONE / 2);
        # step_train_metrics reflects only the latest one.
        assert_that(trainer.train_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(trainer.step_train_metrics["Accuracy"], equal_to(ZERO))

        verify(self._accuracy_computer_mock, times=2).compute()

        self._assert_accuracies_are_zero(
            trainer.valid_metrics, trainer.test_metrics,
            trainer.step_valid_metrics, trainer.step_test_metrics)
        self._assert_losses_are_zero(
            trainer.step_train_loss, trainer.step_valid_loss,
            trainer.step_test_loss, trainer.train_loss,
            trainer.valid_loss, trainer.test_loss)

    def test_should_compute_valid_metric_and_update_state(self):
        trainer = self._model_trainer

        metric = trainer.compute_metrics(MODEL_PREDICTION_CLASS_0,
                                         TARGET_CLASS_0)
        trainer.update_valid_metrics(metric)
        assert_that(trainer.valid_metrics["Accuracy"], equal_to(ONE))
        assert_that(trainer.step_valid_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update(
            (MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))

        metric = trainer.compute_metrics(MODEL_PREDICTION_CLASS_1,
                                         TARGET_CLASS_0)
        trainer.update_valid_metrics(metric)
        assert_that(trainer.valid_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(trainer.step_valid_metrics["Accuracy"], equal_to(ZERO))

        verify(self._accuracy_computer_mock, times=2).compute()

        self._assert_accuracies_are_zero(
            trainer.train_metrics, trainer.test_metrics,
            trainer.step_train_metrics, trainer.step_test_metrics)
        self._assert_losses_are_zero(
            trainer.step_train_loss, trainer.step_valid_loss,
            trainer.step_test_loss, trainer.train_loss,
            trainer.valid_loss, trainer.test_loss)

    def test_should_compute_test_metric_and_update_state(self):
        trainer = self._model_trainer

        metric = trainer.compute_metrics(MODEL_PREDICTION_CLASS_0,
                                         TARGET_CLASS_0)
        trainer.update_test_metrics(metric)
        assert_that(trainer.test_metrics["Accuracy"], equal_to(ONE))
        assert_that(trainer.step_test_metrics["Accuracy"], equal_to(ONE))
        verify(self._accuracy_computer_mock).update(
            (MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0))

        metric = trainer.compute_metrics(MODEL_PREDICTION_CLASS_1,
                                         TARGET_CLASS_0)
        trainer.update_test_metrics(metric)
        assert_that(trainer.test_metrics["Accuracy"], equal_to(ONE / 2))
        assert_that(trainer.step_test_metrics["Accuracy"], equal_to(ZERO))

        verify(self._accuracy_computer_mock, times=2).compute()

        self._assert_accuracies_are_zero(
            trainer.train_metrics, trainer.valid_metrics,
            trainer.step_train_metrics, trainer.step_valid_metrics)
        self._assert_losses_are_zero(
            trainer.step_train_loss, trainer.step_valid_loss,
            trainer.step_test_loss, trainer.train_loss,
            trainer.valid_loss, trainer.test_loss)

    def test_should_compute_train_loss_and_update_state(self):
        trainer = self._model_trainer

        loss = trainer.compute_and_update_train_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(loss._loss,
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_train_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.train_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        loss = trainer.compute_and_update_train_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(loss._loss,
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_train_loss[self.LOSS_NAME],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        # train_loss reflects both updates; step_train_loss only the latest.
        assert_that(trainer.train_loss[self.LOSS_NAME],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        self._assert_accuracies_are_zero(
            trainer.train_metrics, trainer.valid_metrics,
            trainer.test_metrics, trainer.step_train_metrics,
            trainer.step_valid_metrics, trainer.step_test_metrics)
        self._assert_losses_are_zero(
            trainer.step_valid_loss, trainer.step_test_loss,
            trainer.valid_loss, trainer.test_loss)

    def test_should_compute_valid_loss_and_update_state(self):
        trainer = self._model_trainer

        loss = trainer.compute_and_update_valid_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(loss._loss,
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_valid_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.valid_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        loss = trainer.compute_and_update_valid_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(loss._loss,
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_valid_loss[self.LOSS_NAME],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.valid_loss[self.LOSS_NAME],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        self._assert_accuracies_are_zero(
            trainer.train_metrics, trainer.valid_metrics,
            trainer.test_metrics, trainer.step_train_metrics,
            trainer.step_valid_metrics, trainer.step_test_metrics)
        self._assert_losses_are_zero(
            trainer.step_train_loss, trainer.step_test_loss,
            trainer.train_loss, trainer.test_loss)

    def test_should_compute_test_loss_and_update_state(self):
        trainer = self._model_trainer

        loss = trainer.compute_and_update_test_loss(
            self.LOSS_NAME, MODEL_PREDICTION_CLASS_0, TARGET_CLASS_0)
        assert_that(loss._loss,
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_test_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.test_loss[self.LOSS_NAME],
                    close_to(MINIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        # The second update goes through the all-losses variant, which
        # returns a dict keyed by loss name rather than a single loss.
        losses = trainer.compute_and_update_test_losses(
            MODEL_PREDICTION_CLASS_0, TARGET_CLASS_1)
        assert_that(losses[self.LOSS_NAME]._loss,
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.step_test_loss[self.LOSS_NAME],
                    close_to(MAXIMUM_BINARY_CROSS_ENTROPY_LOSS, DELTA))
        assert_that(trainer.test_loss[self.LOSS_NAME],
                    close_to(AVERAGED_BINARY_CROSS_ENTROPY_LOSS, DELTA))

        self._assert_accuracies_are_zero(
            trainer.train_metrics, trainer.valid_metrics,
            trainer.test_metrics, trainer.step_train_metrics,
            trainer.step_valid_metrics, trainer.step_test_metrics)
        self._assert_losses_are_zero(
            trainer.step_train_loss, trainer.step_valid_loss,
            trainer.train_loss, trainer.valid_loss)