def test_download_artifact_throws_value_error_when_listed_blobs_do_not_contain_artifact_root_prefix(
        mock_client):
    """A listing containing a blob outside the artifact root raises ValueError."""
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    # Create a "bad blob" with a name that is not prefixed by the root path
    # of the artifact store
    bad_blob_props = BlobProperties()
    bad_blob_props.content_length = 42
    bad_blob = Blob("file_path", props=bad_blob_props)

    def get_mock_listing(*args, **kwargs):
        """
        Produces a mock listing that only contains content if the specified
        prefix is the artifact root. This allows us to mock `list_artifacts`
        during the `_download_artifacts_into` subroutine without recursively
        listing the same artifacts at every level of the directory traversal.
        """
        # pylint: disable=unused-argument
        if os.path.abspath(kwargs["prefix"]) == os.path.abspath(TEST_ROOT_PATH):
            # Return a blob that is not prefixed by the root path of the
            # artifact store. This should result in an exception being raised
            return MockBlobList([bad_blob])
        else:
            return MockBlobList([])

    mock_client.list_blobs.side_effect = get_mock_listing

    # Use `match=` to check the exception message directly (consistent with
    # `test_exception_if_no_env_vars`). Asserting a substring of `str(exc)` is
    # fragile: the ExceptionInfo string includes location/traceback text, not
    # just the exception message.
    with pytest.raises(
            ValueError,
            match="Azure blob does not begin with the specified artifact path"):
        repo.download_artifacts("")
def test_log_artifact(mock_client, tmpdir):
    """`log_artifact` uploads a single local file under the artifact root."""
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    local_file = tmpdir.mkdir("data").join("test.txt")
    local_file.write("hello world!")

    repo.log_artifact(local_file.strpath)

    mock_client.create_blob_from_path.assert_called_with(
        "container", TEST_ROOT_PATH + "/test.txt", local_file.strpath)
def test_download_directory_artifact_succeeds_when_artifact_root_is_blob_container_root(
        mock_client, tmpdir):
    """Downloading the root directory works when the artifact root is the
    blob container root itself."""
    repo = AzureBlobArtifactRepository(TEST_BLOB_CONTAINER_ROOT, mock_client)

    subdir_path = "my_directory"
    dir_prefix = BlobPrefix()
    dir_prefix.name = subdir_path

    file_path_1 = "file_1"
    file_path_2 = "file_2"

    blob_props_1 = BlobProperties()
    blob_props_1.content_length = 42
    blob_1 = Blob(os.path.join(subdir_path, file_path_1), props=blob_props_1)

    blob_props_2 = BlobProperties()
    blob_props_2.content_length = 42
    blob_2 = Blob(os.path.join(subdir_path, file_path_2), props=blob_props_2)

    def get_mock_listing(*args, **kwargs):
        """
        Produces a mock listing that only contains content if the specified
        prefix is the artifact root or a relevant subdirectory. This allows us
        to mock `list_artifacts` during the `_download_artifacts_into`
        subroutine without recursively listing the same artifacts at every
        level of the directory traversal.
        """
        # pylint: disable=unused-argument
        listed_prefix = os.path.abspath(kwargs["prefix"])
        if listed_prefix == "/":
            return MockBlobList([dir_prefix])
        if listed_prefix == os.path.abspath(subdir_path):
            return MockBlobList([blob_1, blob_2])
        return MockBlobList([])

    def create_file(container, cloud_path, local_path):
        # pylint: disable=unused-argument
        tmpdir.join(os.path.basename(local_path)).write("hello world!")

    mock_client.list_blobs.side_effect = get_mock_listing
    mock_client.get_blob_to_path.side_effect = create_file

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")

    # Ensure that the `create_file` side effect copied every downloaded
    # artifact into `tmpdir`
    downloaded_names = os.listdir(tmpdir.strpath)
    assert file_path_1 in downloaded_names
    assert file_path_2 in downloaded_names
def from_artifact_uri(artifact_uri, store):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or
    s3://my/bucket), returns an ArtifactRepository instance capable of logging
    and downloading artifacts on behalf of this URI.

    :param store: An instance of AbstractStore which the artifacts are
                  registered in.
    """
    # Repository classes are imported lazily, per scheme, to avoid creating a
    # circular import loop between this module and the store modules.
    if artifact_uri.startswith("s3:/"):
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    if artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    if artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    if artifact_uri.startswith("sftp:/"):
        from mlflow.store.sftp_artifact_repo import SFTPArtifactRepository
        return SFTPArtifactRepository(artifact_uri)
    if artifact_uri.startswith("dbfs:/"):
        from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
        # DBFS artifacts are accessed through the store's HTTP credentials,
        # so a Databricks-backed store is required here.
        if not isinstance(store, DatabricksStore):
            raise MlflowException(
                '`store` must be an instance of DatabricksStore.')
        return DbfsArtifactRepository(artifact_uri, store.http_request_kwargs)
    # Default: treat the URI as a local filesystem path
    from mlflow.store.local_artifact_repo import LocalArtifactRepository
    return LocalArtifactRepository(artifact_uri)
def test_log_artifacts(mock_client, tmpdir):
    """`log_artifacts` recursively uploads a directory tree under the root."""
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    data_dir = tmpdir.mkdir("data")
    sub_dir = data_dir.mkdir("subdir")
    data_dir.join("a.txt").write("A")
    sub_dir.join("b.txt").write("B")
    sub_dir.join("c.txt").write("C")

    repo.log_artifacts(data_dir.strpath)

    expected_uploads = [
        mock.call("container", TEST_ROOT_PATH + "/a.txt",
                  data_dir.strpath + "/a.txt"),
        mock.call("container", TEST_ROOT_PATH + "/subdir/b.txt",
                  sub_dir.strpath + "/b.txt"),
        mock.call("container", TEST_ROOT_PATH + "/subdir/c.txt",
                  sub_dir.strpath + "/c.txt"),
    ]
    mock_client.create_blob_from_path.assert_has_calls(
        expected_uploads, any_order=True)
def test_download_file_artifact(mock_client, tmpdir):
    """Downloading a single file fetches the correct blob into a local file."""
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)
    mock_client.list_blobs.return_value = MockBlobList([])

    def create_file(container, cloud_path, local_path):
        # pylint: disable=unused-argument
        # Simulate the SDK writing the blob contents to `local_path`
        tmpdir.join(os.path.basename(local_path)).write("hello world!")

    mock_client.get_blob_to_path.side_effect = create_file

    repo.download_artifacts("test.txt")

    assert os.path.exists(os.path.join(tmpdir.strpath, "test.txt"))
    mock_client.get_blob_to_path.assert_called_with(
        "container", TEST_ROOT_PATH + "/test.txt", mock.ANY)
def test_download_artifacts(mock_client, tmpdir):
    """`_download_artifacts_into` writes the blob to the requested directory."""
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)
    mock_client.list_blobs.return_value = MockBlobList([])

    def create_file(container, cloud_path, local_path):
        # pylint: disable=unused-argument
        # Simulate the SDK writing the blob contents to `local_path`
        relative_path = local_path.replace(tmpdir.strpath, '')
        created = tmpdir.join(relative_path)
        created.write("hello world!")
        return created.strpath

    mock_client.get_blob_to_path.side_effect = create_file

    downloaded_path = repo._download_artifacts_into("test.txt", tmpdir.strpath)
    # Opening the returned path proves the file actually exists on disk
    open(downloaded_path).read()

    mock_client.get_blob_to_path.assert_called_with(
        "container", TEST_ROOT_PATH + "/test.txt",
        tmpdir.strpath + "/test.txt")
def test_list_artifacts(mock_client):
    """`list_artifacts` converts blob prefixes to dirs and blobs to files."""
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)

    # Mock listing: one directory prefix and one 42-byte file under the root
    dir_prefix = BlobPrefix()
    dir_prefix.name = TEST_ROOT_PATH + "/dir"
    blob_props = BlobProperties()
    blob_props.content_length = 42
    blob = Blob(TEST_ROOT_PATH + "/file", props=blob_props)
    mock_client.list_blobs.return_value = MockBlobList([dir_prefix, blob])

    listed = repo.list_artifacts()

    # Paths are reported relative to the artifact root
    assert listed[0].path == "dir"
    assert listed[0].is_dir is True
    assert listed[0].file_size is None
    assert listed[1].path == "file"
    assert listed[1].is_dir is False
    assert listed[1].file_size == 42
def from_artifact_uri(artifact_uri):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or
    s3://my/bucket), returns an ArtifactRepository instance capable of logging
    and downloading artifacts on behalf of this URI.
    """
    # Repository classes are imported lazily, per scheme, to avoid creating a
    # circular import loop between this module and the store modules.
    if artifact_uri.startswith("s3:/"):
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    if artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    if artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    # Default: treat the URI as a local filesystem path
    from mlflow.store.local_artifact_repo import LocalArtifactRepository
    return LocalArtifactRepository(artifact_uri)
def test_list_artifacts_empty(mock_client):
    """An empty blob listing yields an empty artifact list."""
    repo = AzureBlobArtifactRepository(TEST_URI, mock_client)
    mock_client.list_blobs.return_value = MockBlobList([])
    listed = repo.list_artifacts()
    assert listed == []
def test_exception_if_no_env_vars(mock_client):  # pylint: disable=unused-argument
    """Constructing the repository without credentials raises.

    `mock_client` is requested only so the fixture clears the Azure
    environment variables; it is not used directly.
    """
    with pytest.raises(Exception, match="AZURE_STORAGE_CONNECTION_STRING"):
        AzureBlobArtifactRepository(TEST_URI)