def test_get_store_for_unregistered_scheme():
    """A URI whose scheme was never registered must be rejected by ``get_store``.

    The registry is expected to raise ``MlflowException`` with a message
    containing "Unexpected URI scheme".
    """
    registry = TrackingStoreRegistry()
    unknown_uri = "unknown-scheme://"
    with pytest.raises(mlflow.exceptions.MlflowException, match="Unexpected URI scheme"):
        registry.get_store(unknown_uri)
def test_plugin_registration():
    """A builder registered for a scheme is stored and used by ``get_store``.

    ``get_store`` must call the builder exactly once, passing the full URI as
    ``store_uri`` and ``artifact_uri=None``, and must return whatever the
    builder returns.
    """
    scheme = "mock-scheme"
    uri = "mock-scheme://fake-host/fake-path"
    builder = mock.Mock()
    registry = TrackingStoreRegistry()

    registry.register(scheme, builder)

    assert scheme in registry._registry
    store = registry.get_store(uri)
    assert store == builder.return_value
    builder.assert_called_once_with(store_uri=uri, artifact_uri=None)
def test_handle_plugin_registration_failure_via_entrypoints(exception):
    """An entrypoint whose ``load()`` raises is reported as a ``UserWarning``.

    ``exception`` is raised by the mocked ``load()``. The warning must carry
    its message, so the message is expected to contain "test exception".
    NOTE(review): ``exception`` is presumably supplied by a
    ``@pytest.mark.parametrize`` decorator that is not visible here; confirm.
    """
    failing_entrypoint = mock.Mock()
    failing_entrypoint.name = "mock-scheme"
    failing_entrypoint.load = mock.Mock(side_effect=exception)

    patcher = mock.patch("entrypoints.get_group_all", return_value=[failing_entrypoint])
    with patcher as group_lookup:
        registry = TrackingStoreRegistry()
        # The warning text must echo the message of the exception raised by load().
        with pytest.warns(UserWarning, match="test exception"):
            registry.register_entrypoints()

    failing_entrypoint.load.assert_called_once()
    group_lookup.assert_called_once_with("mlflow.tracking_store")
def test_plugin_registration_via_entrypoints():
    """Builders found in the ``mlflow.tracking_store`` entrypoint group are registered.

    The entrypoint's name becomes the scheme. A later ``get_store`` for that
    scheme must call the loaded builder once and return its result.
    """
    builder = mock.Mock()
    entrypoint = mock.Mock()
    entrypoint.name = "mock-scheme"
    entrypoint.load = mock.Mock(return_value=builder)

    with mock.patch("entrypoints.get_group_all", return_value=[entrypoint]) as group_lookup:
        registry = TrackingStoreRegistry()
        registry.register_entrypoints()

    uri = "mock-scheme://"
    assert registry.get_store(uri) == builder.return_value
    builder.assert_called_once_with(store_uri=uri, artifact_uri=None)
    group_lookup.assert_called_once_with("mlflow.tracking_store")
""" Get the Databricks profile specified by the tracking URI (if any), otherwise returns None. """ parsed_uri = urllib.parse.urlparse(uri) if parsed_uri.scheme == "databricks": return parsed_uri.netloc return None def _get_databricks_rest_store(store_uri, **_): profile = get_db_profile_from_uri(store_uri) return DatabricksRestStore(lambda: get_databricks_host_creds(profile)) _tracking_store_registry = TrackingStoreRegistry() _tracking_store_registry.register('', _get_file_store) _tracking_store_registry.register('file', _get_file_store) _tracking_store_registry.register('databricks', _get_databricks_rest_store) for scheme in ['http', 'https']: _tracking_store_registry.register(scheme, _get_rest_store) for scheme in DATABASE_ENGINES: _tracking_store_registry.register(scheme, _get_sqlalchemy_store) _tracking_store_registry.register_entrypoints() def _get_store(store_uri=None, artifact_uri=None): return _tracking_store_registry.get_store(store_uri, artifact_uri)
# NOTE(review): this line is whitespace-collapsed and starts inside a function whose
# `def` header is not visible here; the code is left byte-identical.
# What the visible code does:
# - Tail of a route helper: returns prefix + route when prefix is truthy, else route.
# - _get_file_store / _get_sqlalchemy_store(store_uri, artifact_uri): import FileStore /
#   SqlAlchemyStore inside the function body, i.e. only on first use, and construct it
#   from (store_uri, artifact_uri).
# - Import-time side effect: builds _tracking_store_registry mapping '' and 'file' to
#   _get_file_store and every scheme in DATABASE_ENGINES to _get_sqlalchemy_store.
#   No other schemes are registered and register_entrypoints() is not called here.
# - _get_store(backend_store_uri=None, default_artifact_root=None): builds the
#   module-global _store on first call. Each value is taken from the argument when it
#   is truthy. Otherwise it comes from the environment variables named by
#   BACKEND_STORE_URI_ENV_VAR / ARTIFACT_ROOT_ENV_VAR (None if unset).
#   NOTE(review): _store is presumably initialised to None at module level (not
#   visible here). Once set, later calls return the cached store and ignore their
#   arguments. There is no lock, so concurrent first calls could each build a store;
#   confirm whether that matters for the server's worker model.
if prefix: return prefix + route return route def _get_file_store(store_uri, artifact_uri): from mlflow.store.file_store import FileStore return FileStore(store_uri, artifact_uri) def _get_sqlalchemy_store(store_uri, artifact_uri): from mlflow.store.sqlalchemy_store import SqlAlchemyStore return SqlAlchemyStore(store_uri, artifact_uri) _tracking_store_registry = TrackingStoreRegistry() _tracking_store_registry.register('', _get_file_store) _tracking_store_registry.register('file', _get_file_store) for scheme in DATABASE_ENGINES: _tracking_store_registry.register(scheme, _get_sqlalchemy_store) def _get_store(backend_store_uri=None, default_artifact_root=None): from mlflow.server import BACKEND_STORE_URI_ENV_VAR, ARTIFACT_ROOT_ENV_VAR global _store if _store is None: store_uri = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None) artifact_root = default_artifact_root or os.environ.get(ARTIFACT_ROOT_ENV_VAR, None) _store = _tracking_store_registry.get_store(store_uri, artifact_root) return _store