Esempio n. 1
0
 def test_log_artifact_aws_with_headers(self, databricks_artifact_repo,
                                        test_file, artifact_path,
                                        expected_location):
     """Check that credential headers are forwarded on an AWS presigned-URL upload.

     ``_get_write_credentials`` is patched to hand back an AWS presigned-URL
     credential carrying ``MOCK_HEADERS`` and ``requests.put`` is patched to
     return a response with status code 200. After ``log_artifact`` runs,
     the test asserts the credentials were requested for ``expected_location``
     and that ``requests.put`` received the headers as a name -> value dict.
     """
     # The PUT call is expected to receive the headers as a plain mapping.
     headers_by_name = {hdr.name: hdr.value for hdr in MOCK_HEADERS}

     ok_response = Response()
     ok_response.status_code = 200

     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AWS_SIGNED_URI,
         type=ArtifactCredentialType.AWS_PRESIGNED_URL,
         headers=MOCK_HEADERS)
     credentials_response = GetCredentialsForWrite.Response(
         credentials=credential_info)

     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials',
                     return_value=credentials_response) as get_creds, \
             mock.patch('requests.put', return_value=ok_response) as put_mock:
         databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)

         get_creds.assert_called_with(MOCK_RUN_ID, expected_location)
         put_mock.assert_called_with(ANY, ANY, headers=headers_by_name)
Esempio n. 2
0
 def test_databricks_download_file_with_relative_path(
         self, remote_file_path, local_path):
     """Check downloads from a repository rooted at a run-artifact subdirectory.

     The repository is created for ``MOCK_SUBDIR_ROOT_URI`` while
     ``_get_run_artifact_root`` is patched to return ``MOCK_RUN_ROOT_URI``.
     The test asserts that read credentials are requested for the remote
     path joined onto ``MOCK_SUBDIR`` and that ``_download_from_cloud``
     receives the mocked credential.
     """
     repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AZURE_SIGNED_URI,
         type=ArtifactCredentialType.AZURE_SAS_URI)
     credentials_response = GetCredentialsForRead.Response(
         credentials=credential_info)
     # Path the credentials request is expected to be issued for.
     path_under_run_root = posixpath.join(MOCK_SUBDIR, remote_file_path)

     with mock.patch(repo_cls + '._get_run_artifact_root',
                     return_value=MOCK_RUN_ROOT_URI), \
             mock.patch(repo_cls + '._get_read_credentials',
                        return_value=credentials_response) as get_creds, \
             mock.patch(repo_cls + '.list_artifacts', return_value=[]), \
             mock.patch(repo_cls + '._download_from_cloud',
                        return_value=None) as download:
         subdir_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
         subdir_repo.download_artifacts(remote_file_path, local_path)

         get_creds.assert_called_with(MOCK_RUN_ID, path_under_run_root)
         download.assert_called_with(credential_info, ANY)
Esempio n. 3
0
 def test_log_artifact_with_relative_path(self, test_file, artifact_path,
                                          expected_location):
     """Check uploads through a repository rooted at a run-artifact subdirectory.

     The repository is created for ``MOCK_SUBDIR_ROOT_URI`` while
     ``_get_run_artifact_root`` is patched to return ``MOCK_RUN_ROOT_URI``.
     The test asserts that write credentials are requested for
     ``expected_location`` and that ``_upload_to_cloud`` is called with the
     credentials response, the local file path and ``expected_location``.
     """
     repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AZURE_SIGNED_URI,
         type=ArtifactCredentialType.AZURE_SAS_URI)
     credentials_response = GetCredentialsForWrite.Response(
         credentials=credential_info)

     with mock.patch(repo_cls + '._get_run_artifact_root',
                     return_value=MOCK_RUN_ROOT_URI), \
             mock.patch(repo_cls + '._get_write_credentials',
                        return_value=credentials_response) as get_creds, \
             mock.patch(repo_cls + '._upload_to_cloud',
                        return_value=None) as upload:
         # The repository must be built while the run root is patched.
         subdir_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
         subdir_repo.log_artifact(test_file.strpath, artifact_path)

         get_creds.assert_called_with(MOCK_RUN_ID, expected_location)
         upload.assert_called_with(credentials_response,
                                   test_file.strpath,
                                   expected_location)
Esempio n. 4
0
 def test_log_artifact_aws_presigned_url_error(self,
                                               databricks_artifact_repo,
                                               test_file):
     """Check that a failing ``requests.put`` surfaces as ``MlflowException``.

     ``_get_write_credentials`` returns an AWS presigned-URL credential and
     ``requests.put`` is patched to raise ``MlflowException("MOCK ERROR")``.
     The test expects ``log_artifact`` to raise ``MlflowException`` and the
     credentials to have been requested for ``MOCK_RUN_ID``.
     """
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AWS_SIGNED_URI,
         type=ArtifactCredentialType.AWS_PRESIGNED_URL)
     credentials_response = GetCredentialsForWrite.Response(
         credentials=credential_info)
     put_failure = MlflowException("MOCK ERROR")

     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials',
                     return_value=credentials_response) as get_creds, \
             mock.patch('requests.put', side_effect=put_failure):
         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)

         get_creds.assert_called_with(MOCK_RUN_ID, ANY)
Esempio n. 5
0
 def test_log_artifact_azure_blob_client_sas_error(self,
                                                   databricks_artifact_repo,
                                                   test_file):
     """Check that a blob-client construction failure surfaces as ``MlflowException``.

     ``_get_write_credentials`` returns an Azure SAS credential and
     ``BlobClient.from_blob_url`` is patched to raise
     ``MlflowException("MOCK ERROR")``. The test expects ``log_artifact`` to
     raise ``MlflowException`` and the credentials to have been requested
     for ``MOCK_RUN_ID``.
     """
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AZURE_SIGNED_URI,
         type=ArtifactCredentialType.AZURE_SAS_URI)
     credentials_response = GetCredentialsForWrite.Response(
         credentials=credential_info)
     client_failure = MlflowException("MOCK ERROR")

     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials',
                     return_value=credentials_response) as get_creds, \
             mock.patch('azure.storage.blob.BlobClient.from_blob_url',
                        side_effect=client_failure):
         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)

         get_creds.assert_called_with(MOCK_RUN_ID, ANY)
Esempio n. 6
0
 def test_databricks_download_file_get_request_fail(
         self, databricks_artifact_repo, test_file):
     """Check that a failing ``requests.get`` surfaces as ``MlflowException``.

     ``_get_read_credentials`` returns an Azure SAS credential,
     ``list_artifacts`` returns an empty list and ``requests.get`` is patched
     to raise ``MlflowException("MOCK ERROR")``. The test expects
     ``download_artifacts`` to raise ``MlflowException`` and the read
     credentials to have been requested for the file path.
     """
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \
             as read_credentials_mock, \
             mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \
             mock.patch('requests.get') as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         read_credentials_response_proto = GetCredentialsForRead.Response(
             credentials=mock_credentials)
         read_credentials_mock.return_value = read_credentials_response_proto
         get_list_mock.return_value = []
         # Use side_effect so the GET call actually raises. Assigning the
         # exception to return_value would only hand back an exception
         # object, and the test would then pass (if at all) for an
         # incidental reason. This matches the other failure tests.
         request_mock.side_effect = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.download_artifacts(test_file.strpath)
         read_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                  test_file.strpath)
Esempio n. 7
0
 def test_log_artifact_aws(self, databricks_artifact_repo, test_file,
                           artifact_path, expected_location):
     """Check that an AWS presigned-URL credential routes to ``_aws_upload_file``.

     ``_get_write_credentials`` returns an AWS presigned-URL credential and
     ``_aws_upload_file`` is patched out. The test asserts the credentials
     were requested for ``expected_location`` and that the upload helper was
     called with the credential and the local file path.
     """
     repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AWS_SIGNED_URI,
         type=ArtifactCredentialType.AWS_PRESIGNED_URL)
     credentials_response = GetCredentialsForWrite.Response(
         credentials=credential_info)

     with mock.patch(repo_cls + '._get_write_credentials',
                     return_value=credentials_response) as get_creds, \
             mock.patch(repo_cls + '._aws_upload_file',
                        return_value=None) as aws_upload:
         databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)

         get_creds.assert_called_with(MOCK_RUN_ID, expected_location)
         aws_upload.assert_called_with(credential_info, test_file.strpath)
Esempio n. 8
0
 def test_databricks_download_file(self, databricks_artifact_repo,
                                   remote_file_path, local_path,
                                   cloud_credential_type):
     """Check that a download requests read credentials and calls the cloud helper.

     ``_get_read_credentials`` returns a credential of
     ``cloud_credential_type``, ``list_artifacts`` returns an empty list and
     ``_download_from_cloud`` is patched out. The test asserts the
     credentials were requested for ``remote_file_path`` and that the
     download helper received the mocked credential.
     """
     repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type)
     credentials_response = GetCredentialsForRead.Response(
         credentials=credential_info)

     with mock.patch(repo_cls + '._get_read_credentials',
                     return_value=credentials_response) as get_creds, \
             mock.patch(repo_cls + '.list_artifacts', return_value=[]), \
             mock.patch(repo_cls + '._download_from_cloud',
                        return_value=None) as download:
         databricks_artifact_repo.download_artifacts(remote_file_path, local_path)

         get_creds.assert_called_with(MOCK_RUN_ID, remote_file_path)
         download.assert_called_with(credential_info, ANY)
Esempio n. 9
0
    def test_log_artifact_azure_with_headers(self, databricks_artifact_repo,
                                             test_file, artifact_path,
                                             expected_location):
        """Check that credential headers are forwarded on an Azure SAS upload.

        ``_get_write_credentials`` returns an Azure SAS credential carrying
        ``MOCK_HEADERS`` and ``BlobClient.from_blob_url`` is patched to return
        a mock blob client. After ``log_artifact`` runs, the test asserts that
        the credentials were requested for ``expected_location``, that the
        blob client was created from the signed URI with the headers, and
        that ``stage_block`` and ``commit_block_list`` received the same
        headers.
        """
        expected_headers = {
            header.name: header.value
            for header in MOCK_HEADERS
        }
        # ``spec`` restricts the mock to BlobClient's real attributes.
        # ``MagicMock(autospec=...)`` does not do that: unknown keyword
        # arguments merely set an attribute named ``autospec`` on the mock.
        mock_blob_service = mock.MagicMock(spec=BlobClient)
        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \
                as write_credentials_mock, \
                mock.patch(
                    'azure.storage.blob.BlobClient.from_blob_url') as mock_create_blob_client:
            mock_credentials = ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI,
                type=ArtifactCredentialType.AZURE_SAS_URI,
                headers=MOCK_HEADERS)
            write_credentials_response_proto = GetCredentialsForWrite.Response(
                credentials=mock_credentials)
            write_credentials_mock.return_value = write_credentials_response_proto

            mock_create_blob_client.return_value = mock_blob_service

            databricks_artifact_repo.log_artifact(test_file.strpath,
                                                  artifact_path)
            write_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                      expected_location)
            mock_create_blob_client.assert_called_with(
                blob_url=MOCK_AZURE_SIGNED_URI,
                credential=None,
                headers=expected_headers)
            mock_blob_service.stage_block.assert_called_with(
                ANY, ANY, headers=expected_headers)
            mock_blob_service.commit_block_list.assert_called_with(
                ANY, headers=expected_headers)
Esempio n. 10
0
from kiwi.exceptions import MlflowException
from kiwi.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead, \
    ArtifactCredentialType, ArtifactCredentialInfo
from kiwi.protos.service_pb2 import ListArtifacts, FileInfo
from kiwi.store.artifact.artifact_repository_registry import get_artifact_repository
from kiwi.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository

# Dotted path of the repository class; tests append a method name to build
# ``mock.patch`` targets.
# NOTE(review): this file imports from ``kiwi.*`` while the path below is
# under ``mlflow.*`` -- confirm the patch target resolves to the module that
# is actually exercised.
DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = 'mlflow.store.artifact.databricks_artifact_repo'
DATABRICKS_ARTIFACT_REPOSITORY = (
    DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".DatabricksArtifactRepository"
)

# Placeholder signed URIs and run identifier used to build mock credentials.
MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
MOCK_RUN_ID = "MOCK-RUN-ID"

# HTTP headers attached to mock credentials in the "with_headers" tests.
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name='Mock-Name1', value='Mock-Value1'),
    ArtifactCredentialInfo.HttpHeader(name='Mock-Name2', value='Mock-Value2'),
]

# Artifact root of the mock run, and a subdirectory beneath it.
MOCK_RUN_ROOT_URI = (
    "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
)
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    """Return an artifact repository for the mock run's artifact root URI.

    The repository is constructed while ``_get_run_artifact_root`` is patched
    to return ``MOCK_RUN_ROOT_URI``; the patch is no longer active once the
    fixture returns.
    """
    run_root_patch = mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root',
        return_value=MOCK_RUN_ROOT_URI)
    with run_root_patch:
        repo = get_artifact_repository(
            "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")
    return repo