Example #1
0
    def test_get_read_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
        """
        Verifies that the `_get_read_credential_infos` method, which is used to resolve read access
        credentials for a collection of artifacts, splits a list of paths that exceeds the maximum
        request size into chunks of at most `_MAX_CREDENTIALS_REQUEST_SIZE` paths and issues one
        request per chunk

        :param databricks_artifact_repo: Repository instance whose `_get_read_credential_infos`
                                         method is exercised.
        """
        assert _MAX_CREDENTIALS_REQUEST_SIZE == 2000, (
            "The maximum request size configured by the client should be consistent with the"
            " Databricks backend. Only update this value if the backend limit has changed."
        )

        # Create 3 chunks of paths, two of which have the maximum request size and one of which
        # is smaller than the maximum chunk size. Aggregate and pass these to
        # `_get_read_credential_infos`, validating that this method decomposes the aggregate
        # list into these expected chunks and makes 3 separate requests
        paths_chunk_1 = ["path1"] * _MAX_CREDENTIALS_REQUEST_SIZE
        paths_chunk_2 = ["path2"] * _MAX_CREDENTIALS_REQUEST_SIZE
        paths_chunk_3 = ["path3"] * 5
        credential_infos_mock_1 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
        ]
        credential_infos_mock_2 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
        ]
        credential_infos_mock_3 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(5)
        ]

        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
        ) as message_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
        ) as call_endpoint_mock:
            # One canned response per expected chunk; none carries a page token
            call_endpoint_mock.side_effect = [
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_1,),
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_2,),
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_3,),
            ]

            databricks_artifact_repo._get_read_credential_infos(
                MOCK_RUN_ID, paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
            )
            assert call_endpoint_mock.call_count == 3
            assert message_mock.call_count == 3
            # Each serialized request must carry exactly one of the expected chunks, in order
            message_mock.assert_has_calls(
                [
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
                ]
            )
    def test_get_read_credential_infos_handles_pagination(self, databricks_artifact_repo):
        """
        Checks that `_get_read_credential_infos` follows paginated responses: each response's
        `next_page_token` is sent as the `page_token` of the following request, three requests
        are made in total, and the credentials from the pages are returned concatenated in order
        """
        presigned_url_type = ArtifactCredentialType.AWS_PRESIGNED_URL
        first_page_infos = [
            ArtifactCredentialInfo(signed_uri="http://mock_url_1", type=presigned_url_type),
            ArtifactCredentialInfo(signed_uri="http://mock_url_2", type=presigned_url_type),
        ]
        second_page_infos = [
            ArtifactCredentialInfo(signed_uri="http://mock_url_3", type=presigned_url_type),
        ]
        final_page_infos = []

        message_patcher = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json")
        endpoint_patcher = mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint")
        with message_patcher as message_mock, endpoint_patcher as call_endpoint_mock:
            # Two pages that point at a successor, then a final empty page without a token
            call_endpoint_mock.side_effect = [
                GetCredentialsForRead.Response(
                    credential_infos=first_page_infos, next_page_token="2"
                ),
                GetCredentialsForRead.Response(
                    credential_infos=second_page_infos, next_page_token="3"
                ),
                GetCredentialsForRead.Response(credential_infos=final_page_infos),
            ]

            read_credential_infos = databricks_artifact_repo._get_read_credential_infos(
                MOCK_RUN_ID, ["testpath"]
            )

            assert read_credential_infos == first_page_infos + second_page_infos
            expected_requests = [
                GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"]),
                GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="2"),
                GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="3"),
            ]
            message_mock.assert_has_calls([mock.call(request) for request in expected_requests])
            assert call_endpoint_mock.call_count == 3
Example #3
0
 def test_databricks_download_file_with_relative_path(
         self, remote_file_path, local_path):
     """
     Checks that a repository created from `MOCK_SUBDIR_ROOT_URI` requests read credentials
     for the remote path joined under `MOCK_SUBDIR`, and hands the resulting credentials to
     `_download_from_cloud`
     """
     repo_target = DATABRICKS_ARTIFACT_REPOSITORY
     with mock.patch(repo_target + "._get_run_artifact_root") as get_run_artifact_root_mock, \
             mock.patch(repo_target + "._get_read_credentials") as read_credentials_mock, \
             mock.patch(repo_target + ".list_artifacts") as get_list_mock, \
             mock.patch(repo_target + "._download_from_cloud") as download_mock:
         # Stub out every collaborator before the repository is constructed
         get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
         get_list_mock.return_value = []
         download_mock.return_value = None
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         read_credentials_mock.return_value = GetCredentialsForRead.Response(
             credentials=mock_credentials)

         databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
         databricks_artifact_repo.download_artifacts(remote_file_path, local_path)

         expected_remote_path = posixpath.join(MOCK_SUBDIR, remote_file_path)
         read_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_remote_path)
         download_mock.assert_called_with(mock_credentials, ANY)
Example #4
0
    def test_download_artifacts_provides_failure_info(
            self, databricks_artifact_repo):
        """
        Verifies that when every mocked cloud download fails, `download_artifacts` raises an
        `MlflowException` whose message names the run root URI, each listed file, and each
        underlying error message
        """
        with mock.patch(
                DATABRICKS_ARTIFACT_REPOSITORY +
                "._get_read_credentials") as read_credentials_mock, mock.patch(
                    DATABRICKS_ARTIFACT_REPOSITORY +
                    ".list_artifacts") as get_list_mock, mock.patch(
                        DATABRICKS_ARTIFACT_REPOSITORY +
                        "._download_from_cloud") as download_mock:
            mock_credentials = ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI,
                type=ArtifactCredentialType.AZURE_SAS_URI)
            read_credentials_response_proto = GetCredentialsForRead.Response(
                credentials=mock_credentials)
            read_credentials_mock.return_value = read_credentials_response_proto
            get_list_mock.return_value = [
                FileInfo(path="file_1.txt", is_dir=False, file_size=100),
                FileInfo(path="file_2.txt", is_dir=False, file_size=0),
            ]
            # Successive download attempts raise these errors in turn
            download_mock.side_effect = [
                MlflowException("MOCK ERROR 1"),
                MlflowException("MOCK ERROR 2"),
            ]

            with pytest.raises(MlflowException) as exc:
                databricks_artifact_repo.download_artifacts("test_path")

            # Inspect the raised exception itself: `str(exc)` would stringify pytest's
            # `ExceptionInfo` wrapper, whose representation may be abbreviated
            err_msg = str(exc.value)
            assert MOCK_RUN_ROOT_URI in err_msg
            assert "file_1.txt" in err_msg
            assert "MOCK ERROR 1" in err_msg
            assert "file_2.txt" in err_msg
            assert "MOCK ERROR 2" in err_msg
 def test_databricks_download_file_get_request_fail(
         self, databricks_artifact_repo, test_file):
     """
     Verifies that `download_artifacts` raises an `MlflowException` when the mocked
     `requests.get` call fails, and that read credentials were requested for the file's path
     """
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \
             as read_credentials_mock, \
             mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \
             mock.patch('requests.get') as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         read_credentials_response_proto = GetCredentialsForRead.Response(
             credentials=mock_credentials)
         read_credentials_mock.return_value = read_credentials_response_proto
         get_list_mock.return_value = []
         # `side_effect` makes the mocked `requests.get` raise; assigning the exception to
         # `return_value` would only hand the exception object back as if it were a response
         request_mock.side_effect = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.download_artifacts(test_file.strpath)
         read_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                  test_file.strpath)
Example #6
0
    def test_download_artifacts_awaits_download_completion(
            self, databricks_artifact_repo, tmpdir):
        """
        Checks that `download_artifacts()` does not return before the artifact downloads have
        finished: each mocked download sleeps before writing its file, and the files must
        exist with the expected content as soon as the call returns
        """
        repo_target = DATABRICKS_ARTIFACT_REPOSITORY
        with mock.patch(repo_target + "._get_read_credentials") as read_credentials_mock, \
                mock.patch(repo_target + ".list_artifacts") as get_list_mock, \
                mock.patch(repo_target + "._download_from_cloud") as download_mock:

            def slow_download(cloud_credential, local_file_path):  # pylint: disable=unused-argument
                # Delay the write so that a caller returning early would find no file yet
                time.sleep(2)
                with open(local_file_path, "w") as f:
                    f.write("content")

            mock_credentials = ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI,
                type=ArtifactCredentialType.AZURE_SAS_URI)
            read_credentials_mock.return_value = GetCredentialsForRead.Response(
                credentials=mock_credentials)
            get_list_mock.return_value = [
                FileInfo(path="file_1.txt", is_dir=False, file_size=100),
                FileInfo(path="file_2.txt", is_dir=False, file_size=0),
            ]
            download_mock.side_effect = slow_download

            output_dir = str(tmpdir)
            databricks_artifact_repo.download_artifacts("test_path", output_dir)

            # Both listed files must be fully written by the time the call has returned
            for file_name in ("file_1.txt", "file_2.txt"):
                downloaded_path = os.path.join(output_dir, file_name)
                assert os.path.exists(downloaded_path)
                with open(downloaded_path, "r") as f:
                    assert f.read() == "content"
 def test_databricks_download_file(self, databricks_artifact_repo,
                                   remote_file_path, local_path,
                                   cloud_credential_type):
     """
     Checks that `download_artifacts` requests read credentials for the given remote path
     and passes the returned credentials, of the given credential type, to
     `_download_from_cloud`
     """
     repo_target = DATABRICKS_ARTIFACT_REPOSITORY
     with mock.patch(repo_target + '._get_read_credentials') as read_credentials_mock, \
             mock.patch(repo_target + '.list_artifacts') as get_list_mock, \
             mock.patch(repo_target + '._download_from_cloud') as download_mock:
         # No artifacts are listed under the path and the download itself is a no-op
         get_list_mock.return_value = []
         download_mock.return_value = None
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type)
         read_credentials_mock.return_value = GetCredentialsForRead.Response(
             credentials=mock_credentials)

         databricks_artifact_repo.download_artifacts(remote_file_path, local_path)

         read_credentials_mock.assert_called_with(MOCK_RUN_ID, remote_file_path)
         download_mock.assert_called_with(mock_credentials, ANY)