Ejemplo n.º 1
0
    def do_bulk_inference(
        self,
        model_name: str,
        objects: List[dict],
        top_n: int = TOP_N,
        retry: bool = False,
    ) -> List[dict]:
        """
        Performs bulk inference for larger collections.

        For *objects* collections larger than *LIMIT_OBJECTS_PER_CALL*, splits
        the data into several smaller Inference requests, which are sent
        sequentially.

        Returns the aggregated values of the *predictions* of the original API response
        as returned by :meth:`create_inference_request`, concatenated in the
        order the requests were sent.

        If *objects* is empty, no request is made and an empty list is returned.

        :param model_name: name of the model used for inference
        :param objects: Objects to be classified
        :param top_n: How many predictions to return per object
        :param retry: forwarded unchanged to :meth:`create_inference_request`
        :return: the aggregated ObjectPrediction dictionaries
        """
        # split_list yields a single empty chunk for an empty list; without this
        # guard an empty input would still trigger one request with no objects.
        if not objects:
            return []

        result = []  # type: List[dict]
        for work_package in split_list(objects, LIMIT_OBJECTS_PER_CALL):
            response = self.create_inference_request(
                model_name, work_package, top_n=top_n, retry=retry
            )
            # Concatenate each chunk's predictions in request order.
            result.extend(response["predictions"])
        return result
 def test_slice_size_invalid(self):
     """Non-positive slice sizes must be rejected with ValueError."""
     bad_sizes = (-1000, -1, 0)
     for size in bad_sizes:
         with pytest.raises(ValueError):
             # Consume the result: split_list may be lazy, so validation
             # might only run once iteration starts.
             list(split_list(["a", "b"], size))
 def test_slice_size_bigger_than_list(self):
     res = list(split_list(["a", "b", "c", "d"], 6))
     assert res == [["a", "b", "c", "d"]]
 def test_list_uneven(self):
     res = list(split_list(["a", "b", "c", "d"], 3))
     assert res == [["a", "b", "c"], ["d"]]
 def test_regular_case(self):
     res = list(split_list(["a", "b", "c", "d"], 2))
     assert res == [["a", "b"], ["c", "d"]]
 def test_empty_list(self):
     res = list(split_list([], slice_size=1))
     assert res == [[]]