def test_databricks_download_file_with_relative_path(self, remote_file_path, local_path):
    """Download through a repository rooted at a sub-directory of the run root.

    The repository is created for ``MOCK_SUBDIR_ROOT_URI`` while the run's
    artifact root is mocked as ``MOCK_RUN_ROOT_URI``. The test checks that
    read credentials are requested for the remote path prefixed with
    ``MOCK_SUBDIR`` and that the cloud download receives the mocked
    credential info.
    """
    credential_info = ArtifactCredentialInfo(
        signed_uri=MOCK_AZURE_SIGNED_URI,
        type=ArtifactCredentialType.AZURE_SAS_URI,
    )
    read_response = GetCredentialsForRead.Response(credentials=credential_info)
    expected_remote_path = posixpath.join(MOCK_SUBDIR, remote_file_path)

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root",
        return_value=MOCK_RUN_ROOT_URI,
    ), mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credentials",
        return_value=read_response,
    ) as read_creds, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts",
        return_value=[],
    ), mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud",
        return_value=None,
    ) as cloud_download:
        repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        repo.download_artifacts(remote_file_path, local_path)

        read_creds.assert_called_with(MOCK_RUN_ID, expected_remote_path)
        cloud_download.assert_called_with(credential_info, ANY)
def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location):
    """Log an artifact through a repository rooted at a run sub-directory.

    The repository is created for ``MOCK_SUBDIR_ROOT_URI`` while the run's
    artifact root is mocked as ``MOCK_RUN_ROOT_URI``. The test checks that
    write credentials are requested for ``expected_location`` and that the
    upload receives the credentials response, the local file path and
    ``expected_location``.
    """
    credential_info = ArtifactCredentialInfo(
        signed_uri=MOCK_AZURE_SIGNED_URI,
        type=ArtifactCredentialType.AZURE_SAS_URI,
    )
    write_response = GetCredentialsForWrite.Response(credentials=credential_info)
    local_file = test_file.strpath

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root",
        return_value=MOCK_RUN_ROOT_URI,
    ), mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials",
        return_value=write_response,
    ) as write_creds, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud",
        return_value=None,
    ) as cloud_upload:
        repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        repo.log_artifact(local_file, artifact_path)

        write_creds.assert_called_with(MOCK_RUN_ID, expected_location)
        cloud_upload.assert_called_with(write_response, local_file, expected_location)
def test_log_artifact_aws_with_headers(self, databricks_artifact_repo, test_file,
                                       artifact_path, expected_location):
    """Check that credential headers are forwarded to ``http_put`` for AWS uploads.

    The write credentials carry ``MOCK_HEADERS``; ``http_put`` is mocked to
    return a response with status code 200 and must be called with those
    headers flattened into a ``{name: value}`` dict.
    """
    ok_response = Response()
    ok_response.status_code = 200
    header_map = {hdr.name: hdr.value for hdr in MOCK_HEADERS}

    credential_info = ArtifactCredentialInfo(
        signed_uri=MOCK_AWS_SIGNED_URI,
        type=ArtifactCredentialType.AWS_PRESIGNED_URL,
        headers=MOCK_HEADERS,
    )
    write_response = GetCredentialsForWrite.Response(credentials=credential_info)

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials",
        return_value=write_response,
    ) as write_creds, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".http_put",
        return_value=ok_response,
    ) as http_put_mock:
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)

        write_creds.assert_called_with(MOCK_RUN_ID, expected_location)
        http_put_mock.assert_called_with(ANY, ANY, headers=header_map)
def test_log_artifact_azure_with_headers(self, databricks_artifact_repo, test_file,
                                         artifact_path, expected_location):
    """Check that credential headers are forwarded to the Azure block helpers.

    The write credentials carry ``MOCK_HEADERS``; both ``put_block`` and
    ``put_block_list`` are mocked and must be called with the signed URI and
    those headers flattened into a ``{name: value}`` dict.
    """
    header_map = {hdr.name: hdr.value for hdr in MOCK_HEADERS}

    credential_info = ArtifactCredentialInfo(
        signed_uri=MOCK_AZURE_SIGNED_URI,
        type=ArtifactCredentialType.AZURE_SAS_URI,
        headers=MOCK_HEADERS,
    )
    write_response = GetCredentialsForWrite.Response(credentials=credential_info)

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials",
        return_value=write_response,
    ) as write_creds, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".put_block"
    ) as put_block_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".put_block_list"
    ) as put_block_list_mock:
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)

        write_creds.assert_called_with(MOCK_RUN_ID, expected_location)
        put_block_mock.assert_called_with(
            MOCK_AZURE_SIGNED_URI, ANY, ANY, headers=header_map
        )
        put_block_list_mock.assert_called_with(
            MOCK_AZURE_SIGNED_URI, ANY, headers=header_map
        )
def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, test_file):
    """Check that a failing ``http_put`` surfaces as ``MlflowException``.

    ``http_put`` is mocked to raise ``MlflowException("MOCK ERROR")``; logging
    an artifact must propagate an ``MlflowException``, and write credentials
    must still have been requested for the mock run.
    """
    mock_credentials = ArtifactCredentialInfo(
        signed_uri=MOCK_AWS_SIGNED_URI,
        type=ArtifactCredentialType.AWS_PRESIGNED_URL,
    )
    write_credentials_response_proto = GetCredentialsForWrite.Response(
        credentials=mock_credentials
    )

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    ) as write_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".http_put"
    ) as request_mock:
        write_credentials_mock.return_value = write_credentials_response_proto
        request_mock.side_effect = MlflowException("MOCK ERROR")

        # Only the raising call lives inside pytest.raises: any statement
        # placed after it in the same block would be skipped once the
        # exception propagates.
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)

        # Verified after the raises block so the check actually executes.
        write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
def test_log_artifact_aws(self, databricks_artifact_repo, test_file, artifact_path,
                          expected_location):
    """Check that AWS pre-signed credentials are handed to ``_aws_upload_file``.

    Write credentials of type ``AWS_PRESIGNED_URL`` are mocked; the test
    asserts they are requested for ``expected_location`` and that the AWS
    upload helper receives the credential info and the local file path.
    """
    credential_info = ArtifactCredentialInfo(
        signed_uri=MOCK_AWS_SIGNED_URI,
        type=ArtifactCredentialType.AWS_PRESIGNED_URL,
    )
    write_response = GetCredentialsForWrite.Response(credentials=credential_info)
    local_file = test_file.strpath

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials",
        return_value=write_response,
    ) as write_creds, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._aws_upload_file",
        return_value=None,
    ) as aws_upload:
        databricks_artifact_repo.log_artifact(local_file, artifact_path)

        write_creds.assert_called_with(MOCK_RUN_ID, expected_location)
        aws_upload.assert_called_with(credential_info, local_file)
def test_databricks_download_file_get_request_fail(self, databricks_artifact_repo, test_file):
    """Check that a failing ``http_get`` surfaces as ``MlflowException``.

    ``http_get`` is mocked to raise ``MlflowException("MOCK ERROR")``;
    downloading must propagate an ``MlflowException``, and read credentials
    must have been requested for the mock run and the requested path.
    """
    mock_credentials = ArtifactCredentialInfo(
        signed_uri=MOCK_AZURE_SIGNED_URI,
        type=ArtifactCredentialType.AZURE_SAS_URI,
    )
    read_credentials_response_proto = GetCredentialsForRead.Response(
        credentials=mock_credentials
    )

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credentials"
    ) as read_credentials_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".http_get"
    ) as request_mock:
        read_credentials_mock.return_value = read_credentials_response_proto
        get_list_mock.return_value = []
        # side_effect makes the mocked request *raise* the error. Assigning
        # the exception to return_value would merely hand back an exception
        # object as if it were a response, so the failure path under test
        # would only be hit by accident.
        request_mock.side_effect = MlflowException("MOCK ERROR")

        # Only the raising call lives inside pytest.raises: any statement
        # placed after it in the same block would be skipped once the
        # exception propagates.
        with pytest.raises(MlflowException):
            databricks_artifact_repo.download_artifacts(test_file.strpath)

        # Verified after the raises block so the check actually executes.
        read_credentials_mock.assert_called_with(MOCK_RUN_ID, test_file.strpath)
def test_databricks_download_file(self, databricks_artifact_repo, remote_file_path,
                                  local_path, cloud_credential_type):
    """Check that a download fetches read credentials and delegates to the cloud helper.

    Read credentials of the given ``cloud_credential_type`` are mocked; the
    test asserts they are requested for ``remote_file_path`` and that
    ``_download_from_cloud`` receives the mocked credential info.
    """
    credential_info = ArtifactCredentialInfo(
        signed_uri=MOCK_AZURE_SIGNED_URI,
        type=cloud_credential_type,
    )
    read_response = GetCredentialsForRead.Response(credentials=credential_info)

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credentials",
        return_value=read_response,
    ) as read_creds, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts",
        return_value=[],
    ), mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud",
        return_value=None,
    ) as cloud_download:
        databricks_artifact_repo.download_artifacts(remote_file_path, local_path)

        read_creds.assert_called_with(MOCK_RUN_ID, remote_file_path)
        cloud_download.assert_called_with(credential_info, ANY)
GetCredentialsForWrite, GetCredentialsForRead, ArtifactCredentialType, ArtifactCredentialInfo,
)
from mlflow_databricks_artifacts.store.artifact_repo import DatabricksArtifactRepository

# Dotted path of the module under test; tests append a function name to it
# (e.g. ".http_put") to build mock.patch targets for module-level helpers.
DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow_databricks_artifacts.store.artifact_repo"
# Dotted path of the repository class; tests append a method name to it
# (e.g. "._get_read_credentials") to build mock.patch targets.
DATABRICKS_ARTIFACT_REPOSITORY = (
    DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksArtifactRepository")

# Placeholder signed URIs placed into the mocked credential protos.
MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
# Run id the tests expect in credential requests; it matches the run segment
# of MOCK_RUN_ROOT_URI below.
MOCK_RUN_ID = "MOCK-RUN-ID"
# Header protos attached to mocked credentials in the "*_with_headers" tests.
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name1", value="Mock-Value1"),
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name2", value="Mock-Value2"),
]
# Artifact root URI returned by the mocked _get_run_artifact_root.
MOCK_RUN_ROOT_URI = "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
# Sub-directory below the run root, and the URI of a repository rooted there
# (used by the "*_with_relative_path" tests).
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    """Return an artifact repository for the mock run's artifact root URI.

    ``_get_run_artifact_root`` is patched to return ``MOCK_RUN_ROOT_URI`` while
    ``get_artifact_repository`` runs. The patch is undone when the ``with``
    block exits on return, so it is not active while a test uses the
    returned repository.
    """
    with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root") as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # NOTE(review): this literal has the same value as MOCK_RUN_ROOT_URI.
        return get_artifact_repository(
            "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")