Example #1
    def handle_batch_request(self, requests: Iterable[SimpleRequest],
                             func) -> Iterable[SimpleResponse]:
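        """Parse each request body as JSON, merge the parsed payloads into a
        single batch, call ``func`` once on the merged batch, and let the
        output adapter split the result back into per-request responses
        (``fallbacks`` holds the per-request error responses)."""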
        bad_resp = SimpleResponse(400, None, "Bad Input")
        instances_list = [None] * len(requests)
        fallbacks = [bad_resp] * len(requests)
        batch_flags = [None] * len(requests)

        for i, request in enumerate(requests):
            batch_flags[i] = self.is_batch_request(request)
            try:
                raw_str = request.data
                parsed_json = json.loads(raw_str)
                instances_list[i] = parsed_json
            except (json.JSONDecodeError, UnicodeDecodeError):
                fallbacks[i] = SimpleResponse(400, None, "Not a valid json")
            except Exception:  # pylint: disable=broad-except
                import traceback

                err = traceback.format_exc()
                fallbacks[i] = SimpleResponse(500, None,
                                              f"Internal Server Error: {err}")

        merged_instances, slices = concat_list(instances_list,
                                               batch_flags=batch_flags)
        merged_result = func(merged_instances)
        return self.output_adapter.to_batch_response(merged_result, slices,
                                                     fallbacks, requests)
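Example #1 calls self.is_batch_request without showing it. Example #3 below inlines what looks like the equivalent header check; pulled out as a method it might look like the sketch below. The body is an assumption reconstructed from Example #3, not code taken from the library.

    def is_batch_request(self, request) -> bool:
        # A per-request header overrides the adapter-level "is_batch_input"
        # config; both sides are compared as the literal strings
        # "true"/"false", mirroring the inline check in Example #3.
        default = "true" if self.config.get("is_batch_input") else "false"
        flag = request.formated_headers.get(
            self._BATCH_REQUEST_HEADER.lower(), default
        )
        return flag == "true"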
Example #2
def test_concat_lists_with_flags():
    lst = [
        [[1], [2]],
        [],
        None,
        [1],
        "string",
        None,
    ]
    flags = [
        True,
        True,
        True,
        False,
        False,
        False,
    ]

    datas, slices = concat_list(lst, flags)
    assert datas == [[1], [2], [1], "string"]

    for s, origin_data in zip(slices, lst):
        if s is None:
            assert origin_data is None
        else:
            assert origin_data == datas[s]
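None of the examples show concat_list itself. Judging only from the assertions in the tests (this example and Examples #4 and #5), a minimal compatible sketch could look like the one below; it is an inference from the tests, not the library's actual implementation. It returns the merged data together with, for each input, a slice object (batch inputs), an integer index (non-batch inputs), or None (missing inputs), so that datas[slices[i]] reproduces the i-th input.

def concat_list(lst, batch_flags=None):
    """Flatten a list of optional batches into one list plus recovery slices."""
    datas = []
    slices = [None] * len(lst)
    for i, item in enumerate(lst):
        if item is None:
            continue
        if batch_flags is None or batch_flags[i]:
            # Batch input: extend the merged list and remember the range.
            slices[i] = slice(len(datas), len(datas) + len(item))
            datas.extend(item)
        else:
            # Single instance: append as one element, remember its index.
            slices[i] = len(datas)
            datas.append(item)
    return datas, slices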
Example #3
    def handle_batch_request(self, requests: Iterable[SimpleRequest],
                             func) -> Iterable[SimpleResponse]:
        """
        TODO(hrmthw):
        1. specify batch dim
        1. output str fromat
        """
        import tensorflow as tf

        bad_resp = SimpleResponse(400, None, "input format error")
        instances_list = [None] * len(requests)
        responses = [bad_resp] * len(requests)
        batch_flags = [None] * len(requests)

        for i, request in enumerate(requests):
            try:
                raw_str = request.data
                batch_flags[i] = (request.formated_headers.get(
                    self._BATCH_REQUEST_HEADER.lower(),
                    "true" if self.config.get("is_batch_input") else "false",
                ) == "true")
                parsed_json = json.loads(raw_str)
                instances = parsed_json.get("instances")
                if instances is not None:
                    instances = decode_b64_if_needed(instances)
                    instances_list[i] = instances

                elif parsed_json.get("inputs"):
                    responses[i] = SimpleResponse(
                        501, None, "Column format 'inputs' not implemented")

            except (json.JSONDecodeError, UnicodeDecodeError):
                pass
            except Exception:  # pylint: disable=broad-except
                import traceback

                err = traceback.format_exc()
                responses[i] = SimpleResponse(500, None,
                                              f"Internal Server Error: {err}")

        merged_instances, slices = concat_list(instances_list,
                                               batch_flags=batch_flags)

        parsed_tensor = tf.constant(merged_instances)
        merged_result = func(parsed_tensor)
        return self.output_adapter.to_batch_response(merged_result,
                                                     slices=slices,
                                                     fallbacks=responses,
                                                     requests=requests)
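Example #3 runs the parsed instances through decode_b64_if_needed, which is not shown either. A minimal sketch consistent with the TensorFlow Serving REST convention, where binary values arrive wrapped as {"b64": "<base64 string>"} objects, might look like the following; the exact behaviour is an assumption.

import base64

def decode_b64_if_needed(value):
    # Recursively replace {"b64": "..."} wrappers with the decoded bytes,
    # leaving every other value untouched.
    if isinstance(value, dict):
        if len(value) == 1 and "b64" in value:
            return base64.b64decode(value["b64"])
        return {k: decode_b64_if_needed(v) for k, v in value.items()}
    if isinstance(value, list):
        return [decode_b64_if_needed(v) for v in value]
    return value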
Example #4
def test_concat():
    lst = [
        [1],
        [1, 2],
        [],
        [1, 2, 3],
        None,
    ]
    datas, slices = concat_list(lst)

    for s, origin_data in zip(slices, lst):
        if s is None:
            assert origin_data is None
        else:
            assert origin_data == datas[s]

    lst = [
        [1],
        None,
        1,
        None,
    ]
    flags = [
        True,
        True,
        False,
        False,
    ]

    datas, slices = concat_list(lst, flags)

    for s, origin_data in zip(slices, lst):
        if s is None:
            assert origin_data is None
        else:
            assert origin_data == datas[s]
Example #5
def test_concat():
    lst = [
        None,
        [],
        [1],
        [1, 2],
        [1, 2, 3],
    ]
    datas, slices = concat_list(lst)

    for s, origin_data in zip(slices, lst):
        if s is None:
            assert origin_data is None
        else:
            assert origin_data == datas[s]
Example #6
    def handle_batch_request(
        self, requests: Iterable[SimpleRequest], func
    ) -> Iterable[SimpleResponse]:
        """
        TODO(bojiang):
        1. specify batch dim
        """
        import tensorflow as tf

        bad_resp = SimpleResponse(400, None, "input format error")
        instances_list = [None] * len(requests)
        responses = [bad_resp] * len(requests)
        batch_flags = [None] * len(requests)

        for i, request in enumerate(requests):
            try:
                raw_str = request.data
                batch_flags[i] = self.is_batch_request(request)
                parsed_json = json.loads(raw_str, object_hook=b64_hook)
                instances = parsed_json.get("instances")
                if instances is not None:
                    instances_list[i] = instances

                elif parsed_json.get("inputs"):
                    responses[i] = SimpleResponse(
                        501, None, "Column format 'inputs' not implemented"
                    )

            except (json.JSONDecodeError, UnicodeDecodeError):
                pass
            except Exception:  # pylint: disable=broad-except
                import traceback

                err = traceback.format_exc()
                responses[i] = SimpleResponse(
                    500, None, f"Internal Server Error: {err}"
                )
        merged_instances, slices = concat_list(instances_list, batch_flags=batch_flags)
        parsed_tensor = tf.constant(merged_instances)
        merged_result = func(parsed_tensor)
        return self.output_adapter.to_batch_response(
            merged_result, slices=slices, fallbacks=responses, requests=requests
        )
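Example #6 decodes the same {"b64": ...} wrappers during parsing instead of afterwards, by passing b64_hook as the object_hook of json.loads. A compatible sketch, again an assumption rather than the library's actual code, could be:

import base64

def b64_hook(obj):
    # json.loads calls this for every decoded JSON object; unwrap
    # TF-Serving style {"b64": "..."} objects into raw bytes.
    if isinstance(obj, dict) and len(obj) == 1 and "b64" in obj:
        return base64.b64decode(obj["b64"])
    return obj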