def test_score_function(daily_data_with_reg):
    """Checks the score function without a null model, with regressors.

    Fits the estimator on ``train_df`` with ``ct1`` and two regressor columns
    as extra predictors, then verifies that:

    * ``predict`` returns exactly the time and predicted-value columns,
    * ``score`` on ``test_df`` equals the mean squared error of the
      predictions against the actuals, and
    * the score is close to the reference value 4.6 (10% relative tolerance).
    """
    model = BaseSilverkiteEstimator()
    train_df = daily_data_with_reg["train_df"]
    test_df = daily_data_with_reg["test_df"]

    # Every subclass `fit` follows these steps: call `fit`, set
    # `model_dict`, then call `finish_fit`.
    model.fit(X=train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
    silverkite = SilverkiteForecast()
    model.model_dict = silverkite.forecast(
        df=train_df,
        time_col=cst.TIME_COL,
        value_col=cst.VALUE_COL,
        origin_for_time_vars=None,
        extra_pred_cols=["ct1", "regressor1", "regressor2"],
        train_test_thresh=None,
        training_fraction=None,
        fit_algorithm="linear",
        fit_algorithm_params=None,
        daily_event_df_dict=None,
        changepoints_dict=None,
        fs_components_df=pd.DataFrame({
            "name": ["tod", "tow", "conti_year"],
            "period": [24.0, 7.0, 1.0],
            "order": [3, 3, 5],
            "seas_names": ["daily", "weekly", "yearly"]
        }),
        autoreg_dict=None,
        min_admissible_value=None,
        max_admissible_value=None,
        uncertainty_dict=None)
    model.finish_fit()

    score = model.score(test_df, test_df[cst.VALUE_COL])
    pred_df = model.predict(test_df)
    assert list(pred_df.columns) == [cst.TIME_COL, cst.PREDICTED_COL]
    # sklearn's signature is mean_squared_error(y_true, y_pred). MSE is
    # symmetric, so the value is the same either way, but passing the
    # actuals first matches the documented contract.
    assert score == pytest.approx(
        mean_squared_error(test_df[cst.VALUE_COL], pred_df[cst.PREDICTED_COL]))
    assert score == pytest.approx(4.6, rel=1e-1)
def test_fit_predict(daily_data):
    """Walks a mean-strategy null-model estimator through its fit/predict lifecycle.

    Stages checked, in order:

    1. Fresh estimator: no prediction cache, and ``predict`` refuses to run.
    2. After ``fit``: only ``df`` is stored, and ``finish_fit`` refuses to run
       without ``model_dict``.
    3. After assigning ``model_dict``: ``predict`` still refuses (``finish_fit``
       not called) but records the attempted input.
    4. After ``finish_fit``: ``pred_cols``, ``feature_cols`` and ``coef_`` are
       derived from ``model_dict``.
    5. Prediction caching: new input is computed and cached, repeated input
       is served from the cache (with a DEBUG log), and a refit clears the
       cache.
    """
    estimator = BaseSilverkiteEstimator(null_model_params={"strategy": "mean"})
    train_df, test_df = daily_data["train_df"], daily_data["test_df"]

    # Stage 1: nothing cached yet, and predicting before `fit` is rejected.
    assert estimator.last_predicted_X_ is None
    assert estimator.cached_predictions_ is None
    with pytest.raises(NotFittedError, match="Call `fit` before calling `predict`."):
        estimator.predict(train_df)

    # Stage 2: `fit` (the first step of every subclass `fit`) stores the
    # training frame but leaves model-derived attributes unset.
    estimator.fit(train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
    assert_equal(estimator.df, train_df)
    for attr_name in ("pred_cols", "feature_cols", "coef_"):
        assert getattr(estimator, attr_name) is None
    with pytest.raises(
            ValueError,
            match="Must set `self.model_dict` before calling this function."):
        estimator.finish_fit()

    # Stage 3: populate `model_dict` by running the forecast directly.
    seasonality = pd.DataFrame({
        "name": ["tod", "tow", "conti_year"],
        "period": [24.0, 7.0, 1.0],
        "order": [3, 3, 5],
        "seas_names": ["daily", "weekly", "yearly"],
    })
    forecaster = SilverkiteForecast()
    estimator.model_dict = forecaster.forecast(
        df=train_df,
        time_col=cst.TIME_COL,
        value_col=cst.VALUE_COL,
        origin_for_time_vars=None,
        extra_pred_cols=None,
        train_test_thresh=None,
        training_fraction=None,
        fit_algorithm="linear",
        fit_algorithm_params=None,
        daily_event_df_dict=None,
        changepoints_dict=None,
        fs_components_df=seasonality,
        autoreg_dict=None,
        min_admissible_value=None,
        max_admissible_value=None,
        uncertainty_dict=None,
    )
    # `finish_fit` has not run yet, so `predict` must still fail, although
    # the attempted input is recorded and nothing is cached.
    with pytest.raises(
            NotFittedError,
            match="Subclass must call `finish_fit` inside the `fit` method."):
        estimator.predict(train_df)
    assert estimator.last_predicted_X_ is not None
    assert estimator.cached_predictions_ is None

    # Stage 4: `finish_fit` derives attributes from `model_dict`.
    estimator.finish_fit()
    fitted = estimator.model_dict
    assert_equal(estimator.pred_cols, fitted["pred_cols"])
    assert_equal(estimator.feature_cols, fitted["x_mat"].columns)
    expected_coef = pd.DataFrame(
        fitted["ml_model"].coef_, index=estimator.feature_cols)
    assert_equal(estimator.coef_, expected_coef)

    # Stage 5: prediction caching.
    def predict_uncached(frame):
        """Predicts on `frame`, expecting a fresh (non-cached) computation."""
        with LogCapture(cst.LOGGER_NAME) as captured:
            result = estimator.predict(frame)
            assert_equal(estimator.last_predicted_X_, frame)
            assert_equal(estimator.cached_predictions_, result)
            captured.check()  # nothing logged: the cache was not used
        return result

    # New input is computed and cached.
    predicted = predict_uncached(test_df)

    # Repeating the same input is served from the cache.
    with LogCapture(cst.LOGGER_NAME) as captured:
        assert_equal(estimator.predict(test_df), predicted)
        captured.check(
            (cst.LOGGER_NAME, "DEBUG", "Returning cached predictions."))

    # A different input triggers a fresh computation again.
    predict_uncached(train_df)

    # Refitting clears the prediction cache.
    estimator.fit(train_df, time_col=cst.TIME_COL, value_col=cst.VALUE_COL)
    assert estimator.last_predicted_X_ is None
    assert estimator.cached_predictions_ is None