def test_store_object_can_be_serialized_by_pickle():
    """
    Stores returned by `_get_store` must be picklable; a regression of this kind is
    described in https://github.com/mlflow/mlflow/issues/2954
    """
    # Exercise an HTTPS URI first, then the special "databricks" URI, each into a fresh buffer.
    for candidate_uri in ("https://example.com", "databricks"):
        pickle.dump(_get_store(candidate_uri), io.BytesIO())
def test_get_store_bad_uris(bad_uri):
    """
    With the tracking-URI environment variable set to ``bad_uri``, `_get_store()` must
    raise `MlflowException` (or a subclass of it).

    ``bad_uri`` is injected from outside this function -- presumably a pytest
    parametrization or fixture; confirm at the decorator / conftest.
    """
    # NOTE(review): another `test_get_store_bad_uris` definition appears next to this one;
    # if both live in the same module, the later one shadows this test and it never runs.
    patched_env = {_TRACKING_URI_ENV_VAR: bad_uri}
    with mock.patch.dict(os.environ, patched_env):
        with pytest.raises(MlflowException):
            _get_store()
def test_get_store_bad_uris(bad_uri):
    """
    An unusable registry URI, supplied through the tracking-URI environment variable,
    must make `_get_store()` raise `UnsupportedModelRegistryStoreURIException` with a
    message reporting that model registry functionality is unavailable.

    ``bad_uri`` is injected from outside this function (presumably pytest
    parametrization -- confirm at the decorator).
    """
    patched_env = {_TRACKING_URI_ENV_VAR: bad_uri}
    expected_error = UnsupportedModelRegistryStoreURIException
    with mock.patch.dict(os.environ, patched_env):
        # `match` is applied as a regex search against the exception message.
        with pytest.raises(expected_error, match="Model registry functionality is unavailable"):
            _get_store()
def test_get_store_caches_on_store_uri(tmpdir):
    """
    `_get_store` hands back the identical object for repeated calls with one URI,
    and a different object for a different URI.
    """
    first_uri, second_uri = (
        "sqlite:///" + tmpdir.join(db_file).strpath for db_file in ("store1.db", "store2.db")
    )

    first_call = _get_store(first_uri)
    first_repeat = _get_store(first_uri)
    assert first_call is first_repeat

    second_call = _get_store(second_uri)
    second_repeat = _get_store(second_uri)
    assert second_call is second_repeat

    # Distinct URIs must not collapse onto the same cached store.
    assert first_call is not second_call
def test_fallback_to_tracking_store():
    """
    When the tracking-URI environment variable points at an HTTPS server, calling
    `_get_store()` with no argument yields a `RestStore` whose host credentials
    target that server and carry no token.
    """
    server_uri = "https://my-tracking-server:5050"
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: server_uri}):
        store = _get_store()
        assert isinstance(store, RestStore)
        assert store.get_host_creds().host == server_uri
        assert store.get_host_creds().token is None
def test_get_store_rest_store_from_arg():
    """
    A URI passed explicitly to `_get_store` takes precedence over the tracking-URI
    environment variable and produces a `RestStore` for the explicit host.
    """
    explicit_uri = "http://some/path"
    # This environment value is expected to be ignored because an explicit URI is given.
    overridden_env = {_TRACKING_URI_ENV_VAR: "https://my-tracking-server:5050"}
    with mock.patch.dict(os.environ, overridden_env):
        store = _get_store(explicit_uri)
        assert isinstance(store, RestStore)
        assert store.get_host_creds().host == explicit_uri
def test_get_store_sqlalchemy_store(db_type):
    """
    A database-style URI in the tracking-URI environment variable makes `_get_store()`
    build a `SqlAlchemyStore` for that URI.

    `sqlalchemy.create_engine` and `SqlAlchemyStore._verify_registry_tables_exist` are
    patched out while the store is built. The test then checks that the engine factory
    was called exactly once, with ``pool_pre_ping=True``.
    ``db_type`` is injected from outside (presumably pytest parametrization -- confirm).
    """
    uri = "{db}://hostname/database".format(db=db_type)
    verify_tables_target = (
        "mlflow.store.model_registry.sqlalchemy_store.SqlAlchemyStore."
        "_verify_registry_tables_exist"
    )
    with mock.patch.dict(os.environ, {_TRACKING_URI_ENV_VAR: uri}):
        with mock.patch("sqlalchemy.create_engine") as engine_factory:
            with mock.patch(verify_tables_target):
                store = _get_store()
                assert isinstance(store, SqlAlchemyStore)
                assert store.db_uri == uri
    # The mock keeps its call record after the patch is undone.
    engine_factory.assert_called_once_with(uri, pool_pre_ping=True)
def __init__(self, registry_uri):
    """
    :param registry_uri: Location of the model registry server, local or remote.
    """
    self.registry_uri = registry_uri
    # The backing store is resolved eagerly, during construction.
    # NOTE(review): if this class also defines `store` as a read-only property (a
    # `store(self)` accessor appears nearby), this assignment would raise
    # AttributeError -- confirm which variant this class uses.
    resolved_store = utils._get_store(self.registry_uri)
    self.store = resolved_store
def store(self):
    """
    Return the model registry store for ``self.registry_uri``, obtained from
    ``utils._get_store`` on each access.
    """
    # NOTE(review): presumably relies on `_get_store` caching per URI so that repeated
    # access does not rebuild the store -- confirm in `utils`.
    target_uri = self.registry_uri
    return utils._get_store(target_uri)