def test_get_gridsearch(gridsearch_params, expected_estimator, expected_error):
    """Check that ``get_gridsearch`` builds the expected ``GridSearchCV``.

    Args:
        gridsearch_params: Keyword arguments forwarded to ``get_gridsearch``.
        expected_estimator: Mapping of expected pipeline pieces. Keys read
            here: ``"exog_passthrough"``, ``"holiday"`` and ``"model"`` (types
            checked with ``isinstance``); ``"holiday_step"`` (type of every
            holiday sub-step, or a falsy value to skip the sub-step checks);
            and, when ``"holiday_step"`` is truthy, ``"holiday_steps_codes"``
            and ``"holiday_steps_columns"`` (expected ``country_code`` and
            ``country_code_column`` of each sub-step, in order).
        expected_error: Exception type ``get_gridsearch`` should raise, or
            ``None`` when the call is expected to succeed.

    NOTE(review): this module defines ``test_get_gridsearch`` a second time
    further down; that later definition shadows this one at import time, so
    pytest never collects this version. Rename or merge the two.
    """
    if expected_error is not None:
        with pytest.raises(expected_error):
            get_gridsearch(**gridsearch_params)
        return

    res = get_gridsearch(**gridsearch_params)

    assert isinstance(res, GridSearchCV)
    assert isinstance(res.cv, FinerTimeSplit)
    # Check NaN-ness rather than identity: `is np.nan` only holds for the
    # exact np.nan object, not for any other NaN float.
    assert np.isnan(res.error_score)
    assert res.refit is False

    assert isinstance(res.estimator["exog_passthrough"],
                      expected_estimator["exog_passthrough"])
    assert isinstance(res.estimator["holiday"], expected_estimator["holiday"])

    if expected_estimator["holiday_step"]:
        steps = res.estimator["holiday"].steps
        codes = expected_estimator["holiday_steps_codes"]
        columns = expected_estimator["holiday_steps_columns"]

        assert all(isinstance(step[1], expected_estimator["holiday_step"])
                   for step in steps)
        # zip() stops at the shortest input, so verify the lengths first;
        # otherwise extra steps or missing expectations pass unnoticed.
        assert len(steps) == len(codes) == len(columns)
        # Use ==, not `is`: equal strings are not guaranteed to be the
        # same object.
        assert all(step[1].country_code == code
                   and step[1].country_code_column == col
                   for step, code, col in zip(steps, codes, columns))

    assert isinstance(res.estimator["model"], expected_estimator["model"])
def test_add_model_to_gridsearch():
    """Check that ``add_model_to_gridsearch`` adds one grid entry per model.

    A single estimator yields a one-entry ``param_grid``. A list of
    estimators yields one entry per estimator, in the same order. Estimator
    parameters are compared through their ``str`` representation.
    """
    # Case 1: a single model instance.
    single_gs = get_gridsearch(frequency="D", sklearn_models=False)
    single_model = ProphetWrapper()
    single_gs = add_model_to_gridsearch(single_model, single_gs)

    assert len(single_gs.param_grid) == 1
    single_grid_params = single_gs.param_grid[0]["model"][0].get_params()
    assert str(single_grid_params) == str(single_model.get_params())

    # Case 2: a list of model instances -> one grid entry each, order kept.
    multi_gs = get_gridsearch(frequency="D", sklearn_models=False)
    candidates = [ProphetWrapper(), ProphetWrapper(clip_predictions_lower=0.0)]
    multi_gs = add_model_to_gridsearch(candidates, multi_gs)

    assert len(multi_gs.param_grid) == 2
    for idx, candidate in enumerate(candidates):
        grid_params = multi_gs.param_grid[idx]["model"][0].get_params()
        assert str(grid_params) == str(candidate.get_params())
def test_get_gridsearch(gridsearch_params, expected_estimator, expected_error):
    """Check that ``get_gridsearch`` builds the expected ``GridSearchCV``.

    Args:
        gridsearch_params: Keyword arguments forwarded to ``get_gridsearch``.
        expected_estimator: Mapping of expected pipeline pieces. Keys read
            here: ``"exog_passthrough"``, ``"holiday"`` and ``"model"`` (types
            checked with ``isinstance``). Optionally ``"holiday_step"`` (type
            of every holiday sub-step; missing or falsy skips the sub-step
            checks). When it is truthy, ``"holiday_steps_codes"`` and
            ``"holiday_steps_columns"`` give the expected ``country_code`` and
            ``country_code_column`` of each sub-step, in order.
        expected_error: Exception type ``get_gridsearch`` should raise, or
            ``None`` when the call is expected to succeed.

    NOTE(review): this is the second ``test_get_gridsearch`` in this module.
    It replaces the earlier definition at import time, so it is the one
    pytest collects. It therefore also carries the earlier definition's
    holiday sub-step checks. Remove or rename one of the two definitions.
    """
    if expected_error is not None:
        with pytest.raises(expected_error):
            get_gridsearch(**gridsearch_params)
        return

    res = get_gridsearch(**gridsearch_params)

    assert isinstance(res, GridSearchCV)
    assert isinstance(res.cv, FinerTimeSplit)
    # Check NaN-ness rather than identity: `is np.nan` only holds for the
    # exact np.nan object, not for any other NaN float.
    assert np.isnan(res.error_score)
    assert res.refit is False

    assert isinstance(res.estimator["exog_passthrough"],
                      expected_estimator["exog_passthrough"])
    assert isinstance(res.estimator["holiday"], expected_estimator["holiday"])

    # .get(): parametrizations written for the original version of this
    # test may not provide the holiday sub-step keys at all.
    holiday_step_type = expected_estimator.get("holiday_step")
    if holiday_step_type:
        steps = res.estimator["holiday"].steps
        codes = expected_estimator["holiday_steps_codes"]
        columns = expected_estimator["holiday_steps_columns"]

        assert all(isinstance(step[1], holiday_step_type) for step in steps)
        # zip() stops at the shortest input, so verify the lengths first;
        # otherwise extra steps or missing expectations pass unnoticed.
        assert len(steps) == len(codes) == len(columns)
        # Use ==, not `is`: equal strings are not guaranteed to be the
        # same object.
        assert all(step[1].country_code == code
                   and step[1].country_code_column == col
                   for step, code, col in zip(steps, codes, columns))

    assert isinstance(res.estimator["model"], expected_estimator["model"])