def do_bulk_inference(
    self,
    model_name: str,
    objects: List[dict],
    top_n: int = TOP_N,
    retry: bool = False,
) -> List[dict]:
    """
    Performs bulk inference for larger collections.

    For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
    the data into several smaller Inference requests.

    Returns the aggregated values of the *predictions* of the original API
    response as returned by :meth:`create_inference_request`.

    :param model_name: name of the model used for inference
    :param objects: Objects to be classified
    :param top_n: How many predictions to return per object
    :param retry: whether to retry failed requests; passed through to
        :meth:`create_inference_request`
    :return: the aggregated ObjectPrediction dictionaries
    """
    result = []  # type: List[dict]
    # Issue one request per chunk so no single call exceeds the API's
    # per-request object limit.
    for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
        response = self.create_inference_request(
            model_name, work_package, top_n=top_n, retry=retry
        )
        result.extend(response["predictions"])
    return result
def test_slice_size_invalid(self):
    """split_list must reject non-positive slice sizes with ValueError."""
    for bad_size in (-1000, -1, 0):
        with pytest.raises(ValueError):
            list(split_list(["a", "b"], bad_size))
def test_slice_size_bigger_than_list(self):
    """A slice size exceeding the list length yields the whole list as one chunk."""
    chunks = list(split_list(["a", "b", "c", "d"], 6))
    assert chunks == [["a", "b", "c", "d"]]
def test_list_uneven(self):
    """When the length is not a multiple of the slice size, the last chunk holds the remainder."""
    chunks = list(split_list(["a", "b", "c", "d"], 3))
    assert [["a", "b", "c"], ["d"]] == chunks
def test_regular_case(self):
    """An evenly divisible list splits into equal-sized chunks."""
    expected = [["a", "b"], ["c", "d"]]
    assert list(split_list(["a", "b", "c", "d"], 2)) == expected
def test_empty_list(self):
    """An empty input yields a single empty chunk rather than no chunks."""
    chunks = list(split_list([], slice_size=1))
    assert chunks == [[]]