Ejemplo n.º 1
0
def test_get_store_for_unregistered_scheme():
    """A URI whose scheme was never registered must be rejected with an MlflowException."""
    registry = TrackingStoreRegistry()
    unknown_uri = "unknown-scheme://"

    expected_error = mlflow.exceptions.MlflowException
    with pytest.raises(expected_error, match="Unexpected URI scheme"):
        registry.get_store(unknown_uri)
Ejemplo n.º 2
0
def test_plugin_registration():
    """Registering a factory directly makes get_store delegate to it for that scheme."""
    registry = TrackingStoreRegistry()

    scheme = "mock-scheme"
    uri = "mock-scheme://fake-host/fake-path"
    factory = mock.Mock()

    registry.register(scheme, factory)
    assert scheme in registry._registry

    # get_store should hand back whatever the registered factory returned...
    store = registry.get_store(uri)
    assert store == factory.return_value
    # ...and call the factory exactly once, passing both URIs as keywords.
    factory.assert_called_once_with(store_uri=uri, artifact_uri=None)
Ejemplo n.º 3
0
def test_handle_plugin_registration_failure_via_entrypoints(exception):
    """An entrypoint whose load() raises is reported as a UserWarning, not propagated.

    ``exception`` is presumably supplied by a parametrize decorator outside this view;
    its message is expected to contain "test exception".
    """
    broken_entrypoint = mock.Mock(load=mock.Mock(side_effect=exception))
    # `name` is a Mock constructor argument (it names the mock), so set the attribute afterwards.
    broken_entrypoint.name = "mock-scheme"

    patcher = mock.patch("entrypoints.get_group_all", return_value=[broken_entrypoint])
    with patcher as group_all:
        registry = TrackingStoreRegistry()

        # The warning must carry the message of the exception raised by load().
        with pytest.warns(UserWarning, match="test exception"):
            registry.register_entrypoints()

    broken_entrypoint.load.assert_called_once()
    group_all.assert_called_once_with("mlflow.tracking_store")
Ejemplo n.º 4
0
def test_plugin_registration_via_entrypoints():
    """Factories exposed via the "mlflow.tracking_store" entrypoint group get registered."""
    store_factory = mock.Mock()
    entrypoint = mock.Mock(load=mock.Mock(return_value=store_factory))
    # `name` is a Mock constructor argument (it names the mock), so set the attribute afterwards.
    entrypoint.name = "mock-scheme"

    patcher = mock.patch("entrypoints.get_group_all", return_value=[entrypoint])
    with patcher as group_all:
        registry = TrackingStoreRegistry()
        registry.register_entrypoints()

    uri = "mock-scheme://"
    assert registry.get_store(uri) == store_factory.return_value

    store_factory.assert_called_once_with(store_uri=uri, artifact_uri=None)
    group_all.assert_called_once_with("mlflow.tracking_store")
Ejemplo n.º 5
0
    """
    Get the Databricks profile specified by the tracking URI (if any), otherwise
    returns None.
    """
    parsed_uri = urllib.parse.urlparse(uri)
    if parsed_uri.scheme == "databricks":
        return parsed_uri.netloc
    return None


def _get_databricks_rest_store(store_uri, **_):
    """Build a DatabricksRestStore for a ``databricks`` tracking URI.

    The profile is looked up from ``store_uri`` once. Any extra keyword arguments,
    such as ``artifact_uri`` passed by the registry, are accepted and ignored.
    """
    profile = get_db_profile_from_uri(store_uri)

    # The store receives a zero-argument callable rather than resolved credentials;
    # when it invokes it is up to DatabricksRestStore.
    def _host_creds():
        return get_databricks_host_creds(profile)

    return DatabricksRestStore(_host_creds)


# Module-level registry mapping a tracking-URI scheme to the factory that builds its store.
# Registration order matters: entrypoint plugins (last statement) are registered after
# all built-in schemes.
_tracking_store_registry = TrackingStoreRegistry()
# The empty scheme presumably covers scheme-less URIs such as plain local paths;
# explicit file: URIs go to the same factory.
_tracking_store_registry.register('', _get_file_store)
_tracking_store_registry.register('file', _get_file_store)
_tracking_store_registry.register('databricks', _get_databricks_rest_store)

# Both http and https tracking URIs are handled by _get_rest_store.
# NOTE(review): these module-level loops leave `scheme` bound as a module global.
for scheme in ['http', 'https']:
    _tracking_store_registry.register(scheme, _get_rest_store)

# Every scheme listed in DATABASE_ENGINES is handled by _get_sqlalchemy_store.
for scheme in DATABASE_ENGINES:
    _tracking_store_registry.register(scheme, _get_sqlalchemy_store)

# Load plugin stores advertised via package entrypoints (the registry tests expect the
# "mlflow.tracking_store" group). Since this runs after the built-ins, a plugin for an
# already-registered scheme presumably replaces it - confirm against
# TrackingStoreRegistry.register.
_tracking_store_registry.register_entrypoints()


def _get_store(store_uri=None, artifact_uri=None):
    """Return a tracking store for ``store_uri``, resolved through the module-level scheme registry.

    Both arguments are forwarded positionally and unchanged to the registry's get_store.
    """
    registry = _tracking_store_registry
    return registry.get_store(store_uri, artifact_uri)
Ejemplo n.º 6
0
    if prefix:
        return prefix + route
    return route


def _get_file_store(store_uri, artifact_uri):
    """Construct a FileStore from the given store URI and artifact URI.

    The import happens at call time, so the file-store module is only loaded when this
    factory is actually used (presumably to keep module import light or avoid cycles).
    """
    from mlflow.store.file_store import FileStore as _FileStore

    file_store = _FileStore(store_uri, artifact_uri)
    return file_store


def _get_sqlalchemy_store(store_uri, artifact_uri):
    """Construct a SqlAlchemyStore from the given database URI and artifact URI.

    The import happens at call time, so SQLAlchemy-backed code is only loaded when a
    database scheme is actually requested.
    """
    from mlflow.store.sqlalchemy_store import SqlAlchemyStore as _SqlAlchemyStore

    sql_store = _SqlAlchemyStore(store_uri, artifact_uri)
    return sql_store


# Module-level registry mapping a backend-store URI scheme to the factory that builds it.
_tracking_store_registry = TrackingStoreRegistry()
# The empty scheme presumably covers scheme-less URIs such as plain local paths;
# explicit file: URIs go to the same factory.
_tracking_store_registry.register('', _get_file_store)
_tracking_store_registry.register('file', _get_file_store)
# Every scheme listed in DATABASE_ENGINES is handled by _get_sqlalchemy_store.
# NOTE(review): this module-level loop leaves `scheme` bound as a module global.
for scheme in DATABASE_ENGINES:
    _tracking_store_registry.register(scheme, _get_sqlalchemy_store)


def _get_store(backend_store_uri=None, default_artifact_root=None):
    """Return the module-level cached tracking store, creating it on the first call.

    A truthy argument takes precedence; otherwise the value is read from the
    environment variable named by BACKEND_STORE_URI_ENV_VAR / ARTIFACT_ROOT_ENV_VAR.
    Once ``_store`` has been set, later calls return it and their arguments are ignored.
    """
    from mlflow.server import BACKEND_STORE_URI_ENV_VAR, ARTIFACT_ROOT_ENV_VAR
    global _store
    if _store is not None:
        return _store

    store_uri = backend_store_uri
    if not store_uri:
        store_uri = os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)

    artifact_root = default_artifact_root
    if not artifact_root:
        artifact_root = os.environ.get(ARTIFACT_ROOT_ENV_VAR, None)

    _store = _tracking_store_registry.get_store(store_uri, artifact_root)
    return _store