예제 #1
0
 def test_not_xgboost(self, model):
     """Verify ``must_xgboost_sklearn`` rejects ``model`` with a TypeError.

     The raised message must start with the "model must be from XGBoost's
     scikit-learn API" text (``pytest.raises`` matches via ``re.search``).
     """
     with pytest.raises(
         TypeError,
         match=r"^model must be from XGBoost's scikit-learn API.*",
     ):
         model_validator.must_xgboost_sklearn(model)
예제 #2
0
    def create_standard_model_from_xgboost(
        self,
        obj,
        environment,
        model_api=None,
        name=None,
        desc=None,
        labels=None,
        attrs=None,
        lock_level=None,
    ):
        """Create a Standard Verta Model version from an XGBoost model.

        .. versionadded:: 0.18.2

        .. note::

            If using an XGBoost model from their scikit-learn API,
            ``"scikit-learn"`` must also be specified in `environment`
            (in addition to ``"xgboost"``).

        Parameters
        ----------
        obj : `xgboost.sklearn.XGBModel <https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.sklearn>`__
            XGBoost model using their scikit-learn wrapper interface. It is
            validated before the model version is created.
        environment : :class:`~verta.environment.Python`
            pip and apt dependencies.
        model_api : :class:`~verta.utils.ModelAPI`, optional
            Model API specifying the model's expected input and output.
        name : str, optional
            Name of the model version. If no name is provided, one will be
            generated.
        desc : str, optional
            Description of the model version.
        labels : list of str, optional
            Labels of the model version.
        attrs : dict of str to {None, bool, float, int, str}, optional
            Attributes of the model version.
        lock_level : :mod:`~verta.registry.lock`, default :class:`~verta.registry.lock.Open`
            Lock level to set when creating this model version.

        Returns
        -------
        :class:`~verta.registry.entities.RegisteredModelVersion`

        Raises
        ------
        TypeError
            If `obj` is not a model from XGBoost's scikit-learn API.

        Examples
        --------
        .. code-block:: python

            import xgboost as xgb
            from verta.environment import Python

            model = xgb.XGBClassifier(**hyperparams)
            model.fit(X_train, y_train)

            model_ver = reg_model.create_standard_model_from_xgboost(
                model,
                Python(["scikit-learn", "xgboost"]),
            )
            endpoint.update(model_ver, wait=True)
            endpoint.get_deployed_model().predict(input)

        """
        # Validate the model type up front so an unsupported object is
        # rejected before any model version is created.
        model_validator.must_xgboost_sklearn(obj)

        # All remaining options pass through untouched to the spec-based
        # constructor; only the model itself is XGBoost-specific here.
        version_options = dict(
            environment=environment,
            model_api=model_api,
            name=name,
            desc=desc,
            labels=labels,
            attrs=attrs,
            lock_level=lock_level,
        )
        return self._create_standard_model_from_spec(
            model=obj,
            **version_options,
        )
예제 #3
0
 def test_xgboost(self, model):
     """Verify ``must_xgboost_sklearn`` accepts ``model`` with a truthy result."""
     is_valid = model_validator.must_xgboost_sklearn(model)
     assert is_valid