コード例 #1
0
ファイル: test_model.py プロジェクト: zyfra/ebonite
def _check_model_wrapper(graph, model, feed_dict, tmpdir):
    """Round-trip a TensorFlow (v1, graph-mode) model wrapper through dump/unbind/load.

    Analyzes ``model`` inside a session on ``graph`` and checks the detected requirements.
    It records a prediction, dumps the wrapper to ``tmpdir`` and unbinds it. It then
    reloads the wrapper and checks that the reloaded one predicts the same values.

    :param graph: graph ``model`` belongs to; a session on it is opened here
    :param model: object passed to ``ModelAnalyzer.analyze``; the wrapper's
        ``model.tensors`` is expected to be this exact object
    :param feed_dict: sample input used both for analysis and for ``predict``
    :param tmpdir: directory the dumped artifacts are materialized into
    """
    # this import is required to ensure that Tensorflow model wrapper is registered
    import ebonite.ext.tensorflow  # noqa
    import numpy as np

    expected_requirements = {'tensorflow', 'numpy'}

    with tf.Session(graph=graph) as session:
        session.run(tf.global_variables_initializer())
        # training here is just random initialization

        tmw = ModelAnalyzer.analyze(model, input_data=feed_dict)
        assert tmw.model.tensors is model
        # tmw.model is an object holding `model` in `.tensors`, so `tmw.model is not model`
        # holds trivially; keep the bound object to verify that load() really replaces it
        bound_model = tmw.model

        assert set(tmw.requirements.modules) == expected_requirements

        pred = tmw.call_method('predict', feed_dict)

        with tmw.dump() as artifact:
            artifact.materialize(tmpdir)

        tmw.unbind()
        with pytest.raises(ValueError):
            tmw.call_method('predict', feed_dict)

    tmw.load(tmpdir)
    assert tmw.model is not bound_model
    assert tmw.model is not model

    pred2 = tmw.call_method('predict', feed_dict)
    # `pred2 == pred` is ambiguous (raises) for multi-element numpy arrays;
    # compare element-wise instead
    np.testing.assert_array_equal(pred2, pred)

    assert set(tmw.requirements.modules) == expected_requirements
コード例 #2
0
ファイル: test_model.py プロジェクト: zyfra/ebonite
def _check_model_wrapper(net, input_data, tmpdir):
    """Check analysis, requirements and a dump/unbind/load round-trip of a torch model.

    :param net: torch model that ``ModelAnalyzer`` is expected to wrap as-is
    :param input_data: sample input used for analysis and prediction
    :param tmpdir: directory the dumped artifacts are materialized into
    """
    # this import is required for dataset type to be registered
    import ebonite.ext.torch  # noqa

    torch_wrapper = ModelAnalyzer.analyze(net, input_data=input_data)
    assert torch_wrapper.model is net

    required = {'torch'}
    assert set(torch_wrapper.requirements.modules) == required

    before = torch_wrapper.call_method('predict', input_data)

    with torch_wrapper.dump() as dumped:
        dumped.materialize(tmpdir)

    # once unbound, the wrapper must refuse to predict
    torch_wrapper.unbind()
    with pytest.raises(ValueError):
        torch_wrapper.call_method('predict', input_data)

    # reloading must produce a new model object that predicts identically
    torch_wrapper.load(tmpdir)
    assert torch_wrapper.model is not net

    after = torch_wrapper.call_method('predict', input_data)
    assert torch.equal(before, after)

    assert set(torch_wrapper.requirements.modules) == required
コード例 #3
0
ファイル: core.py プロジェクト: geffy/ebonite
    def create(cls,
               model_object,
               input_data,
               model_name: str = None,
               additional_artifacts: ArtifactCollection = None,
               additional_requirements: AnyRequirements = None,
               custom_wrapper: ModelWrapper = None,
               custom_artifact: ArtifactCollection = None,
               custom_input_meta: DatasetType = None,
               custom_output_meta: DatasetType = None,
               custom_prediction=None,
               custom_requirements: AnyRequirements = None) -> 'Model':
        """
        Creates Model instance from arbitrary model objects and sample of input data

        :param model_object: The model object to analyze.
        :param input_data: Input data sample. It is used to infer the input DatasetType and,
            unless ``custom_prediction`` is given, to compute a sample prediction.
        :param model_name: The model name. If empty or None, a name is generated from the wrapper.
        :param additional_artifacts: Additional artifacts appended to the artifact collection.
        :param additional_requirements: Additional requirements appended to the requirements
            (also when ``custom_requirements`` is given).
        :param custom_wrapper: Custom model wrapper used instead of analyzing ``model_object``.
        :param custom_artifact: Custom artifact collection used instead of the wrapper's one.
        :param custom_input_meta: Custom input DatasetType.
        :param custom_output_meta: Custom output DatasetType.
        :param custom_prediction: Custom prediction output, used instead of
            ``wrapper.predict(input_data)`` whenever it is not None.
        :param custom_requirements: Custom requirements replacing the ones inferred from
            model, input data and prediction.
        :returns: :py:class:`Model`
        """
        # Custom overrides are selected with `is None` checks rather than `or`: a prediction
        # may be array-like (e.g. numpy array, DataFrame) whose truth value is ambiguous and
        # raises, and a legitimately falsy value (e.g. a prediction of 0) must not be replaced.
        wrapper: ModelWrapper = (custom_wrapper if custom_wrapper is not None
                                 else ModelAnalyzer.analyze(model_object))
        name = model_name or _generate_model_name(wrapper)

        artifact = (custom_artifact if custom_artifact is not None
                    else WrapperArtifactCollection(wrapper))
        if additional_artifacts is not None:
            artifact += additional_artifacts

        input_meta = (custom_input_meta if custom_input_meta is not None
                      else DatasetAnalyzer.analyze(input_data))
        prediction = (custom_prediction if custom_prediction is not None
                      else wrapper.predict(input_data))
        output_meta = (custom_output_meta if custom_output_meta is not None
                       else DatasetAnalyzer.analyze(prediction))

        if custom_requirements is not None:
            requirements = resolve_requirements(custom_requirements)
        else:
            requirements = get_object_requirements(model_object)
            requirements += get_object_requirements(input_data)
            requirements += get_object_requirements(prediction)

        if additional_requirements is not None:
            requirements += additional_requirements
        model = Model(name, wrapper, None, input_meta, output_meta,
                      requirements)
        # artifacts are attached to the model here and are not persisted yet
        model._unpersisted_artifacts = artifact
        return model
コード例 #4
0
ファイル: test_model.py プロジェクト: geffy/ebonite
def _check_model_wrapper(net, tmpdir):
    """Analyze ``net``, dump it to ``tmpdir`` and load it back into the same wrapper.

    :param net: model object ``ModelAnalyzer`` is expected to wrap as-is
    :param tmpdir: directory the dumped artifacts are materialized into
    :return: the reloaded model object, which must be a different object than ``net``
    """
    wrapped = ModelAnalyzer.analyze(net)
    assert wrapped.model is net

    with wrapped.dump() as dumped:
        dumped.materialize(tmpdir)
    wrapped.load(tmpdir)

    reloaded = wrapped.model
    assert reloaded is not net
    return reloaded
コード例 #5
0
ファイル: wrapper.py プロジェクト: geffy/ebonite
    def _safe_analyze(self, obj):
        """
        Try to build a model wrapper for obj without raising

        :param obj: object to check
        :return: :class:`ModelWrapper` instance, or None when the analyzer raises ValueError
        """
        # analyzer is imported here, not at module level, to avoid a circular import
        from ebonite.core.analyzer.model import ModelAnalyzer
        try:
            found = ModelAnalyzer.analyze(obj)
        except ValueError:
            found = None
        return found
コード例 #6
0
    def create(cls,
               model_object,
               input_data,
               model_name: str = None,
               params: Dict[str, Any] = None,
               description: str = None,
               additional_artifacts: ArtifactCollection = None,
               additional_requirements: AnyRequirements = None,
               custom_wrapper: ModelWrapper = None,
               custom_artifact: ArtifactCollection = None,
               custom_requirements: AnyRequirements = None) -> 'Model':
        """
        Creates Model instance from arbitrary model objects and sample of input data

        :param model_object: The model object to analyze.
        :param input_data: Input data sample to determine structure of inputs and outputs for given model object.
        :param model_name: The model name. If empty or None, a name is generated from the wrapper.
        :param params: dict with arbitrary parameters. Must be json-serializable.
            It is copied, so the caller's dict is never modified; the copy gets a
            ``cls.PYTHON_VERSION`` entry (current python version) unless one is already present.
        :param description: text description of this model
        :param additional_artifacts: Additional artifact.
        :param additional_requirements: Additional requirements.
        :param custom_wrapper: Custom model wrapper.
        :param custom_artifact: Custom artifact collection to replace all other.
        :param custom_requirements: Custom requirements to replace all other.
        :returns: :py:class:`Model`
        """
        wrapper: ModelWrapper = custom_wrapper or ModelAnalyzer.analyze(
            model_object, input_data=input_data)
        name = model_name or _generate_model_name(wrapper)

        artifact = custom_artifact or WrapperArtifactCollection(wrapper)
        if additional_artifacts is not None:
            artifact += additional_artifacts

        if custom_requirements is not None:
            requirements = resolve_requirements(custom_requirements)
        else:
            requirements = wrapper.requirements

        if additional_requirements is not None:
            requirements += additional_requirements

        requirements = RequirementAnalyzer.analyze(requirements)
        # copy before adding the python version: writing into the caller's dict would
        # leak this key back to the caller (and across calls reusing the same dict)
        params = dict(params) if params else {}
        params.setdefault(cls.PYTHON_VERSION, get_python_version())
        model = Model(name, wrapper, None, requirements, params, description)
        # artifacts are attached to the model here and are not persisted yet
        model._unpersisted_artifacts = artifact
        return model
コード例 #7
0
def test_wrapper__dump_load(tmpdir, model, inp_data, request):
    """Dump, unbind and reload a wrapper; predictions and requirements must survive."""
    fitted = request.getfixturevalue(model)
    sk_wrapper = ModelAnalyzer.analyze(fitted, input_data=inp_data)

    required_modules = {'sklearn', 'numpy'}
    assert set(sk_wrapper.requirements.modules) == required_modules

    with sk_wrapper.dump() as dumped:
        dumped.materialize(tmpdir)

    # without a bound model, predicting must fail
    sk_wrapper.unbind()
    with pytest.raises(ValueError):
        sk_wrapper.call_method('predict', inp_data)

    sk_wrapper.load(tmpdir)
    expected = fitted.predict(inp_data)
    actual = sk_wrapper.call_method('predict', inp_data)
    np.testing.assert_array_almost_equal(expected, actual)
    assert set(sk_wrapper.requirements.modules) == required_modules
コード例 #8
0
ファイル: wrapper.py プロジェクト: DariaMishina/ebonite
    def _get_non_pickle_io(self, obj):
        """
        Checks if obj has non-Pickle IO and returns it

        :param obj: object to check
        :return: non-Pickle :class:`ModelIO` instance or None
        """

        # avoid calling heavy analyzer machinery for "unknown" objects:
        # they are either non-models or callables
        if not isinstance(obj, self.known_types):
            return None

        # we couldn't import analyzer at top as it leads to circular import failure
        from ebonite.core.analyzer.model import ModelAnalyzer
        try:
            io = ModelAnalyzer._find_hook(obj)._wrapper_factory().io
            return None if isinstance(io, PickleModelIO) else io
        except ValueError:
            # non-model object
            return None
コード例 #9
0
def test_catboost_model_wrapper(catboost_model, pandas_data, tmpdir, request):
    """A CatBoost wrapper must survive dump/unbind/load with identical predictions."""
    cb_model = request.getfixturevalue(catboost_model)

    # this import is required to ensure that CatBoost model wrapper is registered
    import ebonite.ext.catboost  # noqa

    cb_wrapper = ModelAnalyzer.analyze(cb_model)
    assert cb_wrapper.model is cb_model

    with cb_wrapper.dump() as dumped:
        dumped.materialize(tmpdir)

    # unbound wrapper cannot predict
    cb_wrapper.unbind()
    with pytest.raises(ValueError):
        cb_wrapper.predict(pandas_data)

    cb_wrapper.load(tmpdir)
    assert cb_wrapper.model is not cb_model

    reference = cb_model.predict(pandas_data)
    reloaded = cb_wrapper.predict(pandas_data)
    np.testing.assert_array_almost_equal(reference, reloaded)
コード例 #10
0
def test_model_wrapper(net, input_data, tmpdir, request):
    """A TF2 wrapper must predict like the raw model and survive dump/unbind/load."""
    # force loading of dataset and model hooks
    import ebonite.ext.tensorflow_v2  # noqa

    model = request.getfixturevalue(net)
    data = request.getfixturevalue(input_data)

    # callable models are invoked directly, others through their `predict` method
    if callable(model):
        reference = model(data)
    else:
        reference = model.predict(data)

    tf_wrapper = ModelAnalyzer.analyze(model, input_data=data)
    assert tf_wrapper.model is model

    required = {'tensorflow', 'numpy'}
    assert set(tf_wrapper.requirements.modules) == required

    first = tf_wrapper.call_method('predict', data)
    np.testing.assert_array_equal(reference, first)

    with tf_wrapper.dump() as dumped:
        dumped.materialize(tmpdir)

    tf_wrapper.unbind()
    with pytest.raises(ValueError):
        tf_wrapper.call_method('predict', data)

    tf_wrapper.load(tmpdir)
    assert tf_wrapper.model is not model

    second = tf_wrapper.call_method('predict', data)
    np.testing.assert_array_equal(first, second)

    assert set(tf_wrapper.requirements.modules) == required
コード例 #11
0
def test_wrapper__reg_predict_proba(regressor, inp_data):
    """Calling ``predict_proba`` on the wrapped regressor is expected to raise ValueError."""
    reg_wrapper = ModelAnalyzer.analyze(regressor, input_data=inp_data)
    with pytest.raises(ValueError):
        reg_wrapper.call_method('predict_proba', inp_data)
コード例 #12
0
def test_wrapper__predict(model, inp_data, request):
    """The wrapper's ``predict`` must match the underlying model's ``predict``."""
    estimator = request.getfixturevalue(model)
    analyzed = ModelAnalyzer.analyze(estimator, input_data=inp_data)

    expected = estimator.predict(inp_data)
    actual = analyzed.call_method('predict', inp_data)
    np.testing.assert_array_almost_equal(expected, actual)
コード例 #13
0
def test_wrapper__clf_predict_proba(classifier, inp_data):
    """The wrapper's ``predict_proba`` must match the classifier's own ``predict_proba``."""
    clf_wrapper = ModelAnalyzer.analyze(classifier, input_data=inp_data)

    expected = classifier.predict_proba(inp_data)
    actual = clf_wrapper.call_method('predict_proba', inp_data)
    np.testing.assert_array_almost_equal(expected, actual)
コード例 #14
0
def test_hook(model, inp_data, request):
    """The analyzer must select ``SklearnModelWrapper`` for the fixture model."""
    estimator = request.getfixturevalue(model)
    analyzed = ModelAnalyzer.analyze(estimator, input_data=inp_data)
    assert isinstance(analyzed, SklearnModelWrapper)
コード例 #15
0
def wrapper(booster, dmatrix_np) -> ModelWrapper:
    """Wrap ``booster`` via ``ModelAnalyzer``, using ``dmatrix_np`` as the input sample."""
    analyzed = ModelAnalyzer.analyze(booster, input_data=dmatrix_np)
    return analyzed
コード例 #16
0
ファイル: test_model.py プロジェクト: geffy/ebonite
def test_hook(booster):
    """The analyzer must select ``XGBoostModelWrapper`` and wrap the booster object itself."""
    wrapper = ModelAnalyzer.analyze(booster)
    assert isinstance(wrapper, XGBoostModelWrapper)
    # identity, not equality: the wrapper must hold this exact booster object
    # (consistent with the `tmw.model is net` checks used by the other wrapper tests)
    assert wrapper.model is booster
コード例 #17
0
def wrapper(booster) -> ModelWrapper:
    """Return the model wrapper ``ModelAnalyzer`` selects for ``booster``."""
    analyzed = ModelAnalyzer.analyze(booster)
    return analyzed
コード例 #18
0
def wrapper(summer_model, numpy_data):
    """Wrap ``summer_model`` with ``_wrap_model`` and analyze it using ``numpy_data`` as input sample."""
    return ModelAnalyzer.analyze(_wrap_model(summer_model), input_data=numpy_data)
コード例 #19
0
ファイル: test_model.py プロジェクト: geffy/ebonite
def wrapper(summer_model):
    """Wrap ``summer_model`` with ``_wrap_model`` and return the analyzer's model wrapper for it."""
    return ModelAnalyzer.analyze(_wrap_model(summer_model))
コード例 #20
0
def wrapper(model):
    """Return the model wrapper ``ModelAnalyzer`` selects for ``model``."""
    result = ModelAnalyzer.analyze(model)
    return result