def test_databricks_download_file_with_relative_path(
         self, remote_file_path, local_path):
     """Relative download paths resolve against the repo's run subdirectory.

     The repository is created from ``MOCK_SUBDIR_ROOT_URI``, so read
     credentials are expected for ``MOCK_SUBDIR`` joined with
     ``remote_file_path``, and the returned credential info is handed to
     ``_download_from_cloud``.
     """
     root_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
     read_creds_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credentials"
     list_target = DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
     download_target = DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
     with mock.patch(root_target) as root_mock, \
             mock.patch(read_creds_target) as read_creds_mock, \
             mock.patch(list_target) as list_mock, \
             mock.patch(download_target) as cloud_download_mock:
         root_mock.return_value = MOCK_RUN_ROOT_URI
         credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         read_creds_mock.return_value = GetCredentialsForRead.Response(
             credentials=credential_info)
         cloud_download_mock.return_value = None
         # No directory listing: the path is downloaded as a single file.
         list_mock.return_value = []

         repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
         repo.download_artifacts(remote_file_path, local_path)

         expected_path = posixpath.join(MOCK_SUBDIR, remote_file_path)
         read_creds_mock.assert_called_with(MOCK_RUN_ID, expected_path)
         cloud_download_mock.assert_called_with(credential_info, ANY)
 def test_log_artifact_with_relative_path(self, test_file, artifact_path,
                                          expected_location):
     """Uploads from a subdirectory-rooted repo target ``expected_location``.

     Write credentials are requested for ``expected_location``. The upload
     helper receives the full credentials response, the local file path and
     that same location.
     """
     repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
     with mock.patch(repo_cls + "._get_run_artifact_root") as root_mock, \
             mock.patch(repo_cls + "._get_write_credentials") as write_creds_mock, \
             mock.patch(repo_cls + "._upload_to_cloud") as cloud_upload_mock:
         root_mock.return_value = MOCK_RUN_ROOT_URI
         repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)

         credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         write_response = GetCredentialsForWrite.Response(
             credentials=credential_info)
         write_creds_mock.return_value = write_response
         cloud_upload_mock.return_value = None

         local_file = test_file.strpath
         repo.log_artifact(local_file, artifact_path)

         write_creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
         cloud_upload_mock.assert_called_with(write_response, local_file,
                                              expected_location)
 def test_log_artifact_aws_with_headers(self, databricks_artifact_repo,
                                        test_file, artifact_path,
                                        expected_location):
     """AWS presigned-URL uploads pass the credential headers to ``http_put``.

     The headers on the mocked credential are expected as a ``{name: value}``
     dict in the ``headers`` keyword of the ``http_put`` call.
     """
     expected_headers = dict(
         (header.name, header.value) for header in MOCK_HEADERS)
     ok_response = Response()
     ok_response.status_code = 200
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AWS_SIGNED_URI,
         type=ArtifactCredentialType.AWS_PRESIGNED_URL,
         headers=MOCK_HEADERS,
     )
     write_response = GetCredentialsForWrite.Response(
         credentials=credential_info)

     write_creds_target = (DATABRICKS_ARTIFACT_REPOSITORY +
                           "._get_write_credentials")
     http_put_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".http_put"
     with mock.patch(write_creds_target) as write_creds_mock, \
             mock.patch(http_put_target) as http_put_mock:
         write_creds_mock.return_value = write_response
         http_put_mock.return_value = ok_response

         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)

         write_creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
         http_put_mock.assert_called_with(ANY, ANY, headers=expected_headers)
    def test_log_artifact_azure_with_headers(self, databricks_artifact_repo,
                                             test_file, artifact_path,
                                             expected_location):
        """Azure SAS uploads pass the credential headers to both block calls.

        ``put_block`` and ``put_block_list`` are both expected to be called
        with the signed URI and a ``{name: value}`` dict of the mocked
        credential headers.
        """
        expected_headers = dict(
            (header.name, header.value) for header in MOCK_HEADERS)
        credential_info = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI,
            type=ArtifactCredentialType.AZURE_SAS_URI,
            headers=MOCK_HEADERS,
        )
        package = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE
        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                        "._get_write_credentials") as write_creds_mock, \
                mock.patch(package + ".put_block") as put_block_mock, \
                mock.patch(package + ".put_block_list") as put_block_list_mock:
            write_creds_mock.return_value = GetCredentialsForWrite.Response(
                credentials=credential_info)

            databricks_artifact_repo.log_artifact(test_file.strpath,
                                                  artifact_path)

            write_creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
            put_block_mock.assert_called_with(
                MOCK_AZURE_SIGNED_URI, ANY, ANY, headers=expected_headers)
            put_block_list_mock.assert_called_with(
                MOCK_AZURE_SIGNED_URI, ANY, headers=expected_headers)
 def test_log_artifact_aws_presigned_url_error(self,
                                               databricks_artifact_repo,
                                               test_file):
     """A failing ``http_put`` makes ``log_artifact`` raise ``MlflowException``.

     Write credentials must still have been requested for the run before the
     failure.
     """
     credential_info = ArtifactCredentialInfo(
         signed_uri=MOCK_AWS_SIGNED_URI,
         type=ArtifactCredentialType.AWS_PRESIGNED_URL)
     write_creds_target = (DATABRICKS_ARTIFACT_REPOSITORY +
                           "._get_write_credentials")
     http_put_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".http_put"
     with mock.patch(write_creds_target) as write_creds_mock, \
             mock.patch(http_put_target) as http_put_mock:
         write_creds_mock.return_value = GetCredentialsForWrite.Response(
             credentials=credential_info)
         # Calling the mocked PUT raises instead of returning a response.
         http_put_mock.side_effect = MlflowException("MOCK ERROR")

         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)

         write_creds_mock.assert_called_with(MOCK_RUN_ID, ANY)
 def test_log_artifact_aws(self, databricks_artifact_repo, test_file,
                           artifact_path, expected_location):
     """AWS presigned-URL credentials route the upload via ``_aws_upload_file``.

     The helper is expected to receive the credential info and the local
     file path.
     """
     repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
     with mock.patch(repo_cls + "._get_write_credentials") as write_creds_mock, \
             mock.patch(repo_cls + "._aws_upload_file") as aws_upload_mock:
         credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AWS_SIGNED_URI,
             type=ArtifactCredentialType.AWS_PRESIGNED_URL)
         write_creds_mock.return_value = GetCredentialsForWrite.Response(
             credentials=credential_info)
         aws_upload_mock.return_value = None

         local_file = test_file.strpath
         databricks_artifact_repo.log_artifact(local_file, artifact_path)

         write_creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
         aws_upload_mock.assert_called_with(credential_info, local_file)
 def test_databricks_download_file_get_request_fail(
         self, databricks_artifact_repo, test_file):
     """A failing ``http_get`` makes ``download_artifacts`` raise.

     The mocked ``http_get`` is given a ``side_effect`` so that calling it
     actually raises ``MlflowException``. Assigning the exception to
     ``return_value`` would only hand back an exception *object* as the
     response, so the simulated request failure would never happen. Any
     ``MlflowException`` observed would then come from how the repository
     treats that bogus response rather than from the failure under test.
     """
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY +
             "._get_read_credentials") as read_credentials_mock, mock.patch(
                 DATABRICKS_ARTIFACT_REPOSITORY +
                 ".list_artifacts") as get_list_mock, mock.patch(
                     DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE +
                     ".http_get") as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         read_credentials_response_proto = GetCredentialsForRead.Response(
             credentials=mock_credentials)
         read_credentials_mock.return_value = read_credentials_response_proto
         # No directory listing: the path is downloaded as a single file.
         get_list_mock.return_value = []
         # side_effect makes the mocked GET raise; return_value would not.
         request_mock.side_effect = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.download_artifacts(test_file.strpath)
         read_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                  test_file.strpath)
 def test_databricks_download_file(self, databricks_artifact_repo,
                                   remote_file_path, local_path,
                                   cloud_credential_type):
     """Downloads request read credentials for ``remote_file_path`` as given.

     The returned credential info, built with the parametrized
     ``cloud_credential_type``, is expected to be passed to
     ``_download_from_cloud``.
     """
     repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
     with mock.patch(repo_cls + "._get_read_credentials") as read_creds_mock, \
             mock.patch(repo_cls + ".list_artifacts") as list_mock, \
             mock.patch(repo_cls + "._download_from_cloud") as cloud_download_mock:
         credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type)
         read_creds_mock.return_value = GetCredentialsForRead.Response(
             credentials=credential_info)
         cloud_download_mock.return_value = None
         # No directory listing: the path is downloaded as a single file.
         list_mock.return_value = []

         databricks_artifact_repo.download_artifacts(remote_file_path,
                                                     local_path)

         read_creds_mock.assert_called_with(MOCK_RUN_ID, remote_file_path)
         cloud_download_mock.assert_called_with(credential_info, ANY)
    GetCredentialsForWrite,
    GetCredentialsForRead,
    ArtifactCredentialType,
    ArtifactCredentialInfo,
)
from mlflow_databricks_artifacts.store.artifact_repo import DatabricksArtifactRepository

# Dotted import paths used to build ``mock.patch`` targets.
DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow_databricks_artifacts.store.artifact_repo"
DATABRICKS_ARTIFACT_REPOSITORY = "".join(
    [DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE, ".DatabricksArtifactRepository"])

# Fake signed URIs placed in the mocked ArtifactCredentialInfo messages.
MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
MOCK_RUN_ID = "MOCK-RUN-ID"
# Header name/value pairs attached to mocked credentials; the upload tests
# expect them to be forwarded as a ``{name: value}`` dict.
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name=header_name, value=header_value)
    for header_name, header_value in (
        ("Mock-Name1", "Mock-Value1"),
        ("Mock-Name2", "Mock-Value2"),
    )
]
# Artifact root returned by the mocked ``_get_run_artifact_root``, plus a
# subdirectory of it used to exercise relative-path resolution.
MOCK_RUN_ROOT_URI = "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    """Return the artifact repository for ``MOCK_RUN_ROOT_URI``.

    ``_get_run_artifact_root`` is patched to return ``MOCK_RUN_ROOT_URI``
    only while ``get_artifact_repository`` runs. Returning from inside the
    ``with`` block undoes the patch before the test body executes.
    """
    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                    "._get_run_artifact_root") as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        # Use the shared constant rather than a duplicated literal so the
        # requested URI cannot drift from the root the mock reports.
        return get_artifact_repository(MOCK_RUN_ROOT_URI)