コード例 #1
0
def test_transformer_n_iter():
    """Yield a ``check_transformer_n_iter`` test for each eligible transformer.

    A transformer is checked only if an instance exposes ``max_iter`` and
    its name is not in the external-solver exclusion set. The deprecated
    ``ProjectedGradientNMF`` is constructed, and its check yielded, with
    warnings ignored.
    """
    transformers = all_estimators(type_filter='transformer')
    # Dependent on external solvers and hence accessing the iter
    # param is non-trivial. Loop-invariant, so built once.
    external_solver = frozenset([
        'Isomap', 'KernelPCA', 'LocallyLinearEmbedding', 'RandomizedLasso',
        'LogisticRegressionCV'
    ])
    for name, Estimator in transformers:
        # The ProjectedGradientNMF class is deprecated.
        deprecated = issubclass(Estimator, ProjectedGradientNMF)
        if deprecated:
            with ignore_warnings():
                estimator = Estimator()
        else:
            estimator = Estimator()

        if hasattr(estimator, "max_iter") and name not in external_solver:
            if deprecated:
                # Yielding inside the ``with`` keeps warnings ignored while
                # this generator is suspended (presumably while the runner
                # executes the yielded check).
                with ignore_warnings():
                    yield _set_test_name(check_transformer_n_iter,
                                         name), name, estimator
            else:
                yield _set_test_name(check_transformer_n_iter,
                                     name), name, estimator
コード例 #2
0
def test_non_meta_estimators():
    """Yield every common check for each non-meta estimator.

    Estimators deriving from ``BiclusterMixin`` and those whose name starts
    with an underscore are skipped. Checks for the deprecated
    ``ProjectedGradientNMF`` are yielded with warnings ignored.
    """
    # input validation etc for non-meta estimators
    for est_name, est_class in all_estimators():
        excluded = (issubclass(est_class, BiclusterMixin)
                    or est_name.startswith("_"))
        if excluded:
            continue
        # The ProjectedGradientNMF class is deprecated
        silence = issubclass(est_class, ProjectedGradientNMF)
        for check in _yield_all_checks(est_name, est_class):
            if silence:
                with ignore_warnings():
                    yield _set_test_name(check, est_name), est_name, est_class
            else:
                yield _set_test_name(check, est_name), est_name, est_class
コード例 #3
0
def test_non_meta_estimators():
    """Yield all common checks for every plain (non-meta) estimator.

    Underscore-prefixed estimators and ``BiclusterMixin`` subclasses are not
    checked. Checks for the deprecated ``ProjectedGradientNMF`` are yielded
    with warnings ignored.
    """
    # input validation etc for non-meta estimators
    for name, Estimator in all_estimators():
        if issubclass(Estimator, BiclusterMixin) or name.startswith("_"):
            continue
        for check in _yield_all_checks(name, Estimator):
            if not issubclass(Estimator, ProjectedGradientNMF):
                yield _set_test_name(check, name), name, Estimator
                continue
            # The ProjectedGradientNMF class is deprecated
            with ignore_warnings():
                yield _set_test_name(check, name), name, Estimator
コード例 #4
0
def test_non_transformer_estimators_n_iter():
    """Yield n_iter checks for non-transformer estimators with ``max_iter``.

    Test that all estimators of type which are non-transformer and which
    have an attribute of max_iter, return the attribute of n_iter at least
    1. Regressors, classifiers and clusterers are scanned. Estimators backed
    by external solvers, cross-decomposition models, and models covered by
    ``test_transformer_n_iter`` are skipped.
    """
    # These models are dependent on external solvers like
    # libsvm and accessing the iter parameter is non-trivial.
    external_solver = frozenset([
        'Ridge', 'SVR', 'NuSVR', 'NuSVC', 'RidgeClassifier', 'SVC',
        'RandomizedLasso', 'LogisticRegressionCV'
    ])
    # Tested in test_transformer_n_iter instead.
    tested_as_transformer = frozenset(['LinearSVC', 'LogisticRegression'])

    for est_type in ['regressor', 'classifier', 'cluster']:
        # Holds regressors, classifiers or clusterers depending on est_type.
        estimators = all_estimators(type_filter=est_type)
        for name, Estimator in estimators:
            # LassoLars stops early for the default alpha=1.0 for
            # the iris dataset.
            if name == 'LassoLars':
                estimator = Estimator(alpha=0.)
            else:
                estimator = Estimator()

            if not hasattr(estimator, "max_iter"):
                continue
            if name in external_solver:
                continue
            if name in CROSS_DECOMPOSITION or name in tested_as_transformer:
                continue

            # Multitask models related to ENet cannot handle mono-output y,
            # so the last argument flags names containing 'Multi'.
            yield (_set_test_name(check_non_transformer_estimators_n_iter,
                                  name),
                   name, estimator, 'Multi' in name)
コード例 #5
0
def test_non_transformer_estimators_n_iter():
    """Yield n_iter checks for non-transformer estimators with ``max_iter``.

    Test that all estimators of type which are non-transformer and which
    have an attribute of max_iter, return the attribute of n_iter at least
    1. Regressors, classifiers and clusterers are scanned. Estimators backed
    by external solvers, cross-decomposition models, and models covered by
    ``test_transformer_n_iter`` are skipped.
    """
    # These models are dependent on external solvers like
    # libsvm and accessing the iter parameter is non-trivial.
    external_solver = frozenset([
        'Ridge', 'SVR', 'NuSVR', 'NuSVC', 'RidgeClassifier', 'SVC',
        'RandomizedLasso', 'LogisticRegressionCV'
    ])
    # Tested in test_transformer_n_iter instead.
    tested_as_transformer = frozenset(['LinearSVC', 'LogisticRegression'])

    for est_type in ['regressor', 'classifier', 'cluster']:
        # Holds regressors, classifiers or clusterers depending on est_type.
        estimators = all_estimators(type_filter=est_type)
        for name, Estimator in estimators:
            # LassoLars stops early for the default alpha=1.0 for
            # the iris dataset.
            if name == 'LassoLars':
                estimator = Estimator(alpha=0.)
            else:
                estimator = Estimator()

            if not hasattr(estimator, "max_iter"):
                continue
            if name in external_solver:
                continue
            if name in CROSS_DECOMPOSITION or name in tested_as_transformer:
                continue

            # Multitask models related to ENet cannot handle mono-output y,
            # so the last argument flags names containing 'Multi'.
            yield (_set_test_name(check_non_transformer_estimators_n_iter,
                                  name),
                   name, estimator, 'Multi' in name)
コード例 #6
0
def test_get_params_invariance():
    """Yield a get_params-invariance check for each estimator with get_params.

    The check tests that ``get_params(deep=False)`` is a subset of
    ``get_params(deep=True)`` (related to issue #4465). Meta-estimators are
    excluded and "other" estimators are included. Classes whose ``__init__``
    carries a ``deprecated_original`` attribute are treated as deprecated,
    and their check is yielded with warnings ignored.
    """
    estimators = all_estimators(include_meta_estimators=False,
                                include_other=True)
    for name, Estimator in estimators:
        if not hasattr(Estimator, 'get_params'):
            continue
        if not hasattr(Estimator.__init__, "deprecated_original"):
            yield _set_test_name(check_get_params_invariance,
                                 name), name, Estimator
            continue
        # If class is deprecated, ignore deprecated warnings
        with ignore_warnings():
            yield _set_test_name(check_get_params_invariance,
                                 name), name, Estimator
コード例 #7
0
def test_get_params_invariance():
    """Check get_params(deep=False) is a subset of get_params(deep=True).

    One check is yielded per non-meta estimator (including "other"
    estimators) that defines ``get_params``; related to issue #4465.
    Deprecated classes, recognised by a ``deprecated_original`` attribute
    on ``__init__``, have their check yielded with warnings ignored.
    """
    candidates = all_estimators(include_meta_estimators=False,
                                include_other=True)
    for est_name, est_class in candidates:
        if not hasattr(est_class, 'get_params'):
            continue
        is_deprecated = hasattr(est_class.__init__, "deprecated_original")
        if is_deprecated:
            with ignore_warnings():
                yield _set_test_name(
                    check_get_params_invariance, est_name), est_name, est_class
        else:
            yield _set_test_name(
                check_get_params_invariance, est_name), est_name, est_class
コード例 #8
0
def test_all_estimators():
    """Yield a default-constructibility check for every estimator.

    Test that estimators are default-constructible, cloneable and have a
    working repr. Meta-estimators are included in the scan.
    """
    found = all_estimators(include_meta_estimators=True)

    # Meta sanity-check: estimator introspection must discover something.
    assert_greater(len(found), 0)

    for est_name, est_class in found:
        # some can just not be sensibly default constructed
        named_check = _set_test_name(check_parameters_default_constructible,
                                     est_name)
        yield named_check, est_name, est_class
コード例 #9
0
def test_all_estimators():
    """Yield ``check_parameters_default_constructible`` for all estimators.

    Estimators (meta-estimators included) are tested to be
    default-constructible and cloneable, and to have a working repr.
    """
    estimator_pairs = all_estimators(include_meta_estimators=True)

    # Make sure that the estimator introspection runs properly.
    assert_greater(len(estimator_pairs), 0)

    for pair in estimator_pairs:
        name, Estimator = pair
        # some can just not be sensibly default constructed
        check = _set_test_name(check_parameters_default_constructible, name)
        yield (check, name, Estimator)
コード例 #10
0
def test_class_weight_balanced_linear_classifiers():
    """Yield balanced-class-weight checks for linear classifiers.

    A classifier is selected when it derives from ``LinearClassifierMixin``
    and a default-constructed instance has a ``class_weight`` parameter.
    """
    classifiers = all_estimators(type_filter='classifier')

    clean_warning_registry()
    # Warnings raised while constructing estimators are recorded (and
    # discarded) rather than shown.
    with warnings.catch_warnings(record=True):
        # The cheap subclass test runs first, so only linear classifiers
        # are instantiated to inspect their parameters.
        linear_classifiers = [
            (name, clazz) for name, clazz in classifiers
            if (issubclass(clazz, LinearClassifierMixin)
                and 'class_weight' in clazz().get_params())
        ]

    for name, Classifier in linear_classifiers:
        yield _set_test_name(check_class_weight_balanced_linear_classifier,
                             name), name, Classifier
コード例 #11
0
def test_transformer_n_iter():
    """Yield a ``check_transformer_n_iter`` test for each eligible transformer.

    A transformer is checked only if an instance exposes ``max_iter`` and
    its name is not in the external-solver exclusion set. The deprecated
    ``ProjectedGradientNMF`` is constructed, and its check yielded, with
    warnings ignored.
    """
    transformers = all_estimators(type_filter='transformer')
    # Dependent on external solvers and hence accessing the iter
    # param is non-trivial. Loop-invariant, so built once.
    external_solver = frozenset(['Isomap', 'KernelPCA',
                                 'LocallyLinearEmbedding',
                                 'RandomizedLasso', 'LogisticRegressionCV'])
    for name, Estimator in transformers:
        # The ProjectedGradientNMF class is deprecated.
        deprecated = issubclass(Estimator, ProjectedGradientNMF)
        if deprecated:
            with ignore_warnings():
                estimator = Estimator()
        else:
            estimator = Estimator()

        if hasattr(estimator, "max_iter") and name not in external_solver:
            if deprecated:
                # Yielding inside the ``with`` keeps warnings ignored while
                # this generator is suspended (presumably while the runner
                # executes the yielded check).
                with ignore_warnings():
                    yield _set_test_name(
                        check_transformer_n_iter, name), name, estimator
            else:
                yield _set_test_name(
                    check_transformer_n_iter, name), name, estimator
コード例 #12
0
def test_class_weight_balanced_linear_classifiers():
    """Yield balanced-class-weight checks for linear classifiers.

    Only classifiers deriving from ``LinearClassifierMixin`` whose default
    instance exposes a ``class_weight`` parameter are checked.
    """
    classifiers = all_estimators(type_filter='classifier')

    clean_warning_registry()
    # Warnings raised while constructing estimators are recorded (and
    # discarded) rather than shown.
    with warnings.catch_warnings(record=True):
        # Filter on the class hierarchy before instantiating, so that
        # non-linear classifiers are never constructed.
        linear_classifiers = [
            (name, clazz)
            for name, clazz in classifiers
            if (issubclass(clazz, LinearClassifierMixin) and
                'class_weight' in clazz().get_params())]

    for name, Classifier in linear_classifiers:
        yield _set_test_name(check_class_weight_balanced_linear_classifier,
                             name), name, Classifier