Exemplo n.º 1
0
def test_valid_tag_types(estimator):
    """Check that estimator tags are valid.

    Every tag returned by ``_safe_tags(estimator)`` must have an entry in
    ``_DEFAULT_TAGS``, and its value must be an instance of the type of that
    default. ``_xfail_checks`` is additionally allowed to be a ``dict``.
    """
    tags = _safe_tags(estimator)
    estimator_name = type(estimator).__name__

    for name, tag in tags.items():
        # Report an unknown tag explicitly instead of surfacing a bare
        # KeyError from the lookup below.
        assert name in _DEFAULT_TAGS, (
            f"{estimator_name} reports unknown tag {name!r}"
        )
        correct_tags = type(_DEFAULT_TAGS[name])
        if name == "_xfail_checks":
            # _xfail_checks can be a dictionary
            correct_tags = (correct_tags, dict)
        assert isinstance(tag, correct_tags), (
            f"Tag {name!r} of {estimator_name} has type "
            f"{type(tag).__name__}, expected {correct_tags}"
        )
Exemplo n.º 2
0
    def transform(self, X, y=None):
        """Keep only the selected entries of X along the selection axis.

        Parameters
        ----------
        X : ndarray of shape [n_samples, n_features]
            The input samples. A 1-D array is first reshaped into a single
            column.
        y : ignored

        Returns
        -------
        X_r : ndarray
            X restricted to the entries flagged by ``get_support()`` along
            ``self._axis`` (columns when the axis is 1, rows otherwise).

        Raises
        ------
        ValueError
            If the length of the support mask differs from the size of X
            along ``self._axis``.
        """
        # Promote 1-D input to a column before anything else looks at it.
        if len(X.shape) == 1:
            X = X.reshape(-1, 1)

        support_mask = self.get_support()

        # note: we use _safe_tags instead of _get_tags because this is a
        # public Mixin.
        nan_allowed = _safe_tags(self, key="allow_nan")
        axis = self._axis
        X = self._validate_data(
            X,
            dtype=None,
            accept_sparse="csr",
            force_all_finite=not nan_allowed,
            reset=False,
            ensure_2d=axis,
        )

        if X.shape[axis] != len(support_mask):
            raise ValueError("X has a different shape than during fitting.")

        selector = safe_mask(X, support_mask)
        if axis == 1:
            return X[:, selector]
        return X[selector]
Exemplo n.º 3
0
 def _more_tags(self):
     """Return tags inherited from the first step's estimator.

     The ``pairwise`` tag is copied from the estimator stored as the second
     element of ``self.steps[0]`` (presumably a ``(name, estimator)`` pair).
     """
     # The whole sequence expects pairwise input exactly when its first
     # estimator does.
     first_estimator = self.steps[0][1]
     pairwise = _safe_tags(first_estimator, key="pairwise")
     return {"pairwise": pairwise}
Exemplo n.º 4
0
def test_safe_tags_no_get_tags(estimator, key, expected_results):
    """_safe_tags must still answer for an estimator lacking ``_get_tags``."""
    observed = _safe_tags(estimator, key=key)
    assert observed == expected_results
Exemplo n.º 5
0
def test_safe_tags_error(estimator, err_msg):
    """An ambiguous tag lookup through _safe_tags must raise ValueError."""
    unknown_key = "xxx"
    with pytest.raises(ValueError, match=err_msg):
        _safe_tags(estimator, key=unknown_key)