def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location):
    """A repository rooted at a run subdirectory logs to the expected location.

    The repository is built from ``MOCK_SUBDIR_ROOT_URI`` while the run's artifact
    root is mocked to ``MOCK_RUN_ROOT_URI``. After logging ``test_file`` under
    ``artifact_path``, the test checks two things:

    * write credentials were requested for ``expected_location``;
    * ``_upload_to_cloud`` received those credentials, the local file path and
      that same location.
    """
    root_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    creds_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    upload_target = DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"

    azure_credentials = ArtifactCredentialInfo(
        type=ArtifactCredentialType.AZURE_SAS_URI,
        signed_uri=MOCK_AZURE_SIGNED_URI,
    )
    creds_response = GetCredentialsForWrite.Response(credentials=azure_credentials)
    local_path = test_file.strpath

    # The run-root patch must be active before the repository is constructed.
    with mock.patch(root_target, return_value=MOCK_RUN_ROOT_URI), mock.patch(
        creds_target, return_value=creds_response
    ) as creds_mock, mock.patch(upload_target, return_value=None) as upload_mock:
        repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        repo.log_artifact(local_path, artifact_path)

        creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        upload_mock.assert_called_with(creds_response, local_path, expected_location)
def test_log_artifact_aws_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Headers attached to AWS presigned-URL credentials are forwarded to ``http_put``.

    Write credentials carry ``MOCK_HEADERS``, and ``http_put`` is mocked to return a
    200 ``Response``. The test asserts that the PUT was issued with a ``headers``
    mapping built from those name/value pairs.
    """
    wanted_headers = {hdr.name: hdr.value for hdr in MOCK_HEADERS}

    ok_response = Response()
    ok_response.status_code = 200

    aws_credentials = ArtifactCredentialInfo(
        type=ArtifactCredentialType.AWS_PRESIGNED_URL,
        signed_uri=MOCK_AWS_SIGNED_URI,
        headers=MOCK_HEADERS,
    )
    creds_response = GetCredentialsForWrite.Response(credentials=aws_credentials)

    creds_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    put_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".http_put"

    with mock.patch(creds_target, return_value=creds_response) as creds_mock, mock.patch(
        put_target, return_value=ok_response
    ) as put_mock:
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)

        creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # Only the headers matter here; URI and payload are not pinned.
        put_mock.assert_called_with(ANY, ANY, headers=wanted_headers)
def test_log_artifact_azure_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Headers attached to Azure SAS credentials reach ``put_block`` and ``put_block_list``.

    Write credentials carry ``MOCK_HEADERS``. The test asserts that both the
    block upload and the final block-list commit target ``MOCK_AZURE_SIGNED_URI``
    and pass a ``headers`` mapping built from those name/value pairs.
    """
    wanted_headers = {hdr.name: hdr.value for hdr in MOCK_HEADERS}

    azure_credentials = ArtifactCredentialInfo(
        type=ArtifactCredentialType.AZURE_SAS_URI,
        signed_uri=MOCK_AZURE_SIGNED_URI,
        headers=MOCK_HEADERS,
    )
    creds_response = GetCredentialsForWrite.Response(credentials=azure_credentials)

    creds_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    block_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".put_block"
    block_list_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".put_block_list"

    with mock.patch(creds_target, return_value=creds_response) as creds_mock, mock.patch(
        block_target
    ) as block_mock, mock.patch(block_list_target) as block_list_mock:
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)

        creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        block_mock.assert_called_with(MOCK_AZURE_SIGNED_URI, ANY, ANY, headers=wanted_headers)
        block_list_mock.assert_called_with(MOCK_AZURE_SIGNED_URI, ANY, headers=wanted_headers)
def test_log_artifact_aws_presigned_url_error(self, databricks_artifact_repo, test_file):
    """A failing ``http_put`` on an AWS presigned URL surfaces as ``MlflowException``.

    ``http_put`` is mocked to raise ``MlflowException("MOCK ERROR")``. The test
    expects ``log_artifact`` to raise ``MlflowException``. It also checks that
    write credentials were still requested for the run; the location is not
    pinned.
    """
    aws_credentials = ArtifactCredentialInfo(
        type=ArtifactCredentialType.AWS_PRESIGNED_URL,
        signed_uri=MOCK_AWS_SIGNED_URI,
    )
    creds_response = GetCredentialsForWrite.Response(credentials=aws_credentials)

    creds_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    put_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".http_put"

    with mock.patch(creds_target, return_value=creds_response) as creds_mock, mock.patch(
        put_target, side_effect=MlflowException("MOCK ERROR")
    ):
        with pytest.raises(MlflowException):
            databricks_artifact_repo.log_artifact(test_file.strpath)

        # Checked outside pytest.raises so this assertion is actually executed.
        creds_mock.assert_called_with(MOCK_RUN_ID, ANY)
def test_log_artifact_aws(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """AWS presigned-URL credentials route the upload through ``_aws_upload_file``.

    After logging ``test_file`` under ``artifact_path``, the test checks two things:

    * write credentials were requested for ``expected_location``;
    * ``_aws_upload_file`` received the credential info and the local file path.
    """
    aws_credentials = ArtifactCredentialInfo(
        type=ArtifactCredentialType.AWS_PRESIGNED_URL,
        signed_uri=MOCK_AWS_SIGNED_URI,
    )
    creds_response = GetCredentialsForWrite.Response(credentials=aws_credentials)
    local_path = test_file.strpath

    creds_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
    aws_upload_target = DATABRICKS_ARTIFACT_REPOSITORY + "._aws_upload_file"

    with mock.patch(creds_target, return_value=creds_response) as creds_mock, mock.patch(
        aws_upload_target, return_value=None
    ) as aws_upload:
        databricks_artifact_repo.log_artifact(local_path, artifact_path)

        creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        aws_upload.assert_called_with(aws_credentials, local_path)