def test_infer_requirements_does_not_print_warning_for_recognized_packages():
    """No warning is logged when every inferred package is in the PyPI package index.

    The imported module ``sklearn`` is patched in, and the patched package index
    contains ``scikit-learn``. ``_infer_requirements`` must therefore not call
    ``_logger.warning``.
    """
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["sklearn"],
    ), mock.patch(
        "mlflow.utils.requirements_utils._PYPI_PACKAGE_INDEX",
        # Set literal instead of set([...]): no intermediate list is built.
        _PyPIPackageIndex(date="2022-01-01", package_names={"scikit-learn"}),
    ), mock.patch(
        "mlflow.utils.requirements_utils._logger.warning"
    ) as mock_warning:
        _infer_requirements("path/to/model", "sklearn")
        mock_warning.assert_not_called()
def test_infer_requirements_prints_warning_for_unrecognized_packages():
    """Exactly one warning is logged when an inferred package is missing from the index.

    The patched PyPI package index is empty, so the ``scikit-learn`` package
    inferred from the ``sklearn`` import is unrecognized. The test rebuilds the
    logged message from the warning call's %-style template and its arguments.
    It then checks that the message mentions the index and the package name.
    """
    empty_index = _PyPIPackageIndex(date="2022-01-01", package_names=set())
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["sklearn"],
    ), mock.patch(
        "mlflow.utils.requirements_utils._PYPI_PACKAGE_INDEX",
        empty_index,
    ), mock.patch(
        "mlflow.utils.requirements_utils._logger.warning"
    ) as mock_warning:
        _infer_requirements("path/to/model", "sklearn")

        mock_warning.assert_called_once()
        positional = mock_warning.call_args[0]
        template = positional[0]
        index_date, missing = positional[1:3]
        rendered = template % (index_date, missing)

        assert "not found in the public PyPI package index" in rendered
        assert "scikit-learn" in rendered
def test_infer_requirements_excludes_mlflow():
    """MLflow's own package is left out of the inferred requirements.

    With ``mlflow`` and ``pytest`` reported as imported modules, only the pinned
    ``pytest`` requirement should be returned. The expected MLflow distribution
    name depends on the ``MLFLOW_SKINNY`` environment variable.
    """
    with mock.patch(
        "mlflow.utils.requirements_utils._capture_imported_modules",
        return_value=["mlflow", "pytest"],
    ):
        if "MLFLOW_SKINNY" in os.environ:
            mlflow_package = "mlflow-skinny"
        else:
            mlflow_package = "mlflow"
        assert mlflow_package in _module_to_packages("mlflow")

        expected = [f"pytest=={pytest.__version__}"]
        assert _infer_requirements("path/to/model", "sklearn") == expected
def infer_pip_requirements(model_uri, flavor, fallback=None):
    """
    Infers the pip requirements of the specified model by creating a subprocess and loading
    the model in it to determine which packages are imported.

    :param model_uri: The URI of the model.
    :param flavor: The flavor name of the model.
    :param fallback: If provided, an unexpected error during the inference procedure is
                     swallowed and the value of ``fallback`` is returned. Otherwise, the
                     error is raised.
    :return: A list of inferred pip requirements (e.g. ``["scikit-learn==0.24.2", ...]``).
    """
    try:
        return _infer_requirements(model_uri, flavor)
    except Exception:
        # Without a fallback, the caller gets the original error unchanged.
        if fallback is None:
            raise
        # With a fallback, log the failure and its traceback, then degrade gracefully.
        _logger.exception(_INFER_PIP_REQUIREMENTS_FALLBACK_MESSAGE, model_uri, flavor)
        return fallback