コード例 #1
0
def test_infer_requirements_does_not_print_warning_for_recognized_packages():
    # The patched PyPI index already contains "scikit-learn", the package behind the
    # captured "sklearn" import, so inference must finish without logging a warning.
    known_index = _PyPIPackageIndex(date="2022-01-01", package_names={"scikit-learn"})
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["sklearn"],
    ), mock.patch(
        "mlflow.utils.requirements_utils._PYPI_PACKAGE_INDEX",
        known_index,
    ), mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        _infer_requirements("path/to/model", "sklearn")
        mock_warning.assert_not_called()
コード例 #2
0
def test_infer_requirements_prints_warning_for_unrecognized_packages():
    # With an empty patched PyPI index, the package behind the captured "sklearn" import
    # cannot be found, so exactly one warning naming it should be logged.
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["sklearn"],
    ), mock.patch(
        "mlflow.utils.requirements_utils._PYPI_PACKAGE_INDEX",
        _PyPIPackageIndex(date="2022-01-01", package_names=set()),
    ), mock.patch("mlflow.utils.requirements_utils._logger.warning") as mock_warning:
        _infer_requirements("path/to/model", "sklearn")

        mock_warning.assert_called_once()
        # Render the message the way logging does: the first positional argument is the
        # %-style template and every remaining positional argument fills it. Using all of
        # them (instead of unpacking exactly two) keeps this check meaningful if the
        # warning call gains or loses format arguments.
        warning_args = mock_warning.call_args[0]
        warning_text = warning_args[0] % warning_args[1:]
        assert "not found in the public PyPI package index" in warning_text
        assert "scikit-learn" in warning_text
コード例 #3
0
def test_infer_requirements_excludes_mlflow():
    # When MLFLOW_SKINNY is set, the `mlflow` module is expected to resolve to the
    # "mlflow-skinny" package; otherwise to "mlflow".
    if "MLFLOW_SKINNY" in os.environ:
        expected_mlflow_package = "mlflow-skinny"
    else:
        expected_mlflow_package = "mlflow"
    expected_requirements = [f"pytest=={pytest.__version__}"]

    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["mlflow", "pytest"],
    ):
        assert expected_mlflow_package in _module_to_packages("mlflow")
        # "mlflow" is among the captured imports, yet it must be dropped from the
        # inferred requirements; only the pinned pytest requirement remains.
        inferred = _infer_requirements("path/to/model", "sklearn")
        assert inferred == expected_requirements
コード例 #4
0
def infer_pip_requirements(model_uri, flavor, fallback=None):
    """
    Infer the pip requirements of a model by loading it in a subprocess and recording
    which packages get imported.

    :param model_uri: The URI of the model.
    :param flavor: The flavor name of the model.
    :param fallback: Optional value to return if inference fails unexpectedly. When it is
                     provided, the error is logged and swallowed; when it is ``None``, the
                     error propagates to the caller.
    :return: A list of inferred pip requirements (e.g. ``["scikit-learn==0.24.2", ...]``).
    """
    try:
        inferred = _infer_requirements(model_uri, flavor)
    except Exception:
        if fallback is None:
            # No fallback was supplied: re-raise the original error unchanged.
            raise
        # Best-effort mode: record the failure with its traceback, then hand back the
        # caller-provided fallback instead of failing.
        _logger.exception(_INFER_PIP_REQUIREMENTS_FALLBACK_MESSAGE, model_uri, flavor)
        return fallback
    return inferred