コード例 #1
0
def check_series_to_series_transform_univariate(Estimator):
    """Check a series-to-series transformer on a univariate input series.

    The estimator is fitted and applied via ``_construct_fit_transform`` to a
    univariate input of 15 time points. ``add_nan=True`` is requested from the
    construction helper only when the estimator carries the
    ``"handles-missing-data"`` tag.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.
    """
    supports_missing = _has_tag(Estimator, "handles-missing-data")
    transformed = _construct_fit_transform(
        Estimator, n_timepoints=15, add_nan=supports_missing
    )
    # Univariate output may be either a pandas Series or a plain numpy array.
    assert isinstance(transformed, (pd.Series, np.ndarray))
コード例 #2
0
def check_series_to_primitive_transform_multivariate(Estimator):
    """Check a series-to-primitives transformer on a 3-column input.

    Estimators tagged ``"univariate-only"`` are delegated to
    ``_check_raises_error``. All others must return a Series or array with
    exactly one value per input column.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.
    """
    n_columns = 3

    if _has_tag(Estimator, "univariate-only"):
        # Univariate-only estimators are expected to reject multi-column data.
        _check_raises_error(Estimator, n_columns=n_columns)
        return

    result = _construct_fit_transform(Estimator, n_columns=n_columns)
    assert isinstance(result, (pd.Series, np.ndarray))
    expected_shape = (n_columns,)
    assert result.shape == expected_shape
コード例 #3
0
def check_panel_to_tabular_transform_multivariate(Estimator):
    """Check a panel-to-tabular transformer on a 5-instance, 3-column panel.

    Estimators tagged ``"univariate-only"`` are delegated to
    ``_check_raises_error``. All others must return a DataFrame or array with
    one row per instance.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.
    """
    n_instances = 5
    n_columns = 3

    if _has_tag(Estimator, "univariate-only"):
        # Univariate-only estimators are expected to reject multi-column data.
        _check_raises_error(Estimator, n_instances=n_instances, n_columns=n_columns)
        return

    table = _construct_fit_transform(
        Estimator, n_instances=n_instances, n_columns=n_columns
    )
    assert isinstance(table, (pd.DataFrame, np.ndarray))
    n_rows = table.shape[0]
    assert n_rows == n_instances
コード例 #4
0
def _check_raises_error(Estimator, **kwargs):
    """Assert that ``Estimator`` raises a ``ValueError`` matching "univariate".

    The check runs at ``transform`` time for estimators tagged
    ``"fit-in-transform"`` and at ``fit`` time for all others. Callers use it
    for univariate-only estimators given multi-column data.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.
    **kwargs
        Forwarded unchanged to ``_construct_fit_transform`` / ``_construct_fit``
        (e.g. ``n_columns``, ``n_timepoints``, ``n_instances``).
    """
    # Resolve the tag and pick the helper *before* entering ``pytest.raises``,
    # so that only the estimator call itself can satisfy the expected-exception
    # check. A ValueError from the tag lookup must not be mistaken for a pass.
    if _has_tag(Estimator, "fit-in-transform"):
        # As some estimators have an empty fit method, we here check if they
        # raise the appropriate error in transform rather than fit.
        construct = _construct_fit_transform
    else:
        # All other estimators should raise the error in fit.
        construct = _construct_fit

    with pytest.raises(ValueError, match=r"univariate"):
        construct(Estimator, **kwargs)
コード例 #5
0
def check_series_to_series_transform_multivariate(Estimator):
    """Check a series-to-series transformer on a 15-timepoint, 3-column input.

    Estimators tagged ``"univariate-only"`` are delegated to
    ``_check_raises_error``. All others must return a DataFrame or array whose
    shape equals the input's ``(n_timepoints, n_columns)``.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.
    """
    n_timepoints, n_columns = 15, 3
    data_kwargs = {"n_timepoints": n_timepoints, "n_columns": n_columns}

    if _has_tag(Estimator, "univariate-only"):
        # Univariate-only estimators are expected to reject multi-column data.
        _check_raises_error(Estimator, **data_kwargs)
        return

    transformed = _construct_fit_transform(Estimator, **data_kwargs)
    assert isinstance(transformed, (pd.DataFrame, np.ndarray))
    assert transformed.shape == (n_timepoints, n_columns)
コード例 #6
0
def _yield_transformer_checks(Estimator):
    """Yield every check function applicable to the transformer class ``Estimator``.

    The order is fixed:

    1. All generic transformer checks.
    2. The inverse-transform check, if the class defines ``inverse_transform``.
    3. The check groups for each transformer base class it subclasses.
    4. The same-time-index check, if tagged ``"transform-returns-same-time-index"``.

    Parameters
    ----------
    Estimator : class
        Transformer class under test.
    """
    yield from all_transformer_checks

    if hasattr(Estimator, "inverse_transform"):
        yield check_transform_inverse_transform_equivalent

    # Each base class contributes its own group of checks. The tuple order
    # determines the order in which the groups are yielded.
    checks_by_base = (
        (_SeriesToPrimitivesTransformer, series_to_primitive_checks),
        (_SeriesToSeriesTransformer, series_to_series_checks),
        (_PanelToTabularTransformer, panel_to_tabular_checks),
        (_PanelToPanelTransformer, panel_to_panel_checks),
    )
    for base_class, group in checks_by_base:
        if issubclass(Estimator, base_class):
            yield from group

    if _has_tag(Estimator, "transform-returns-same-time-index"):
        yield check_transform_returns_same_time_index