def __get_config_with_default_model_template_and_components(
        self,
        config: Optional[ForecastConfig] = None) -> ForecastConfig:
    """Resolves only ``model_template`` and ``model_components_param`` on a config.

    Defaults used when a field is not provided:

    - model_template : ``self.default_model_template_name``.
    - model_components_param : an empty ``ModelComponentsParam()``.

    A list holding a single element is unpacked to that element. The full
    ``apply_forecast_config_defaults`` is intentionally NOT called, because the
    template class may implement its own forecast config defaults.

    Parameters
    ----------
    config : :class:`~greykite.framework.templates.model_templates.ForecastConfig` or None
        Config object for template class to use.
        See :class:`~greykite.framework.templates.model_templates.ForecastConfig`.
        If None, a new empty ForecastConfig is created.

    Returns
    -------
    config : :class:`~greykite.framework.templates.model_templates.ForecastConfig`
        The input ``config`` object itself (modified in place), or the newly
        created ForecastConfig when ``config`` was None, with
        ``model_template`` and ``model_components_param`` resolved.
    """
    if config is None:
        config = ForecastConfig()

    defaults = ForecastConfigDefaults()
    # Redirect the fallback template to this object's default so that a
    # missing ``model_template`` resolves to ``self.default_model_template_name``.
    defaults.DEFAULT_MODEL_TEMPLATE = self.default_model_template_name

    resolved_template = defaults.apply_model_template_defaults(config.model_template)
    config.model_template = resolved_template

    resolved_components = defaults.apply_model_components_defaults(
        config.model_components_param)
    config.model_components_param = resolved_components

    return config
def test_apply_model_components_defaults():
    """Checks default-filling and list unpacking of apply_model_components_defaults."""
    def _apply_defaults(value):
        # A fresh ForecastConfigDefaults per call keeps every check independent.
        return ForecastConfigDefaults().apply_model_components_defaults(value)

    # None is replaced by an empty ModelComponentsParam.
    assert _apply_defaults(None) == ModelComponentsParam()

    components = ModelComponentsParam({"growth": "growth"})
    # A bare value and a single-element list both resolve to the value itself.
    assert _apply_defaults(components) == components
    assert _apply_defaults([components]) == components
    # In a longer list, each None entry is replaced by the default.
    assert _apply_defaults([None, components]) == [ModelComponentsParam(), components]
def test_apply_model_template_defaults():
    """Checks default-filling and list unpacking of apply_model_template_defaults."""
    def _apply_defaults(template):
        # A fresh ForecastConfigDefaults per call keeps every check independent.
        return ForecastConfigDefaults().apply_model_template_defaults(
            model_template=template)

    default_template = "SILVERKITE"
    custom_template = "RANDOM_TEMPLATE"

    # None falls back to the default template.
    assert _apply_defaults(None) == default_template
    # A bare string and a single-element list both resolve to the string.
    assert _apply_defaults(custom_template) == custom_template
    assert _apply_defaults([custom_template]) == custom_template
    # In a longer list, None entries are replaced by the default template.
    assert _apply_defaults([None, custom_template]) == [default_template, custom_template]
def assert_forecast_config_json_multiple_model_componments_parameter(
        config: Optional[ForecastConfig] = None):
    """Asserts the forecast config values for a config with multiple model components.

    This function expects a particular config and is not generic. The config
    must have three model templates and two ``ModelComponentsParam`` entries
    with the specific values checked below.
    ``ForecastConfigDefaults.apply_forecast_config_defaults`` is applied before
    any value is checked.

    Parameters
    ----------
    config : ForecastConfig or None
        The config to check.
    """
    config = ForecastConfigDefaults().apply_forecast_config_defaults(config)
    assert config.model_template == [
        ModelTemplateEnum.SILVERKITE.name,
        ModelTemplateEnum.SILVERKITE_DAILY_90.name,
        ModelTemplateEnum.SILVERKITE_WEEKLY.name,
    ]
    assert config.evaluation_metric_param.relative_error_tolerance == 0.02

    # First model_components_param
    model_components_param_1 = config.model_components_param[0]
    assert model_components_param_1.autoregression is None
    assert model_components_param_1.changepoints is None
    assert model_components_param_1.custom is None
    assert model_components_param_1.growth == {"growth_param": 0}
    assert model_components_param_1.events == {"events_param": 1}
    # No override is expected on the first component set (a single check;
    # repeating the same `is None` test under `or` adds nothing).
    assert model_components_param_1.hyperparameter_override is None
    assert model_components_param_1.regressors == {
        "names": ["regressor1", "regressor2"]
    }
    assert model_components_param_1.lagged_regressors is None
    assert model_components_param_1.seasonality == {"seas_param": 2}
    assert model_components_param_1.uncertainty == {"uncertainty_param": 3}

    # Second model_components_param
    model_components_param_2 = config.model_components_param[1]
    assert model_components_param_2.autoregression == {
        "autoreg_dict": {"autoreg_param": 0}
    }
    assert model_components_param_2.changepoints is None
    assert model_components_param_2.custom == {"custom_param": 1}
    assert model_components_param_2.growth == {"growth_param": 2}
    assert model_components_param_2.events == {"events_param": 3}
    # Either form of the trailing override entry (None or {}) is accepted.
    assert (model_components_param_2.hyperparameter_override
            == [{"h1": 4}, {"h2": 5}, None]
            or model_components_param_2.hyperparameter_override
            == [{"h1": 4}, {"h2": 5}, {}])
    assert model_components_param_2.regressors == {
        "names": ["regressor1", "regressor2"]
    }
    assert model_components_param_2.lagged_regressors == {
        "lagged_regressor_dict": {"lag_reg_param": 0}
    }
    assert model_components_param_2.seasonality == {"seas_param": 6}
    assert model_components_param_2.uncertainty == {"uncertainty_param": 7}
    assert config.to_dict()  # runs without error
def assert_default_forecast_config(config: Optional[ForecastConfig] = None):
    """Asserts for the default ForecastConfig values.

    ``ForecastConfigDefaults.apply_forecast_config_defaults`` is applied to
    ``config``. The resulting fields are then compared against the module-level
    default constants.

    Only the calls that build the config and serialize it with ``to_dict`` are
    guarded. If either raises, the test is failed via ``fail`` with the message
    "Config should not raise Exception". Value mismatches are deliberately left
    outside the ``try``. Otherwise ``except Exception`` would also catch the
    ``AssertionError`` and hide which comparison failed.

    Parameters
    ----------
    config : ForecastConfig or None
        The config to check.
    """
    try:
        config = ForecastConfigDefaults().apply_forecast_config_defaults(config)
    except Exception:
        fail("Config should not raise Exception")

    assert config.model_template == ModelTemplateEnum.SILVERKITE.name
    assert config.metadata_param.time_col == TIME_COL
    assert config.metadata_param.value_col == VALUE_COL
    assert config.evaluation_period_param.periods_between_train_test is None
    assert config.evaluation_period_param.cv_max_splits == EVALUATION_PERIOD_CV_MAX_SPLITS
    assert (config.evaluation_metric_param.cv_selection_metric
            == EvaluationMetricEnum.MeanAbsolutePercentError.name)
    assert config.evaluation_metric_param.cv_report_metrics == CV_REPORT_METRICS_ALL
    assert config.computation_param.n_jobs == COMPUTATION_N_JOBS
    assert config.computation_param.verbose == COMPUTATION_VERBOSE

    try:
        config_dict = config.to_dict()
    except Exception:
        fail("Config should not raise Exception")
    else:
        assert config_dict  # runs without error and is non-empty
def assert_forecast_config(config: Optional[ForecastConfig] = None):
    """Checks one specific ForecastConfig after defaults have been applied.

    This helper is not generic. It expects the exact custom config used by the
    calling tests, for example a custom time column and a MeanSquaredError
    selection metric. ``ForecastConfigDefaults.apply_forecast_config_defaults``
    runs first, so fields the caller left unset are checked against their
    defaulted values.

    Parameters
    ----------
    config : ForecastConfig or None
        The config to check.
    """
    config = ForecastConfigDefaults().apply_forecast_config_defaults(config)
    assert config.model_template == ModelTemplateEnum.SILVERKITE.name

    metadata = config.metadata_param
    assert metadata.time_col == "custom_time_col"
    assert metadata.value_col == VALUE_COL
    assert metadata.freq is None
    assert metadata.date_format is None
    assert metadata.train_end_date is None
    assert metadata.anomaly_info == [{"key": "value"}, {"key2": "value2"}]

    evaluation_period = config.evaluation_period_param
    assert evaluation_period.test_horizon == 10
    assert evaluation_period.periods_between_train_test == 5
    assert evaluation_period.cv_horizon is None
    assert evaluation_period.cv_min_train_periods == 20
    assert evaluation_period.cv_expanding_window is True
    assert evaluation_period.cv_use_most_recent_splits is None
    assert evaluation_period.cv_periods_between_splits is None
    # The CV gap is expected to mirror the train/test gap
    # (presumably filled in by the defaults; confirm against the input config).
    assert (evaluation_period.cv_periods_between_train_test
            == evaluation_period.periods_between_train_test)
    assert evaluation_period.cv_max_splits == EVALUATION_PERIOD_CV_MAX_SPLITS

    evaluation_metric = config.evaluation_metric_param
    assert (evaluation_metric.cv_selection_metric
            == EvaluationMetricEnum.MeanSquaredError.name)
    assert evaluation_metric.cv_report_metrics == [
        EvaluationMetricEnum.MeanAbsoluteError.name,
        EvaluationMetricEnum.MeanAbsolutePercentError.name,
    ]
    assert evaluation_metric.agg_periods is None
    assert evaluation_metric.agg_func is None
    assert evaluation_metric.null_model_params is None
    assert evaluation_metric.relative_error_tolerance == 0.02

    components = config.model_components_param
    assert components.autoregression == {"autoreg_dict": {"autoreg_param": 0}}
    assert components.changepoints is None
    assert components.custom == {"custom_param": 1}
    assert components.growth == {"growth_param": 2}
    assert components.events == {"events_param": 3}
    # Either form of the trailing override entry (None or {}) is accepted.
    hyperparameter_override = components.hyperparameter_override
    assert (hyperparameter_override == [{"h1": 4}, {"h2": 5}, None]
            or hyperparameter_override == [{"h1": 4}, {"h2": 5}, {}])
    assert components.regressors == {"names": ["regressor1", "regressor2"]}
    assert components.seasonality == {"seas_param": 6}
    assert components.uncertainty == {"uncertainty_param": 7}

    computation = config.computation_param
    assert computation.hyperparameter_budget is None
    assert computation.n_jobs == COMPUTATION_N_JOBS
    assert computation.verbose == COMPUTATION_VERBOSE

    assert config.to_dict()  # runs without error