コード例 #1
0
ファイル: test_uri.py プロジェクト: TheVinhLuong102/mlflow
def test_extract_db_type_from_uri():
    """Check DB-type extraction for supported engines and rejection of unsupported ones.

    Each entry of ``DATABASE_ENGINES`` must be recovered by both ``extract_db_type_from_uri``
    and ``get_uri_scheme``, with and without a ``+<driver>`` suffix on the scheme.

    Unsupported engine names are checked in two forms:

    * embedded in the same URI template as the supported engines, so the rejection is
      attributable to the engine name itself;
    * passed as a bare string, which is not a URI at all.
    """
    uri = "{}://username:password@host:port/database"
    for legit_db in DATABASE_ENGINES:
        assert legit_db == extract_db_type_from_uri(uri.format(legit_db))
        assert legit_db == get_uri_scheme(uri.format(legit_db))

        # A "+driver" suffix on the scheme must not change the detected engine.
        with_driver = legit_db + "+driver-string"
        assert legit_db == extract_db_type_from_uri(uri.format(with_driver))
        assert legit_db == get_uri_scheme(uri.format(with_driver))

    for unsupported_db in ["a", "aa", "sql"]:
        # Well-formed URI whose scheme is an unsupported engine name.
        with pytest.raises(MlflowException, match="Invalid database engine"):
            extract_db_type_from_uri(uri.format(unsupported_db))
        # Bare engine name (not a URI) must also be rejected.
        with pytest.raises(MlflowException, match="Invalid database engine"):
            extract_db_type_from_uri(unsupported_db)
コード例 #2
0
def dbfs_artifact_repo_factory(artifact_uri):
    """
    Build the ArtifactRepository subclass used to store artifacts for a ``dbfs:/<path>`` URI.

    DBFS-backed artifact storage can only be used together with the RestStore. The repository
    is chosen as follows:

    * If ``is_databricks_acled_artifacts_uri`` accepts the URI (the special form
      ``dbfs:/databricks/mlflow-tracking/<Exp-ID>/<Run-ID>/<path>``), a
      ``DatabricksArtifactRepository`` is returned. It is capable of storing
      access-controlled artifacts.
    * Otherwise, a ``LocalArtifactRepository`` rooted at ``file:///dbfs/...`` is returned when
      all three of these hold:

      - the DBFS FUSE mount is available;
      - the ``USE_FUSE_ENV_VAR`` environment variable is not set to ``"false"``
        (compared case-insensitively);
      - the URI does not start with ``dbfs:/databricks/mlflow-registry``.
    * In every other case a ``DbfsRestArtifactRepository`` is returned.

    Trailing slashes are stripped from the URI before it is handed to a repository.

    :param artifact_uri: DBFS root artifact URI (string).
    :return: Subclass of ArtifactRepository capable of storing artifacts on DBFS.
    :raises MlflowException: If the URI scheme is not ``dbfs``.
    """
    trimmed_uri = artifact_uri.rstrip('/')
    if get_uri_scheme(artifact_uri) != 'dbfs':
        raise MlflowException(
            "DBFS URI must be of the form "
            "dbfs:/<path>, but received {uri}".format(uri=artifact_uri))

    if is_databricks_acled_artifacts_uri(artifact_uri):
        return DatabricksArtifactRepository(trimmed_uri)

    # Short-circuit evaluation keeps the original order: the FUSE probe runs only for
    # non-ACL'd URIs, and the env var / prefix checks only when FUSE is available.
    use_fuse_mount = (
        mlflow.utils.databricks_utils.is_dbfs_fuse_available()
        and os.environ.get(USE_FUSE_ENV_VAR, "").lower() != "false"
        and not artifact_uri.startswith("dbfs:/databricks/mlflow-registry")
    )
    if not use_fuse_mount:
        return DbfsRestArtifactRepository(trimmed_uri)

    # FUSE mount path: write artifacts directly under /dbfs/... with local filesystem APIs.
    dbfs_relative_path = strip_prefix(trimmed_uri, "dbfs:/")
    return LocalArtifactRepository("file:///dbfs/{}".format(dbfs_relative_path))
コード例 #3
0
    def get_artifact_repository(self, artifact_uri):
        """Get an artifact repository from the registry based on the scheme of ``artifact_uri``.

        :param artifact_uri: The artifact URI. Its scheme (as returned by ``get_uri_scheme``)
                             selects which registered artifact repository implementation to
                             instantiate, and the full URI is passed to the constructor of that
                             implementation.

        :return: An instance of `mlflow.store.ArtifactRepository` that fulfills the artifact URI
                 requirements.
        :raises MlflowException: If no repository is registered for the URI's scheme.
        """
        scheme = get_uri_scheme(artifact_uri)
        repository = self._registry.get(scheme)
        if repository is None:
            raise MlflowException(
                "Could not find a registered artifact repository for: {}. "
                "Currently registered schemes are: {}".format(
                    artifact_uri, list(self._registry.keys())))
        return repository(artifact_uri)
コード例 #4
0
    def get_store_builder(self, store_uri):
        """Get a store builder from the registry based on the scheme of ``store_uri``.

        The literal string ``"databricks"`` is used as the registry key as-is. Any other URI is
        reduced to its scheme with ``get_uri_scheme`` first. This method does not resolve a
        missing URI itself; callers must pass the URI they want a builder for.

        :param store_uri: The store URI. This URI is used to select which store implementation
                          to instantiate and is passed to the constructor of the
                          implementation.
        :return: A function that returns an instance of
                 ``mlflow.store.{tracking|model_registry}.AbstractStore`` that fulfills the store
                 URI requirements.
        :raises UnsupportedModelRegistryStoreURIException: If no builder is registered for the
                 selected scheme.
        """
        scheme = store_uri if store_uri == "databricks" else get_uri_scheme(store_uri)

        try:
            store_builder = self._registry[scheme]
        except KeyError:
            # The domain exception already names the URI and the supported schemes, so the
            # internal KeyError is not chained onto it.
            raise UnsupportedModelRegistryStoreURIException(
                unsupported_uri=store_uri,
                supported_uri_schemes=list(self._registry.keys())) from None
        return store_builder
コード例 #5
0
    def get_store(self, store_uri=None, artifact_uri=None):
        """Instantiate the tracking store registered for the scheme of ``store_uri``.

        The literal URI ``"databricks"`` is used as the registry key directly. Any other URI is
        reduced to its scheme with ``get_uri_scheme``. The matching builder is then called with
        the resolved ``store_uri`` and ``artifact_uri`` as keyword arguments.

        :param store_uri: The store URI. If None, ``mlflow.tracking.utils.get_tracking_uri()`` is
                          used instead. This URI selects which tracking store implementation to
                          instantiate and is passed to the constructor of the implementation.
        :param artifact_uri: Artifact repository URI. Passed through to the tracking store
                             implementation.

        :return: An instance of `mlflow.store.AbstractStore` that fulfills the store URI
                 requirements.
        :raises MlflowException: If no store builder is registered for the selected scheme.
        """
        from mlflow.tracking import utils

        if store_uri is None:
            store_uri = utils.get_tracking_uri()

        if store_uri == "databricks":
            registry_key = store_uri
        else:
            registry_key = get_uri_scheme(store_uri)

        try:
            builder = self._registry[registry_key]
        except KeyError:
            valid_schemes = list(self._registry.keys())
            raise MlflowException(
                "Unexpected URI scheme '{}' for tracking store. "
                "Valid schemes are: {}".format(store_uri, valid_schemes))
        return builder(store_uri=store_uri, artifact_uri=artifact_uri)