def test_get_read_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
    """
    Verifies that the `_get_read_credential_infos` method, which is used to resolve read access
    credentials for a collection of artifacts, splits the requested paths into chunks of at
    most `_MAX_CREDENTIALS_REQUEST_SIZE` paths, issues one credentials request per chunk, and
    returns the credentials aggregated across all chunks
    """
    assert _MAX_CREDENTIALS_REQUEST_SIZE == 2000, (
        "The maximum request size configured by the client should be consistent with the"
        " Databricks backend. Only update this value if the backend limit has changed."
    )

    def _make_credential_infos(signed_uri, count):
        # One presigned-URL credential per requested path in a chunk
        return [
            ArtifactCredentialInfo(
                signed_uri=signed_uri, type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(count)
        ]

    # Create 3 chunks of paths, two of which have the maximum request size and one of which
    # is smaller than the maximum chunk size. Aggregate and pass these to
    # `_get_read_credential_infos`, validating that this method decomposes the aggregate
    # list into these expected chunks and makes 3 separate requests
    paths_chunk_1 = ["path1"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_2 = ["path2"] * _MAX_CREDENTIALS_REQUEST_SIZE
    paths_chunk_3 = ["path3"] * 5

    credential_infos_mock_1 = _make_credential_infos(
        "http://mock_url_1", _MAX_CREDENTIALS_REQUEST_SIZE
    )
    credential_infos_mock_2 = _make_credential_infos(
        "http://mock_url_2", _MAX_CREDENTIALS_REQUEST_SIZE
    )
    credential_infos_mock_3 = _make_credential_infos("http://mock_url_3", 5)

    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    ) as message_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
    ) as call_endpoint_mock:
        # One response per chunk; none carries a next_page_token, so each chunk is a
        # single request
        call_endpoint_mock.side_effect = [
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_1),
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_2),
            GetCredentialsForRead.Response(credential_infos=credential_infos_mock_3),
        ]

        read_credential_infos = databricks_artifact_repo._get_read_credential_infos(
            MOCK_RUN_ID, paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
        )

        # Credentials from every chunk must be returned, in request order
        assert read_credential_infos == (
            credential_infos_mock_1 + credential_infos_mock_2 + credential_infos_mock_3
        )
        assert call_endpoint_mock.call_count == 3
        assert message_mock.call_count == 3
        message_mock.assert_has_calls(
            [
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
            ]
        )
def test_get_read_credential_infos_handles_pagination(self, databricks_artifact_repo):
    """
    Checks that `_get_read_credential_infos` follows `next_page_token` across paginated
    credential responses: it keeps requesting pages (passing the previous token) until a
    response arrives without a token, and returns the credentials collected from all pages
    """
    def presigned(url):
        # Single AWS presigned-URL credential pointing at `url`
        return ArtifactCredentialInfo(
            signed_uri=url, type=ArtifactCredentialType.AWS_PRESIGNED_URL
        )

    first_page = [presigned("http://mock_url_1"), presigned("http://mock_url_2")]
    second_page = [presigned("http://mock_url_3")]
    final_page = []

    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json") \
            as message_mock, \
            mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint") \
            as call_endpoint_mock:
        # Pages 1 and 2 point at a follow-up page; the last page has no token
        call_endpoint_mock.side_effect = [
            GetCredentialsForRead.Response(credential_infos=first_page, next_page_token="2"),
            GetCredentialsForRead.Response(credential_infos=second_page, next_page_token="3"),
            GetCredentialsForRead.Response(credential_infos=final_page),
        ]

        collected = databricks_artifact_repo._get_read_credential_infos(
            MOCK_RUN_ID, ["testpath"],
        )
        assert collected == first_page + second_page

        expected_requests = [
            GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"]),
            GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="2"),
            GetCredentialsForRead(run_id=MOCK_RUN_ID, path=["testpath"], page_token="3"),
        ]
        message_mock.assert_has_calls([mock.call(request) for request in expected_requests])
        assert call_endpoint_mock.call_count == 3
def test_databricks_download_file_with_relative_path(
        self, remote_file_path, local_path):
    """
    Downloads `remote_file_path` through a repository rooted at a run subdirectory and checks
    that read credentials are requested for the path joined onto that subdirectory, and that
    the mocked cloud download receives those credentials
    """
    repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
    with mock.patch(repo_cls + "._get_run_artifact_root") as run_root_mock, \
            mock.patch(repo_cls + "._get_read_credentials") as read_credentials_mock, \
            mock.patch(repo_cls + ".list_artifacts") as list_artifacts_mock, \
            mock.patch(repo_cls + "._download_from_cloud") as download_mock:
        run_root_mock.return_value = MOCK_RUN_ROOT_URI
        list_artifacts_mock.return_value = []
        download_mock.return_value = None

        azure_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI)
        read_credentials_mock.return_value = GetCredentialsForRead.Response(
            credentials=azure_credentials)

        subdir_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        subdir_repo.download_artifacts(remote_file_path, local_path)

        # The repo is rooted at MOCK_SUBDIR, so the requested path carries that prefix
        expected_remote_path = posixpath.join(MOCK_SUBDIR, remote_file_path)
        read_credentials_mock.assert_called_with(MOCK_RUN_ID, expected_remote_path)
        download_mock.assert_called_with(azure_credentials, ANY)
def test_download_artifacts_provides_failure_info(
        self, databricks_artifact_repo):
    """
    Verifies that when individual file downloads fail, `download_artifacts()` raises an
    MlflowException whose message names the artifact root, each failed file, and each
    underlying download error
    """
    with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credentials") as read_credentials_mock, \
            mock.patch(
                DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts") as get_list_mock, \
            mock.patch(
                DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud") as download_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI)
        read_credentials_response_proto = GetCredentialsForRead.Response(
            credentials=mock_credentials)
        read_credentials_mock.return_value = read_credentials_response_proto
        get_list_mock.return_value = [
            FileInfo(path="file_1.txt", is_dir=False, file_size=100),
            FileInfo(path="file_2.txt", is_dir=False, file_size=0),
        ]
        # Every file download fails, each with a distinct error
        download_mock.side_effect = [
            MlflowException("MOCK ERROR 1"),
            MlflowException("MOCK ERROR 2"),
        ]

        with pytest.raises(MlflowException) as exc:
            databricks_artifact_repo.download_artifacts("test_path")

        # Check the raised exception's own message. `str(exc)` renders pytest's
        # ExceptionInfo wrapper, whose repr of the exception may be truncated, so substring
        # checks against it are unreliable.
        err_msg = str(exc.value)
        assert MOCK_RUN_ROOT_URI in err_msg
        assert "file_1.txt" in err_msg
        assert "MOCK ERROR 1" in err_msg
        assert "file_2.txt" in err_msg
        assert "MOCK ERROR 2" in err_msg
def test_databricks_download_file_get_request_fail(
        self, databricks_artifact_repo, test_file):
    """
    Verifies that a failing HTTP GET during the artifact download surfaces as an
    MlflowException from `download_artifacts()`, and that read credentials were requested
    for the downloaded file's path
    """
    with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \
            as read_credentials_mock, \
            mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \
            mock.patch('requests.get') as request_mock:
        mock_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI)
        read_credentials_response_proto = GetCredentialsForRead.Response(
            credentials=mock_credentials)
        read_credentials_mock.return_value = read_credentials_response_proto
        get_list_mock.return_value = []
        # Make the GET request itself raise. Assigning the exception to `return_value`
        # would hand the caller an exception *object* as the response instead of raising
        # it, so the request would never actually fail.
        request_mock.side_effect = MlflowException("MOCK ERROR")
        with pytest.raises(MlflowException):
            databricks_artifact_repo.download_artifacts(test_file.strpath)
        read_credentials_mock.assert_called_with(MOCK_RUN_ID, test_file.strpath)
def test_download_artifacts_awaits_download_completion(
        self, databricks_artifact_repo, tmpdir):
    """
    Checks that `download_artifacts()` only returns once every file download it started has
    finished: each mocked download sleeps before writing its file, so the files exist with
    the expected content on return only if those downloads were awaited
    """
    def slow_download(cloud_credential, local_file_path):  # pylint: disable=unused-argument
        # Parameter names mirror the real `_download_from_cloud` signature, since the mock
        # forwards the caller's arguments unchanged. Delay the write to mimic a slow
        # asynchronous download.
        time.sleep(2)
        with open(local_file_path, "w") as out:
            out.write("content")

    repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
    with mock.patch(repo_cls + "._get_read_credentials") as read_credentials_mock, \
            mock.patch(repo_cls + ".list_artifacts") as list_artifacts_mock, \
            mock.patch(repo_cls + "._download_from_cloud") as download_mock:
        azure_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI)
        read_credentials_mock.return_value = GetCredentialsForRead.Response(
            credentials=azure_credentials)
        list_artifacts_mock.return_value = [
            FileInfo(path="file_1.txt", is_dir=False, file_size=100),
            FileInfo(path="file_2.txt", is_dir=False, file_size=0),
        ]
        download_mock.side_effect = slow_download

        destination = str(tmpdir)
        databricks_artifact_repo.download_artifacts("test_path", destination)

        for file_name in ("file_1.txt", "file_2.txt"):
            downloaded_path = os.path.join(destination, file_name)
            assert os.path.exists(downloaded_path)
            with open(downloaded_path, "r") as downloaded:
                assert downloaded.read() == "content"
def test_databricks_download_file(self, databricks_artifact_repo, remote_file_path,
                                  local_path, cloud_credential_type):
    """
    Downloads `remote_file_path` with read credentials of the parametrized
    `cloud_credential_type` and checks that credentials were requested for exactly that path
    and forwarded to the mocked cloud download
    """
    repo_cls = DATABRICKS_ARTIFACT_REPOSITORY
    with mock.patch(repo_cls + '._get_read_credentials') as read_credentials_mock, \
            mock.patch(repo_cls + '.list_artifacts') as list_artifacts_mock, \
            mock.patch(repo_cls + '._download_from_cloud') as download_mock:
        expected_credentials = ArtifactCredentialInfo(
            signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type)
        read_credentials_mock.return_value = GetCredentialsForRead.Response(
            credentials=expected_credentials)
        download_mock.return_value = None
        list_artifacts_mock.return_value = []

        databricks_artifact_repo.download_artifacts(remote_file_path, local_path)

        read_credentials_mock.assert_called_with(MOCK_RUN_ID, remote_file_path)
        download_mock.assert_called_with(expected_credentials, ANY)