def test_log_artifact_with_relative_path(self, test_file, artifact_path,
                                         expected_location):
    """Check where ``log_artifact`` uploads for a subdirectory-rooted repo.

    A dedicated repository is created from ``MOCK_SUBDIR_ROOT_URI`` while
    ``_get_run_artifact_root`` is patched to report ``MOCK_RUN_ROOT_URI``.
    Logging ``test_file`` under ``artifact_path`` must request write
    credentials for, and upload to, ``expected_location``.
    """
    repo_path = DATABRICKS_ARTIFACT_REPOSITORY
    azure_credentials = ArtifactCredentialInfo(
        signed_uri=MOCK_AZURE_SIGNED_URI,
        type=ArtifactCredentialType.AZURE_SAS_URI)
    credentials_response = GetCredentialsForWrite.Response(
        credentials=azure_credentials)

    with mock.patch(repo_path + "._get_run_artifact_root",
                    return_value=MOCK_RUN_ROOT_URI), \
            mock.patch(repo_path + "._get_write_credentials",
                       return_value=credentials_response) as creds_mock, \
            mock.patch(repo_path + "._upload_to_cloud",
                       return_value=None) as cloud_upload_mock:
        # Built inside the patch; presumably construction resolves the run
        # artifact root -- TODO confirm against the repository constructor.
        subdir_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        subdir_repo.log_artifact(test_file.strpath, artifact_path)

        creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
        # _upload_to_cloud is expected to receive the whole credentials
        # response, not just the embedded ArtifactCredentialInfo.
        cloud_upload_mock.assert_called_with(credentials_response,
                                             test_file.strpath,
                                             expected_location)
 def test_log_artifact_aws_with_headers(self, databricks_artifact_repo,
                                        test_file, artifact_path,
                                        expected_location):
     """AWS presigned-URL uploads forward the credential headers to http_put.

     Write credentials carrying ``MOCK_HEADERS`` are returned by the patched
     ``_get_write_credentials``; the patched ``http_put`` must then be called
     with those headers flattened into a ``{name: value}`` mapping.
     """
     expected_headers = {hdr.name: hdr.value for hdr in MOCK_HEADERS}

     ok_response = Response()
     ok_response.status_code = 200

     aws_credentials = ArtifactCredentialInfo(
         signed_uri=MOCK_AWS_SIGNED_URI,
         type=ArtifactCredentialType.AWS_PRESIGNED_URL,
         headers=MOCK_HEADERS,
     )
     credentials_response = GetCredentialsForWrite.Response(
         credentials=aws_credentials)

     repo_path = DATABRICKS_ARTIFACT_REPOSITORY
     package_path = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE
     with mock.patch(repo_path + "._get_write_credentials",
                     return_value=credentials_response) as creds_mock, \
             mock.patch(package_path + ".http_put",
                        return_value=ok_response) as put_mock:
         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)

         creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
         # Only the headers are pinned; URL and payload are left open.
         put_mock.assert_called_with(ANY, ANY, headers=expected_headers)
    def test_log_artifact_azure_with_headers(self, databricks_artifact_repo,
                                             test_file, artifact_path,
                                             expected_location):
        """Azure SAS uploads forward credential headers to both block calls.

        With write credentials carrying ``MOCK_HEADERS``, the patched
        ``put_block`` and ``put_block_list`` must each be called against
        ``MOCK_AZURE_SIGNED_URI`` with the headers as a ``{name: value}`` map.
        """
        expected_headers = {hdr.name: hdr.value for hdr in MOCK_HEADERS}

        azure_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI,
            type=ArtifactCredentialType.AZURE_SAS_URI,
            headers=MOCK_HEADERS,
        )
        credentials_response = GetCredentialsForWrite.Response(
            credentials=azure_credentials)

        repo_path = DATABRICKS_ARTIFACT_REPOSITORY
        package_path = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE
        with mock.patch(repo_path + "._get_write_credentials",
                        return_value=credentials_response) as creds_mock, \
                mock.patch(package_path + ".put_block") as block_mock, \
                mock.patch(package_path + ".put_block_list") as block_list_mock:
            databricks_artifact_repo.log_artifact(test_file.strpath,
                                                  artifact_path)

            creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)

            # Block id and payload are left unconstrained; the URI and the
            # forwarded headers are what this test pins.
            block_mock.assert_called_with(MOCK_AZURE_SIGNED_URI, ANY, ANY,
                                          headers=expected_headers)
            block_list_mock.assert_called_with(MOCK_AZURE_SIGNED_URI, ANY,
                                               headers=expected_headers)
 def test_log_artifact_aws_presigned_url_error(self,
                                               databricks_artifact_repo,
                                               test_file):
     """An ``MlflowException`` raised by ``http_put`` escapes ``log_artifact``.

     ``log_artifact`` is called without an artifact path, so only the run ID
     of the write-credentials request is asserted.
     """
     credentials_response = GetCredentialsForWrite.Response(
         credentials=ArtifactCredentialInfo(
             signed_uri=MOCK_AWS_SIGNED_URI,
             type=ArtifactCredentialType.AWS_PRESIGNED_URL))
     put_failure = MlflowException("MOCK ERROR")

     repo_path = DATABRICKS_ARTIFACT_REPOSITORY
     package_path = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE
     with mock.patch(repo_path + "._get_write_credentials",
                     return_value=credentials_response) as creds_mock, \
             mock.patch(package_path + ".http_put", side_effect=put_failure):
         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)
         creds_mock.assert_called_with(MOCK_RUN_ID, ANY)
 def test_log_artifact_aws(self, databricks_artifact_repo, test_file,
                           artifact_path, expected_location):
     """AWS presigned-URL credentials route the upload through ``_aws_upload_file``.

     The patched ``_aws_upload_file`` must receive the ``ArtifactCredentialInfo``
     (not the surrounding response) together with the local file path.
     """
     aws_credentials = ArtifactCredentialInfo(
         signed_uri=MOCK_AWS_SIGNED_URI,
         type=ArtifactCredentialType.AWS_PRESIGNED_URL)
     credentials_response = GetCredentialsForWrite.Response(
         credentials=aws_credentials)

     repo_path = DATABRICKS_ARTIFACT_REPOSITORY
     with mock.patch(repo_path + "._get_write_credentials",
                     return_value=credentials_response) as creds_mock, \
             mock.patch(repo_path + "._aws_upload_file",
                        return_value=None) as aws_upload_mock:
         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)

         creds_mock.assert_called_with(MOCK_RUN_ID, expected_location)
         aws_upload_mock.assert_called_with(aws_credentials,
                                            test_file.strpath)