Example #1
def test_download_artifact_throws_value_error_when_listed_blobs_do_not_contain_artifact_root_prefix(
        mock_client):
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    # Create a "bad blob" with a name that is not prefixed by the root path of the artifact store
    bad_blob_props = BlobProperties()
    bad_blob_props.content_length = 42
    bad_blob = Blob("file_path", props=bad_blob_props)

    def get_mock_listing(*args, **kwargs):
        """
        Produces a mock listing that only contains content if the
        specified prefix is the artifact root. This allows us to mock
        `list_artifacts` during the `_download_artifacts_into` subroutine
        without recursively listing the same artifacts at every level of the
        directory traversal.
        """
        # pylint: disable=unused-argument
        if os.path.abspath(
                kwargs["prefix"]) == os.path.abspath(TEST_ROOT_PATH):
            # Return a blob that is not prefixed by the root path of the artifact store. This
            # should result in an exception being raised
            return MockBlobList([bad_blob])
        else:
            return MockBlobList([])

    mock_client.list_blobs.side_effect = get_mock_listing

    with pytest.raises(ValueError) as exc:
        repo.download_artifacts("")

    assert "Azure blob does not begin with the specified artifact path" in str(
        exc)
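The tests in these examples rely on shared helpers that the snippets do not show: TEST_URI, TEST_ROOT_PATH, TEST_BLOB_CONTAINER_ROOT, MockBlobList, and the mock_client fixture, plus Blob, BlobPrefix, and BlobProperties from the legacy azure-storage-blob SDK and AzureBlobArtifactRepository from mlflow.store.azure_blob_artifact_repo. Below is a minimal sketch of the test-local helpers with illustrative values; the real test module's definitions may differ.

import os
from unittest import mock

import pytest

# Illustrative constants; a wasbs URI has the layout
# wasbs://<container>@<account>.blob.core.windows.net/<path>.
TEST_ROOT_PATH = "some/path"
TEST_BLOB_CONTAINER_ROOT = "wasbs://container@account.blob.core.windows.net/"
TEST_URI = os.path.join(TEST_BLOB_CONTAINER_ROOT, TEST_ROOT_PATH)


class MockBlobList(object):
    """Iterable stand-in for the paged listing returned by the client's list_blobs."""

    def __init__(self, items, next_marker=None):
        self.items = items
        self.next_marker = next_marker

    def __iter__(self):
        return iter(self.items)


@pytest.fixture
def mock_client():
    # The tests only exercise list_blobs, create_blob_from_path, and
    # get_blob_to_path, so a bare MagicMock is enough. The real fixture also
    # clears the Azure credential environment variables (see the last example).
    return mock.MagicMock()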
Example #2
def test_log_artifact(mock_client, tmpdir):
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    d = tmpdir.mkdir("data")
    f = d.join("test.txt")
    f.write("hello world!")
    fpath = f.strpath

    repo.log_artifact(fpath)
    mock_client.create_blob_from_path.assert_called_with(
        "container", TEST_ROOT_PATH + "/test.txt", fpath)
Example #3
def test_download_directory_artifact_succeeds_when_artifact_root_is_blob_container_root(
        mock_client, tmpdir):
    repo = AzureBlobArtifactRepository(TEST_BLOB_CONTAINER_ROOT, mock_client)

    subdir_path = "my_directory"
    dir_prefix = BlobPrefix()
    dir_prefix.name = subdir_path

    file_path_1 = "file_1"
    file_path_2 = "file_2"

    blob_props_1 = BlobProperties()
    blob_props_1.content_length = 42
    blob_1 = Blob(os.path.join(subdir_path, file_path_1), props=blob_props_1)

    blob_props_2 = BlobProperties()
    blob_props_2.content_length = 42
    blob_2 = Blob(os.path.join(subdir_path, file_path_2), props=blob_props_2)

    def get_mock_listing(*args, **kwargs):
        """
        Produces a mock listing that only contains content if the specified prefix is the artifact
        root or a relevant subdirectory. This allows us to mock `list_artifacts` during the
        `_download_artifacts_into` subroutine without recursively listing the same artifacts at
        every level of the directory traversal.
        """
        # pylint: disable=unused-argument
        if os.path.abspath(kwargs["prefix"]) == "/":
            return MockBlobList([dir_prefix])
        if os.path.abspath(kwargs["prefix"]) == os.path.abspath(subdir_path):
            return MockBlobList([blob_1, blob_2])
        else:
            return MockBlobList([])

    def create_file(container, cloud_path, local_path):
        # pylint: disable=unused-argument
        fname = os.path.basename(local_path)
        f = tmpdir.join(fname)
        f.write("hello world!")

    mock_client.list_blobs.side_effect = get_mock_listing
    mock_client.get_blob_to_path.side_effect = create_file

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")
    # Ensure that the `create_file` side effect copied all of the downloaded artifacts into `tmpdir`
    dir_contents = os.listdir(tmpdir.strpath)
    assert file_path_1 in dir_contents
    assert file_path_2 in dir_contents
Example #4
def from_artifact_uri(artifact_uri, store):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or s3://my/bucket),
    returns an ArtifactRepository instance capable of logging and downloading artifacts
    on behalf of this URI.
    :param store: An instance of AbstractStore which the artifacts are registered in.
    """
    if artifact_uri.startswith("s3:/"):
        # Import these locally to avoid creating a circular import loop
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("sftp:/"):
        from mlflow.store.sftp_artifact_repo import SFTPArtifactRepository
        return SFTPArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("dbfs:/"):
        from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
        if not isinstance(store, DatabricksStore):
            raise MlflowException(
                '`store` must be an instance of DatabricksStore.')
        return DbfsArtifactRepository(artifact_uri,
                                      store.http_request_kwargs)
    else:
        from mlflow.store.local_artifact_repo import LocalArtifactRepository
        return LocalArtifactRepository(artifact_uri)
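A brief, hypothetical usage sketch for the dispatcher above. Only the local fallback constructs a repository here, because the cloud-backed branches need credentials at construction time (the Azure case is shown in the last example).

from mlflow.store.local_artifact_repo import LocalArtifactRepository

# Illustrative path; any URI without a recognized scheme falls through to the
# local repository.
repo = from_artifact_uri("/tmp/mlruns/0/run_id/artifacts", store=None)
assert isinstance(repo, LocalArtifactRepository)

# A URI such as "wasbs://container@account.blob.core.windows.net/some/path"
# would instead take the Azure branch and return an AzureBlobArtifactRepository.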
Example #5
def test_log_artifacts(mock_client, tmpdir):
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    parentd = tmpdir.mkdir("data")
    subd = parentd.mkdir("subdir")
    parentd.join("a.txt").write("A")
    subd.join("b.txt").write("B")
    subd.join("c.txt").write("C")

    repo.log_artifacts(parentd.strpath)

    mock_client.create_blob_from_path.assert_has_calls([
        mock.call("container", TEST_ROOT_PATH + "/a.txt", parentd.strpath + "/a.txt"),
        mock.call("container", TEST_ROOT_PATH + "/subdir/b.txt", subd.strpath + "/b.txt"),
        mock.call("container", TEST_ROOT_PATH + "/subdir/c.txt", subd.strpath + "/c.txt"),
    ], any_order=True)
Example #6
def test_download_file_artifact(mock_client, tmpdir):
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    mock_client.list_blobs.return_value = MockBlobList([])

    def create_file(container, cloud_path, local_path):
        # pylint: disable=unused-argument
        local_path = os.path.basename(local_path)
        f = tmpdir.join(local_path)
        f.write("hello world!")

    mock_client.get_blob_to_path.side_effect = create_file

    repo.download_artifacts("test.txt")
    assert os.path.exists(os.path.join(tmpdir.strpath, "test.txt"))
    mock_client.get_blob_to_path.assert_called_with(
        "container", TEST_ROOT_PATH + "/test.txt", mock.ANY)
Example #7
def test_download_artifacts(mock_client, tmpdir):
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    mock_client.list_blobs.return_value = MockBlobList([])

    def create_file(container, cloud_path, local_path):
        # pylint: disable=unused-argument
        local_path = local_path.replace(tmpdir.strpath, '')
        f = tmpdir.join(local_path)
        f.write("hello world!")
        return f.strpath

    mock_client.get_blob_to_path.side_effect = create_file

    open(repo._download_artifacts_into("test.txt", tmpdir.strpath)).read()
    mock_client.get_blob_to_path.assert_called_with(
        "container", TEST_ROOT_PATH + "/test.txt", tmpdir.strpath + "/test.txt")
Example #8
def test_list_artifacts(mock_client):
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    # Create some files to return
    dir_prefix = BlobPrefix()
    dir_prefix.name = TEST_ROOT_PATH + "/dir"

    blob_props = BlobProperties()
    blob_props.content_length = 42
    blob = Blob(TEST_ROOT_PATH + "/file", props=blob_props)

    mock_client.list_blobs.return_value = MockBlobList([dir_prefix, blob])

    artifacts = repo.list_artifacts()
    assert artifacts[0].path == "dir"
    assert artifacts[0].is_dir is True
    assert artifacts[0].file_size is None
    assert artifacts[1].path == "file"
    assert artifacts[1].is_dir is False
    assert artifacts[1].file_size == 42
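This test pins down how listing entries become MLflow FileInfo objects: directory prefixes turn into is_dir=True entries with no size, blobs keep their content_length, and the artifact-root prefix is stripped from every path. Below is a rough sketch of that conversion under those assumptions; the real logic lives in AzureBlobArtifactRepository.list_artifacts and may differ in details.

import posixpath

from mlflow.entities import FileInfo


def _listing_to_file_infos(results, artifact_root):
    """Hypothetical helper: convert a blob listing into root-relative FileInfo entries."""
    infos = []
    for item in results:
        # Listing names are fully qualified, e.g. "some/path/dir"; report them
        # relative to the artifact root, e.g. "dir".
        rel_path = posixpath.relpath(item.name, artifact_root)
        if getattr(item, "properties", None) is not None:
            # A Blob: a regular file with a known size.
            infos.append(FileInfo(rel_path, False, item.properties.content_length))
        else:
            # A BlobPrefix: a virtual directory, so no size is reported.
            infos.append(FileInfo(rel_path, True, None))
    return infos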
Example #9
def from_artifact_uri(artifact_uri):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or s3://my/bucket),
    returns an ArtifactRepository instance capable of logging and downloading artifacts
    on behalf of this URI.
    """
    if artifact_uri.startswith("s3:/"):
        # Import these locally to avoid creating a circular import loop
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    else:
        from mlflow.store.local_artifact_repo import LocalArtifactRepository
        return LocalArtifactRepository(artifact_uri)
Example #10
def test_list_artifacts_empty(mock_client):
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)
    mock_client.list_blobs.return_value = MockBlobList([])
    assert repo.list_artifacts() == []
Example #11
def test_exception_if_no_env_vars(mock_client):
    # pylint: disable=unused-argument
    # We pass in the mock_client here to clear Azure environment variables, but we don't use it
    with pytest.raises(Exception, match="AZURE_STORAGE_CONNECTION_STRING"):
        AzureBlobArtifactRepository(TEST_URI)