Example #1
0
    def test_get_read_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
        """
        Verifies that the `_get_read_credential_infos` method, which is used to resolve read access
        credentials for a collection of artifacts, splits a list of paths that is larger than
        `_MAX_CREDENTIALS_REQUEST_SIZE` into chunks of at most that size and issues one request
        per chunk
        """
        assert _MAX_CREDENTIALS_REQUEST_SIZE == 2000, (
            "The maximum request size configured by the client should be consistent with the"
            " Databricks backend. Only update this value if the backend limit has changed."
        )

        # Create 3 chunks of paths, two of which have the maximum request size and one of which
        # is smaller than the maximum chunk size. Aggregate and pass these to
        # `_get_read_credential_infos`, validating that this method decomposes the aggregate
        # list into these expected chunks and makes 3 separate requests
        paths_chunk_1 = ["path1"] * _MAX_CREDENTIALS_REQUEST_SIZE
        paths_chunk_2 = ["path2"] * _MAX_CREDENTIALS_REQUEST_SIZE
        paths_chunk_3 = ["path3"] * 5
        # One mocked credential per requested path, so each response is the same size as the
        # chunk it answers
        credential_infos_mock_1 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
        ]
        credential_infos_mock_2 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
        ]
        credential_infos_mock_3 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(5)
        ]

        # `message_to_json` is patched so that the request messages handed to it can be
        # inspected; `_call_endpoint` is patched so that no real request is issued
        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
        ) as message_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
        ) as call_endpoint_mock:
            # None of the responses carries a `next_page_token`, so each chunk is expected to
            # be resolved by exactly one request
            call_endpoint_mock.side_effect = [
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_1),
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_2),
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_3),
            ]

            databricks_artifact_repo._get_read_credential_infos(
                MOCK_RUN_ID, paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
            )
            assert call_endpoint_mock.call_count == 3
            assert message_mock.call_count == 3
            message_mock.assert_has_calls(
                [
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
                ]
            )
    def test_get_read_credential_infos_handles_pagination(self, databricks_artifact_repo):
        """
        Checks that `_get_read_credential_infos` keeps requesting further pages while the mocked
        endpoint returns a `next_page_token`, forwards each token as the `page_token` of the
        following request, and returns the credentials of all pages combined
        """

        def presigned(uri):
            # Every mocked credential in this test is an AWS presigned URL
            return ArtifactCredentialInfo(
                signed_uri=uri, type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )

        first_page = [presigned("http://mock_url_1"), presigned("http://mock_url_2")]
        second_page = [presigned("http://mock_url_3")]
        # The final page is empty and carries no token
        final_page = []

        json_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
        endpoint_target = DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"

        with mock.patch(json_target) as json_mock, \
                mock.patch(endpoint_target) as endpoint_mock:
            # One response per expected request, served in order
            endpoint_mock.side_effect = [
                GetCredentialsForRead.Response(credential_infos=first_page, next_page_token="2"),
                GetCredentialsForRead.Response(credential_infos=second_page, next_page_token="3"),
                GetCredentialsForRead.Response(credential_infos=final_page),
            ]

            resolved_infos = databricks_artifact_repo._get_read_credential_infos(
                MOCK_RUN_ID, ["testpath"]
            )

            assert resolved_infos == first_page + second_page

            # The first request has no page token; each later one echoes the previous token
            expected_requests = [
                GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"]),
                GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="2"),
                GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="3"),
            ]
            json_mock.assert_has_calls([mock.call(request) for request in expected_requests])
            assert endpoint_mock.call_count == 3
Example #3
0
 def _get_read_credentials(self, run_id, path=None):
     """
     Send a `GetCredentialsForRead` request for the given run and path.

     A `GetCredentialsForRead` message is built from `run_id` and `path`, serialized with
     `message_to_json`, and handed to `self._call_endpoint` together with
     `DatabricksMlflowArtifactsService`. The result of `_call_endpoint` is returned as-is.
     """
     request_message = GetCredentialsForRead(run_id=run_id, path=path)
     request_json = message_to_json(request_message)
     return self._call_endpoint(
         DatabricksMlflowArtifactsService, GetCredentialsForRead, request_json
     )