def test_log_artifacts_provides_failure_info(self, databricks_artifact_repo, tmpdir):
    """Verify that when uploads fail during `log_artifacts()`, the raised exception
    message names the run root URI, each failed file, and each underlying error."""
    # Create two local files to be uploaded
    src_file1_path = os.path.join(str(tmpdir), "file_1.txt")
    with open(src_file1_path, "w") as f:
        f.write("file1")
    src_file2_path = os.path.join(str(tmpdir), "file_2.txt")
    with open(src_file2_path, "w") as f:
        f.write("file2")
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
    ) as upload_mock:
        # One write credential per file; both uploads fail with distinct errors
        write_credentials_mock.return_value = [
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
        ]
        upload_mock.side_effect = [
            MlflowException("MOCK ERROR 1"),
            MlflowException("MOCK ERROR 2"),
        ]
        with pytest.raises(MlflowException) as exc:
            databricks_artifact_repo.log_artifacts(str(tmpdir), "test_artifacts")
        # The aggregated error message must identify every failed file and its cause
        err_msg = str(exc.value)
        assert MOCK_RUN_ROOT_URI in err_msg
        assert "file_1.txt" in err_msg
        assert "MOCK ERROR 1" in err_msg
        assert "file_2.txt" in err_msg
        assert "MOCK ERROR 2" in err_msg
def test_get_read_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
    """
    Verifies that the `_get_read_credential_infos` method, which is used to resolve read
    access credentials for a collection of artifacts, batches requests according to the
    maximum request size configured by the client, issuing one request per chunk of paths
    """
    # NOTE: the docstring previously described response pagination (copy-pasted from the
    # pagination test); this test actually exercises request-size batching.
    assert _MAX_CREDENTIALS_REQUEST_SIZE == 2000, (
        "The maximum request size configured by the client should be consistent with the"
        " Databricks backend. Only update this value if the backend limit has changed."
    )

    # Create 3 chunks of paths, two of which have the maximum request size and one of which
    # is smaller than the maximum chunk size. Aggregate and pass these to
    # `_get_read_credential_infos`, validating that this method decomposes the aggregate
    # list into these expected chunks and makes 3 separate requests
    paths_chunk_1 = ["path1"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_2 = ["path2"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_3 = ["path3"] * 5
    credential_infos_mock_1 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
    ]
    credential_infos_mock_2 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
    ]
    credential_infos_mock_3 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        for _ in range(5)
    ]

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        call_endpoint_mock.side_effect = [
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_1),
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_2),
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_3),
        ]
        databricks_artifact_repo._get_read_credential_infos(
            MOCK_RUN_ID,
            paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
        )
        # One request per chunk, each carrying exactly one chunk of paths
        assert call_endpoint_mock.call_count == 3
        assert message_mock.call_count == 3
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
            ]
        )
def test_download_artifacts_provides_failure_info(self, databricks_artifact_repo):
    """Verify that when downloads fail during `download_artifacts()`, the raised
    exception message names the run root URI, each failed file, and each error."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
    ) as read_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
    ) as download_mock:
        # Two artifacts listed; both cloud downloads fail with distinct errors
        get_list_mock.return_value = [
            FileInfo(path="file_1.txt", is_dir=False, file_size=100),
            FileInfo(path="file_2.txt", is_dir=False, file_size=0),
        ]
        read_credentials_mock.return_value = [
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
        ]
        download_mock.side_effect = [
            MlflowException("MOCK ERROR 1"),
            MlflowException("MOCK ERROR 2"),
        ]
        with pytest.raises(MlflowException) as exc:
            databricks_artifact_repo.download_artifacts("test_path")
        # The aggregated error message must identify every failed file and its cause
        err_msg = str(exc.value)
        assert MOCK_RUN_ROOT_URI in err_msg
        assert "file_1.txt" in err_msg
        assert "MOCK ERROR 1" in err_msg
        assert "file_2.txt" in err_msg
        assert "MOCK ERROR 2" in err_msg
def test_get_write_credential_infos_handles_pagination(self, databricks_artifact_repo):
    """
    Verifies that the `_get_write_credential_infos` method, which is used to resolve write
    access credentials for a collection of artifacts, handles paginated responses properly,
    issuing incremental requests until all pages have been consumed
    """
    # Three response pages: two non-empty pages carrying next_page_token, then an
    # empty terminal page (no token) that ends the pagination loop
    credential_infos_mock_1 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        ),
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        ),
    ]
    credential_infos_mock_2 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
    ]
    credential_infos_mock_3 = []
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        get_credentials_for_write_responses = [
            GetCredentialsForWrite.Response(
                credential_infos=credential_infos_mock_1, next_page_token="2"
            ),
            GetCredentialsForWrite.Response(
                credential_infos=credential_infos_mock_2, next_page_token="3"
            ),
            GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_3),
        ]
        call_endpoint_mock.side_effect = get_credentials_for_write_responses
        write_credential_infos = databricks_artifact_repo._get_write_credential_infos(
            MOCK_RUN_ID,
            ["testpath"],
        )
        # All non-empty pages are concatenated in order
        assert write_credential_infos == credential_infos_mock_1 + credential_infos_mock_2
        # Each follow-up request must carry the previous response's page token
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=["testpath"])),
                mock.call(
                    GetCredentialsForWrite(
                        run_id=MOCK_RUN_ID, path=["testpath"], page_token="2"
                    )
                ),
                mock.call(
                    GetCredentialsForWrite(
                        run_id=MOCK_RUN_ID, path=["testpath"], page_token="3"
                    )
                ),
            ]
        )
        assert call_endpoint_mock.call_count == 3
def test_get_write_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
    """
    Verifies that the `_get_write_credential_infos` method, which is used to resolve write
    access credentials for a collection of artifacts, batches requests according to a
    maximum request size configured by the backend
    """
    # Create 3 chunks of paths, two of which have the maximum request size and one of which
    # is smaller than the maximum chunk size. Aggregate and pass these to
    # `_get_write_credential_infos`, validating that this method decomposes the aggregate
    # list into these expected chunks and makes 3 separate requests.
    # Use `_MAX_CREDENTIALS_REQUEST_SIZE` rather than a hard-coded 2000 so this test stays
    # in sync with the client's batching limit (mirroring the read-credentials test).
    paths_chunk_1 = ["path1"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_2 = ["path2"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_3 = ["path3"] * 5
    credential_infos_mock_1 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
    ]
    credential_infos_mock_2 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
    ]
    credential_infos_mock_3 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        for _ in range(5)
    ]
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        call_endpoint_mock.side_effect = [
            GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_1),
            GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_2),
            GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_3),
        ]
        databricks_artifact_repo._get_write_credential_infos(
            MOCK_RUN_ID,
            paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
        )
        # One request per chunk, each carrying exactly one chunk of paths
        assert call_endpoint_mock.call_count == message_mock.call_count == 3
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
            ]
        )
def test_log_artifact_gcp_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Verify that a GCP signed-URL upload forwards the credential's HTTP headers
    to the PUT request issued via cloud_storage_http_request."""
    expected_headers = {header.name: header.value for header in MOCK_HEADERS}
    mock_response = Response()
    mock_response.status_code = 200
    mock_response.close = lambda: None
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
    ) as write_credential_infos_mock, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request"
    ) as request_mock:
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_GCP_SIGNED_URL,
            type=ArtifactCredentialType.GCP_SIGNED_URL,
            headers=MOCK_HEADERS,
        )
        write_credential_infos_mock.return_value = [mock_credential_info]
        request_mock.return_value = mock_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credential_infos_mock.assert_called_with(
            run_id=MOCK_RUN_ID, paths=[expected_location]
        )
        # The PUT must carry exactly the headers supplied by the credential
        request_mock.assert_called_with(
            "put", MOCK_GCP_SIGNED_URL, data=ANY, headers=expected_headers
        )
def test_databricks_download_file_with_relative_path(
        self, remote_file_path, local_path):
    """Verify that, for a repository rooted at a run subdirectory, the subdirectory
    is prepended to the requested path when resolving read credentials."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    ) as get_run_artifact_root_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credentials"
    ) as read_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
    ) as download_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        read_credentials_response_proto = GetCredentialsForRead.Response(
            credentials=mock_credentials
        )
        read_credentials_mock.return_value = read_credentials_response_proto
        download_mock.return_value = None
        get_list_mock.return_value = []
        # Repository rooted below the run root, at MOCK_SUBDIR
        databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        databricks_artifact_repo.download_artifacts(remote_file_path, local_path)
        # The run-relative subdirectory must be joined onto the requested path
        read_credentials_mock.assert_called_with(
            MOCK_RUN_ID, posixpath.join(MOCK_SUBDIR, remote_file_path)
        )
        download_mock.assert_called_with(mock_credentials, ANY)
def test_log_artifact_aws_with_headers(
        self, databricks_artifact_repo, test_file, artifact_path, expected_location):
    """Verify that an AWS presigned-URL upload forwards the credential's HTTP
    headers to the `requests.put` call."""
    expected_headers = {header.name: header.value for header in MOCK_HEADERS}
    mock_response = Response()
    mock_response.status_code = 200
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch("requests.put") as request_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AWS_SIGNED_URI,
            type=ArtifactCredentialType.AWS_PRESIGNED_URL,
            headers=MOCK_HEADERS,
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        request_mock.return_value = mock_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # Only the header kwargs are pinned; URL and data are don't-cares here
        request_mock.assert_called_with(ANY, ANY, headers=expected_headers)
def test_databricks_download_file_with_relative_path(self, remote_file_path, local_path):
    """Verify that, for a repository rooted at a run subdirectory, the subdirectory
    is prepended to the requested path when resolving read credential infos."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    ) as get_run_artifact_root_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
    ) as read_credential_infos_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
    ) as download_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        read_credential_infos_mock.return_value = [mock_credential_info]
        download_mock.return_value = None
        get_list_mock.return_value = []
        # Repository rooted below the run root, at MOCK_SUBDIR
        databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        databricks_artifact_repo.download_artifacts(remote_file_path, local_path)
        # The run-relative subdirectory must be joined onto the requested path
        read_credential_infos_mock.assert_called_with(
            run_id=MOCK_RUN_ID, paths=[posixpath.join(MOCK_SUBDIR, remote_file_path)]
        )
        download_mock.assert_called_with(
            cloud_credential_info=mock_credential_info,
            dst_local_file_path=ANY,
        )
def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location):
    """Verify that, for a repository rooted at a run subdirectory, write credentials
    and the upload target use the subdirectory-relative artifact location."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    ) as get_run_artifact_root_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
    ) as upload_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # Repository rooted below the run root, at MOCK_SUBDIR
        databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        upload_mock.return_value = None
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        upload_mock.assert_called_with(
            write_credentials_response_proto, test_file.strpath, expected_location
        )
def test_log_artifact_gcp(
        self, databricks_artifact_repo, test_file, artifact_path, expected_location):
    """Verify that a GCP signed-URL credential routes the upload through a PUT on
    cloud_storage_http_request with empty headers."""
    mock_response = Response()
    mock_response.status_code = 200
    mock_response.close = lambda: None
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request"
    ) as request_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_GCP_SIGNED_URL, type=ArtifactCredentialType.GCP_SIGNED_URL
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        request_mock.return_value = mock_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # No credential headers were supplied, so the PUT carries an empty header dict
        request_mock.assert_called_with("put", MOCK_GCP_SIGNED_URL, ANY, headers={})
def test_artifact_logging_awaits_upload_completion(
        self, databricks_artifact_repo, tmpdir):
    """
    Verifies that all asynchronous artifact uploads initiated by `log_artifact()` and
    `log_artifacts()` are joined before these methods return a result to the caller
    """
    # Source directory with two files to upload, and a destination directory that
    # the mocked cloud upload writes into
    src_dir = os.path.join(str(tmpdir), "src")
    os.makedirs(src_dir)
    src_file1_path = os.path.join(src_dir, "file_1.txt")
    with open(src_file1_path, "w") as f:
        f.write("file1")
    src_file2_path = os.path.join(src_dir, "file_2.txt")
    with open(src_file2_path, "w") as f:
        f.write("file2")
    dst_dir = os.path.join(str(tmpdir), "dst")
    os.makedirs(dst_dir)
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
    ) as upload_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto

        def mock_upload_to_cloud(
                cloud_credentials, local_file_path, artifact_path):  # pylint: disable=unused-argument
            # Sleep in order to simulate a longer-running asynchronous upload
            time.sleep(2)
            dst_path = os.path.join(dst_dir, artifact_path)
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            shutil.copyfile(src=local_file_path, dst=dst_path)

        upload_mock.side_effect = mock_upload_to_cloud
        databricks_artifact_repo.log_artifacts(src_dir, "dir_artifact")
        # Both uploads must have completed (files fully copied) by the time
        # log_artifacts() returns
        expected_dst_dir_file1_path = os.path.join(dst_dir, "dir_artifact", "file_1.txt")
        expected_dst_dir_file2_path = os.path.join(dst_dir, "dir_artifact", "file_2.txt")
        assert os.path.exists(expected_dst_dir_file1_path)
        assert os.path.exists(expected_dst_dir_file2_path)
        with open(expected_dst_dir_file1_path, "r") as f:
            assert f.read() == "file1"
        with open(expected_dst_dir_file2_path, "r") as f:
            assert f.read() == "file2"
        # Same guarantee for the single-file log_artifact() path
        databricks_artifact_repo.log_artifact(src_file1_path)
        expected_dst_file_path = os.path.join(dst_dir, "file_1.txt")
        assert os.path.exists(expected_dst_file_path)
        with open(expected_dst_file_path, "r") as f:
            assert f.read() == "file1"
def test_download_artifacts_awaits_download_completion(self, databricks_artifact_repo, tmpdir):
    """
    Verifies that all asynchronous artifact downloads are joined before
    `download_artifacts()` returns a result to the caller
    """
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
    ) as read_credential_infos_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
    ) as download_mock:
        # (A dangling, unused ArtifactCredentialInfo(...) expression was removed here;
        # it was constructed and immediately discarded.)
        # One read credential per listed artifact
        read_credential_infos_mock.return_value = [
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
        ]
        get_list_mock.return_value = [
            FileInfo(path="file_1.txt", is_dir=False, file_size=100),
            FileInfo(path="file_2.txt", is_dir=False, file_size=0),
        ]

        def mock_download_from_cloud(
                cloud_credential_info, dst_local_file_path):  # pylint: disable=unused-argument
            # Sleep in order to simulate a longer-running asynchronous download
            time.sleep(2)
            with open(dst_local_file_path, "w") as f:
                f.write("content")

        download_mock.side_effect = mock_download_from_cloud
        databricks_artifact_repo.download_artifacts("test_path", str(tmpdir))
        # Both downloads must have completed (files fully written) by the time
        # download_artifacts() returns
        expected_file1_path = os.path.join(str(tmpdir), "file_1.txt")
        expected_file2_path = os.path.join(str(tmpdir), "file_2.txt")
        for path in [expected_file1_path, expected_file2_path]:
            assert os.path.exists(path)
            with open(path, "r") as f:
                assert f.read() == "content"
def test_log_artifact_azure_with_headers(
        self, databricks_artifact_repo, test_file, artifact_path, expected_location):
    """Verify that an Azure SAS upload forwards the credential's headers to the
    blocklist-commit PUT, filtering out the `x-ms-blob-type` header."""
    mock_azure_headers = {
        "x-ms-encryption-scope": "test-scope",
        "x-ms-tags": "some-tags",
        "x-ms-blob-type": "some-type",
    }
    # `x-ms-blob-type` is expected to be dropped before the request is issued
    filtered_azure_headers = {
        "x-ms-encryption-scope": "test-scope",
        "x-ms-tags": "some-tags",
    }
    mock_response = Response()
    mock_response.status_code = 200
    mock_response.close = lambda: None
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request"
    ) as request_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI,
            type=ArtifactCredentialType.AZURE_SAS_URI,
            headers=[
                ArtifactCredentialInfo.HttpHeader(name=header_name, value=header_value)
                for header_name, header_value in mock_azure_headers.items()
            ],
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        request_mock.return_value = mock_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # The commit request targets the block-list endpoint with filtered headers
        request_mock.assert_called_with(
            "put",
            MOCK_AZURE_SIGNED_URI + "?comp=blocklist",
            ANY,
            headers=filtered_azure_headers,
        )
def test_log_artifact_gcp_presigned_url_error(self, databricks_artifact_repo, test_file):
    """Verify that a failing GCP signed-URL PUT propagates as an MlflowException
    from log_artifact()."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
    ) as write_credential_infos_mock, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request"
    ) as request_mock:
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_GCP_SIGNED_URL, type=ArtifactCredentialType.GCP_SIGNED_URL
        )
        write_credential_infos_mock.return_value = [mock_credential_info]
        # Simulate a failure from the cloud PUT request
        request_mock.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)
        write_credential_infos_mock.assert_called_with(run_id=MOCK_RUN_ID, paths=ANY)
def test_log_artifact_azure_blob_client_sas_error(self, databricks_artifact_repo, test_file):
    """Verify that a BlobClient construction failure propagates as an
    MlflowException from log_artifact()."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
    ) as write_credential_infos_mock, mock.patch(
        "azure.storage.blob.BlobClient.from_blob_url"
    ) as mock_create_blob_client:
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credential_infos_mock.return_value = [mock_credential_info]
        # Simulate a failure while constructing the Azure blob client from the SAS URI
        mock_create_blob_client.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)
        write_credential_infos_mock.assert_called_with(run_id=MOCK_RUN_ID, paths=ANY)
def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, test_file):
    """Verify that a failing AWS presigned-URL PUT propagates as an
    MlflowException from log_artifact()."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials'
    ) as write_credentials_mock, mock.patch('requests.put') as request_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AWS_SIGNED_URI, type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        # Simulate a failure from the cloud PUT request
        request_mock.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
def test_log_artifact_azure_blob_client_sas_error(self, databricks_artifact_repo, test_file):
    """Verify that a BlobClient construction failure propagates as an
    MlflowException from log_artifact() (legacy credentials API)."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials'
    ) as write_credentials_mock, mock.patch(
        'azure.storage.blob.BlobClient.from_blob_url'
    ) as mock_create_blob_client:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        # Simulate a failure while constructing the Azure blob client from the SAS URI
        mock_create_blob_client.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
def test_databricks_download_file_get_request_fail(self, databricks_artifact_repo, test_file):
    """Verify that a failing download GET propagates as an MlflowException from
    download_artifacts()."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
    ) as read_credential_infos_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        "requests.get"
    ) as request_mock:
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        read_credential_infos_mock.return_value = [mock_credential_info]
        get_list_mock.return_value = []
        # NOTE(review): the mock *returns* an exception instance rather than raising it
        # (return_value, not side_effect); presumably the repository fails while handling
        # this non-Response object and wraps the failure — confirm this is intentional.
        request_mock.return_value = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.download_artifacts(test_file.strpath)
        read_credential_infos_mock.assert_called_with(
            run_id=MOCK_RUN_ID, paths=[test_file.strpath]
        )
def test_databricks_download_file_get_request_fail(
        self, databricks_artifact_repo, test_file):
    """Verify that a failing download GET propagates as an MlflowException from
    download_artifacts() (legacy credentials API)."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials'
    ) as read_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts'
    ) as get_list_mock, mock.patch('requests.get') as request_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        read_credentials_response_proto = GetCredentialsForRead.Response(
            credentials=mock_credentials
        )
        read_credentials_mock.return_value = read_credentials_response_proto
        get_list_mock.return_value = []
        # NOTE(review): the mock *returns* an exception instance rather than raising it
        # (return_value, not side_effect); presumably the repository fails while handling
        # this non-Response object and wraps the failure — confirm this is intentional.
        request_mock.return_value = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.download_artifacts(test_file.strpath)
        read_credentials_mock.assert_called_with(MOCK_RUN_ID, test_file.strpath)
def test_log_artifact_aws(
        self, databricks_artifact_repo, test_file, artifact_path, expected_location):
    """Verify that an AWS presigned-URL credential routes the upload through
    `_aws_upload_file` with the resolved credentials."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials'
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._aws_upload_file'
    ) as aws_upload_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AWS_SIGNED_URI, type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        aws_upload_mock.return_value = None
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        aws_upload_mock.assert_called_with(mock_credentials, test_file.strpath)
def test_log_artifact_azure(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Verify that an Azure SAS credential routes the upload through
    `_azure_upload_file` with the resolved credential info and target location."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
    ) as write_credential_infos_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._azure_upload_file"
    ) as azure_upload_mock:
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credential_infos_mock.return_value = [mock_credential_info]
        azure_upload_mock.return_value = None
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credential_infos_mock.assert_called_with(
            run_id=MOCK_RUN_ID, paths=[expected_location]
        )
        azure_upload_mock.assert_called_with(
            mock_credential_info, test_file.strpath, expected_location
        )
def test_databricks_download_file(
        self, databricks_artifact_repo, remote_file_path, local_path, cloud_credential_type):
    """Verify that download_artifacts() resolves read credentials for the requested
    path and hands them to `_download_from_cloud` (legacy credentials API)."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials'
    ) as read_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts'
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud'
    ) as download_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type
        )
        read_credentials_response_proto = GetCredentialsForRead.Response(
            credentials=mock_credentials
        )
        read_credentials_mock.return_value = read_credentials_response_proto
        download_mock.return_value = None
        get_list_mock.return_value = []
        databricks_artifact_repo.download_artifacts(remote_file_path, local_path)
        read_credentials_mock.assert_called_with(MOCK_RUN_ID, remote_file_path)
        download_mock.assert_called_with(mock_credentials, ANY)
def test_log_artifact_azure_with_headers(
        self, databricks_artifact_repo, test_file, artifact_path, expected_location):
    """Verify that an Azure upload via BlobClient forwards the credential headers
    to client construction, block staging, and block-list commit (legacy API)."""
    expected_headers = {header.name: header.value for header in MOCK_HEADERS}
    mock_blob_service = mock.MagicMock(autospec=BlobClient)
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "azure.storage.blob.BlobClient.from_blob_url"
    ) as mock_create_blob_client:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI,
            type=ArtifactCredentialType.AZURE_SAS_URI,
            headers=MOCK_HEADERS,
        )
        write_credentials_response_proto = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        write_credentials_mock.return_value = write_credentials_response_proto
        mock_create_blob_client.return_value = mock_blob_service
        mock_blob_service.stage_block.side_effect = None
        mock_blob_service.commit_block_list.side_effect = None
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # Headers must be threaded through every step of the blob upload
        mock_create_blob_client.assert_called_with(
            blob_url=MOCK_AZURE_SIGNED_URI, credential=None, headers=expected_headers
        )
        mock_blob_service.stage_block.assert_called_with(ANY, ANY, headers=expected_headers)
        mock_blob_service.commit_block_list.assert_called_with(ANY, headers=expected_headers)
def test_databricks_download_file(
    self, databricks_artifact_repo, remote_file_path, local_path, cloud_credential_type
):
    """Verify that download_artifacts() resolves read credential infos for the
    requested path and hands them to `_download_from_cloud`."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
    ) as read_credential_infos_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
    ) as download_mock:
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type
        )
        read_credential_infos_mock.return_value = [mock_credential_info]
        download_mock.return_value = None
        get_list_mock.return_value = []
        databricks_artifact_repo.download_artifacts(remote_file_path, local_path)
        read_credential_infos_mock.assert_called_with(
            run_id=MOCK_RUN_ID, paths=[remote_file_path]
        )
        download_mock.assert_called_with(
            cloud_credential_info=mock_credential_info,
            dst_local_file_path=ANY,
        )
def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location):
    """Verify that, for a repository rooted at a run subdirectory, write credential
    infos and the upload target use the subdirectory-relative artifact location."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    ) as get_run_artifact_root_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
    ) as write_credential_infos_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
    ) as upload_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # Repository rooted below the run root, at MOCK_SUBDIR
        databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        mock_credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credential_infos_mock.return_value = [mock_credential_info]
        upload_mock.return_value = None
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credential_infos_mock.assert_called_with(
            run_id=MOCK_RUN_ID, paths=[expected_location]
        )
        upload_mock.assert_called_with(
            cloud_credential_info=mock_credential_info,
            src_file_path=test_file.strpath,
            dst_run_relative_artifact_path=expected_location,
        )
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead, \
    ArtifactCredentialType, ArtifactCredentialInfo
from mlflow.protos.service_pb2 import ListArtifacts, FileInfo
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository

# Dotted import paths used as mock.patch targets throughout these tests
DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = 'mlflow.store.artifact.databricks_artifact_repo'
DATABRICKS_ARTIFACT_REPOSITORY = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + \
    ".DatabricksArtifactRepository"

# Canned signed URIs, identifiers, and headers shared by the tests below
MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
MOCK_RUN_ID = "MOCK-RUN-ID"
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name='Mock-Name1', value='Mock-Value1'),
    ArtifactCredentialInfo.HttpHeader(name='Mock-Name2', value='Mock-Value2')
]
MOCK_RUN_ROOT_URI = \
    "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    """Return a repository for MOCK_RUN_ROOT_URI with the artifact-root lookup mocked."""
    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
            as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        return get_artifact_repository(
            "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")
ArtifactCredentialType,
    ArtifactCredentialInfo,
)
from mlflow.protos.service_pb2 import ListArtifacts, FileInfo
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository

# Dotted import paths used as mock.patch targets throughout these tests
DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow.store.artifact.databricks_artifact_repo"
DATABRICKS_ARTIFACT_REPOSITORY = (DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE +
    ".DatabricksArtifactRepository")

# Canned signed URIs, identifiers, and headers shared by the tests below
MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
MOCK_RUN_ID = "MOCK-RUN-ID"
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name1", value="Mock-Value1"),
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name2", value="Mock-Value2"),
]
MOCK_RUN_ROOT_URI = "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    """Return a repository for MOCK_RUN_ROOT_URI with the artifact-root lookup mocked."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    ) as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        return get_artifact_repository(
            "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")