Ejemplo n.º 1
0
def test_download_artifact_throws_value_error_when_listed_blobs_do_not_contain_artifact_root_prefix(
        mock_client):
    """
    Verify that a download fails when a listed blob is not under the artifact root.

    The mocked ``walk_blobs`` listing returns a blob named ``"file_path"`` when the
    artifact root is listed. ``download_artifacts("")`` is then expected to raise an
    ``MlflowException`` whose message reports the prefix mismatch.
    """
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    # Create a "bad blob" with a name that is not prefixed by the root path of the artifact store
    bad_blob_props = BlobProperties()
    bad_blob_props.size = 42
    bad_blob_props.name = "file_path"

    def get_mock_listing(*args, **kwargs):
        """
        Produces a mock listing that only contains content if the
        specified prefix is the artifact root. This allows us to mock
        `list_artifacts` during the `_download_artifacts_into` subroutine
        without recursively listing the same artifacts at every level of the
        directory traversal.
        """
        # pylint: disable=unused-argument
        requested_prefix = posixpath.abspath(kwargs["name_starts_with"])
        if requested_prefix == posixpath.abspath(TEST_ROOT_PATH):
            # Return a blob that is not prefixed by the root path of the artifact store. This
            # should result in an exception being raised
            return MockBlobList([bad_blob_props])
        return MockBlobList([])

    mock_client.get_container_client(
    ).walk_blobs.side_effect = get_mock_listing

    with pytest.raises(MlflowException) as exc:
        repo.download_artifacts("")

    # Check the raised exception itself (`exc.value`). `str(exc)` stringifies pytest's
    # `ExceptionInfo` wrapper, whose textual format differs between pytest versions.
    assert "Azure blob does not begin with the specified artifact path" in str(
        exc.value)
Ejemplo n.º 2
0
def test_download_directory_artifact_succeeds_when_artifact_root_is_blob_container_root(
        mock_client, tmpdir):
    """
    Download the whole artifact tree of a repository rooted at the blob container root.

    The mocked ``walk_blobs`` listing exposes one directory prefix at the root and two
    blobs inside that directory. The mocked ``readinto`` writes a file into ``tmpdir``
    for every download, and the test checks that both file names end up there.
    """
    repo = AzureBlobArtifactRepository(TEST_BLOB_CONTAINER_ROOT, mock_client)

    subdir_path = "my_directory"
    file_names = ("file_1", "file_2")

    # Directory entry that is listed at the container root.
    dir_prefix = BlobPrefix()
    dir_prefix.name = subdir_path

    # One blob entry per file, all living under the sub-directory.
    blob_entries = []
    for file_name in file_names:
        props = BlobProperties()
        props.size = 42
        props.name = posixpath.join(subdir_path, file_name)
        blob_entries.append(props)

    def get_mock_listing(*args, **kwargs):
        """
        Return a listing for the artifact root or the known sub-directory, and an empty
        listing for any other prefix. This keeps the mocked `list_artifacts` from returning
        the same artifacts at every level of the `_download_artifacts_into` traversal.
        """
        # pylint: disable=unused-argument
        requested = posixpath.abspath(kwargs["name_starts_with"])
        if requested == "/":
            return MockBlobList([dir_prefix])
        if requested == posixpath.abspath(subdir_path):
            return MockBlobList(list(blob_entries))
        return MockBlobList([])

    def create_file(buffer):
        # Stand-in for the real download: create a file in `tmpdir` named after the
        # destination buffer.
        target = tmpdir.join(os.path.basename(buffer.name))
        target.write("hello world!")

    mock_client.get_container_client().walk_blobs.side_effect = get_mock_listing
    mock_client.get_container_client().download_blob().readinto.side_effect = create_file

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")

    # Ensure that the `create_file` side effect produced every downloaded artifact in `tmpdir`
    downloaded = os.listdir(tmpdir.strpath)
    for file_name in file_names:
        assert file_name in downloaded
Ejemplo n.º 3
0
def test_download_directory_artifact_succeeds_when_artifact_root_is_not_blob_container_root(
        mock_client, tmpdir):
    """
    Download the whole artifact tree of a repository rooted below the container root.

    The mocked ``list_blobs`` returns two blobs under ``TEST_ROOT_PATH``. The mocked
    ``get_blob_to_path`` writes a file into ``tmpdir`` for every download, and the test
    checks that both file names end up there.
    """
    # Compare by value: an identity check (`is not`) would also pass for two distinct
    # string objects holding the same URI, defeating the purpose of this precondition.
    assert TEST_URI != TEST_BLOB_CONTAINER_ROOT
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    file_path_1 = "file_1"
    file_path_2 = "file_2"

    blob_props_1 = BlobProperties()
    blob_props_1.content_length = 42
    blob_1 = Blob(posixpath.join(TEST_ROOT_PATH, file_path_1),
                  props=blob_props_1)

    blob_props_2 = BlobProperties()
    blob_props_2.content_length = 42
    blob_2 = Blob(posixpath.join(TEST_ROOT_PATH, file_path_2),
                  props=blob_props_2)

    def get_mock_listing(*args, **kwargs):
        """
        Produces a mock listing that only contains content if the
        specified prefix is the artifact root. This allows us to mock
        `list_artifacts` during the `_download_artifacts_into` subroutine
        without recursively listing the same artifacts at every level of the
        directory traversal.
        """
        # pylint: disable=unused-argument
        if posixpath.abspath(
                kwargs["prefix"]) == posixpath.abspath(TEST_ROOT_PATH):
            return MockBlobList([blob_1, blob_2])
        return MockBlobList([])

    def create_file(container, cloud_path, local_path):
        # Stand-in for the real download: create a file in `tmpdir` named after the
        # requested local destination.
        # pylint: disable=unused-argument
        fname = os.path.basename(local_path)
        f = tmpdir.join(fname)
        f.write("hello world!")

    mock_client.list_blobs.side_effect = get_mock_listing
    mock_client.get_blob_to_path.side_effect = create_file

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")
    # Ensure that the `create_file` side effect produced every downloaded artifact in `tmpdir`
    dir_contents = os.listdir(tmpdir.strpath)
    assert file_path_1 in dir_contents
    assert file_path_2 in dir_contents
Ejemplo n.º 4
0
def test_download_file_artifact(mock_client, tmpdir):
    """
    Download a single file artifact and check which blob path was requested.

    The mocked ``walk_blobs`` returns an empty listing and the mocked ``readinto``
    writes a file into ``tmpdir``. The test asserts that the file exists and that
    ``download_blob`` was last called with the artifact path joined onto
    ``TEST_ROOT_PATH``.
    """
    artifact_name = "test.txt"

    def create_file(buffer):
        # Stand-in for the real download: create a file in `tmpdir` named after the
        # destination buffer.
        target = tmpdir.join(os.path.basename(buffer.name))
        target.write("hello world!")

    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    mock_client.get_container_client().walk_blobs.return_value = MockBlobList([])
    mock_client.get_container_client().download_blob().readinto.side_effect = create_file

    repo.download_artifacts(artifact_name)

    downloaded_file = os.path.join(tmpdir.strpath, artifact_name)
    assert os.path.exists(downloaded_file)

    expected_blob_path = posixpath.join(TEST_ROOT_PATH, artifact_name)
    mock_client.get_container_client().download_blob.assert_called_with(
        expected_blob_path)
Ejemplo n.º 5
0
def test_download_file_artifact(mock_client, tmpdir):
    """
    Download a single file artifact and check the arguments of the blob download call.

    The mocked ``list_blobs`` returns an empty listing and the mocked
    ``get_blob_to_path`` writes a file into ``tmpdir``. The test asserts that the file
    exists and that ``get_blob_to_path`` was last called with the container name, the
    artifact path below ``TEST_ROOT_PATH``, and any local destination.
    """
    artifact_name = "test.txt"

    def create_file(container, cloud_path, local_path):
        # Stand-in for the real download: create a file in `tmpdir` named after the
        # requested local destination.
        # pylint: disable=unused-argument
        target = tmpdir.join(os.path.basename(local_path))
        target.write("hello world!")

    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    mock_client.list_blobs.return_value = MockBlobList([])
    mock_client.get_blob_to_path.side_effect = create_file

    repo.download_artifacts(artifact_name)

    downloaded_file = os.path.join(tmpdir.strpath, artifact_name)
    assert os.path.exists(downloaded_file)

    expected_cloud_path = TEST_ROOT_PATH + "/test.txt"
    mock_client.get_blob_to_path.assert_called_with(
        "container", expected_cloud_path, mock.ANY)