def test_get_read_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
    """
    Verifies that the `_get_read_credential_infos` method, which is used to resolve read access
    credentials for a collection of artifacts, splits the requested paths into chunks of at most
    `_MAX_CREDENTIALS_REQUEST_SIZE` paths, issues one request per chunk, and returns the
    credentials gathered from every chunk
    """
    assert _MAX_CREDENTIALS_REQUEST_SIZE == 2000, (
        "The maximum request size configured by the client should be consistent with the"
        " Databricks backend. Only update this value if the backend limit has changed."
    )

    def _make_credential_infos(signed_uri, count):
        # One mocked presigned-URL credential per requested path in a chunk
        return [
            ArtifactCredentialInfo(
                signed_uri=signed_uri, type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(count)
        ]

    # Create 3 chunks of paths, two of which have the maximum request size and one of which
    # is smaller than the maximum chunk size. Aggregate and pass these to
    # `_get_read_credential_infos`, validating that this method decomposes the aggregate
    # list into these expected chunks and makes 3 separate requests
    paths_chunk_1 = ["path1"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_2 = ["path2"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_3 = ["path3"] * 5

    credential_infos_mock_1 = _make_credential_infos(
        "http://mock_url_1", _MAX_CREDENTIALS_REQUEST_SIZE
    )
    credential_infos_mock_2 = _make_credential_infos(
        "http://mock_url_2", _MAX_CREDENTIALS_REQUEST_SIZE
    )
    credential_infos_mock_3 = _make_credential_infos("http://mock_url_3", 5)

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        # None of these responses carries a next_page_token, so each chunk is a single request
        call_endpoint_mock.side_effect = [
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_1,),
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_2,),
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_3,),
        ]

        read_credential_infos = databricks_artifact_repo._get_read_credential_infos(
            MOCK_RUN_ID, paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
        )

        # Credentials from every chunk must be returned, in request order
        assert read_credential_infos == (
            credential_infos_mock_1 + credential_infos_mock_2 + credential_infos_mock_3
        )
        assert call_endpoint_mock.call_count == 3
        assert message_mock.call_count == 3
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
            ]
        )
def test_get_read_credential_infos_handles_pagination(self, databricks_artifact_repo):
    """
    Verifies that the `_get_read_credential_infos` method, which is used to resolve read access
    credentials for a collection of artifacts, handles paginated responses properly, issuing
    incremental requests (forwarding each response's `next_page_token`) until all pages have
    been consumed, and returns the credentials from every page
    """
    credential_infos_mock_1 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        ),
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        ),
    ]
    credential_infos_mock_2 = [
        ArtifactCredentialInfo(
            signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )
    ]
    credential_infos_mock_3 = []

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        get_credentials_for_read_responses = [
            GetCredentialsForRead.Response(
                credential_infos=credential_infos_mock_1, next_page_token="2"
            ),
            GetCredentialsForRead.Response(
                credential_infos=credential_infos_mock_2, next_page_token="3"
            ),
            # Final page: empty and without a next_page_token, so no further request is made
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_3),
        ]
        call_endpoint_mock.side_effect = get_credentials_for_read_responses

        read_credential_infos = databricks_artifact_repo._get_read_credential_infos(
            MOCK_RUN_ID, ["testpath"],
        )

        assert read_credential_infos == credential_infos_mock_1 + credential_infos_mock_2
        # assert_has_calls tolerates extra calls, so pin the exact count as well
        assert message_mock.call_count == 3
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"])),
                mock.call(
                    GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="2")
                ),
                mock.call(
                    GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="3")
                ),
            ]
        )
        assert call_endpoint_mock.call_count == 3
def _get_read_credentials(self, run_id, path=None, page_token=None):
    """
    Request read access credentials for artifact path(s) of a run.

    :param run_id: ID of the run whose artifacts are being read.
    :param path: Artifact path(s) to request credentials for; forwarded as the ``path`` field
                 of the ``GetCredentialsForRead`` request.
    :param page_token: Optional ``next_page_token`` from a previous response, used to request
                       the following page of credentials. Defaults to ``None`` (first page),
                       which keeps the original single-request behavior.
    :return: Whatever ``self._call_endpoint`` returns for the ``GetCredentialsForRead``
             endpoint (presumably a ``GetCredentialsForRead.Response``).
    """
    # As with ``path``, a ``None`` page_token is passed straight to the message constructor.
    request = GetCredentialsForRead(run_id=run_id, path=path, page_token=page_token)
    json_body = message_to_json(request)
    return self._call_endpoint(DatabricksMlflowArtifactsService, GetCredentialsForRead, json_body)