示例#1
0
文件: grpc.py 项目: jina-ai/jina
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        compression: Optional[str] = None,
        **kwargs,
    ):
        """Stream requests built from ``inputs`` to the gRPC server and yield responses.

        Each response is passed to ``callback_exec`` (which dispatches to the
        ``on_done``/``on_error``/``on_always`` callbacks) before it is yielded.
        ``KeyboardInterrupt`` and ``asyncio.CancelledError`` are logged and the
        generator simply ends.

        :param inputs: the input source; stored on ``self.inputs`` before
            ``self._get_requests`` is called
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param compression: exact name of a ``grpc.Compression`` member (e.g.
            ``'Gzip'``); ``None`` or empty selects ``NoCompression``
        :param kwargs: kwargs forwarded to ``self._get_requests``
        :yields: responses from the server's ``Call`` stream
        :raises ConnectionError: on ``UNAVAILABLE`` / ``DEADLINE_EXCEEDED``
        :raises grpc.aio.AioRpcError: on ``INTERNAL`` (the original error)
        :raises BadClientInput: on ``UNKNOWN`` whose details mention an asyncio timeout
        :raises BadClient: on any other gRPC status
            (all four are raised only when neither ``on_error`` nor ``on_always``
            is given; otherwise the error is routed to those callbacks)
        """
        try:
            # getattr looks the member up by its exact (case-sensitive) name and
            # raises AttributeError for an unknown name.
            self.compression = (
                getattr(grpc.Compression, compression)
                if compression
                else grpc.Compression.NoCompression
            )

            self.inputs = inputs
            req_iter = self._get_requests(**kwargs)
            async with GrpcConnectionPool.get_grpc_channel(
                f'{self.args.host}:{self.args.port}',
                asyncio=True,
                tls=self.args.tls,
            ) as channel:
                stub = jina_pb2_grpc.JinaRPCStub(channel)
                self.logger.debug(f'connected to {self.args.host}:{self.args.port}')

                with ProgressBar(
                    total_length=self._inputs_length, disable=not (self.show_progress)
                ) as p_bar:

                    async for resp in stub.Call(
                        req_iter,
                        compression=self.compression,
                    ):
                        callback_exec(
                            response=resp,
                            on_error=on_error,
                            on_done=on_done,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                        if self.show_progress:
                            p_bar.update()
                        yield resp

        except KeyboardInterrupt:
            self.logger.warning('user cancel the process')
        except asyncio.CancelledError as ex:
            # NOTE: the cancellation is logged and swallowed; the generator just ends.
            self.logger.warning(f'process error: {ex!r}')
        except (grpc.aio.AioRpcError, InternalNetworkError) as err:
            my_code = err.code()
            # details() is Optional[str]; it may legitimately be None.
            my_details = err.details()
            msg = f'gRPC error: {my_code} {my_details}'

            try:
                if my_code == grpc.StatusCode.UNAVAILABLE:
                    self.logger.error(
                        f'{msg}\nThe ongoing request is terminated as the server is not available or closed already.'
                    )
                    raise ConnectionError(my_details) from None
                elif my_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                    self.logger.error(
                        f'{msg}\nThe ongoing request is terminated due to a server-side timeout.'
                    )
                    raise ConnectionError(my_details) from None
                elif my_code == grpc.StatusCode.INTERNAL:
                    self.logger.error(f'{msg}\ninternal error on the server side')
                    raise err
                elif (
                    my_code == grpc.StatusCode.UNKNOWN
                    # guard: a substring test against None would raise TypeError
                    # and hide the original gRPC error
                    and my_details is not None
                    and 'asyncio.exceptions.TimeoutError' in my_details
                ):
                    raise BadClientInput(
                        f'{msg}\n'
                        'often the case is that you define/send a bad input iterator to jina, '
                        'please double check your input iterator'
                    ) from err
                else:
                    raise BadClient(msg) from err

            except (
                grpc.aio.AioRpcError,
                BaseJinaException,
                ConnectionError,
            ) as e:  # depending on if there are callbacks we catch or not the exception
                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise
示例#2
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """
        POST every request produced from ``inputs`` to the HTTP gateway and
        yield each decoded response as a ``DataRequest``.

        Errors of type ``aiohttp.ClientError``, ``ValueError`` or
        ``ConnectionError`` are logged, then either routed to the
        ``on_error``/``on_always`` callbacks or, when neither is given, re-raised.

        :param inputs: the callable
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_task_name and _get_requests
        :yields: generator over results
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        pending_requests = self._get_requests(**kwargs)

        async with AsyncExitStack() as exit_stack:
            try:
                progress = exit_stack.enter_context(
                    ProgressBar(
                        total_length=self._inputs_length,
                        disable=not self.show_progress,
                    )
                )

                scheme = 'https' if self.args.tls else 'http'
                endpoint = f'{scheme}://{self.args.host}:{self.args.port}/post'
                clientlet = await exit_stack.enter_async_context(
                    HTTPClientlet(url=endpoint, logger=self.logger)
                )

                def _post(request: 'Request') -> 'asyncio.Future':
                    """Schedule an HTTP POST of ``request``; the streamer awaits the future."""
                    return asyncio.ensure_future(clientlet.send_message(request=request))

                streamer = RequestStreamer(
                    self.args,
                    request_handler=_post,
                    result_handler=lambda outcome: outcome,
                )
                async for http_response in streamer.stream(pending_requests):
                    status_code = http_response.status
                    payload = await http_response.json()
                    self._handle_response_status(status_code, payload, endpoint)

                    # Documents travel as a nested dict; lift them out of the
                    # payload before building the DataRequest from the rest.
                    documents = None
                    if 'data' in payload and payload['data'] is not None:
                        from docarray import DocumentArray

                        documents = DocumentArray.from_dict(payload.pop('data'))

                    data_request = DataRequest(payload)
                    if documents is not None:
                        data_request.data.docs = documents

                    callback_exec(
                        response=data_request,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        progress.update()
                    yield data_request

            except (aiohttp.ClientError, ValueError, ConnectionError) as e:
                self.logger.error(
                    f'Error while fetching response from HTTP server {e!r}'
                )

                # Without any callback the caller gets the exception itself.
                if not (on_error or on_always):
                    raise e
                if on_error:
                    callback_exec_on_error(on_error, e, self.logger)
                if on_always:
                    callback_exec(
                        response=None,
                        on_error=None,
                        on_done=None,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
示例#3
0
文件: websocket.py 项目: sthagen/jina
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """
        Send every request produced from ``inputs`` over a websocket and yield
        the matching responses.

        Each outgoing request registers a future in ``request_buffer`` keyed by
        its request id; a background ``_receive`` task resolves those futures as
        responses arrive. ``aiohttp.ClientError`` is logged, then routed to the
        ``on_error``/``on_always`` callbacks or re-raised when neither is given.

        :param inputs: the callable
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_task_name and _get_requests
        :yields: generator over results
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        async with AsyncExitStack() as stack:
            try:
                cm1 = ProgressBar(
                    total_length=self._inputs_length, disable=not (self.show_progress)
                )
                p_bar = stack.enter_context(cm1)

                proto = 'wss' if self.args.tls else 'ws'
                url = f'{proto}://{self.args.host}:{self.args.port}/'
                iolet = await stack.enter_async_context(
                    WebsocketClientlet(url=url, logger=self.logger)
                )

                # request id -> future resolved by `_receive` when the response arrives
                request_buffer: Dict[str, asyncio.Future] = dict()
                # The event loop only keeps weak references to tasks, so the
                # fire-and-forget send tasks must be referenced until they finish,
                # otherwise they may be garbage-collected mid-execution.
                background_tasks = set()

                def _spawn(coro) -> 'asyncio.Task':
                    """Schedule ``coro`` as a task and keep it referenced until done."""
                    task = asyncio.create_task(coro)
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)
                    return task

                def _result_handler(result):
                    return result

                async def _receive():
                    """Await messages from WebsocketGateway and process them in the request buffer"""

                    def _response_handler(response):
                        if response.header.request_id in request_buffer:
                            future = request_buffer.pop(response.header.request_id)
                            future.set_result(response)
                        else:
                            self.logger.warning(
                                f'discarding unexpected response with request id {response.header.request_id}'
                            )

                    try:
                        async for response in iolet.recv_message():
                            _response_handler(response)
                    finally:
                        # the receive loop ended: nobody will resolve the remaining
                        # futures, so cancel them instead of leaving waiters hanging
                        if request_buffer:
                            self.logger.warning(
                                f'{self.__class__.__name__} closed, cancelling all outstanding requests'
                            )
                            for future in request_buffer.values():
                                future.cancel()
                            request_buffer.clear()

                def _handle_end_of_iter():
                    """Send End of iteration signal to the Gateway"""
                    _spawn(iolet.send_eoi())

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """
                    For websocket requests from client, for each request in the iterator, we send the request in `bytes`
                    using `iolet.send_message()`.
                    Then add {<request-id>: <an-empty-future>} to the request buffer.
                    This empty future is used to track the `result` of this request during `receive`.
                    :param request: current request in the iterator
                    :return: asyncio Future for sending message
                    """
                    future = get_or_reuse_loop().create_future()
                    request_buffer[request.header.request_id] = future
                    _spawn(iolet.send_message(request))
                    return future

                streamer = RequestStreamer(
                    args=self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                    end_of_iter_handler=_handle_end_of_iter,
                )

                receive_task = get_or_reuse_loop().create_task(_receive())

                # NOTE(review): nothing is awaited between create_task and this
                # check, so with the default task factory the task has not started
                # yet and this branch cannot fire — confirm whether a real
                # liveness check was intended.
                if receive_task.done():
                    raise RuntimeError(
                        'receive task not running, can not send messages'
                    )
                async for response in streamer.stream(request_iterator):
                    callback_exec(
                        response=response,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield response

            except aiohttp.ClientError as e:
                self.logger.error(
                    f'Error while streaming response from websocket server {e!r}'
                )

                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise
示例#4
0
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        compression: Optional[str] = 'NoCompression',
        **kwargs,
    ):
        """Stream requests built from ``inputs`` to the gRPC server and yield responses.

        Each response is passed to ``callback_exec`` before it is yielded.
        ``KeyboardInterrupt`` and ``asyncio.CancelledError`` are logged and the
        generator simply ends.

        :param inputs: the input source; stored on ``self.inputs`` before
            ``self._get_requests`` is called
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param compression: case-insensitive key of ``GRPC_COMPRESSION_MAP``;
            ``None`` means ``'NoCompression'``, an unknown value emits a warning
            and falls back to ``'NoCompression'``
        :param kwargs: kwargs forwarded to ``self._get_requests``
        :yields: responses from the server's ``Call`` stream
        :raises grpc.aio.AioRpcError: on ``UNAVAILABLE`` or ``INTERNAL``
        :raises BadClientInput: on ``UNKNOWN`` whose details mention an asyncio timeout
        :raises BadClient: on any other gRPC status
            (raised only when neither ``on_error`` nor ``on_always`` is given;
            otherwise the error is routed to those callbacks)
        """
        try:
            if compression is None:
                # accept None as "no compression" instead of failing on `None.lower()`
                compression = 'NoCompression'
            if compression.lower() not in GRPC_COMPRESSION_MAP:
                import warnings

                warnings.warn(
                    message=
                    f'Your compression "{compression}" is not supported. Supported '
                    f'algorithms are `Gzip`, `Deflate` and `NoCompression`. NoCompression will be used as '
                    f'default')
                compression = 'NoCompression'
            grpc_compression = GRPC_COMPRESSION_MAP.get(
                compression.lower(), grpc.Compression.NoCompression
            )

            self.inputs = inputs
            req_iter = self._get_requests(**kwargs)
            async with GrpcConnectionPool.get_grpc_channel(
                    f'{self.args.host}:{self.args.port}',
                    asyncio=True,
                    tls=self.args.tls,
            ) as channel:
                stub = jina_pb2_grpc.JinaRPCStub(channel)
                self.logger.debug(
                    f'connected to {self.args.host}:{self.args.port}')

                with ProgressBar(total_length=self._inputs_length,
                                 disable=not (self.show_progress)) as p_bar:

                    async for resp in stub.Call(
                            req_iter,
                            compression=grpc_compression,
                    ):
                        callback_exec(
                            response=resp,
                            on_error=on_error,
                            on_done=on_done,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                        if self.show_progress:
                            p_bar.update()
                        yield resp

        except KeyboardInterrupt:
            self.logger.warning('user cancel the process')
        except asyncio.CancelledError as ex:
            # NOTE: the cancellation is logged and swallowed; the generator just ends.
            self.logger.warning(f'process error: {ex!r}')
        except grpc.aio.AioRpcError as rpc_ex:
            my_code = rpc_ex.code()
            # details() is Optional[str]; it may legitimately be None.
            my_details = rpc_ex.details()
            msg = f'gRPC error: {my_code} {my_details}'

            try:
                if my_code == grpc.StatusCode.UNAVAILABLE:
                    self.logger.error(
                        f'{msg}\nthe ongoing request is terminated as the server is not available or closed already'
                    )
                    raise rpc_ex
                elif my_code == grpc.StatusCode.INTERNAL:
                    self.logger.error(
                        f'{msg}\ninternal error on the server side')
                    raise rpc_ex
                elif (my_code == grpc.StatusCode.UNKNOWN
                      # guard: a substring test against None would raise
                      # TypeError and hide the original gRPC error
                      and my_details is not None
                      and 'asyncio.exceptions.TimeoutError' in my_details):
                    raise BadClientInput(
                        f'{msg}\n'
                        'often the case is that you define/send a bad input iterator to jina, '
                        'please double check your input iterator') from rpc_ex
                else:
                    raise BadClient(msg) from rpc_ex

            except (
                    grpc.aio.AioRpcError,
                    BaseJinaException,
            ) as e:  # depending on if there are callbacks we catch or not the exception
                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    raise