Exemplo n.º 1
0
    def test_train_model_pipeline_young_model(
        self, serializer_mock, pipeline_mock, save_model_mock
    ):
        """Check the mocked pipeline is not called when the loaded model is young.

        The serializer mock hands back an old model with ``age == 3`` and the
        pipeline is run with ``check_old_model_age=True``.
        """
        young_model = MagicMock(age=3)

        serializer = MagicMock()
        serializer.load_model.return_value = (young_model, self.model_specs)
        serializer_mock.return_value = serializer

        pipeline_mock.return_value = ("a", MagicMock())

        pipeline_kwargs = {
            "pj": self.pj,
            "input_data": self.train_input,
            "check_old_model_age": True,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }
        train_model_pipeline(**pipeline_kwargs)

        # The age check is expected to short-circuit before the mocked pipeline runs.
        self.assertFalse(pipeline_mock.called)
Exemplo n.º 2
0
    def test_train_model_pipeline_happy_flow(
        self, serializer_mock, pipeline_mock, save_model_mock
    ):
        """Run the train model pipeline with serializer and pipeline mocked.

        The serializer mock returns an old model with ``age == 8``; the old-model
        age check is disabled. The test only requires that the call completes
        without raising.
        """
        stored_model = MagicMock(age=8)

        serializer = MagicMock()
        serializer.load_model.return_value = (stored_model, self.model_specs)
        serializer_mock.return_value = serializer

        mocked_pipeline_result = (
            "a",
            MagicMock(),
            self.model_specs,
            (None, None, None),
        )
        pipeline_mock.return_value = mocked_pipeline_result

        pipeline_kwargs = {
            "pj": self.pj,
            "input_data": self.train_input,
            "check_old_model_age": False,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }
        train_model_pipeline(**pipeline_kwargs)
Exemplo n.º 3
0
    def test_train_model_pipeline_update_stored_model(self):
        """Run the train model pipeline with the old-model age check disabled.

        No mocks are injected into this test method; it only requires that the
        call completes without raising.
        """
        pipeline_kwargs = {
            "pj": self.pj,
            "input_data": self.train_input,
            "check_old_model_age": False,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }
        train_model_pipeline(**pipeline_kwargs)
Exemplo n.º 4
0
    def test_train_model_InputDataWrongColumnOrderError(self, save_model_mock):
        """Expect InputDataWrongColumnOrderError when the input columns are reversed."""
        # Reverse the column order of the training input.
        reversed_columns_input = self.train_input.iloc[:, ::-1]

        pipeline_kwargs = {
            "pj": self.pj,
            "input_data": reversed_columns_input,
            "check_old_model_age": False,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }
        with self.assertRaises(InputDataWrongColumnOrderError):
            train_model_pipeline(**pipeline_kwargs)
Exemplo n.º 5
0
    def test_train_model_InputDataInsufficientError(
        self, validation_is_data_sufficient_mock, save_model_mock
    ):
        """Expect InputDataInsufficientError to propagate out of the pipeline."""
        # This error is caught and then raised again and logged
        pipeline_kwargs = {
            "pj": self.pj,
            "input_data": self.train_input,
            "check_old_model_age": False,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }
        with self.assertRaises(InputDataInsufficientError):
            train_model_pipeline(**pipeline_kwargs)
Exemplo n.º 6
0
    def test_train_model_pipeline_with_default_modelspecs(self, mock_serializer):
        """Check that the default modelspecs of the prediction job reach save_model.

        When loading a previous model fails with ``FileNotFoundError``, the
        ``model_specs`` keyword passed to the serializer's ``save_model`` must equal
        the ``default_modelspecs`` set on the prediction job.
        """
        # Mimic the absence of an older model.
        serializer = MagicMock()
        serializer.load_model.side_effect = FileNotFoundError()
        mock_serializer.return_value = serializer

        # Start from these values and shift every one of them, so the resulting
        # hyper params differ from the starting values.
        base_hyper_params = {
            "subsample": 0.9,
            "min_child_weight": 4,
            "max_depth": 8,
            "gamma": 0.5,
            "colsample_bytree": 0.85,
            "eta": 0.1,
            "training_period_days": 90,
        }
        shifted_hyper_params = {
            name: value + (0.01 if isinstance(value, float) else 1)
            for name, value in base_hyper_params.items()
        }

        custom_specs = copy.deepcopy(self.model_specs)
        custom_specs.hyper_params = shifted_hyper_params

        # Custom features
        custom_specs.feature_modules = [
            "test.unit.feature_engineering.test_feature_adder"
        ]
        custom_specs.feature_names.append("dummy_0.5")

        job = copy.deepcopy(self.pj)
        job.default_modelspecs = custom_specs

        pipeline_kwargs = {
            "pj": job,
            "input_data": self.train_input,
            "check_old_model_age": True,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }
        train_model_pipeline(**pipeline_kwargs)

        save_call_kwargs = serializer.save_model.call_args.kwargs
        self.assertEqual(save_call_kwargs["model_specs"], custom_specs)
Exemplo n.º 7
0
    def test_train_model_No_old_model(self, serializer_mock, save_model_mock):
        """Train when no previously stored model can be loaded.

        The serializer's ``load_model`` raises ``FileNotFoundError``. The pipeline
        is expected to complete anyway and to make exactly three method calls on
        the serializer mock.
        """
        serializer_mock_instance = MagicMock()
        # Mimic the absence of an older model. A ``side_effect`` exception always
        # takes precedence over ``return_value``, so no return value is configured.
        serializer_mock_instance.load_model.side_effect = FileNotFoundError()
        serializer_mock.return_value = serializer_mock_instance

        train_model_pipeline(
            pj=self.pj,
            input_data=self.train_input,
            check_old_model_age=True,
            mlflow_tracking_uri="./test/unit/trained_models/mlruns",
            artifact_folder="./test/unit/trained_models",
        )
        self.assertEqual(len(serializer_mock_instance.method_calls), 3)
Exemplo n.º 8
0
    def test_train_model_pipeline_with_save_train_forecasts(self, mock_serializer):
        """Check the pipeline's return value with and without save_train_forecasts.

        With the copied prediction job used as-is the pipeline must return ``None``.
        After setting ``pj.save_train_forecasts = True`` it must return datasets,
        each of which has a "forecast" column.
        """
        # Mimic the absence of an older model.
        serializer = MagicMock()
        serializer.load_model.side_effect = FileNotFoundError()
        mock_serializer.return_value = serializer

        job = copy.deepcopy(self.pj)
        pipeline_kwargs = {
            "pj": job,
            "input_data": self.train_input,
            "check_old_model_age": True,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }

        # First run: the flag is left untouched on the copied job.
        result_without_flag = train_model_pipeline(**pipeline_kwargs)
        self.assertIsNone(result_without_flag)

        # Second run: same job object, now with the flag switched on.
        job.save_train_forecasts = True
        result_with_flag = train_model_pipeline(**pipeline_kwargs)
        self.assertIsNotNone(result_with_flag)

        for data_set in result_with_flag:
            self.assertIn("forecast", data_set.columns)
Exemplo n.º 9
0
    def test_train_model_log_new_model_better(self, serializer_mock, save_model_mock):
        """Train with a stored old model of age 8 whose score is fixed at 0.1.

        The pipeline must return ``None`` and make exactly three method calls on
        the serializer mock.
        """
        # Old model mock: age 8, ``score()`` always returns 0.1.
        stored_model = MagicMock(age=8)
        stored_model.score.return_value = 0.1

        serializer = MagicMock()
        serializer.load_model.return_value = (stored_model, self.model_specs)
        serializer_mock.return_value = serializer

        pipeline_kwargs = {
            "pj": self.pj,
            "input_data": self.train_input,
            "check_old_model_age": True,
            "mlflow_tracking_uri": "./test/unit/trained_models/mlruns",
            "artifact_folder": "./test/unit/trained_models",
        }
        outcome = train_model_pipeline(**pipeline_kwargs)

        self.assertIsNone(outcome)
        self.assertEqual(len(serializer.method_calls), 3)
def train_model_task(
    pj: PredictionJobDataClass,
    context: TaskContext,
    check_old_model_age: bool = DEFAULT_CHECK_MODEL_AGE,
    datetime_start: datetime = None,
    datetime_end: datetime = None,
) -> None:
    """Train model task.

    Top level task that trains a new model and makes sure the best available model is
    stored. On this task level all database and context manager dependencies are resolved.

    Expected prediction job keys:  "id", "model", "lat", "lon", "name"

    Args:
        pj (PredictionJobDataClass): Prediction job
        context (TaskContext): Context object that holds a config manager and a
            database connection.
        check_old_model_age (bool): check if model is too young to be retrained
        datetime_start (datetime): Start of the training input window. When None,
            it is set to ``TRAINING_PERIOD_DAYS`` days before ``datetime_end``.
        datetime_end (datetime): End of the training input window. When None,
            it is set to ``datetime.utcnow()``.

    Raises:
        RuntimeError: If ``pj.save_train_forecasts`` is set and either the pipeline
            returned no data sets or the database connector has no
            ``write_train_forecasts`` attribute.
    """
    # Get the paths for storing model and reports from the config manager
    mlflow_tracking_uri = context.config.paths.mlflow_tracking_uri
    context.logger.debug(f"MLflow tracking uri: {mlflow_tracking_uri}")
    artifact_folder = context.config.paths.artifact_folder
    context.logger.debug(f"Artifact folder: {artifact_folder}")

    context.perf_meter.checkpoint("Added metadata to PredictionJob")

    # Define start and end of the training input data
    if datetime_end is None:
        # NOTE: utcnow() yields a naive (timezone-unaware) datetime.
        datetime_end = datetime.utcnow()
    if datetime_start is None:
        datetime_start = datetime_end - timedelta(days=TRAINING_PERIOD_DAYS)

    # todo: See if we can check model age before getting the data
    # Get training input data from database
    input_data = context.database.get_model_input(
        pid=pj["id"],
        location=[pj["lat"], pj["lon"]],
        datetime_start=datetime_start,
        datetime_end=datetime_end,
    )

    context.perf_meter.checkpoint("Retrieved timeseries input")

    # Execute the model training pipeline
    data_sets = train_model_pipeline(
        pj,
        input_data,
        check_old_model_age=check_old_model_age,
        mlflow_tracking_uri=mlflow_tracking_uri,
        artifact_folder=artifact_folder,
    )

    if pj.save_train_forecasts:
        # Storing train forecasts needs both the pipeline output and a database
        # connector that can write it.
        if data_sets is None:
            raise RuntimeError("Forecasts were not retrieved")
        if not hasattr(context.database, "write_train_forecasts"):
            raise RuntimeError(
                "Database connector does not support 'write_train_forecasts' while "
                "the 'save_train_forecasts' option was activated."
            )
        context.database.write_train_forecasts(pj, data_sets)

    context.perf_meter.checkpoint("Model trained")