Example #1
0
def test_get_store_for_unregistered_scheme():
    """Resolving a URI whose scheme was never registered raises MlflowException.

    The registry is created empty (nothing is registered here), so any scheme
    lookup should fail with an "Unexpected URI scheme" error.
    """
    registry = TrackingStoreRegistry()
    expected_error = mlflow.exceptions.MlflowException

    with pytest.raises(expected_error, match="Unexpected URI scheme"):
        registry.get_store("unknown-scheme://")
Example #2
0
def test_get_store_for_unregistered_scheme():
    """An unregistered scheme yields UnsupportedModelRegistryStoreURIException.

    NOTE(review): this expects the model-registry exception from a class named
    ``TrackingStoreRegistry`` -- presumably an alias for the model-registry
    store registry in this test module; confirm the import.
    """
    registry = TrackingStoreRegistry()
    unknown_uri = "unknown-scheme://"

    with pytest.raises(
        UnsupportedModelRegistryStoreURIException,
        match="Model registry functionality is unavailable",
    ):
        registry.get_store(unknown_uri)
Example #3
0
def test_plugin_registration():
    """A factory registered for a scheme is recorded and used by ``get_store``.

    The factory must be invoked exactly once, with the full URI as
    ``store_uri`` and ``artifact_uri=None``, and its return value is the store.
    """
    registry = TrackingStoreRegistry()

    uri = "mock-scheme://fake-host/fake-path"
    scheme = "mock-scheme"
    plugin_factory = mock.Mock()

    registry.register(scheme, plugin_factory)
    assert scheme in registry._registry

    store = registry.get_store(uri)
    assert store == plugin_factory.return_value
    plugin_factory.assert_called_once_with(store_uri=uri, artifact_uri=None)
Example #4
0
File: test_utils.py  Project: tnixon/mlflow
def test_handle_plugin_registration_failure_via_entrypoints(exception):
    """A plugin entry point that fails to load produces a UserWarning.

    ``exception`` is supplied from outside this view (presumably a pytest
    parametrization -- confirm) and is raised when the entry point is loaded.
    The warning is expected to carry the original exception's message.
    """
    failing_load = mock.Mock(side_effect=exception)
    broken_entrypoint = mock.Mock(load=failing_load)
    # ``name`` is consumed by the Mock constructor itself, so set it afterwards.
    broken_entrypoint.name = "mock-scheme"

    patcher = mock.patch("entrypoints.get_group_all", return_value=[broken_entrypoint])
    with patcher as get_group_all:
        registry = TrackingStoreRegistry()

        # The warning text should include the message of the load failure.
        with pytest.warns(UserWarning, match="test exception"):
            registry.register_entrypoints()

    broken_entrypoint.load.assert_called_once()
    get_group_all.assert_called_once_with("mlflow.tracking_store")
Example #5
0
def test_plugin_registration_via_entrypoints():
    """Factories found in the ``mlflow.tracking_store`` entry-point group get registered.

    After ``register_entrypoints`` runs against a patched entry-point lookup,
    ``get_store`` for the plugin's scheme must delegate to the loaded factory.
    """
    plugin_factory = mock.Mock()
    entrypoint = mock.Mock(load=mock.Mock(return_value=plugin_factory))
    # ``name`` is consumed by the Mock constructor itself, so set it afterwards.
    entrypoint.name = "mock-scheme"

    with mock.patch("entrypoints.get_group_all", return_value=[entrypoint]) as get_group_all:
        registry = TrackingStoreRegistry()
        registry.register_entrypoints()

    store = registry.get_store("mock-scheme://")
    assert store == plugin_factory.return_value

    plugin_factory.assert_called_once_with(store_uri="mock-scheme://", artifact_uri=None)
    get_group_all.assert_called_once_with("mlflow.tracking_store")
Example #6
0
        token=os.environ.get(_TRACKING_TOKEN_ENV_VAR),
        ignore_tls_verification=os.environ.get(_TRACKING_INSECURE_TLS_ENV_VAR) == 'true',
        client_cert_path=os.environ.get(_TRACKING_CLIENT_CERT_PATH_ENV_VAR),
        server_cert_path=os.environ.get(_TRACKING_SERVER_CERT_PATH_ENV_VAR),
    )


def _get_rest_store(store_uri, **_):
    """Build a RestStore for an http(s) tracking URI.

    The store receives a credentials callable that resolves host credentials
    for ``store_uri`` via ``_get_default_host_creds``. Extra keyword arguments
    (e.g. ``artifact_uri``) are accepted and ignored.
    """
    creds_provider = partial(_get_default_host_creds, store_uri)
    return RestStore(creds_provider)


def _get_databricks_rest_store(store_uri, **_):
    """Build a DatabricksRestStore for a ``databricks`` tracking URI.

    Credentials are resolved lazily: ``get_databricks_host_creds`` is looked up
    and called each time the store asks for them, not when the store is built.
    Extra keyword arguments are accepted and ignored.
    """
    def _host_creds():
        return get_databricks_host_creds(store_uri)

    return DatabricksRestStore(_host_creds)


# Module-wide registry mapping URI schemes to tracking-store factory functions.
_tracking_store_registry = TrackingStoreRegistry()

# File-backed stores: URIs with no scheme, and explicit ``file`` URIs.
for scheme in ('', 'file'):
    _tracking_store_registry.register(scheme, _get_file_store)

_tracking_store_registry.register('databricks', _get_databricks_rest_store)

# Remote tracking servers reached over HTTP(S).
for scheme in ('http', 'https'):
    _tracking_store_registry.register(scheme, _get_rest_store)

# Every supported database engine is served by the SQLAlchemy-backed store.
for scheme in DATABASE_ENGINES:
    _tracking_store_registry.register(scheme, _get_sqlalchemy_store)

# Pick up additional scheme handlers advertised by installed plugins.
_tracking_store_registry.register_entrypoints()


def _get_store(store_uri=None, artifact_uri=None):
    """Return the tracking store for ``store_uri`` from the module-level registry.

    Both arguments are forwarded positionally to the registry's ``get_store``.
    """
    registry = _tracking_store_registry
    return registry.get_store(store_uri, artifact_uri)
Example #7
0
def test_get_store_for_unregistered_scheme():
    """Looking up an unregistered URI scheme raises UnsupportedModelRegistryStoreURIException.

    NOTE(review): ``TrackingStoreRegistry`` here is presumably an alias for the
    model-registry store registry, given the expected exception; confirm.
    """
    registry = TrackingStoreRegistry()
    unknown_uri = "unknown-scheme://"

    with pytest.raises(UnsupportedModelRegistryStoreURIException):
        registry.get_store(unknown_uri)