Пример #1
0
    def __get_config_with_default_model_template_and_components(
            self, config: Optional[ForecastConfig] = None) -> ForecastConfig:
        """Fills in ``model_template`` and ``model_components_param`` on a config.

        Only these two attributes are populated. ``apply_forecast_config_defaults``
        is intentionally not called, because the template class may have its own
        implementation of forecast config defaults.

        Parameters
        ----------
        config : :class:`~greykite.framework.templates.model_templates.ForecastConfig` or None
            Config object for template class to use.
            If None, a new empty ``ForecastConfig`` is created and used.

        Returns
        -------
        config : :class:`~greykite.framework.templates.model_templates.ForecastConfig`
            The input ``config`` (updated in place), or the newly created one, where:

                - ``model_template`` is the result of ``apply_model_template_defaults``,
                  with ``self.default_model_template_name`` as the default value.
                - ``model_components_param`` is the result of
                  ``apply_model_components_defaults``.
        """
        if config is None:
            config = ForecastConfig()

        # The default template name is overridden on this local defaults object,
        # so that a missing `model_template` resolves to this instance's default.
        defaults = ForecastConfigDefaults()
        defaults.DEFAULT_MODEL_TEMPLATE = self.default_model_template_name

        # Each helper unpacks a single-element list and substitutes the default for None.
        resolved_template = defaults.apply_model_template_defaults(
            config.model_template)
        config.model_template = resolved_template
        resolved_components = defaults.apply_model_components_defaults(
            config.model_components_param)
        config.model_components_param = resolved_components
        return config
def test_apply_model_components_defaults():
    """Checks ``ForecastConfigDefaults.apply_model_components_defaults``.

    Cases covered: None, a single object, a one-element list and
    a multi-element list that contains None.
    """
    # None is replaced by an empty ModelComponentsParam.
    from_none = ForecastConfigDefaults().apply_model_components_defaults(None)
    assert from_none == ModelComponentsParam()

    mcp = ModelComponentsParam({"growth": "growth"})

    # A single object is returned unchanged.
    from_single = ForecastConfigDefaults().apply_model_components_defaults(mcp)
    assert from_single == mcp

    # A list of length one is unpacked to its only element.
    from_singleton_list = ForecastConfigDefaults().apply_model_components_defaults([mcp])
    assert from_singleton_list == mcp

    # A longer list stays a list, with None entries replaced by the default.
    from_list = ForecastConfigDefaults().apply_model_components_defaults([None, mcp])
    assert from_list == [ModelComponentsParam(), mcp]
def test_apply_model_template_defaults():
    """Checks ``ForecastConfigDefaults.apply_model_template_defaults``.

    Each case pairs a ``model_template`` input with the expected output.
    """
    mt = "RANDOM_TEMPLATE"
    cases = (
        (None, "SILVERKITE"),  # None falls back to the default template name
        (mt, mt),  # a name is returned as is
        ([mt], mt),  # a one-element list is unpacked
        ([None, mt], ["SILVERKITE", mt]),  # None inside a list is replaced
    )
    for given, expected in cases:
        # A fresh defaults object is used for every case.
        actual = ForecastConfigDefaults().apply_model_template_defaults(
            model_template=given)
        assert actual == expected
Пример #4
0
def assert_forecast_config_json_multiple_model_componments_parameter(
        config: Optional[ForecastConfig] = None):
    """Asserts the values of one particular forecast config. Not generic.

    The expected config has a list of three model templates and at least two
    entries in ``model_components_param``; the first two entries are checked
    attribute by attribute.

    Parameters
    ----------
    config : `ForecastConfig` or None
        Config to check. It is first passed through
        ``ForecastConfigDefaults().apply_forecast_config_defaults``.

    Raises
    ------
    AssertionError
        If any checked value differs from the expected one.
    """
    config = ForecastConfigDefaults().apply_forecast_config_defaults(config)
    assert config.model_template == [
        ModelTemplateEnum.SILVERKITE.name,
        ModelTemplateEnum.SILVERKITE_DAILY_90.name,
        ModelTemplateEnum.SILVERKITE_WEEKLY.name,
    ]
    assert config.evaluation_metric_param.relative_error_tolerance == 0.02

    # First model_components_param
    first = config.model_components_param[0]
    assert first.autoregression is None
    assert first.changepoints is None
    assert first.custom is None
    assert first.growth == {"growth_param": 0}
    assert first.events == {"events_param": 1}
    # The original check OR-ed this same condition with itself; one test is equivalent.
    # NOTE(review): the matching check on the second entry accepts `{}` in place of
    #   `None` for the last element -- confirm whether `{}` should be accepted here too.
    assert first.hyperparameter_override is None
    assert first.regressors == {"names": ["regressor1", "regressor2"]}
    assert first.lagged_regressors is None
    assert first.seasonality == {"seas_param": 2}
    assert first.uncertainty == {"uncertainty_param": 3}

    # Second model_components_param
    second = config.model_components_param[1]
    assert second.autoregression == {"autoreg_dict": {"autoreg_param": 0}}
    assert second.changepoints is None
    assert second.custom == {"custom_param": 1}
    assert second.growth == {"growth_param": 2}
    assert second.events == {"events_param": 3}
    # The last override may be either None or an empty dict.
    override = second.hyperparameter_override
    assert (override == [{"h1": 4}, {"h2": 5}, None]
            or override == [{"h1": 4}, {"h2": 5}, {}])
    assert second.regressors == {"names": ["regressor1", "regressor2"]}
    assert second.lagged_regressors == {
        "lagged_regressor_dict": {
            "lag_reg_param": 0
        }
    }
    assert second.seasonality == {"seas_param": 6}
    assert second.uncertainty == {"uncertainty_param": 7}
    assert config.to_dict()  # runs without error
Пример #5
0
def assert_default_forecast_config(config: Optional[ForecastConfig] = None):
    """Asserts for the default ForecastConfig values.

    Parameters
    ----------
    config : `ForecastConfig` or None
        Config to check. It is first passed through
        ``ForecastConfigDefaults().apply_forecast_config_defaults``.

    Raises
    ------
    AssertionError
        If a checked value differs from the expected default.
        The error is propagated unchanged, so that the report
        points at the assertion that failed.

    Notes
    -----
    Any other ``Exception`` raised while applying the defaults, reading the
    attributes or calling ``to_dict`` is reported through ``fail``.
    ``fail`` is presumably ``pytest.fail`` -- confirm against the file imports.
    """
    try:
        config = ForecastConfigDefaults().apply_forecast_config_defaults(
            config)
        assert config.model_template == ModelTemplateEnum.SILVERKITE.name
        assert config.metadata_param.time_col == TIME_COL
        assert config.metadata_param.value_col == VALUE_COL
        assert config.evaluation_period_param.periods_between_train_test is None
        assert config.evaluation_period_param.cv_max_splits == EVALUATION_PERIOD_CV_MAX_SPLITS
        assert config.evaluation_metric_param.cv_selection_metric == EvaluationMetricEnum.MeanAbsolutePercentError.name
        assert config.evaluation_metric_param.cv_report_metrics == CV_REPORT_METRICS_ALL
        assert config.computation_param.n_jobs == COMPUTATION_N_JOBS
        assert config.computation_param.verbose == COMPUTATION_VERBOSE
        assert config.to_dict()  # runs without error
    except AssertionError:
        # A value mismatch is not "the config raised an exception".
        # Without this clause the handler below would catch it too and
        # replace the assertion's diagnostics with a misleading message.
        raise
    except Exception:
        fail("Config should not raise Exception")
Пример #6
0
def assert_forecast_config(config: Optional[ForecastConfig] = None):
    """Asserts the values of one particular forecast config. Not generic.

    Parameters
    ----------
    config : `ForecastConfig` or None
        Config to check. It is first passed through
        ``ForecastConfigDefaults().apply_forecast_config_defaults``.

    Raises
    ------
    AssertionError
        If any checked value differs from the expected one.
    """
    config = ForecastConfigDefaults().apply_forecast_config_defaults(config)
    assert config.model_template == ModelTemplateEnum.SILVERKITE.name

    # Metadata
    metadata = config.metadata_param
    assert metadata.time_col == "custom_time_col"
    assert metadata.value_col == VALUE_COL
    assert metadata.freq is None
    assert metadata.date_format is None
    assert metadata.train_end_date is None
    assert metadata.anomaly_info == [{"key": "value"}, {"key2": "value2"}]

    # Evaluation period
    period = config.evaluation_period_param
    assert period.test_horizon == 10
    assert period.periods_between_train_test == 5
    assert period.cv_horizon is None
    assert period.cv_min_train_periods == 20
    assert period.cv_expanding_window is True
    assert period.cv_use_most_recent_splits is None
    assert period.cv_periods_between_splits is None
    # The CV gap is expected to equal the train/test gap.
    assert period.cv_periods_between_train_test == config.evaluation_period_param.periods_between_train_test
    assert period.cv_max_splits == EVALUATION_PERIOD_CV_MAX_SPLITS

    # Evaluation metric
    metric = config.evaluation_metric_param
    assert metric.cv_selection_metric == EvaluationMetricEnum.MeanSquaredError.name
    assert metric.cv_report_metrics == [
        EvaluationMetricEnum.MeanAbsoluteError.name,
        EvaluationMetricEnum.MeanAbsolutePercentError.name,
    ]
    assert metric.agg_periods is None
    assert metric.agg_func is None
    assert metric.null_model_params is None
    assert metric.relative_error_tolerance == 0.02

    # Model components
    components = config.model_components_param
    assert components.autoregression == {"autoreg_dict": {"autoreg_param": 0}}
    assert components.changepoints is None
    assert components.custom == {"custom_param": 1}
    assert components.growth == {"growth_param": 2}
    assert components.events == {"events_param": 3}
    # The last override may be either None or an empty dict.
    override = components.hyperparameter_override
    assert (override == [{"h1": 4}, {"h2": 5}, None]
            or override == [{"h1": 4}, {"h2": 5}, {}])
    assert components.regressors == {"names": ["regressor1", "regressor2"]}
    assert components.seasonality == {"seas_param": 6}
    assert components.uncertainty == {"uncertainty_param": 7}

    # Computation
    computation = config.computation_param
    assert computation.hyperparameter_budget is None
    assert computation.n_jobs == COMPUTATION_N_JOBS
    assert computation.verbose == COMPUTATION_VERBOSE

    assert config.to_dict()  # runs without error