示例#1
0
def test_store_object_can_be_serialized_by_pickle():
    """
    Check that stores returned by `_get_store` can be passed to `pickle.dump`.

    Guards against regressions such as https://github.com/mlflow/mlflow/issues/2954
    """
    # One REST-style URI and the "databricks" URI, dumped in that order.
    for store_uri in ("https://example.com", "databricks"):
        store = _get_store(store_uri)
        pickle.dump(store, io.BytesIO())
示例#2
0
def test_get_store_bad_uris(bad_uri):
    """
    Expect `_get_store()` to raise `MlflowException` while the tracking URI
    environment variable is set to `bad_uri`.
    """
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: bad_uri}):
        with pytest.raises(MlflowException):
            _get_store()
示例#3
0
def test_get_store_bad_uris(bad_uri):
    """
    Expect `_get_store()` to raise `UnsupportedModelRegistryStoreURIException`,
    with a message matching the "unavailable" pattern, while the tracking URI
    environment variable is set to `bad_uri`.
    """
    expected_failure = pytest.raises(
        UnsupportedModelRegistryStoreURIException,
        match="Model registry functionality is unavailable",
    )
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: bad_uri}):
        with expected_failure:
            _get_store()
示例#4
0
def test_get_store_caches_on_store_uri(tmpdir):
    """
    Check that `_get_store` returns the same object when called again with the
    same store URI, and a different object for a different store URI.
    """
    first_uri = "sqlite:///{}".format(tmpdir.join("store1.db").strpath)
    second_uri = "sqlite:///{}".format(tmpdir.join("store2.db").strpath)

    # A repeated lookup of the first URI must yield the identical object.
    first_store = _get_store(first_uri)
    assert first_store is _get_store(first_uri)

    # Same expectation for the second URI.
    second_store = _get_store(second_uri)
    assert second_store is _get_store(second_uri)

    # Different URIs must not share a store object.
    assert first_store is not second_store
示例#5
0
def test_fallback_to_tracking_store():
    """
    With the tracking URI environment variable pointing at an HTTPS server,
    `_get_store()` called without arguments should return a `RestStore` whose
    host credentials carry that URI as host and no token.
    """
    tracking_uri = "https://my-tracking-server:5050"
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: tracking_uri}):
        resolved_store = _get_store()
        assert isinstance(resolved_store, RestStore)
        assert resolved_store.get_host_creds().host == tracking_uri
        assert resolved_store.get_host_creds().token is None
示例#6
0
def test_get_store_rest_store_from_arg():
    """
    An explicit URI argument to `_get_store` should produce a `RestStore` for
    that URI even when the tracking URI environment variable names another
    server.
    """
    explicit_uri = "http://some/path"
    # The environment value below should be ignored in favour of the argument.
    ignored_env = {_TRACKING_URI_ENV_VAR: "https://my-tracking-server:5050"}
    with mock.patch.dict(os.environ, ignored_env):
        resolved_store = _get_store(explicit_uri)
        assert isinstance(resolved_store, RestStore)
        assert resolved_store.get_host_creds().host == explicit_uri
示例#7
0
def test_get_store_sqlalchemy_store(db_type):
    """
    With the tracking URI environment variable set to a `db_type` database URI,
    `_get_store()` should return a `SqlAlchemyStore` bound to that URI and call
    `sqlalchemy.create_engine` exactly once with `pool_pre_ping=True`.
    """
    create_engine_patch = mock.patch("sqlalchemy.create_engine")
    db_uri = "{}://hostname/database".format(db_type)
    verify_tables_target = (
        "mlflow.store.model_registry.sqlalchemy_store.SqlAlchemyStore."
        "_verify_registry_tables_exist"
    )

    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: db_uri}):
        with create_engine_patch as mock_create_engine:
            # Table verification is patched out while the store is built.
            with mock.patch(verify_tables_target):
                resolved_store = _get_store()
                assert isinstance(resolved_store, SqlAlchemyStore)
                assert resolved_store.db_uri == db_uri

    mock_create_engine.assert_called_once_with(db_uri, pool_pre_ping=True)
示例#8
0
 def __init__(self, registry_uri):
     """
     Remember the registry URI and resolve its store through `utils._get_store`.

     :param registry_uri: Address of local or remote model registry server.
     """
     self.registry_uri = registry_uri
     # Resolve from the attribute just assigned, then keep the result on the instance.
     resolved_store = utils._get_store(self.registry_uri)
     self.store = resolved_store
示例#9
0
 def store(self):
     """
     Look up the store for this instance's registry URI.

     :return: Whatever `utils._get_store` returns for `self.registry_uri`.
     """
     resolved_store = utils._get_store(self.registry_uri)
     return resolved_store