コード例 #1
0
def test_concat():
    """``concat_list`` flattens sub-lists and returns one slice per sub-list.

    Every returned slice must select exactly its original sub-list from the
    flattened data, including the empty sub-list.
    """
    lst = [
        [1],
        [1, 2],
        [],
        [1, 2, 3],
    ]
    datas, slices = concat_list(lst)

    # ``zip`` below stops at the shorter input, so without this check a missing
    # trailing slice would silently go untested.
    assert len(slices) == len(lst)
    # The flattened data must be exactly the sub-lists joined in order.
    assert list(datas) == [item for sub in lst for item in sub]

    for s, origin_data in zip(slices, lst):
        assert origin_data == datas[s]
コード例 #2
0
    def handle_batch_request(
        self, requests: Iterable[SimpleRequest], func
    ) -> Iterable[SimpleResponse]:
        """Run ``func`` once over the instances of every request in a batch.

        Each request body (``request[0]``) is parsed as JSON. When the body is
        a JSON object with a non-null ``"instances"`` field, those instances
        are base64-decoded if needed and wrapped in a list if a single instance
        was given. All such instances are merged with ``concat_list``, turned
        into a ``tf.constant`` and passed to ``func`` in one call. Each
        contributing request then receives its part of the result, serialized
        with ``api_func_result_to_json``.

        Requests that are not valid JSON, are not JSON objects, or carry
        neither ``"instances"`` nor ``"inputs"`` keep the default 400 "Bad
        Input" response. Requests using the column-oriented ``"inputs"``
        format get a 501.

        TODO(hrmthw):
        1. check content type
        1. specify batch dim
        1. output str format
        """
        import tensorflow as tf

        bad_resp = SimpleResponse(b"Bad Input", None, 400)
        instances_list = [None] * len(requests)
        responses = [bad_resp] * len(requests)

        for i, request in enumerate(requests):
            try:
                raw_str = request[0]  # .decode("utf-8")
                parsed_json = json.loads(raw_str)
                if not isinstance(parsed_json, dict):
                    # Valid JSON that is not an object (array, number, ...)
                    # has no "instances"/"inputs" fields. Calling ``.get`` on
                    # it would raise and abort the whole batch, so reject only
                    # this request.
                    continue
                instances = parsed_json.get("instances")
                if instances is not None:
                    instances = decode_b64_if_needed(instances)
                    if not isinstance(instances, (list, tuple)):
                        instances = [instances]
                    instances_list[i] = instances
                elif parsed_json.get("inputs"):
                    responses[i] = SimpleResponse(
                        "Column format 'inputs' not implemented", None, 501,
                    )
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Undecodable body: log it and keep the default 400 response.
                import traceback

                traceback.print_exc()

        merged_instances, slices = concat_list(instances_list)

        parsed_tensor = tf.constant(merged_instances)
        merged_result = func(parsed_tensor)
        merged_result = decode_tf_if_needed(merged_result)
        assert isinstance(merged_result, (list, tuple))

        for i, s in enumerate(slices):
            # Requests left as ``None`` in ``instances_list`` (rejected or 501)
            # have no slice into the merged result. Keep their error response
            # instead of indexing with ``None``.
            if s is None:
                continue
            result_str = api_func_result_to_json(merged_result[s])
            responses[i] = SimpleResponse(result_str, dict(), 200)

        return responses
コード例 #3
0
ファイル: json_handler.py プロジェクト: zhentan/BentoML
    def handle_batch_request(
        self, requests: Iterable[SimpleRequest], func
    ) -> Iterable[SimpleResponse]:
        """Parse a batch of JSON requests, run ``func`` once, and split the results.

        For each request, the header named by ``self._BATCH_REQUEST_HEADER``
        (defaulting to ``self.config.get('is_batch_input')``) decides the body
        format. If the header value is exactly ``"true"``, the body must be a
        JSON array of instances. Otherwise the body is a single instance and is
        wrapped in a 1-tuple. All instances are concatenated with
        ``concat_list`` and passed to ``func`` in one call. Each request then
        receives its slice of the result (unwrapped again for non-batch
        requests), serialized with ``api_func_result_to_json``.

        Returns:
            One ``SimpleResponse`` per request:
            - 200 with the JSON result;
            - 400 "Not a valid json" for undecodable bodies;
            - 400 "Bad Input" for batch bodies that are not JSON arrays;
            - 500 for any other error while parsing.

        Raises:
            ValueError: if ``func`` does not return a list/tuple with one entry
                per merged instance.
        """
        bad_resp = SimpleResponse(400, None, "Bad Input")
        instances_list = [None] * len(requests)
        responses = [bad_resp] * len(requests)
        batch_flags = [None] * len(requests)

        for i, request in enumerate(requests):
            batch_flags[i] = (
                request.formated_headers.get(
                    self._BATCH_REQUEST_HEADER.lower(),
                    "true" if self.config.get('is_batch_input') else "false",
                )
                == "true"
            )
            try:
                raw_str = request.data
                parsed_json = json.loads(raw_str)
                if not batch_flags[i]:
                    parsed_json = (parsed_json,)
                elif not isinstance(parsed_json, list):
                    # A batch body must be a JSON array. ``concat_list``
                    # flattens each entry into one list, so a string would
                    # presumably be split into characters, an object into its
                    # keys, and a number would raise outside this try and fail
                    # the whole batch. Reject only this request with the
                    # default 400 response.
                    continue
                instances_list[i] = parsed_json
            except (json.JSONDecodeError, UnicodeDecodeError):
                responses[i] = SimpleResponse(400, None, "Not a valid json")
            except Exception:  # pylint: disable=broad-except
                import traceback

                err = traceback.format_exc()
                responses[i] = SimpleResponse(
                    500, None, f"Internal Server Error: {err}"
                )

        merged_instances, slices = concat_list(instances_list)
        merged_result = func(merged_instances)
        if not isinstance(merged_result, (list, tuple)) or len(merged_result) != len(
            merged_instances
        ):
            raise ValueError(
                "The return value with JsonHandler must be list of jsonable objects, "
                "and have same length as the inputs."
            )

        for i, s in enumerate(slices):
            # Requests that failed parsing have no slice; keep their error response.
            if s is None:
                continue
            result = merged_result[s]
            if not batch_flags[i]:
                # Undo the 1-tuple wrapping applied to single-instance requests.
                result = result[0]
            result_str = api_func_result_to_json(result)
            responses[i] = SimpleResponse(200, dict(), result_str)

        return responses
コード例 #4
0
    def handle_batch_request(self, requests: Iterable[SimpleRequest],
                             func) -> Iterable[SimpleResponse]:
        """Run ``func`` once over the JSON-array bodies of a batch of requests.

        Every request body must be a JSON array of instances. The arrays are
        concatenated with ``concat_list`` and passed to ``func`` in a single
        call. Each request then receives its slice of the result, serialized
        with ``api_func_result_to_json``.

        Returns:
            One ``SimpleResponse`` per request:
            - 200 with the JSON result;
            - 400 "not a valid json input" for undecodable bodies;
            - 400 "Bad Input" for bodies that are valid JSON but not an array;
            - 500 "internal server error" for any other error while parsing.

        Raises:
            ValueError: if ``func`` does not return a list/tuple with one entry
                per merged instance.
        """
        bad_resp = SimpleResponse(400, None, "Bad Input")
        instances_list = [None] * len(requests)
        responses = [bad_resp] * len(requests)

        for i, request in enumerate(requests):
            try:
                raw_str = request.data
                parsed_json = json.loads(raw_str)
                if not isinstance(parsed_json, list):
                    # ``concat_list`` flattens each body into one list, so a
                    # non-array body would presumably be split into characters
                    # or keys, or raise outside this try (numbers) and fail
                    # the whole batch. Reject only this request with the
                    # default 400 response.
                    continue
                instances_list[i] = parsed_json
            except (json.JSONDecodeError, UnicodeDecodeError):
                responses[i] = SimpleResponse(400, None,
                                              "not a valid json input")
            except Exception:  # pylint: disable=broad-except
                responses[i] = SimpleResponse(500, None,
                                              "internal server error")
                import traceback

                traceback.print_exc()

        merged_instances, slices = concat_list(instances_list)
        merged_result = func(merged_instances)
        if (not isinstance(merged_result, (list, tuple))
                or len(merged_result) != len(merged_instances)):
            raise ValueError(
                "The return value with JsonHandler must be list of jsonable objects, "
                "and have same length as the inputs.")

        for i, s in enumerate(slices):
            # Requests that were rejected have no slice; keep their error response.
            if s is None:
                continue
            result_str = api_func_result_to_json(merged_result[s])
            responses[i] = SimpleResponse(200, dict(), result_str)

        return responses
コード例 #5
0
    def handle_batch_request(
        self, requests: Iterable[SimpleRequest], func
    ) -> Iterable[SimpleResponse]:
        """Run a TF-serving-style ``func`` once over a batch of JSON requests.

        Each request body is parsed as JSON and must be an object.
        - Batch flag: the header named by ``self._BATCH_REQUEST_HEADER``
          (defaulting to ``self.config.get("is_batch_input")``) marks a request
          as batch when its value is exactly ``"true"``.
        - A non-null ``"instances"`` field supplies the request's instances. For
          batch requests it must be a JSON array. Otherwise it is one instance
          and is wrapped in a 1-tuple.
        - Instances are base64-decoded if needed, merged with ``concat_list``,
          converted with ``tf.constant`` and passed to ``func`` in one call.
        - Each request receives its slice of the result (unwrapped again for
          non-batch requests), serialized with ``api_func_result_to_json``.

        Returns:
            One ``SimpleResponse`` per request:
            - 200 with the JSON result;
            - 400 "input format error" for undecodable or malformed bodies;
            - 501 for the unsupported column-oriented ``"inputs"`` format;
            - 500 for any other error while parsing.

        TODO(hrmthw):
        1. specify batch dim
        1. output str format
        """
        import tensorflow as tf

        bad_resp = SimpleResponse(400, None, "input format error")
        instances_list = [None] * len(requests)
        responses = [bad_resp] * len(requests)
        batch_flags = [None] * len(requests)

        for i, request in enumerate(requests):
            try:
                raw_str = request.data
                batch_flags[i] = (
                    request.formated_headers.get(
                        self._BATCH_REQUEST_HEADER.lower(),
                        "true" if self.config.get("is_batch_input") else "false",
                    )
                    == "true"
                )
                parsed_json = json.loads(raw_str)
                if not isinstance(parsed_json, dict):
                    # Non-object JSON is a client error. Calling ``.get`` on it
                    # would raise and be reported as a 500 with a traceback.
                    # Keep the default 400 response instead.
                    continue
                instances = parsed_json.get("instances")
                if instances is not None:
                    if batch_flags[i] and not isinstance(instances, list):
                        # A batch request must send a JSON array. ``concat_list``
                        # flattens each entry, so anything else would be
                        # mis-split or raise outside this try for the whole
                        # batch.
                        continue
                    instances = decode_b64_if_needed(instances)
                    if not batch_flags[i]:
                        instances = (instances,)
                    instances_list[i] = instances

                elif parsed_json.get("inputs"):
                    responses[i] = SimpleResponse(
                        501, None, "Column format 'inputs' not implemented"
                    )

            except (json.JSONDecodeError, UnicodeDecodeError):
                # Undecodable body: keep the default 400 "input format error".
                pass
            except Exception:  # pylint: disable=broad-except
                import traceback

                err = traceback.format_exc()
                responses[i] = SimpleResponse(
                    500, None, f"Internal Server Error: {err}"
                )

        merged_instances, slices = concat_list(instances_list)

        parsed_tensor = tf.constant(merged_instances)
        merged_result = func(parsed_tensor)
        merged_result = decode_tf_if_needed(merged_result)
        assert isinstance(merged_result, (list, tuple))

        for i, s in enumerate(slices):
            # Requests that were rejected have no slice; keep their error response.
            if s is None:
                continue
            result = merged_result[s]
            if not batch_flags[i]:
                # Undo the 1-tuple wrapping applied to single-instance requests.
                result = result[0]
            result_str = api_func_result_to_json(result)
            responses[i] = SimpleResponse(200, dict(), result_str)

        return responses