def test_import_optional():
    """Check both behaviours of the helper for a package that is not installed.

    By default an ``ImportError`` whose message matches the pattern below is
    expected; with ``raise_on_missing=False`` the helper is expected to
    return ``None`` instead.
    """
    expected_message = "Missing .*notapackage.* pip .* conda .* notapackage"

    # Default mode: a missing package is an error.
    with pytest.raises(ImportError, match=expected_message):
        import_optional_dependency("notapackage")

    # Lenient mode: a missing package yields None.
    assert import_optional_dependency("notapackage", raise_on_missing=False) is None
def test_no_version_raises():
    """A module without ``__version__`` raises when ``VERSIONS`` lists it.

    A bare ``fakemodule`` is registered in ``sys.modules`` together with a
    ``VERSIONS`` entry of ``"1.0.0"``; ``import_optional_dependency`` is then
    expected to raise ``ImportError``. Both global registrations are removed
    afterwards so the fake module cannot leak into other tests.
    """
    name = "fakemodule"
    module = types.ModuleType(name)
    sys.modules[name] = module
    VERSIONS[name] = "1.0.0"

    try:
        with pytest.raises(ImportError, match="Can't determine .* fakemodule"):
            import_optional_dependency(name)
    finally:
        # Undo the global mutations even if the assertion above fails.
        sys.modules.pop(name, None)
        VERSIONS.pop(name, None)
def test_bad_version():
    """Exercise version checking against a fake module registered on the fly.

    With ``VERSIONS["fakemodule"]`` set to ``"1.0.0"``:

    * a module reporting ``0.9.0`` raises ``ImportError`` by default;
    * with ``on_version="warn"`` it emits a ``UserWarning`` and returns
      ``None``;
    * a module reporting exactly ``1.0.0`` is returned as-is.

    The entries added to ``sys.modules`` and ``VERSIONS`` are removed
    afterwards so the fake module cannot leak into other tests.
    """
    name = "fakemodule"
    module = types.ModuleType(name)
    module.__version__ = "0.9.0"
    sys.modules[name] = module
    VERSIONS[name] = "1.0.0"

    try:
        match = "Eland requires .*1.0.0.* of .fakemodule.*'0.9.0'"
        with pytest.raises(ImportError, match=match):
            import_optional_dependency("fakemodule")

        with pytest.warns(UserWarning):
            result = import_optional_dependency("fakemodule", on_version="warn")
        assert result is None

        module.__version__ = "1.0.0"  # exact match is OK
        result = import_optional_dependency("fakemodule")
        assert result is module
    finally:
        # Undo the global mutations even if an assertion above fails.
        sys.modules.pop(name, None)
        VERSIONS.pop(name, None)
def test_xlrd_version_fallback():
    """Smoke-test the helper against a real ``xlrd`` installation.

    The test is skipped when ``xlrd`` cannot be imported; otherwise it only
    checks that ``import_optional_dependency("xlrd")`` completes without
    raising.
    """
    # Skip rather than fail on environments without xlrd.
    pytest.importorskip("xlrd")

    # No assertion on the result: not raising is the pass condition.
    import_optional_dependency("xlrd")
# Example #5
# 0
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information

from typing import List, Union

import numpy as np

from eland.ml._optional import import_optional_dependency
from eland.ml._model_serializer import Tree, TreeNode, Ensemble

# Resolve the optional ML packages through the shared helper before the
# direct `from sklearn ...` / `from xgboost ...` imports that follow.
# NOTE(review): presumably this makes a missing package surface the helper's
# own ImportError (with install guidance) instead of a bare import failure —
# confirm against eland.ml._optional.
sklearn = import_optional_dependency("sklearn")
xgboost = import_optional_dependency("xgboost")

from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils.validation import check_is_fitted
from xgboost import Booster, XGBRegressor, XGBClassifier


class ModelTransformer:
    def __init__(
        self,
        model,
        feature_names: List[str],
        classification_labels: List[str] = None,
        classification_weights: List[float] = None,
    ):
        self._feature_names = feature_names
        self._model = model
        self._classification_labels = classification_labels