예제 #1
0
def test_get_silverkite_hyperparameter_grid(model_components_param, silverkite, silverkite_diagnostics):
    """Tests ``SilverkiteTemplate.get_hyperparameter_grid``.

    Three scenarios are checked:

    1. Default config: the grid holds the default model components and newly
       created ``SilverkiteForecast`` / ``SilverkiteDiagnostics`` instances
       (not the fixture objects).
    2. Custom ``model_components_param`` and ``time_properties``: scalar values
       are auto-converted to single-element lists.
    3. ``hyperparameter_override`` given as a list: one grid is produced per
       entry, and the ``{}`` / ``None`` entries yield the un-overridden grid.
    """
    template = SilverkiteTemplate()
    template.config = template.apply_forecast_config_defaults()
    hyperparameter_grid = template.get_hyperparameter_grid()
    expected_grid = {
        "estimator__silverkite": [SilverkiteForecast()],
        "estimator__silverkite_diagnostics": [SilverkiteDiagnostics()],
        "estimator__origin_for_time_vars": [None],
        "estimator__extra_pred_cols": [["ct1"]],
        "estimator__train_test_thresh": [None],
        "estimator__training_fraction": [None],
        "estimator__fit_algorithm_dict": [{
            "fit_algorithm": "linear",
            "fit_algorithm_params": None}],
        "estimator__daily_event_df_dict": [None],
        "estimator__fs_components_df": [pd.DataFrame({
            "name": ["tod", "tow", "tom", "toq", "toy"],
            "period": [24.0, 7.0, 1.0, 1.0, 1.0],
            "order": [3, 3, 1, 1, 5],
            "seas_names": ["daily", "weekly", "monthly", "quarterly", "yearly"]})],
        "estimator__autoreg_dict": [None],
        "estimator__changepoints_dict": [None],
        "estimator__seasonality_changepoints_dict": [None],
        "estimator__changepoint_detector": [None],
        "estimator__min_admissible_value": [None],
        "estimator__max_admissible_value": [None],
        "estimator__uncertainty_dict": [None],
    }
    # The estimator objects are excluded from the value comparison; they are
    # checked by identity right after.
    assert_equal(hyperparameter_grid, expected_grid, ignore_keys={"estimator__silverkite": None, "estimator__silverkite_diagnostics": None})
    # The template must create fresh instances rather than reuse the fixtures.
    # `is not` states the identity check explicitly instead of relying on the
    # classes not overriding `__eq__`.
    assert hyperparameter_grid["estimator__silverkite"][0] is not silverkite
    assert hyperparameter_grid["estimator__silverkite_diagnostics"][0] is not silverkite_diagnostics

    # Tests auto-list conversion
    template.config.model_components_param = model_components_param
    template.time_properties = {"origin_for_time_vars": 2020}
    hyperparameter_grid = template.get_hyperparameter_grid()
    expected_grid = {
        "estimator__silverkite": [silverkite],
        "estimator__silverkite_diagnostics": [silverkite_diagnostics],
        "estimator__origin_for_time_vars": [2020],
        "estimator__extra_pred_cols": [["ct1"], ["ct2"], ["regressor1", "regressor3"]],
        "estimator__train_test_thresh": [None],
        "estimator__training_fraction": [None],
        "estimator__fit_algorithm_dict": [{
            "fit_algorithm": "linear",
            "fit_algorithm_params": None,
        }],
        "estimator__daily_event_df_dict": [None],
        "estimator__fs_components_df": [None],
        "estimator__autoreg_dict": [None],
        "estimator__changepoints_dict": [{
            "method": "uniform",
            "n_changepoints": 20,
        }],
        "estimator__seasonality_changepoints_dict": [None],
        "estimator__changepoint_detector": [None],
        "estimator__min_admissible_value": [None],
        "estimator__max_admissible_value": [4],
        "estimator__uncertainty_dict": [{
            "uncertainty_method": "simple_conditional_residuals"
        }],
    }
    assert_equal(hyperparameter_grid, expected_grid)

    # Tests hyperparameter_override: a list of overrides yields a list of grids,
    # one per entry. `{}` and `None` entries leave the base grid unchanged.
    template.config.model_components_param.hyperparameter_override = [
        {
            "input__response__null__max_frac": 0.1,
            "estimator__min_admissible_value": [2],
            "estimator__extra_pred_cols": ["override_estimator__extra_pred_cols"],
        },
        {},
        {
            "estimator__extra_pred_cols": ["val1", "val2"],
            "estimator__origin_for_time_vars": [2019],
        },
        None
    ]
    template.time_properties = {"origin_for_time_vars": 2020}
    hyperparameter_grid = template.get_hyperparameter_grid()
    expected_grid["estimator__origin_for_time_vars"] = [2020]
    # Shallow copies are enough: only top-level keys are reassigned below.
    updated_grid1 = expected_grid.copy()
    updated_grid1["input__response__null__max_frac"] = [0.1]
    updated_grid1["estimator__min_admissible_value"] = [2]
    updated_grid1["estimator__extra_pred_cols"] = [["override_estimator__extra_pred_cols"]]
    updated_grid2 = expected_grid.copy()
    updated_grid2["estimator__extra_pred_cols"] = [["val1", "val2"]]
    updated_grid2["estimator__origin_for_time_vars"] = [2019]
    expected_grid = [
        updated_grid1,
        expected_grid,
        updated_grid2,
        expected_grid]
    assert_equal(hyperparameter_grid, expected_grid)
예제 #2
0
def silverkite():
    """Provides a new, default-constructed ``SilverkiteForecast``."""
    forecaster = SilverkiteForecast()
    return forecaster
예제 #3
0
def apply_default_model_components(model_components=None,
                                   time_properties=None):
    """Returns ``model_components`` with every section merged with its defaults.

    Parameters
    ----------
    model_components : :class:`~greykite.framework.templates.autogen.forecast_config.ModelComponentsParam` or None, default None
        Configuration of model growth, seasonality, events, etc.
        See :func:`~greykite.framework.templates.silverkite_templates.silverkite_template` for details.
        The object passed in is not reassigned; a shallow copy
        (``dataclasses.replace``) is updated and returned instead.
    time_properties : `dict` [`str`, `any`] or None, default None
        Time properties dictionary (likely produced by
        `~greykite.common.time_properties_forecast.get_forecast_time_properties`)
        with keys:

        ``"period"`` : `int`
            Period of each observation (i.e. minimum time between observations, in seconds).
        ``"simple_freq"`` : `SimpleTimeFrequencyEnum`
            ``SimpleTimeFrequencyEnum`` member corresponding to data frequency.
        ``"num_training_points"`` : `int`
            Number of observations for training.
        ``"num_training_days"`` : `int`
            Number of days for training.
        ``"start_year"`` : `int`
            Start year of the training period.
        ``"end_year"`` : `int`
            End year of the forecast period.
        ``"origin_for_time_vars"`` : `float`
            Continuous time representation of the first date in ``df``.

        Only ``"origin_for_time_vars"`` is read here. If the key is missing, or
        ``time_properties`` is None, the default origin is ``None``.

    Returns
    -------
    model_components : :class:`~greykite.framework.templates.autogen.forecast_config.ModelComponentsParam`
        A copy of the given ``model_components`` (or a fresh
        ``ModelComponentsParam`` when None was given). Each section has been
        merged with its defaults via ``update_dictionary(...,
        allow_unknown_keys=False)``.
    """
    model_components = (
        ModelComponentsParam() if model_components is None
        # Copy so the caller's object is not modified.
        else dataclasses.replace(model_components))

    # (attribute name, default values) for every section except `custom`,
    # in the order in which they are merged.
    section_defaults = (
        ("seasonality", {
            "fs_components_df": [
                pd.DataFrame({
                    "name": ["tod", "tow", "tom", "toq", "toy"],
                    "period": [24.0, 7.0, 1.0, 1.0, 1.0],
                    "order": [3, 3, 1, 1, 5],
                    "seas_names":
                    ["daily", "weekly", "monthly", "quarterly", "yearly"]
                })
            ],
        }),
        # Must stay empty: growth terms are passed via `extra_pred_cols` instead.
        ("growth", {}),
        ("events", {"daily_event_df_dict": [None]}),
        # `changepoint_detector` is deliberately not offered, to prevent leaking
        # future information into the past. Pass `changepoints_dict` with
        # method="auto" for automatic detection.
        ("changepoints", {
            "changepoints_dict": [None],
            "seasonality_changepoints_dict": [None],
        }),
        ("autoregression", {"autoreg_dict": [None]}),
        ("regressors", {}),
        ("lagged_regressors", {"lagged_regressor_dict": [None]}),
        ("uncertainty", {"uncertainty_dict": [None]}),
    )
    for section_name, defaults in section_defaults:
        merged_section = update_dictionary(
            defaults,
            overwrite_dict=getattr(model_components, section_name),
            allow_unknown_keys=False)
        setattr(model_components, section_name, merged_section)

    origin_for_time_vars = (
        None if time_properties is None
        else time_properties.get("origin_for_time_vars"))

    model_components.custom = update_dictionary(
        {
            # NB: sklearn creates a copy in grid search
            "silverkite": [SilverkiteForecast()],
            "silverkite_diagnostics": [SilverkiteDiagnostics()],
            # Same origin for every split, based on start year of the full dataset.
            # Set to `None` in model_components to use the first date of each
            # training split instead.
            "origin_for_time_vars": [origin_for_time_vars],
            "extra_pred_cols": ["ct1"],  # linear growth
            "fit_algorithm_dict": [{
                "fit_algorithm": "linear",
                "fit_algorithm_params": None,
            }],
            "min_admissible_value": [None],
            "max_admissible_value": [None],
        },
        overwrite_dict=model_components.custom,
        allow_unknown_keys=False)

    # Normalizes the overrides; per the original intent, a None entry (or a
    # None list item) becomes {}.
    model_components.hyperparameter_override = update_dictionaries(
        {}, overwrite_dicts=model_components.hyperparameter_override)

    return model_components
예제 #4
0
def test_apply_default_model_components(model_components_param, silverkite, silverkite_diagnostics):
    """Tests ``apply_default_model_components``.

    Covers:

    - the defaults applied when no input is given;
    - merging provided ``model_components`` and ``time_properties`` with the
      defaults, and that the input object is not modified;
    - a ``time_properties`` dict without ``"origin_for_time_vars"``;
    - a nested ``autoreg_dict`` (including a callable) being kept as provided.
    """
    model_components = apply_default_model_components()
    assert_equal(model_components.seasonality, {
        "fs_components_df": [pd.DataFrame({
            "name": ["tod", "tow", "tom", "toq", "toy"],
            "period": [24.0, 7.0, 1.0, 1.0, 1.0],
            "order": [3, 3, 1, 1, 5],
            "seas_names": ["daily", "weekly", "monthly", "quarterly", "yearly"]})],
    })
    assert model_components.growth == {}
    assert model_components.events == {
        "daily_event_df_dict": [None],
    }
    assert model_components.changepoints == {
        "changepoints_dict": [None],
        "seasonality_changepoints_dict": [None],
    }
    assert model_components.autoregression == {
        "autoreg_dict": [None],
    }
    assert model_components.regressors == {}
    assert model_components.uncertainty == {
        "uncertainty_dict": [None],
    }
    assert_equal(model_components.custom, {
        "silverkite": [SilverkiteForecast()],
        "silverkite_diagnostics": [SilverkiteDiagnostics()],
        "origin_for_time_vars": [None],
        "extra_pred_cols": ["ct1"],  # linear growth
        "fit_algorithm_dict": [{
            "fit_algorithm": "linear",
            "fit_algorithm_params": None,
        }],
        "min_admissible_value": [None],
        "max_admissible_value": [None],
    }, ignore_keys={
        "silverkite": None,
        "silverkite_diagnostics": None
    })
    # A different instance was created. `is not` checks identity explicitly
    # instead of relying on the classes not overriding `__eq__`.
    assert model_components.custom["silverkite"][0] is not silverkite
    assert model_components.custom["silverkite_diagnostics"][0] is not silverkite_diagnostics

    # overwrite some parameters
    time_properties = {
        "origin_for_time_vars": 2020
    }
    # NB: `dataclasses.replace` makes a *shallow* copy. The nested dicts are
    # shared with `model_components_param`, so the equality check below detects
    # reassigned attributes, not in-place edits of those dicts.
    original_components = dataclasses.replace(model_components_param)
    updated_components = apply_default_model_components(
        model_components=model_components_param,
        time_properties=time_properties)
    assert original_components == model_components_param  # not mutated by the function
    assert updated_components.seasonality == model_components_param.seasonality
    assert updated_components.events == {
        "daily_event_df_dict": [None],
    }
    assert updated_components.changepoints == {
        "changepoints_dict": {  # combination of defaults and provided params
            "method": "uniform",
            "n_changepoints": 20,
        },
        "seasonality_changepoints_dict": [None],
    }
    assert updated_components.autoregression == {"autoreg_dict": [None]}
    assert updated_components.uncertainty == model_components_param.uncertainty
    assert updated_components.custom == {  # combination of defaults and provided params
        "silverkite": silverkite,  # the same object that was passed in (not a copy)
        "silverkite_diagnostics": silverkite_diagnostics,
        "origin_for_time_vars": [time_properties["origin_for_time_vars"]],  # from time_properties
        "extra_pred_cols": [["ct1"], ["ct2"], ["regressor1", "regressor3"]],
        "max_admissible_value": 4,
        "fit_algorithm_dict": [{
            "fit_algorithm": "linear",
            "fit_algorithm_params": None,
        }],
        "min_admissible_value": [None],
    }

    # `time_properties` without the "origin_for_time_vars" key
    updated_components = apply_default_model_components(
        model_components=model_components_param,
        time_properties={})
    assert updated_components.custom["origin_for_time_vars"] == [None]

    # A user-provided nested autoreg_dict is kept as given.
    updated_components = apply_default_model_components(
        model_components=ModelComponentsParam(
            autoregression={
                "autoreg_dict": {
                    "lag_dict": {"orders": [7]},
                    "agg_lag_dict": {
                        "orders_list": [[7, 7*2, 7*3]],
                        "interval_list": [(7, 7*2)]},
                    "series_na_fill_func": lambda s: s.bfill().ffill()}
            })
    )

    autoreg_dict = updated_components.autoregression["autoreg_dict"]
    assert autoreg_dict["lag_dict"] == {"orders": [7]}
    assert autoreg_dict["agg_lag_dict"]["orders_list"] == [[7, 14, 21]]
    assert autoreg_dict["agg_lag_dict"]["interval_list"] == [(7, 14)]
예제 #5
0
def test_silverkite_with_components_hourly_data():
    """Checks ``plot_components``, ``plot_trend`` and ``plot_seasonalities``
    of ``BaseSilverkiteEstimator`` fitted on hourly data.
    """
    def _axis_titles(figure, *axis_names):
        """Returns the title text of each named layout axis, in the given order."""
        return [getattr(figure.layout, name).title["text"] for name in axis_names]

    hourly_data = generate_df_with_reg_for_tests(
        freq="H",
        periods=24 * 4,
        train_start_date=datetime.datetime(2018, 1, 1),
        conti_year_origin=2018)
    train_df = hourly_data.get("train_df").copy()
    forecast_params = params_components()

    # `coverage` is an estimator argument; the rest go to `forecast`.
    coverage = forecast_params.pop("coverage")
    model = BaseSilverkiteEstimator(
        coverage=coverage,
        uncertainty_dict=forecast_params["uncertainty_dict"])
    # Same steps a subclass `fit` performs.
    model.fit(X=train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
    model.model_dict = SilverkiteForecast().forecast(
        df=train_df,
        time_col=cst.TIME_COL,
        value_col=cst.VALUE_COL,
        **forecast_params)
    model.finish_fit()

    missing_component_msg = (
        "The following components have not been specified in the model: "
        "{'DUMMY'}, plotting the rest.")

    # plot_components: the unknown "DUMMY" component only triggers a warning.
    with pytest.warns(Warning) as record:
        title = "Custom component plot"
        fig = model.plot_components(
            names=["trend", "DAILY_SEASONALITY", "DUMMY"], title=title)
        expected_names = [cst.VALUE_COL, "trend", "DAILY_SEASONALITY", "trend change point"]
        assert len(fig.data) == len(expected_names)  # includes changepoints
        assert [trace.name for trace in fig.data[:len(expected_names)]] == expected_names
        assert _axis_titles(fig, "xaxis", "xaxis2", "xaxis3") == \
            [cst.TIME_COL, cst.TIME_COL, "Hour of day"]
        assert _axis_titles(fig, "yaxis", "yaxis2", "yaxis3") == \
            [cst.VALUE_COL, "trend", "daily"]
        assert fig.layout.title["text"] == title
        assert missing_component_msg in record[0].message.args[0]

    # plot_trend
    title = "Custom trend plot"
    fig = model.plot_trend(title=title)
    expected_names = [cst.VALUE_COL, "trend"]
    assert len(fig.data) == len(expected_names) + 1  # includes changepoints
    assert [trace.name for trace in fig.data[:len(expected_names)]] == expected_names
    assert _axis_titles(fig, "xaxis", "xaxis2") == [cst.TIME_COL, cst.TIME_COL]
    assert _axis_titles(fig, "yaxis", "yaxis2") == [cst.VALUE_COL, "trend"]
    assert fig.layout.title["text"] == title

    # plot_seasonalities (the warning about removed seasonalities is captured)
    with pytest.warns(Warning):
        title = "Custom seasonality plot"
        fig = model.plot_seasonalities(title=title)
        expected_names = [cst.VALUE_COL, "DAILY_SEASONALITY", "WEEKLY_SEASONALITY", "YEARLY_SEASONALITY"]
        assert len(fig.data) == len(expected_names)
        assert [trace.name for trace in fig.data[:len(expected_names)]] == expected_names
        assert _axis_titles(fig, "xaxis", "xaxis2", "xaxis3", "xaxis4") == \
            [cst.TIME_COL, "Hour of day", "Day of week", "Time of year"]
        assert _axis_titles(fig, "yaxis", "yaxis2", "yaxis3", "yaxis4") == \
            [cst.VALUE_COL, "daily", "weekly", "yearly"]
        assert fig.layout.title["text"] == title
예제 #6
0
def test_silverkite_with_components_daily_data():
    """Checks ``plot_components``, ``plot_trend`` and ``plot_seasonalities``
    of ``BaseSilverkiteEstimator`` on daily data with missing response values.

    Also checks the errors raised before fitting, for an unknown component
    list, and for a non-linear ``fit_algorithm`` ("rf").
    """
    def _axis_titles(figure, *axis_names):
        """Returns the title text of each named layout axis, in the given order."""
        return [getattr(figure.layout, name).title["text"] for name in axis_names]

    daily_data = generate_df_with_reg_for_tests(
        freq="D",
        periods=20,
        train_start_date=datetime.datetime(2018, 1, 1),
        conti_year_origin=2018)
    train_df = daily_data["train_df"].copy()
    # Blanks out a few response values to create missing input.
    train_df.loc[[2, 4, 7], cst.VALUE_COL] = np.nan

    forecast_params = params_components()  # SilverkiteEstimator parameters
    # `coverage` is an estimator argument; the rest go to `forecast`.
    coverage = forecast_params.pop("coverage")
    # No daily seasonality terms.
    forecast_params["fs_components_df"] = pd.DataFrame({
        "name": ["tow", "ct1"],
        "period": [7.0, 1.0],
        "order": [4, 5],
        "seas_names": ["weekly", "yearly"]
    })

    model = BaseSilverkiteEstimator(
        coverage=coverage,
        uncertainty_dict=forecast_params["uncertainty_dict"])

    # Plotting is not possible before fitting.
    with pytest.raises(NotFittedError,
                       match="Call `fit` before calling `plot_components`."):
        model.plot_components()

    # Same steps a subclass `fit` performs; warnings from conf_interval.py and
    # sklearn are captured.
    with pytest.warns(Warning):
        model.fit(X=train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
        forecaster = SilverkiteForecast()
        model.model_dict = forecaster.forecast(
            df=train_df,
            time_col=cst.TIME_COL,
            value_col=cst.VALUE_COL,
            **forecast_params)
        model.finish_fit()

    missing_component_msg = (
        "The following components have not been specified in the model: "
        "{'DUMMY'}, plotting the rest.")

    # plot_components: the unknown "DUMMY" component only triggers a warning.
    with pytest.warns(Warning) as record:
        title = "Custom component plot"
        model._set_silverkite_diagnostics_params()
        fig = model.plot_components(
            names=["trend", "YEARLY_SEASONALITY", "DUMMY"], title=title)
        expected_names = [cst.VALUE_COL, "trend", "YEARLY_SEASONALITY"]
        assert len(fig.data) == len(expected_names) + 1  # includes changepoints
        assert [trace.name for trace in fig.data[:len(expected_names)]] == expected_names
        assert _axis_titles(fig, "xaxis", "xaxis2", "xaxis3") == \
            [cst.TIME_COL, cst.TIME_COL, "Time of year"]
        assert _axis_titles(fig, "yaxis", "yaxis2", "yaxis3") == \
            [cst.VALUE_COL, "trend", "yearly"]
        assert fig.layout.title["text"] == title
        assert missing_component_msg in record[0].message.args[0]

    # Only unknown components: nothing left to plot.
    with pytest.raises(
            ValueError,
            match="None of the provided components have been specified in the model."):
        model.plot_components(names=["DUMMY"])

    # plot_trend
    title = "Custom trend plot"
    fig = model.plot_trend(title=title)
    expected_names = [cst.VALUE_COL, "trend"]
    assert len(fig.data) == len(expected_names) + 1  # includes changepoints
    assert [trace.name for trace in fig.data[:len(expected_names)]] == expected_names
    assert _axis_titles(fig, "xaxis", "xaxis2") == [cst.TIME_COL, cst.TIME_COL]
    assert _axis_titles(fig, "yaxis", "yaxis2") == [cst.VALUE_COL, "trend"]
    assert fig.layout.title["text"] == title

    # plot_seasonalities (the warning about removed seasonalities is captured)
    with pytest.warns(Warning):
        title = "Custom seasonality plot"
        fig = model.plot_seasonalities(title=title)
        expected_names = [cst.VALUE_COL, "WEEKLY_SEASONALITY", "YEARLY_SEASONALITY"]
        assert len(fig.data) == len(expected_names)
        assert [trace.name for trace in fig.data[:len(expected_names)]] == expected_names
        assert _axis_titles(fig, "xaxis", "xaxis2", "xaxis3") == \
            [cst.TIME_COL, "Day of week", "Time of year"]
        assert _axis_titles(fig, "yaxis", "yaxis2", "yaxis3") == \
            [cst.VALUE_COL, "weekly", "yearly"]
        assert fig.layout.title["text"] == title

    # Component plots are unavailable when `fit_algorithm` is "rf" or "gradient_boosting".
    forecast_params["fit_algorithm"] = "rf"
    model = BaseSilverkiteEstimator(
        coverage=coverage,
        uncertainty_dict=forecast_params["uncertainty_dict"])
    with pytest.warns(Warning):
        model.fit(X=train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
        model.model_dict = forecaster.forecast(
            df=train_df,
            time_col=cst.TIME_COL,
            value_col=cst.VALUE_COL,
            **forecast_params)
        model.finish_fit()
    assert model.coef_ is None

    not_linear_msg = "Component plot has only been implemented for additive linear models."
    for plot_method in (model.plot_components, model.plot_trend, model.plot_seasonalities):
        with pytest.raises(NotImplementedError, match=not_linear_msg):
            plot_method()
예제 #7
0
def test_fit_predict(daily_data):
    """Checks ``fit`` / ``predict`` of ``BaseSilverkiteEstimator`` with a null
    model, including the ``finish_fit`` requirement and the prediction cache.
    """
    model = BaseSilverkiteEstimator(null_model_params={"strategy": "mean"})
    train_df = daily_data["train_df"]
    test_df = daily_data["test_df"]
    # Nothing has been predicted or cached yet.
    assert model.last_predicted_X_ is None
    assert model.cached_predictions_ is None

    with pytest.raises(NotFittedError,
                       match="Call `fit` before calling `predict`."):
        model.predict(train_df)

    # Every subclass `fit` follows these steps.
    model.fit(train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
    # The base `fit` sets `df` only; the model-derived attributes stay unset.
    assert_equal(model.df, train_df)
    for unset_attr in ("pred_cols", "feature_cols", "coef_"):
        assert getattr(model, unset_attr) is None

    with pytest.raises(
            ValueError,
            match="Must set `self.model_dict` before calling this function."):
        model.finish_fit()

    forecaster = SilverkiteForecast()
    seasonality_df = pd.DataFrame({
        "name": ["tod", "tow", "conti_year"],
        "period": [24.0, 7.0, 1.0],
        "order": [3, 3, 5],
        "seas_names": ["daily", "weekly", "yearly"],
    })
    model.model_dict = forecaster.forecast(
        df=train_df,
        time_col=cst.TIME_COL,
        value_col=cst.VALUE_COL,
        origin_for_time_vars=None,
        extra_pred_cols=None,
        train_test_thresh=None,
        training_fraction=None,
        fit_algorithm="linear",
        fit_algorithm_params=None,
        daily_event_df_dict=None,
        changepoints_dict=None,
        fs_components_df=seasonality_df,
        autoreg_dict=None,
        min_admissible_value=None,
        max_admissible_value=None,
        uncertainty_dict=None)

    # `predict` is rejected until `finish_fit` runs, but the attempt is recorded.
    with pytest.raises(
            NotFittedError,
            match="Subclass must call `finish_fit` inside the `fit` method."):
        model.predict(train_df)
    assert model.last_predicted_X_ is not None
    assert model.cached_predictions_ is None

    model.finish_fit()
    # `finish_fit` populates the model-derived attributes.
    assert_equal(model.pred_cols, model.model_dict["pred_cols"])
    assert_equal(model.feature_cols, model.model_dict["x_mat"].columns)
    assert_equal(
        model.coef_,
        pd.DataFrame(model.model_dict["ml_model"].coef_, index=model.feature_cols))

    # A new dataset is predicted from scratch (no log messages) and cached.
    with LogCapture(cst.LOGGER_NAME) as log_capture:
        predicted = model.predict(test_df)
        assert_equal(model.last_predicted_X_, test_df)
        assert_equal(model.cached_predictions_, predicted)
        log_capture.check()

    # The same dataset again is served from the cache.
    with LogCapture(cst.LOGGER_NAME) as log_capture:
        assert_equal(model.predict(test_df), predicted)
        log_capture.check(
            (cst.LOGGER_NAME, "DEBUG", "Returning cached predictions."))

    # A different dataset replaces the cache (no log messages).
    with LogCapture(cst.LOGGER_NAME) as log_capture:
        predicted = model.predict(train_df)
        assert_equal(model.last_predicted_X_, train_df)
        assert_equal(model.cached_predictions_, predicted)
        log_capture.check()

    # Refitting clears the cache.
    model.fit(train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
    assert model.last_predicted_X_ is None
    assert model.cached_predictions_ is None
예제 #8
0
    def __init__(
            self,
            silverkite: SilverkiteForecast = SilverkiteForecast(),
            silverkite_diagnostics: SilverkiteDiagnostics = SilverkiteDiagnostics(),
            score_func=mean_squared_error,
            coverage=None,
            null_model_params=None,
            origin_for_time_vars=None,
            extra_pred_cols=None,
            train_test_thresh=None,
            training_fraction=None,
            fit_algorithm_dict=None,
            daily_event_df_dict=None,
            fs_components_df=pd.DataFrame({
                "name": ["tod", "tow", "conti_year"],
                "period": [24.0, 7.0, 1.0],
                "order": [3, 3, 5],
                "seas_names": ["daily", "weekly", "yearly"]}),
            autoreg_dict=None,
            lagged_regressor_dict=None,
            changepoints_dict=None,
            seasonality_changepoints_dict=None,
            changepoint_detector=None,
            min_admissible_value=None,
            max_admissible_value=None,
            uncertainty_dict=None,
            normalize_method=None,
            adjust_anomalous_dict=None,
            impute_dict=None,
            regression_weight_col=None,
            forecast_horizon=None,
            simulation_based=False):
        """Stores the estimator parameters and validates them.

        Some arguments are forwarded to the base class ``__init__``. Every
        argument after ``silverkite_diagnostics`` is also stored on an attribute
        of the same name, so that ``get_params()`` (used in grid search) can
        read it back. ``validate_inputs()`` runs once all attributes are set.

        Note: the defaults of ``silverkite``, ``silverkite_diagnostics`` and
        ``fs_components_df`` are objects created once, when this method is
        defined. Every instance that relies on those defaults shares the same
        objects, so they should not be mutated in place.
        """
        # Every subclass of BaseSilverkiteEstimator must call super().__init__.
        super().__init__(
            silverkite=silverkite,
            silverkite_diagnostics=silverkite_diagnostics,
            score_func=score_func,
            coverage=coverage,
            null_model_params=null_model_params,
            uncertainty_dict=uncertainty_dict)

        # Store every parameter under its own name (sklearn get_params contract),
        # in signature order.
        self.score_func = score_func
        self.coverage = coverage
        self.null_model_params = null_model_params
        self.origin_for_time_vars = origin_for_time_vars
        self.extra_pred_cols = extra_pred_cols
        self.train_test_thresh = train_test_thresh
        self.training_fraction = training_fraction
        self.fit_algorithm_dict = fit_algorithm_dict
        self.daily_event_df_dict = daily_event_df_dict
        self.fs_components_df = fs_components_df
        self.autoreg_dict = autoreg_dict
        self.lagged_regressor_dict = lagged_regressor_dict
        self.changepoints_dict = changepoints_dict
        self.seasonality_changepoints_dict = seasonality_changepoints_dict
        self.changepoint_detector = changepoint_detector
        self.min_admissible_value = min_admissible_value
        self.max_admissible_value = max_admissible_value
        self.uncertainty_dict = uncertainty_dict
        self.normalize_method = normalize_method
        self.adjust_anomalous_dict = adjust_anomalous_dict
        self.impute_dict = impute_dict
        self.regression_weight_col = regression_weight_col
        self.forecast_horizon = forecast_horizon
        self.simulation_based = simulation_based

        # Validate only after every attribute exists.
        self.validate_inputs()
def test_silverkite_constants():
    """Checks which seasonality enum a default ``SilverkiteForecast`` exposes."""
    silverkite = SilverkiteForecast()
    expected_enum = SilverkiteSeasonalityEnum
    assert silverkite._silverkite_seasonality_enum is expected_enum