Example #1
0
def test_download_file_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
        s3_artifact_root, tmpdir):
    """Log one file to a repository built from ``s3_artifact_root`` and read it back.

    The repository is created directly from the fixture URI (no sub-path appended),
    and the downloaded copy must contain the text that was written locally.
    """
    name, contents = "a.txt", "A"
    local_path = os.path.join(str(tmpdir), name)
    with open(local_path, "w") as handle:
        handle.write(contents)

    repo = get_artifact_repository(s3_artifact_root)
    repo.log_artifact(local_path)

    with open(repo.download_artifacts(name), "r") as handle:
        assert handle.read() == contents
Example #2
0
def test_file_artifact_is_logged_and_downloaded_successfully(
        s3_artifact_root, tmpdir):
    """Round-trip one file through a repository rooted at ``<s3_artifact_root>/some/path``.

    The file is written locally, logged, downloaded again, and the downloaded
    text is compared with what was written.
    """
    file_name = "test.txt"
    file_path = os.path.join(str(tmpdir), file_name)
    file_text = "Hello world!"

    with open(file_path, "w") as f:
        f.write(file_text)

    repo = get_artifact_repository(
        posixpath.join(s3_artifact_root, "some/path"))
    repo.log_artifact(file_path)
    # Read through a context manager so the handle is closed deterministically
    # rather than being left open until garbage collection.
    with open(repo.download_artifacts(file_name)) as f:
        downloaded_text = f.read()
    assert downloaded_text == file_text
Example #3
0
    def test_list_artifacts_with_relative_path(self):
        """Check ``list_artifacts`` on a repository created from ``MOCK_SUBDIR_ROOT_URI``.

        ``_get_run_artifact_root``, ``message_to_json`` and ``_call_endpoint`` are
        patched, so no real request is made. The mocked endpoint response carries
        paths prefixed with ``MOCK_SUBDIR``; the test asserts that the returned
        entries expose the paths without that prefix, and that the request message
        is built with the prefixed path.
        """
        # Response payload used for the "path is a file" case at the end of the test.
        list_artifact_file_proto_mock = [
            FileInfo(path=posixpath.join(MOCK_SUBDIR, 'a.txt'),
                     is_dir=False,
                     file_size=0)
        ]
        # Response payload used for the directory listing: one file and one directory.
        list_artifacts_dir_proto_mock = [
            FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/a.txt'),
                     is_dir=False,
                     file_size=100),
            FileInfo(path=posixpath.join(MOCK_SUBDIR, 'test/dir'),
                     is_dir=True,
                     file_size=0)
        ]
        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
                as get_run_artifact_root_mock, \
                mock.patch(
                    DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + '.message_to_json')as message_mock, \
                mock.patch(
                    DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint') as call_endpoint_mock:
            get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
            list_artifact_response_proto = \
                ListArtifacts.Response(root_uri='', files=list_artifacts_dir_proto_mock,
                                       next_page_token=None)
            call_endpoint_mock.return_value = list_artifact_response_proto
            message_mock.return_value = None
            # NOTE(review): MOCK_SUBDIR_ROOT_URI is presumably the run root plus
            # MOCK_SUBDIR - confirm against the module-level constants.
            databricks_artifact_repo = get_artifact_repository(
                MOCK_SUBDIR_ROOT_URI)
            artifacts = databricks_artifact_repo.list_artifacts('test')
            assert isinstance(artifacts, list)
            assert isinstance(artifacts[0], FileInfoEntity)
            assert len(artifacts) == 2
            # Returned paths are expected without the MOCK_SUBDIR prefix.
            assert artifacts[0].path == 'test/a.txt'
            assert artifacts[0].is_dir is False
            assert artifacts[0].file_size == 100
            assert artifacts[1].path == 'test/dir'
            assert artifacts[1].is_dir is True
            # The directory entry is expected to report no size, although the
            # mocked proto carries file_size=0.
            assert artifacts[1].file_size is None
            # The request message must address the path including MOCK_SUBDIR.
            message_mock.assert_called_with(
                ListArtifacts(run_id=MOCK_RUN_ID,
                              path=posixpath.join(MOCK_SUBDIR, "test")))

            # Calling list_artifacts() on a relative path that's a file should return an empty list
            list_artifact_response_proto = \
                ListArtifacts.Response(root_uri='', files=list_artifact_file_proto_mock,
                                       next_page_token=None)
            call_endpoint_mock.return_value = list_artifact_response_proto
            artifacts = databricks_artifact_repo.list_artifacts('a.txt')
            assert len(artifacts) == 0
Example #4
0
    def log_artifact(self, run_id, local_path, artifact_path=None):
        """
        Upload a local file or directory to the ``artifact_uri`` of a run.

        :param run_id: ID of the run whose ``artifact_uri`` is written to.
        :param local_path: Path to the file or directory to write.
        :param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        """
        repo = get_artifact_repository(self.get_run(run_id).info.artifact_uri)
        if not os.path.isdir(local_path):
            # Plain file: hand it to the repository unchanged.
            repo.log_artifact(local_path, artifact_path)
            return

        # A directory is stored under its own base name, nested below
        # ``artifact_path`` when one was given.
        base_name = os.path.basename(os.path.normpath(local_path))
        if artifact_path is None:
            destination = base_name
        else:
            destination = os.path.join(artifact_path, base_name)
        repo.log_artifacts(local_path, destination)
Example #5
0
def test_download_directory_artifact_succeeds_when_artifact_root_is_s3_bucket_root(
        s3_artifact_root, tmpdir):
    """Log a directory tree to a repository built from ``s3_artifact_root`` and fetch a sub-directory.

    ``subdir/nested/a.txt`` is created locally, ``subdir`` is logged, and the
    ``nested`` directory is downloaded and checked for the file and its text.
    """
    name, contents = "a.txt", "A"
    source_dir = str(tmpdir.mkdir("subdir"))
    nested_dir = os.path.join(source_dir, "nested")
    os.makedirs(nested_dir)
    with open(os.path.join(nested_dir, name), "w") as handle:
        handle.write(contents)

    repo = get_artifact_repository(s3_artifact_root)
    repo.log_artifacts(source_dir)

    local_dir = repo.download_artifacts("nested")
    assert name in os.listdir(local_dir)
    with open(os.path.join(local_dir, name), "r") as handle:
        assert handle.read() == contents
def test_plugin_registration_via_installed_package():
    """This test requires the package in tests/resources/mlflow-test-plugin to be installed"""

    # Reload the registry module before inspecting it.
    # NOTE(review): presumably this re-runs plugin discovery at import time - confirm.
    reload(artifact_repository_registry)

    registry = artifact_repository_registry._artifact_repository_registry._registry
    assert "file-plugin" in registry

    from mlflow_test_plugin.local_artifact import PluginLocalArtifactRepository

    repo = artifact_repository_registry.get_artifact_repository("file-plugin:test-path")

    assert isinstance(repo, PluginLocalArtifactRepository)
    assert repo.is_plugin
Example #7
0
    def test_init_validation_and_cleaning(self):
        """Check the attributes set for a basic tracking URI and the rejection of invalid URIs.

        ``_get_run_artifact_root`` is patched to return ``MOCK_RUN_ROOT_URI`` for
        the whole test.
        """
        basic_uri = 'dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts'
        rejected_uris = ('s3://test', 'dbfs:/databricks/mlflow/EXP/RUN/artifact')

        root_patch = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root')
        with root_patch as root_mock:
            root_mock.return_value = MOCK_RUN_ROOT_URI

            # Basic artifact uri
            repo = get_artifact_repository(basic_uri)
            assert repo.artifact_uri == basic_uri
            assert repo.run_id == MOCK_RUN_ID
            assert repo.run_relative_artifact_repo_root_path == ""

            # Constructing the repository directly from these URIs must fail.
            for bad_uri in rejected_uris:
                with pytest.raises(MlflowException):
                    DatabricksArtifactRepository(bad_uri)
Example #8
0
def test_file_artifact_is_logged_with_content_metadata(s3_artifact_root,
                                                       tmpdir):
    """Log a ``.txt`` file and inspect the stored object's content metadata.

    The object is looked up with ``head_object`` under ``some/path/test.txt``;
    its ``ContentType`` must be ``text/plain`` and no ``ContentEncoding`` is expected.
    """
    name = "test.txt"
    local_path = os.path.join(str(tmpdir), name)
    with open(local_path, "w") as handle:
        handle.write("Hello world!")

    repo_uri = posixpath.join(s3_artifact_root, "some/path")
    repo = get_artifact_repository(repo_uri)
    repo.log_artifact(local_path)

    # Only the bucket part of the parsed root URI is needed for the lookup.
    bucket, _unused = repo.parse_s3_uri(s3_artifact_root)
    client = repo._get_s3_client()
    head = client.head_object(Bucket=bucket, Key="some/path/test.txt")
    assert head.get("ContentType") == "text/plain"
    assert head.get("ContentEncoding") is None
Example #9
0
    def download_artifacts(self, run_id, path, dst_path=None):
        """
        Fetch an artifact file or directory of a run to the local filesystem, if
        applicable, and return the local path to it.

        :param run_id: The run to download artifacts from.
        :param path: Relative source path to the desired artifact.
        :param dst_path: Absolute path of an already existing local directory to
                         download into. If unspecified, the artifacts will either be
                         downloaded to a new uniquely-named directory on the local
                         filesystem or will be returned directly in the case of the
                         LocalArtifactRepository.
        :return: Local path of desired artifact.
        """
        # Resolve the repository from the run's artifact URI and delegate to it.
        repo = get_artifact_repository(self.get_run(run_id).info.artifact_uri)
        return repo.download_artifacts(path, dst_path)
def test_plugin_registration_via_entrypoints():
    """Register a mocked entrypoint and check that its scheme dispatches to the loaded plugin.

    ``entrypoints.get_group_all`` is patched to return a single entrypoint named
    ``mock-scheme`` whose ``load()`` yields the plugin callable.
    """
    uri = "mock-scheme://fake-host/fake-path"
    plugin_factory = mock.Mock()
    entrypoint = mock.Mock(load=mock.Mock(return_value=plugin_factory))
    # ``name`` is set after construction because Mock's ``name`` constructor
    # argument names the mock itself instead of creating an attribute.
    entrypoint.name = "mock-scheme"

    with mock.patch("entrypoints.get_group_all", return_value=[entrypoint]) as get_group_all:
        registry = ArtifactRepositoryRegistry()
        registry.register_entrypoints()

    repo = registry.get_artifact_repository(uri)
    assert repo == plugin_factory.return_value

    plugin_factory.assert_called_once_with(uri)
    get_group_all.assert_called_once_with("mlflow.artifact_repository")
Example #11
0
File: cli.py Project: iPieter/kiwi
def download_artifacts(run_id, artifact_path, artifact_uri):
    """
    Download an artifact file or directory to a local directory.
    The output is the name of the file or directory on the local disk.

    Either ``--run-id`` or ``--artifact-uri`` must be provided.
    """
    # An explicit artifact URI takes precedence over the run lookup.
    if artifact_uri is not None:
        print(_download_artifact_from_uri(artifact_uri))
        return

    if run_id is None:
        _logger.error(
            "Either ``--run-id`` or ``--artifact-uri`` must be provided.")
        sys.exit(1)

    # No artifact path means the root of the run's artifact location.
    relative_path = "" if artifact_path is None else artifact_path
    run_info = _get_store().get_run(run_id).info
    repo = get_artifact_repository(run_info.artifact_uri)
    print(repo.download_artifacts(relative_path))
Example #12
0
def test_file_and_directories_artifacts_are_logged_and_downloaded_successfully_in_batch(
        s3_artifact_root, tmpdir):
    """Log a directory tree in one call and download it back file by file and as directories.

    The tree ``subdir/{a.txt, b.txt, nested/c.txt}`` is logged to a repository
    rooted at ``<s3_artifact_root>/some/path``. Individual files, the ``nested``
    directory and the repository root are then downloaded and checked.
    """
    def read_text(path):
        # Close every handle deterministically; a bare open(...).read() leaves
        # the file open until garbage collection.
        with open(path) as f:
            return f.read()

    subdir_path = str(tmpdir.mkdir("subdir"))
    nested_path = os.path.join(subdir_path, "nested")
    os.makedirs(nested_path)
    with open(os.path.join(subdir_path, "a.txt"), "w") as f:
        f.write("A")
    with open(os.path.join(subdir_path, "b.txt"), "w") as f:
        f.write("B")
    with open(os.path.join(nested_path, "c.txt"), "w") as f:
        f.write("C")

    repo = get_artifact_repository(
        posixpath.join(s3_artifact_root, "some/path"))
    repo.log_artifacts(subdir_path)

    # Download individual files and verify correctness of their contents
    assert read_text(repo.download_artifacts("a.txt")) == "A"
    assert read_text(repo.download_artifacts("b.txt")) == "B"
    assert read_text(repo.download_artifacts("nested/c.txt")) == "C"

    # Download the nested directory and verify correctness of its contents
    downloaded_dir = repo.download_artifacts("nested")
    assert os.path.basename(downloaded_dir) == "nested"
    assert read_text(os.path.join(downloaded_dir, "c.txt")) == "C"

    # Download the root directory and verify correctness of its contents
    downloaded_dir = repo.download_artifacts("")
    dir_contents = os.listdir(downloaded_dir)
    assert "nested" in dir_contents
    assert os.path.isdir(os.path.join(downloaded_dir, "nested"))
    assert "a.txt" in dir_contents
    assert "b.txt" in dir_contents
Example #13
0
def dbfs_fuse_artifact_repo(force_dbfs_fuse_repo):  # pylint: disable=unused-argument
    """Return the artifact repository resolved for a placeholder ``dbfs:/`` URI.

    ``force_dbfs_fuse_repo`` is not used in the body.
    NOTE(review): presumably it is requested only for its side effect - confirm.
    """
    placeholder_uri = 'dbfs:/unused/path/replaced/by/mock'
    return get_artifact_repository(placeholder_uri)
Example #14
0
def test_artifact_uri_factory():
    """An ``ftp://`` URI must resolve to an ``FTPArtifactRepository`` instance."""
    ftp_uri = "ftp://*****:*****@test_ftp:123/some/path"
    assert isinstance(get_artifact_repository(ftp_uri), FTPArtifactRepository)
Example #15
0
def dbfs_artifact_repo():
    """Build the repository for ``dbfs:/test/`` with host-credential lookup patched.

    While the repository is created, ``_get_host_creds_from_default_store`` is
    replaced by a mock returning a callable that yields creds for ``http://host``.
    """
    creds_target = 'mlflow.store.artifact.dbfs_artifact_repo._get_host_creds_from_default_store'
    with mock.patch(creds_target) as creds_mock:
        creds_mock.return_value = lambda: MlflowHostCreds('http://host')
        return get_artifact_repository('dbfs:/test/')
Example #16
0
 def __init__(self, artifact_uri):
     """Initialise the base class with ``artifact_uri`` and wrap the underlying repository.

     :param artifact_uri: URI passed to ``get_underlying_uri`` and to the base initialiser.
     """
     # Imported inside the method.
     # NOTE(review): presumably to avoid a circular import with the registry - confirm.
     from kiwi.store.artifact.artifact_repository_registry import get_artifact_repository
     underlying_uri = RunsArtifactRepository.get_underlying_uri(artifact_uri)
     super(RunsArtifactRepository, self).__init__(artifact_uri)
     # Repository built for the resolved underlying URI.
     self.repo = get_artifact_repository(underlying_uri)
Example #17
0
def test_artifact_uri_factory():
    """Resolving an ``sftp://`` URI for this test host is expected to raise ``SSHException``."""
    from paramiko.ssh_exception import SSHException
    sftp_uri = "sftp://*****:*****@test_sftp:123/some/path"
    with pytest.raises(SSHException):
        get_artifact_repository(sftp_uri)
Example #18
0
def _get_artifact_repo(run):
    """Return the artifact repository for ``run.info.artifact_uri``."""
    artifact_uri = run.info.artifact_uri
    return get_artifact_repository(artifact_uri)
Example #19
0
def test_artifact_uri_factory():
    """A ``gs://`` URI must resolve to a ``GCSArtifactRepository`` instance."""
    gcs_uri = "gs://test_bucket/some/path"
    assert isinstance(get_artifact_repository(gcs_uri), GCSArtifactRepository)
Example #20
0
def databricks_artifact_repo():
    """Build a repository for a mock tracking URI with the run-root lookup patched.

    ``_get_run_artifact_root`` returns ``MOCK_RUN_ROOT_URI`` while the repository
    is being created.
    """
    repo_uri = "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
    root_patch = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root')
    with root_patch as root_mock:
        root_mock.return_value = MOCK_RUN_ROOT_URI
        return get_artifact_repository(repo_uri)
def test_get_unknown_scheme():
    """Looking up a scheme nobody registered must raise ``MlflowException`` with a clear message."""
    registry = ArtifactRepositoryRegistry()
    expected_message = "Could not find a registered artifact repository"

    with pytest.raises(kiwi.exceptions.MlflowException, match=expected_message):
        registry.get_artifact_repository("unknown-scheme://")