Exemplo n.º 1
0
def check_methods_do_not_change_state(Estimator):
    """Check that non-state-changing methods leave the estimator untouched.

    An instance of ``Estimator`` is constructed and fitted, and a shallow
    copy of its ``__dict__`` is taken. Every method listed in
    ``NON_STATE_CHANGING_METHODS`` that the instance exposes is then called,
    and after each call the instance ``__dict__`` has to compare equal to the
    copy taken after ``fit`` (this covers hyper-parameters as well as fitted
    parameters).

    Parameters
    ----------
    Estimator : class
        Estimator class under test.

    Raises
    ------
    AssertionError
        If ``estimator.__dict__`` differs from the post-fit copy after one of
        the methods has been called.
    """
    estimator = _construct_instance(Estimator)
    set_random_state(estimator)

    args_for_fit = _make_args(estimator, "fit")
    estimator.fit(*args_for_fit)

    # NOTE: this is a shallow copy, so only rebinding of attributes is
    # compared against the state right after fit.
    state_after_fit = dict(estimator.__dict__)

    for method in NON_STATE_CHANGING_METHODS:
        if not hasattr(estimator, method):
            continue

        call_args = _make_args(estimator, method)
        getattr(estimator, method)(*call_args)

        # Transformers tagged "fit-in-transform" fit on each series handed to
        # transform, so transform is expected to change their state; the
        # method is still called above, only the comparison is skipped.
        if method == "transform" and _has_tag(Estimator, "fit-in-transform"):
            continue

        assert (
            estimator.__dict__ == state_after_fit
        ), f"Estimator: {estimator} changes __dict__ during {method}"
Exemplo n.º 2
0
def check_series_to_primitive_transform_multivariate(Estimator):
    """Check a series-to-primitives transformer on 3-column input.

    If ``Estimator`` carries the "univariate-only" tag, the check is
    delegated to ``_check_raises_error``. Otherwise the output of
    ``_construct_fit_transform`` must be a ``pd.Series`` or ``np.ndarray``
    of shape ``(n_columns,)``.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.

    Raises
    ------
    AssertionError
        If the output has an unexpected type or shape.
    """
    n_columns = 3

    # Univariate-only estimators are handled by the error check instead.
    if _has_tag(Estimator, "univariate-only"):
        _check_raises_error(Estimator, n_columns=n_columns)
        return

    out = _construct_fit_transform(Estimator, n_columns=n_columns)
    expected_types = (pd.Series, np.ndarray)
    assert isinstance(out, expected_types)
    assert out.shape == (n_columns,)
Exemplo n.º 3
0
def check_panel_to_tabular_transform_multivariate(Estimator):
    """Check a panel-to-tabular transformer on 3-column panel input.

    If ``Estimator`` carries the "univariate-only" tag, the check is
    delegated to ``_check_raises_error``. Otherwise the output of
    ``_construct_fit_transform`` must be a ``pd.DataFrame`` or
    ``np.ndarray`` whose first dimension equals the number of instances.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.

    Raises
    ------
    AssertionError
        If the output has an unexpected type or number of rows.
    """
    n_instances = 5
    # Same data-generation settings are used on both branches.
    data_kwargs = {"n_instances": n_instances, "n_columns": 3}

    if _has_tag(Estimator, "univariate-only"):
        _check_raises_error(Estimator, **data_kwargs)
        return

    out = _construct_fit_transform(Estimator, **data_kwargs)
    expected_types = (pd.DataFrame, np.ndarray)
    assert isinstance(out, expected_types)
    assert out.shape[0] == n_instances
Exemplo n.º 4
0
def _check_raises_error(Estimator, **kwargs):
    """Check that ``Estimator`` raises a ``ValueError`` matching "univariate".

    Parameters
    ----------
    Estimator : class
        Estimator class under test.
    **kwargs
        Forwarded unchanged to ``_construct_fit_transform`` or
        ``_construct_fit``.
    """
    if _has_tag(Estimator, "fit-in-transform"):
        # Some estimators have an empty fit method, so for them the error is
        # expected from transform rather than from fit.
        trigger = _construct_fit_transform
    else:
        # Every other estimator is expected to raise already in fit.
        trigger = _construct_fit

    with pytest.raises(ValueError, match=r"univariate"):
        trigger(Estimator, **kwargs)
Exemplo n.º 5
0
def check_series_to_series_transform_multivariate(Estimator):
    """Check a series-to-series transformer on 3-column, 5-timepoint input.

    If ``Estimator`` carries the "univariate-only" tag, the check is
    delegated to ``_check_raises_error``. Otherwise the output of
    ``_construct_fit_transform`` must be a ``pd.DataFrame`` or
    ``np.ndarray`` of shape ``(n_timepoints, n_columns)``.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.

    Raises
    ------
    AssertionError
        If the output has an unexpected type or shape.
    """
    n_columns = 3
    n_timepoints = 5
    # Same data-generation settings are used on both branches.
    data_kwargs = {"n_timepoints": n_timepoints, "n_columns": n_columns}

    if _has_tag(Estimator, "univariate-only"):
        _check_raises_error(Estimator, **data_kwargs)
        return

    out = _construct_fit_transform(Estimator, **data_kwargs)
    expected_types = (pd.DataFrame, np.ndarray)
    assert isinstance(out, expected_types)
    assert out.shape == (n_timepoints, n_columns)
Exemplo n.º 6
0
def _yield_transformer_checks(Estimator):
    """Yield the check functions that apply to the transformer ``Estimator``.

    The checks in ``all_transformer_checks`` are always yielded first.
    Further checks are added depending on whether ``Estimator`` has an
    ``inverse_transform`` attribute, which transformer base classes it
    subclasses, and whether it carries the
    "transform-returns-same-time-index" tag.

    Parameters
    ----------
    Estimator : class
        Transformer class for which checks are collected.

    Yields
    ------
    check
        One item at a time from the applicable check collections.
    """
    # Checks applied to every transformer.
    for check in all_transformer_checks:
        yield check

    if hasattr(Estimator, "inverse_transform"):
        yield check_transform_inverse_transform_equivalent

    # Checks selected by transformer base class; an estimator deriving from
    # several of these bases receives the checks of each of them.
    if issubclass(Estimator, _SeriesToPrimitivesTransformer):
        for check in series_to_primitive_checks:
            yield check

    if issubclass(Estimator, _SeriesToSeriesTransformer):
        for check in series_to_series_checks:
            yield check

    if issubclass(Estimator, _PanelToTabularTransformer):
        for check in panel_to_tabular_checks:
            yield check

    if issubclass(Estimator, _PanelToPanelTransformer):
        for check in panel_to_panel_checks:
            yield check

    # Tag-dependent check.
    if _has_tag(Estimator, "transform-returns-same-time-index"):
        yield check_transform_returns_same_time_index