def test_log_artifact_azure_with_headers(
    self, databricks_artifact_repo, test_file, artifact_path, expected_location
):
    """Azure SAS uploads forward credential headers except ``x-ms-blob-type``.

    The write credentials handed back for the run carry three headers. The
    final ``put`` request (the ``?comp=blocklist`` call against the signed
    URI) must be sent with only the encryption-scope and tags headers.
    """
    credential_headers = {
        "x-ms-encryption-scope": "test-scope",
        "x-ms-tags": "some-tags",
        "x-ms-blob-type": "some-type",
    }
    expected_request_headers = {
        "x-ms-encryption-scope": "test-scope",
        "x-ms-tags": "some-tags",
    }

    ok_response = Response()
    ok_response.status_code = 200
    # Presumably the upload path closes the response; make that a no-op.
    ok_response.close = lambda: None

    header_protos = [
        ArtifactCredentialInfo.HttpHeader(name=name, value=value)
        for name, value in credential_headers.items()
    ]
    sas_credentials = ArtifactCredentialInfo(
        signed_uri=MOCK_AZURE_SIGNED_URI,
        type=ArtifactCredentialType.AZURE_SAS_URI,
        headers=header_protos,
    )
    credentials_response = GetCredentialsForWrite.Response(credentials=sas_credentials)

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials",
        return_value=credentials_response,
    ) as get_write_credentials, mock.patch(
        "mlflow.utils.rest_utils.cloud_storage_http_request",
        return_value=ok_response,
    ) as http_request:
        databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
        get_write_credentials.assert_called_with(MOCK_RUN_ID, expected_location)
        # assert_called_with inspects only the most recent call: the blocklist put.
        http_request.assert_called_with(
            "put",
            MOCK_AZURE_SIGNED_URI + "?comp=blocklist",
            ANY,
            headers=expected_request_headers,
        )
ArtifactCredentialType, ArtifactCredentialInfo, ) from mlflow.protos.service_pb2 import ListArtifacts, FileInfo from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow.store.artifact.databricks_artifact_repo" DATABRICKS_ARTIFACT_REPOSITORY = (DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksArtifactRepository") MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure" MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?" MOCK_RUN_ID = "MOCK-RUN-ID" MOCK_HEADERS = [ ArtifactCredentialInfo.HttpHeader(name="Mock-Name1", value="Mock-Value1"), ArtifactCredentialInfo.HttpHeader(name="Mock-Name2", value="Mock-Value2"), ] MOCK_RUN_ROOT_URI = "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts" MOCK_SUBDIR = "subdir/path" MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR) @pytest.fixture() def databricks_artifact_repo(): with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root") as get_run_artifact_root_mock: get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI return get_artifact_repository( "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_artifacts_pb2 import (
    GetCredentialsForWrite,
    GetCredentialsForRead,
    ArtifactCredentialType,
    ArtifactCredentialInfo,
)
from mlflow.protos.service_pb2 import ListArtifacts, FileInfo
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository

DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow.store.artifact.databricks_artifact_repo"
# Fully qualified class path; used as the prefix for ``mock.patch`` targets.
DATABRICKS_ARTIFACT_REPOSITORY = (
    DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksArtifactRepository"
)

MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
MOCK_RUN_ID = "MOCK-RUN-ID"
# Sample credential headers.
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name1", value="Mock-Value1"),
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name2", value="Mock-Value2"),
]
# Artifact root of the mock run; also what the fixture's patched
# ``_get_run_artifact_root`` returns.
MOCK_RUN_ROOT_URI = "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    """Return the repository ``get_artifact_repository`` resolves for the mock run root.

    ``DatabricksArtifactRepository._get_run_artifact_root`` is patched to
    return ``MOCK_RUN_ROOT_URI`` only while the repository is being built;
    the patch is undone once the fixture returns.
    """
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    ) as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # Use the shared constant instead of re-typing the URI so the
        # repository root and the mocked run root cannot drift apart.
        return get_artifact_repository(MOCK_RUN_ROOT_URI)