Exemplo n.º 1
0
def _check_fit_params(X, fit_params, indices=None):
    """Check and validate the parameters passed during `fit`.

    Delegates to scikit-learn's private validation helper when the
    installed version provides one. Otherwise each value is indexed with
    the legacy ``sklearn.model_selection._validation._index_param_value``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data array.

    fit_params : dict
        Dictionary containing the parameters passed at fit.

    indices : array-like of shape (n_samples,), default=None
        Indices to be selected if the parameter has the same size as
        `X`.

    Returns
    -------
    fit_params_validated : dict
        Validated parameters. We ensure that the values support
        indexing.

    Raises
    ------
    ImportError
        If scikit-learn is not installed.
    """
    from sklearn.utils import validation as _sklearn_validation

    # The helper is private scikit-learn API. It first appeared as
    # `_check_fit_params` and was later renamed `_check_method_params`.
    # Both names are looked up here (preferring the original one), so a
    # rename does not silently drop us onto the legacy path below.
    for helper_name in ("_check_fit_params", "_check_method_params"):
        helper = getattr(_sklearn_validation, helper_name, None)
        if helper is not None:
            # Deliberately called outside any try/except: errors raised
            # *inside* the helper must propagate, not trigger a fallback.
            return helper(X, fit_params, indices)

    # Older scikit-learn without either helper: index every value with the
    # legacy per-parameter function.
    from sklearn.model_selection import _validation

    return {k: _validation._index_param_value(X, v, indices)
            for k, v in fit_params.items()}
Exemplo n.º 2
0
def _check_fit_params(
        X,  # type: TwoDimArrayLikeType
        fit_params,  # type: Dict
        indices  # type: OneDimArrayLikeType
):
    # type: (...) -> Dict
    """Return ``fit_params`` with each value made indexable by ``indices``.

    For scikit-learn 0.22.1 and later this defers to scikit-learn's own
    ``_check_fit_params`` helper. For older releases every value is passed
    through ``_sklearn_index_param_value`` individually.
    """
    # NOTE(review): if `_sklearn_version` is a plain str this comparison is
    # lexicographic, not a version comparison. Confirm it is a parsed
    # version object where it is defined.
    if _sklearn_version >= '0.22.1':
        return _sklearn_check_fit_params(X, fit_params, indices)

    # Legacy scikit-learn: no bulk helper, so index each parameter in turn.
    validated = {}
    for param_name, param_value in fit_params.items():
        validated[param_name] = _sklearn_index_param_value(
            X, param_value, indices
        )
    return validated