def test_log_artifact_aws_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """AWS presigned-URL uploads must forward the credential-supplied HTTP headers."""
    header_map = {header.name: header.value for header in MOCK_HEADERS}
    put_response = Response()
    put_response.status_code = 200
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch("requests.put") as request_mock:
        aws_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AWS_SIGNED_URI,
            type=ArtifactCredentialType.AWS_PRESIGNED_URL,
            headers=MOCK_HEADERS,
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=aws_credentials
        )
        request_mock.return_value = put_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # The PUT body/URL are covered elsewhere; this test pins only the headers
        request_mock.assert_called_with(ANY, ANY, headers=header_map)
def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location):
    """Artifacts logged through a sub-directory root URI resolve against the run root."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    ) as get_run_artifact_root_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
    ) as upload_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # Build the repository from a URI pointing below the run artifact root
        repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        azure_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        credentials_response = GetCredentialsForWrite.Response(credentials=azure_credentials)
        write_credentials_mock.return_value = credentials_response
        upload_mock.return_value = None
        repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        upload_mock.assert_called_with(
            credentials_response, test_file.strpath, expected_location
        )
def test_log_artifact_gcp_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """GCP signed-URL uploads must forward the credential-supplied HTTP headers."""
    header_map = {header.name: header.value for header in MOCK_HEADERS}
    put_response = Response()
    put_response.status_code = 200
    # The upload path closes the response; give the bare Response a no-op close
    put_response.close = lambda: None
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request"
    ) as request_mock:
        gcp_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_GCP_SIGNED_URL,
            type=ArtifactCredentialType.GCP_SIGNED_URL,
            headers=MOCK_HEADERS,
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=gcp_credentials
        )
        request_mock.return_value = put_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        request_mock.assert_called_with("put", MOCK_GCP_SIGNED_URL, ANY, headers=header_map)
def test_artifact_logging_awaits_upload_completion(self, databricks_artifact_repo, tmpdir):
    """
    Verifies that all asynchronous artifact uploads initiated by `log_artifact()` and
    `log_artifacts()` are joined before these methods return a result to the caller
    """
    # Stage a source directory with two small files plus an empty destination directory
    # that the fake upload copies into
    src_dir = os.path.join(str(tmpdir), "src")
    os.makedirs(src_dir)
    src_file1_path = os.path.join(src_dir, "file_1.txt")
    with open(src_file1_path, "w") as handle:
        handle.write("file1")
    src_file2_path = os.path.join(src_dir, "file_2.txt")
    with open(src_file2_path, "w") as handle:
        handle.write("file2")
    dst_dir = os.path.join(str(tmpdir), "dst")
    os.makedirs(dst_dir)
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
    ) as upload_mock:
        azure_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=azure_credentials
        )

        def mock_upload_to_cloud(
            cloud_credentials, local_file_path, artifact_path
        ):  # pylint: disable=unused-argument
            # Sleep in order to simulate a longer-running asynchronous upload
            time.sleep(2)
            dst_path = os.path.join(dst_dir, artifact_path)
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            shutil.copyfile(src=local_file_path, dst=dst_path)

        upload_mock.side_effect = mock_upload_to_cloud

        # log_artifacts must block until both files have landed in dst_dir
        databricks_artifact_repo.log_artifacts(src_dir, "dir_artifact")
        expected_dst_dir_file1_path = os.path.join(dst_dir, "dir_artifact", "file_1.txt")
        expected_dst_dir_file2_path = os.path.join(dst_dir, "dir_artifact", "file_2.txt")
        assert os.path.exists(expected_dst_dir_file1_path)
        assert os.path.exists(expected_dst_dir_file2_path)
        with open(expected_dst_dir_file1_path, "r") as handle:
            assert handle.read() == "file1"
        with open(expected_dst_dir_file2_path, "r") as handle:
            assert handle.read() == "file2"

        # log_artifact must likewise block until its single upload completes
        databricks_artifact_repo.log_artifact(src_file1_path)
        expected_dst_file_path = os.path.join(dst_dir, "file_1.txt")
        assert os.path.exists(expected_dst_file_path)
        with open(expected_dst_file_path, "r") as handle:
            assert handle.read() == "file1"
def test_get_write_credential_infos_handles_pagination(self, databricks_artifact_repo):
    """
    Verifies that the `_get_write_credential_infos` method, which is used to resolve
    write access credentials for a collection of artifacts, handles paginated responses
    properly, issuing incremental requests until all pages have been consumed
    """
    first_page = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        ),
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        ),
    ]
    second_page = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
    ]
    final_empty_page = []
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        # Two populated pages with continuation tokens, then a terminal page without one
        call_endpoint_mock.side_effect = [
            GetCredentialsForWrite.Response(
                credential_infos=first_page, next_page_token="2"
            ),
            GetCredentialsForWrite.Response(
                credential_infos=second_page, next_page_token="3"
            ),
            GetCredentialsForWrite.Response(credential_infos=final_empty_page),
        ]
        write_credential_infos = databricks_artifact_repo._get_write_credential_infos(
            MOCK_RUN_ID,
            ["testpath"],
        )
        assert write_credential_infos == first_page + second_page
        # Each follow-up request must echo the previous response's page token
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=["testpath"])),
                mock.call(
                    GetCredentialsForWrite(
                        run_id=MOCK_RUN_ID, path=["testpath"], page_token="2"
                    )
                ),
                mock.call(
                    GetCredentialsForWrite(
                        run_id=MOCK_RUN_ID, path=["testpath"], page_token="3"
                    )
                ),
            ]
        )
        assert call_endpoint_mock.call_count == 3
def test_get_write_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
    """
    Verifies that the `_get_write_credential_infos` method, which is used to resolve
    write access credentials for a collection of artifacts, batches requests according
    to a maximum request size configured by the backend
    """
    # Create 3 chunks of paths, two of which have the maximum request size and one of
    # which is smaller than the maximum chunk size. Aggregate and pass these to
    # `_get_write_credential_infos`, validating that this method decomposes the
    # aggregate list into these expected chunks and makes 3 separate requests
    paths_chunk_1 = ["path1"] * 2000
    paths_chunk_2 = ["path2"] * 2000
    paths_chunk_3 = ["path3"] * 5

    def make_credential_infos(url, count):
        # Build `count` identical AWS presigned-URL credential entries for one response
        return [
            ArtifactCredentialInfo(
                signed_uri=url, type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(count)
        ]

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        call_endpoint_mock.side_effect = [
            GetCredentialsForWrite.Response(
                credential_infos=make_credential_infos("http://mock_url_1", 2000)
            ),
            GetCredentialsForWrite.Response(
                credential_infos=make_credential_infos("http://mock_url_2", 2000)
            ),
            GetCredentialsForWrite.Response(
                credential_infos=make_credential_infos("http://mock_url_3", 5)
            ),
        ]
        databricks_artifact_repo._get_write_credential_infos(
            MOCK_RUN_ID,
            paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
        )
        assert call_endpoint_mock.call_count == message_mock.call_count == 3
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
            ]
        )
def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, test_file):
    """A failure during the AWS presigned-URL PUT propagates as an MlflowException."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch("requests.put") as request_mock:
        aws_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AWS_SIGNED_URI, type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=aws_credentials
        )
        request_mock.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)
        # Credentials must still have been requested before the upload attempt failed
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
def test_log_artifact_gcp_presigned_url_error(self, databricks_artifact_repo, test_file):
    """A failure inside cloud_storage_http_request propagates as an MlflowException."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request"
    ) as request_mock:
        gcp_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_GCP_SIGNED_URL, type=ArtifactCredentialType.GCP_SIGNED_URL
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=gcp_credentials
        )
        request_mock.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)
        # Credentials must still have been requested before the upload attempt failed
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
def test_log_artifact_azure_blob_client_sas_error(self, databricks_artifact_repo, test_file):
    """An error constructing the Azure BlobClient propagates as an MlflowException."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "azure.storage.blob.BlobClient.from_blob_url"
    ) as mock_create_blob_client:
        sas_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=sas_credentials
        )
        mock_create_blob_client.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)
        # Credentials must still have been requested before client creation failed
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
def test_log_artifact_aws(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Logging to an AWS presigned URI delegates the upload to `_aws_upload_file`."""
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._aws_upload_file"
    ) as aws_upload_mock:
        aws_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AWS_SIGNED_URI, type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=aws_credentials
        )
        aws_upload_mock.return_value = None
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        aws_upload_mock.assert_called_with(aws_credentials, test_file.strpath)
def test_log_artifact_azure_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Azure SAS uploads forward credential headers on the block-list commit request."""
    raw_headers = {
        "x-ms-encryption-scope": "test-scope",
        "x-ms-tags": "some-tags",
        "x-ms-blob-type": "some-type",
    }
    # The commit request is expected to carry every header except x-ms-blob-type
    expected_commit_headers = {
        "x-ms-encryption-scope": "test-scope",
        "x-ms-tags": "some-tags",
    }
    commit_response = Response()
    commit_response.status_code = 200
    # The upload path closes the response; give the bare Response a no-op close
    commit_response.close = lambda: None
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request"
    ) as request_mock:
        azure_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI,
            type=ArtifactCredentialType.AZURE_SAS_URI,
            headers=[
                ArtifactCredentialInfo.HttpHeader(name=name, value=value)
                for name, value in raw_headers.items()
            ],
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=azure_credentials
        )
        request_mock.return_value = commit_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        request_mock.assert_called_with(
            "put",
            MOCK_AZURE_SIGNED_URI + "?comp=blocklist",
            ANY,
            headers=expected_commit_headers,
        )
def test_log_artifact_gcp(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Logging to a GCP signed URL issues a `requests.put` with no extra headers."""
    put_response = Response()
    put_response.status_code = 200
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch("requests.put") as request_mock:
        gcp_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_GCP_SIGNED_URL, type=ArtifactCredentialType.GCP_SIGNED_URL
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=gcp_credentials
        )
        request_mock.return_value = put_response
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # No credential headers were supplied, so the PUT must carry an empty header dict
        request_mock.assert_called_with(MOCK_GCP_SIGNED_URL, ANY, headers={})
def test_log_artifact_azure_blob_client_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """
    Verifies that Azure SAS uploads through the BlobClient forward the
    credential-supplied headers to `from_blob_url`, `stage_block`, and
    `commit_block_list`.

    NOTE: this test was previously named `test_log_artifact_azure_with_headers`,
    which duplicated the name of an earlier test in this class; because Python
    class bodies keep only the last binding for a name, one of the two tests was
    silently never collected or run. Renamed (matching the
    `test_log_artifact_azure_blob_client_*` convention used nearby) so that both
    tests execute.
    """
    expected_headers = {header.name: header.value for header in MOCK_HEADERS}
    mock_blob_service = mock.MagicMock(autospec=BlobClient)
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        "azure.storage.blob.BlobClient.from_blob_url"
    ) as mock_create_blob_client:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI,
            type=ArtifactCredentialType.AZURE_SAS_URI,
            headers=MOCK_HEADERS,
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=mock_credentials
        )
        mock_create_blob_client.return_value = mock_blob_service
        mock_blob_service.stage_block.side_effect = None
        mock_blob_service.commit_block_list.side_effect = None
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # Headers must reach every stage of the block upload protocol
        mock_create_blob_client.assert_called_with(
            blob_url=MOCK_AZURE_SIGNED_URI, credential=None, headers=expected_headers
        )
        mock_blob_service.stage_block.assert_called_with(ANY, ANY, headers=expected_headers)
        mock_blob_service.commit_block_list.assert_called_with(ANY, headers=expected_headers)
def test_log_artifacts_provides_failure_info(self, databricks_artifact_repo, tmpdir):
    """When per-file uploads fail, the raised exception names each file and its error."""
    src_file1_path = os.path.join(str(tmpdir), "file_1.txt")
    with open(src_file1_path, "w") as handle:
        handle.write("file1")
    src_file2_path = os.path.join(str(tmpdir), "file_2.txt")
    with open(src_file2_path, "w") as handle:
        handle.write("file2")
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
    ) as upload_mock:
        azure_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
        )
        write_credentials_mock.return_value = GetCredentialsForWrite.Response(
            credentials=azure_credentials
        )
        # Fail each of the two uploads with a distinct, recognizable error message
        upload_mock.side_effect = [
            MlflowException("MOCK ERROR 1"),
            MlflowException("MOCK ERROR 2"),
        ]
        with pytest.raises(MlflowException) as exc:
            databricks_artifact_repo.log_artifacts(str(tmpdir), "test_artifacts")
        # NOTE(review): these rely on pytest's `str(ExceptionInfo)` conversion;
        # `str(exc.value)` would be the more explicit spelling — confirm before changing.
        assert MOCK_RUN_ROOT_URI in str(exc)
        assert "file_1.txt" in str(exc)
        assert "MOCK ERROR 1" in str(exc)
        assert "file_2.txt" in str(exc)
        assert "MOCK ERROR 2" in str(exc)
def _get_write_credentials(self, run_id, path=None):
    """Request scoped write credentials for `path` within `run_id` from the
    Databricks MLflow artifacts service."""
    request_proto = GetCredentialsForWrite(run_id=run_id, path=path)
    return self._call_endpoint(
        DatabricksMlflowArtifactsService,
        GetCredentialsForWrite,
        message_to_json(request_proto),
    )