def test_run_relative_artifact_repo_root_path(self, artifact_uri, expected_relative_path):
     """Verify run_id and run-relative repo root path are parsed from the URI."""
     target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
     with mock.patch(target) as root_mock:
         root_mock.return_value = MOCK_RUN_ROOT_URI
         # Basic artifact uri
         repo = get_artifact_repository(artifact_uri)
         assert repo.run_id == MOCK_RUN_ID
         assert repo.run_relative_artifact_repo_root_path == expected_relative_path
Exemple #2
0
def get_model_version_artifact_handler():
    """Serve the artifact for a registered model version.

    Reads ``name``, ``version`` and ``path`` from the request query string,
    resolves the model version's download URI, and sends the artifact back.
    """
    from querystring_parser import parser

    params = parser.parse(request.query_string.decode("utf-8"), normalized=True)
    download_uri = _get_model_registry_store().get_model_version_download_uri(
        params.get("name"), params.get("version")
    )
    repo = get_artifact_repository(download_uri)
    return _send_artifact(repo, params["path"])
 def test_init_validation_and_cleaning(self):
     """Trailing slash is stripped on init; non-dbfs URIs raise MlflowException."""
     creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + '._get_host_creds_from_default_store'
     with mock.patch(creds_target) as get_creds_mock:
         get_creds_mock.return_value = lambda: MlflowHostCreds('http://host')
         repo = get_artifact_repository('dbfs:/test/')
         assert repo.artifact_uri == 'dbfs:/test'
         with pytest.raises(MlflowException):
             DbfsRestArtifactRepository('s3://test')
Exemple #4
0
def get_model_version_artifact_handler():
    """Serve the artifact for a registered model version from the query string."""
    params = parser.parse(request.query_string.decode('utf-8'), normalized=True)
    download_uri = _get_model_registry_store().get_model_version_download_uri(
        params.get('name'), params.get('version'))
    repo = get_artifact_repository(download_uri)
    return _send_artifact(repo, params['path'])
 def test_extract_run_id(self):
     """run_id is extracted from tracking URIs regardless of redundant slashes."""
     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
             as get_run_artifact_root_mock:
         get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
         uri_variants = [
             'dbfs:/databricks/mlflow-tracking/EXP/RUN_ID/artifact',
             'dbfs:/databricks/mlflow-tracking/EXP_ID/RUN_ID/artifacts',
             'dbfs:/databricks///mlflow-tracking///EXP_ID///RUN_ID///artifacts/',
             'dbfs:/databricks///mlflow-tracking//EXP_ID//RUN_ID///artifacts//',
         ]
         for uri in uri_variants:
             assert get_artifact_repository(uri).run_id == "RUN_ID"
Exemple #6
0
    def log_artifacts(self, run_id, local_dir, artifact_path=None):
        """
        Write a directory of files to the remote ``artifact_uri``.

        :param run_id: ID of the run whose artifact root receives the files.
        :param local_dir: Path to the directory of files to write.
        :param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        """
        # Resolve the run's artifact root, then delegate the upload to its repository.
        run = self.get_run(run_id)
        artifact_repo = get_artifact_repository(run.info.artifact_uri)
        artifact_repo.log_artifacts(local_dir, artifact_path)
    def __init__(self, artifact_uri):
        """Select the backing repository for a models:/ artifact URI."""
        from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository

        super().__init__(artifact_uri)
        if not is_using_databricks_registry(artifact_uri):
            # Resolve the models:/ URI to its underlying store and delegate to it.
            underlying_uri = ModelsArtifactRepository.get_underlying_uri(artifact_uri)
            self.repo = get_artifact_repository(underlying_uri)
        else:
            # Use the DatabricksModelsArtifactRepository if a databricks profile is being used.
            self.repo = DatabricksModelsArtifactRepository(artifact_uri)
Exemple #8
0
def list_artifacts(run_id, artifact_path):
    """
    Return all the artifacts directly under run's root artifact directory,
    or a sub-directory. The output is a JSON-formatted list.
    """
    if artifact_path is None:
        artifact_path = ""
    run_artifact_uri = _get_store().get_run(run_id).info.artifact_uri
    repo = get_artifact_repository(run_artifact_uri)
    click.echo(_file_infos_to_json(repo.list_artifacts(artifact_path)))
Exemple #9
0
def test_plugin_registration():
    """A registered scheme resolves to the registered plugin's instance."""
    registry = ArtifactRepositoryRegistry()

    plugin = mock.Mock()
    registry.register("mock-scheme", plugin)
    assert "mock-scheme" in registry._registry

    uri = "mock-scheme://fake-host/fake-path"
    instance = registry.get_artifact_repository(artifact_uri=uri)
    assert instance == plugin.return_value

    plugin.assert_called_once_with("mock-scheme://fake-host/fake-path")
Exemple #10
0
def _get_artifact_repo_mlflow_artifacts():
    """
    Get an artifact repository specified by ``--artifacts-destination`` option for ``mlflow server``
    command.
    """
    from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

    global _artifact_repo
    if _artifact_repo is None:
        # Lazily construct the repository on first use and memoize it.
        destination = os.environ[ARTIFACTS_DESTINATION_ENV_VAR]
        _artifact_repo = get_artifact_repository(destination)
    return _artifact_repo
Exemple #11
0
def test_file_artifact_is_logged_and_downloaded_successfully(s3_artifact_root, tmpdir):
    """Round-trip a single text file through the S3 artifact repository."""
    file_name = "test.txt"
    file_path = os.path.join(str(tmpdir), file_name)
    file_text = "Hello world!"

    with open(file_path, "w") as f:
        f.write(file_text)

    repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"))
    repo.log_artifact(file_path)
    # Use a context manager so the downloaded file handle is closed promptly;
    # the original `open(...).read()` leaked the handle.
    with open(repo.download_artifacts(file_name)) as f:
        downloaded_text = f.read()
    assert downloaded_text == file_text
Exemple #12
0
    def log_artifacts(self, run_id, local_dir, artifact_path=None):
        """
        Write a directory of files to the remote ``artifact_uri``.

        :param run_id: ID of the run whose artifact root receives the files.
        :param local_dir: Path to the directory of files to write.
        :param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        """
        run = self.get_run(run_id)
        # TODO: use add_databricks_profile_info_to_artifact_uri(artifact_root, self.tracking_uri)
        #   for DBFS artifact repositories to use the correct tracking URI.
        artifact_repo = get_artifact_repository(run.info.artifact_uri)
        artifact_repo.log_artifacts(local_dir, artifact_path)
Exemple #13
0
def log_artifacts(local_dir, run_id, artifact_path):
    """
    Log the files within a local directory as an artifact of a run, optionally
    within a run-specific artifact path. Run artifacts can be organized into
    directories, so you can place the artifact in a directory this way.
    """
    run_artifact_uri = _get_store().get_run(run_id).info.artifact_uri
    repo = get_artifact_repository(run_artifact_uri)
    repo.log_artifacts(local_dir, artifact_path)
    _logger.info(
        "Logged artifact from local dir %s to artifact_path=%s", local_dir, artifact_path
    )
Exemple #14
0
def _download_artifact_from_uri(artifact_uri, output_path=None):
    """
    Download the artifact at ``artifact_uri`` to the local filesystem.

    :param artifact_uri: The *absolute* URI of the artifact to download.
    :param output_path: The local filesystem path to which to download the artifact. If unspecified,
                        a local output path will be created.
    :return: Local filesystem path of the downloaded artifact.
    """
    if os.path.exists(artifact_uri):
        if os.name != "nt":
            # If we're dealing with local files, just reference the direct pathing.
            # non-nt-based file systems can directly reference path information, while nt-based
            # systems need to url-encode special characters in directory listings to be able to
            # resolve them (i.e., spaces converted to %20 within a file name or path listing)
            root_uri = os.path.dirname(artifact_uri)
            artifact_path = os.path.basename(artifact_uri)
            return get_artifact_repository(
                artifact_uri=root_uri).download_artifacts(
                    artifact_path=artifact_path, dst_path=output_path)
        else:  # if we're dealing with nt-based systems, we need to utilize pathname2url to encode.
            artifact_uri = path_to_local_file_uri(artifact_uri)

    parsed_uri = urllib.parse.urlparse(str(artifact_uri))
    prefix = ""
    if parsed_uri.scheme and not parsed_uri.path.startswith("/"):
        # relative path is a special case, urllib does not reconstruct it properly
        prefix = parsed_uri.scheme + ":"
        parsed_uri = parsed_uri._replace(scheme="")

    # For models:/ URIs, it doesn't make sense to initialize a ModelsArtifactRepository with only
    # the model name portion of the URI, then call download_artifacts with the version info.
    if ModelsArtifactRepository.is_models_uri(artifact_uri):
        root_uri = artifact_uri
        artifact_path = ""
    else:
        # Split the URI into a repository root and the artifact's path relative to that root.
        artifact_path = posixpath.basename(parsed_uri.path)
        parsed_uri = parsed_uri._replace(
            path=posixpath.dirname(parsed_uri.path))
        root_uri = prefix + urllib.parse.urlunparse(parsed_uri)

    return get_artifact_repository(artifact_uri=root_uri).download_artifacts(
        artifact_path=artifact_path, dst_path=output_path)
Exemple #15
0
    def test_list_artifacts_with_relative_path(self):
        """Listing under a run-relative subdir strips the subdir prefix from results."""
        file_proto = [
            FileInfo(path=posixpath.join(MOCK_SUBDIR, "a.txt"),
                     is_dir=False,
                     file_size=0)
        ]
        dir_proto = [
            FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/a.txt"),
                     is_dir=False,
                     file_size=100),
            FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/dir"),
                     is_dir=True,
                     file_size=0),
        ]
        root_patch = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root")
        json_patch = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json")
        call_patch = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint")
        with root_patch as root_mock, json_patch as message_mock, call_patch as call_mock:
            root_mock.return_value = MOCK_RUN_ROOT_URI
            call_mock.return_value = ListArtifacts.Response(
                root_uri="", files=dir_proto, next_page_token=None)
            message_mock.return_value = None
            repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)

            artifacts = repo.list_artifacts("test")
            assert isinstance(artifacts, list)
            assert isinstance(artifacts[0], FileInfoEntity)
            assert len(artifacts) == 2
            assert artifacts[0].path == "test/a.txt"
            assert artifacts[0].is_dir is False
            assert artifacts[0].file_size == 100
            assert artifacts[1].path == "test/dir"
            assert artifacts[1].is_dir is True
            assert artifacts[1].file_size is None
            message_mock.assert_called_with(
                ListArtifacts(run_id=MOCK_RUN_ID,
                              path=posixpath.join(MOCK_SUBDIR, "test")))

            # Calling list_artifacts() on a relative path that's a file should return an empty list
            call_mock.return_value = ListArtifacts.Response(
                root_uri="", files=file_proto, next_page_token=None)
            assert len(repo.list_artifacts("a.txt")) == 0
Exemple #16
0
    def list_artifacts(self, run_id, path=None):
        """
        List the artifacts for a run.

        :param run_id: The run to list artifacts from.
        :param path: The run's relative artifact path to list from. By default it is set to None
                     or the root artifact path.
        :return: List of :py:class:`mlflow.entities.FileInfo`
        """
        artifact_uri = self.get_run(run_id).info.artifact_uri
        repo = get_artifact_repository(artifact_uri)
        return repo.list_artifacts(path)
Exemple #17
0
 def test_init_validation_and_cleaning(self):
     """URI is normalized on init; non-dbfs and profile URIs raise MlflowException."""
     creds_target = DBFS_ARTIFACT_REPOSITORY_PACKAGE + "._get_host_creds_from_default_store"
     with mock.patch(creds_target) as get_creds_mock:
         get_creds_mock.return_value = lambda: MlflowHostCreds("http://host")
         repo = get_artifact_repository("dbfs:/test/")
         assert repo.artifact_uri == "dbfs:/test"
         for bad_uri in ("s3://test", "dbfs://profile@notdatabricks/test/"):
             with pytest.raises(MlflowException):
                 DbfsRestArtifactRepository(bad_uri)
def test_download_file_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
        s3_artifact_root, tmpdir):
    """A file logged at the S3 bucket root downloads back with identical contents."""
    file_a_name = "a.txt"
    file_a_text = "A"
    local_path = os.path.join(str(tmpdir), file_a_name)
    with open(local_path, "w") as f:
        f.write(file_a_text)

    repo = get_artifact_repository(s3_artifact_root)
    repo.log_artifact(local_path)

    with open(repo.download_artifacts(file_a_name), "r") as f:
        assert f.read() == file_a_text
Exemple #19
0
def test_dbfs_artifact_repo_delegates_to_correct_repo(is_dbfs_fuse_available,
                                                      host_creds_mock):  # pylint: disable=unused-argument
    """dbfs:/ URIs resolve to a FUSE-backed local repo when available, else REST."""
    # fuse available
    is_dbfs_fuse_available.return_value = True
    artifact_uri = "dbfs:/databricks/my/absolute/dbfs/path"
    repo = get_artifact_repository(artifact_uri)
    assert isinstance(repo, LocalArtifactRepository)
    assert repo.artifact_dir == os.path.join(os.path.sep, "dbfs", "databricks",
                                             "my", "absolute", "dbfs", "path")
    # fuse available but a model repository DBFS location
    repo = get_artifact_repository(
        "dbfs:/databricks/mlflow-registry/version12345/models")
    assert isinstance(repo, DbfsRestArtifactRepository)
    # fuse not available
    # Explicitly disabling FUSE via the env var forces the REST repository even
    # though the FUSE mount is reported as available.
    with mock.patch.dict(os.environ,
                         {'MLFLOW_ENABLE_DBFS_FUSE_ARTIFACT_REPO': 'false'}):
        fuse_disabled_repo = get_artifact_repository(artifact_uri)
    assert isinstance(fuse_disabled_repo, DbfsRestArtifactRepository)
    assert fuse_disabled_repo.artifact_uri == artifact_uri
    is_dbfs_fuse_available.return_value = False
    rest_repo = get_artifact_repository(artifact_uri)
    assert isinstance(rest_repo, DbfsRestArtifactRepository)
    assert rest_repo.artifact_uri == artifact_uri
def test_get_s3_client_verify_param_set_correctly(s3_artifact_root,
                                                  ignore_tls_env, verify):
    """MLFLOW_S3_IGNORE_TLS controls the ``verify`` argument passed to boto3.client."""
    from unittest.mock import ANY

    # clear=True ensures no other AWS/MLflow env vars leak into the client config.
    with mock.patch.dict("os.environ",
                         {"MLFLOW_S3_IGNORE_TLS": ignore_tls_env},
                         clear=True):
        with mock.patch("boto3.client") as mock_get_s3_client:
            repo = get_artifact_repository(
                posixpath.join(s3_artifact_root, "some/path"))
            repo._get_s3_client()
            mock_get_s3_client.assert_called_with("s3",
                                                  config=ANY,
                                                  endpoint_url=ANY,
                                                  verify=verify)
def test_plugin_registration_via_installed_package():
    """This test requires the package in tests/resources/mlflow-test-plugin to be installed"""

    # Re-import so entry-point plugins installed after first import are picked up.
    reload(artifact_repository_registry)

    registry = artifact_repository_registry._artifact_repository_registry
    assert "file-plugin" in registry._registry

    from mlflow_test_plugin.local_artifact import PluginLocalArtifactRepository

    plugin_repo = artifact_repository_registry.get_artifact_repository("file-plugin:test-path")
    assert isinstance(plugin_repo, PluginLocalArtifactRepository)
    assert plugin_repo.is_plugin
Exemple #22
0
    def list_artifacts(self, run_id, path=None):
        """
        List the artifacts for a run.

        :param run_id: The run to list artifacts from.
        :param path: The run's relative artifact path to list from. By default it is set to None
                     or the root artifact path.
        :return: List of :py:class:`mlflow.entities.FileInfo`
        """
        artifact_root = self.get_run(run_id).info.artifact_uri
        # TODO: use add_databricks_profile_info_to_artifact_uri(artifact_root, self.tracking_uri)
        #   for DBFS artifact repositories to use the correct tracking URI.
        return get_artifact_repository(artifact_root).list_artifacts(path)
Exemple #23
0
    def log_artifact(self, run_id, local_path, artifact_path=None):
        """
        Write a local file or directory to the remote ``artifact_uri``.

        :param run_id: ID of the run whose artifact root receives the file or directory.
        :param local_path: Path to the file or directory to write.
        :param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        """
        run = self.get_run(run_id)
        artifact_repo = get_artifact_repository(run.info.artifact_uri)
        if os.path.isdir(local_path):
            # Directories are uploaded recursively under a subdirectory named after them.
            dir_name = os.path.basename(os.path.normpath(local_path))
            path_name = os.path.join(artifact_path, dir_name) \
                if artifact_path is not None else dir_name
            artifact_repo.log_artifacts(local_path, path_name)
        else:
            artifact_repo.log_artifact(local_path, artifact_path)
Exemple #24
0
def test_file_artifact_is_logged_with_content_metadata(s3_artifact_root, tmpdir):
    """Logged .txt files get a text/plain ContentType and no ContentEncoding."""
    file_path = os.path.join(str(tmpdir), "test.txt")
    with open(file_path, "w") as f:
        f.write("Hello world!")

    repo = get_artifact_repository(posixpath.join(s3_artifact_root, "some/path"))
    repo.log_artifact(file_path)

    bucket, _ = repo.parse_s3_uri(s3_artifact_root)
    head = repo._get_s3_client().head_object(Bucket=bucket, Key="some/path/test.txt")
    assert head.get("ContentType") == "text/plain"
    assert head.get("ContentEncoding") is None
Exemple #25
0
 def _get_artifact_repo(self, run_id):
     """Return (and cache) the artifact repository for ``run_id``."""
     # Attempt to fetch the artifact repo from a local cache
     cached_repo = TrackingServiceClient._artifact_repos_cache.get(run_id)
     if cached_repo is not None:
         return cached_repo
     else:
         run = self.get_run(run_id)
         # Attach the tracking profile so DBFS-backed URIs hit the right host.
         artifact_uri = add_databricks_profile_info_to_artifact_uri(
             run.info.artifact_uri, self.tracking_uri)
         artifact_repo = get_artifact_repository(artifact_uri)
         # Cache the artifact repo to avoid a future network call, removing the oldest
         # entry in the cache if there are too many elements
         if len(TrackingServiceClient._artifact_repos_cache) > 1024:
             TrackingServiceClient._artifact_repos_cache.popitem(last=False)
         TrackingServiceClient._artifact_repos_cache[run_id] = artifact_repo
         return artifact_repo
Exemple #26
0
def _download_artifact_from_uri(artifact_uri, output_path=None):
    """
    Download the artifact at ``artifact_uri`` to the local filesystem.

    :param artifact_uri: The *absolute* URI of the artifact to download.
    :param output_path: The local filesystem path to which to download the artifact. If unspecified,
                        a local output path will be created.
    :return: Local filesystem path of the downloaded artifact.
    """
    # Split the URI into a repository root and the artifact's name within it.
    parsed_uri = urllib.parse.urlparse(artifact_uri)
    prefix = ""
    if parsed_uri.scheme and not parsed_uri.path.startswith("/"):
        # relative path is a special case, urllib does not reconstruct it properly
        prefix = parsed_uri.scheme + ":"
        parsed_uri = parsed_uri._replace(scheme="")
    artifact_path = posixpath.basename(parsed_uri.path)
    parsed_uri = parsed_uri._replace(path=posixpath.dirname(parsed_uri.path))
    root_uri = prefix + urllib.parse.urlunparse(parsed_uri)
    return get_artifact_repository(artifact_uri=root_uri).download_artifacts(
        artifact_path=artifact_path, dst_path=output_path)
def test_download_directory_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
        s3_artifact_root, tmpdir):
    """A nested directory logged at the bucket root round-trips intact."""
    file_a_name = "a.txt"
    file_a_text = "A"
    subdir_path = str(tmpdir.mkdir("subdir"))
    nested_path = os.path.join(subdir_path, "nested")
    os.makedirs(nested_path)
    with open(os.path.join(nested_path, file_a_name), "w") as f:
        f.write(file_a_text)

    repo = get_artifact_repository(s3_artifact_root)
    repo.log_artifacts(subdir_path)

    downloaded = repo.download_artifacts("nested")
    assert file_a_name in os.listdir(downloaded)
    with open(os.path.join(downloaded, file_a_name), "r") as f:
        assert f.read() == file_a_text
Exemple #28
0
    def log_artifact(self, run_id, local_path, artifact_path=None):
        """
        Write a local file or directory to the remote ``artifact_uri``.

        :param run_id: ID of the run whose artifact root receives the file or directory.
        :param local_path: Path to the file or directory to write.
        :param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        """
        run = self.get_run(run_id)
        # TODO: use add_databricks_profile_info_to_artifact_uri(artifact_root, self.tracking_uri)
        #   for DBFS artifact repositories to use the correct tracking URI.
        artifact_repo = get_artifact_repository(run.info.artifact_uri)
        if os.path.isdir(local_path):
            # Directories are uploaded recursively under a subdirectory named after them.
            dir_name = os.path.basename(os.path.normpath(local_path))
            path_name = os.path.join(artifact_path, dir_name) \
                if artifact_path is not None else dir_name
            artifact_repo.log_artifacts(local_path, path_name)
        else:
            artifact_repo.log_artifact(local_path, artifact_path)
    def test_init_validation_and_cleaning(self):
        """Valid mlflow-tracking URIs are accepted; malformed ones raise."""
        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
                as get_run_artifact_root_mock:
            get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
            # Basic artifact uri
            uri = 'dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts'
            repo = get_artifact_repository(uri)
            assert repo.artifact_uri == 'dbfs:/databricks/mlflow-tracking/' \
                                        'MOCK-EXP/MOCK-RUN-ID/artifacts'
            assert repo.run_id == MOCK_RUN_ID
            assert repo.run_relative_artifact_repo_root_path == ""

            for bad_uri in ('s3://test', 'dbfs:/databricks/mlflow/EXP/RUN/artifact'):
                with pytest.raises(MlflowException):
                    DatabricksArtifactRepository(bad_uri)
Exemple #30
0
    def download_artifacts(self, run_id, path, dst_path=None):
        """
        Download an artifact file or directory from a run to a local directory if applicable,
        and return a local path for it.

        :param run_id: The run to download artifacts from.
        :param path: Relative source path to the desired artifact.
        :param dst_path: Absolute path of the local filesystem destination directory to which to
                         download the specified artifacts. This directory must already exist.
                         If unspecified, the artifacts will either be downloaded to a new
                         uniquely-named directory on the local filesystem or will be returned
                         directly in the case of the LocalArtifactRepository.
        :return: Local path of desired artifact.
        """
        artifact_uri = self.get_run(run_id).info.artifact_uri
        return get_artifact_repository(artifact_uri).download_artifacts(path, dst_path)