def test_handle_plugin_registration_failure_via_entrypoints(exception):
    """An entry point whose ``load()`` raises must be reported as a ``UserWarning``.

    The warning text is matched against "test exception" (presumably the message
    of the parametrized ``exception`` — confirm against the parametrize list), and
    the registry must look up the ``mlflow.tracking_store`` group exactly once.
    """
    failing_load = mock.Mock(side_effect=exception)
    broken_entrypoint = mock.Mock(load=failing_load)
    broken_entrypoint.name = "mock-scheme"

    patcher = mock.patch("entrypoints.get_group_all", return_value=[broken_entrypoint])
    with patcher as fake_get_group_all:
        registry = TrackingStoreRegistry()
        # The original exception's message is expected to appear in the warning.
        with pytest.warns(UserWarning, match="test exception"):
            registry.register_entrypoints()

    broken_entrypoint.load.assert_called_once()
    fake_get_group_all.assert_called_once_with("mlflow.tracking_store")
def test_plugin_registration_via_entrypoints():
    """A successfully loaded entry point becomes the store builder for its scheme.

    The loaded object is called with ``store_uri`` and ``artifact_uri`` keyword
    arguments, and its return value is what ``get_store`` hands back.
    """
    plugin_builder = mock.Mock()
    fake_entrypoint = mock.Mock(load=mock.Mock(return_value=plugin_builder))
    fake_entrypoint.name = "mock-scheme"

    with mock.patch("entrypoints.get_group_all",
                    return_value=[fake_entrypoint]) as fake_get_group_all:
        registry = TrackingStoreRegistry()
        registry.register_entrypoints()

    resolved_store = registry.get_store("mock-scheme://")
    assert resolved_store == plugin_builder.return_value
    plugin_builder.assert_called_once_with(store_uri="mock-scheme://", artifact_uri=None)
    fake_get_group_all.assert_called_once_with("mlflow.tracking_store")
def _get_databricks_rest_store(store_uri, **_):
    """Build a ``DatabricksRestStore`` for ``store_uri``.

    Credentials are not fetched here: the store is given a zero-argument callable
    that calls ``get_databricks_host_creds(store_uri)`` each time it is invoked.
    Any extra keyword arguments are accepted and ignored.
    """
    def _host_creds():
        return get_databricks_host_creds(store_uri)

    return DatabricksRestStore(_host_creds)


# Module-level registry mapping URI schemes to tracking-store builder functions.
# Built-in schemes are registered first; entry-point plugins are registered last.
_tracking_store_registry = TrackingStoreRegistry()
# The empty scheme and 'file' both map to the file store.
for scheme in ('', 'file'):
    _tracking_store_registry.register(scheme, _get_file_store)
_tracking_store_registry.register('databricks', _get_databricks_rest_store)
for scheme in ('http', 'https'):
    _tracking_store_registry.register(scheme, _get_rest_store)
# Every scheme listed in DATABASE_ENGINES is served by the SQLAlchemy store.
for scheme in DATABASE_ENGINES:
    _tracking_store_registry.register(scheme, _get_sqlalchemy_store)
_tracking_store_registry.register_entrypoints()


def _get_store(store_uri=None, artifact_uri=None):
    """Return the tracking store the module-level registry resolves for ``store_uri``.

    Both arguments are forwarded positionally to the registry's ``get_store``.
    """
    registry = _tracking_store_registry
    return registry.get_store(store_uri, artifact_uri)


# TODO(sueann): move to a projects utils module
def _get_git_url_if_present(uri):
    """
    Return the path git_uri#sub_directory if the URI passed is a local path that's part of
    a Git repo, or returns the original URI otherwise.
    :param uri: The expanded uri
    :return: The git_uri#sub_directory if the uri is part of a Git repo,
             otherwise return the original uri
    """