async def _get_results(
    self,
    inputs: 'InputType',
    on_done: 'CallbackFnType',
    on_error: Optional['CallbackFnType'] = None,
    on_always: Optional['CallbackFnType'] = None,
    compression: Optional[str] = None,
    **kwargs,
):
    """Send ``inputs`` to the gRPC server and yield every response of the ``Call`` stream.

    Each response is passed to ``callback_exec`` (``on_done`` / ``on_error`` /
    ``on_always``) and counted on the progress bar before it is yielded.

    ``KeyboardInterrupt`` and ``asyncio.CancelledError`` are logged as warnings
    and swallowed. gRPC errors are translated by status code:

    * ``UNAVAILABLE`` / ``DEADLINE_EXCEEDED`` -> ``ConnectionError``
    * ``INTERNAL`` -> the original gRPC error is re-raised
    * ``UNKNOWN`` whose details mention ``asyncio.exceptions.TimeoutError``
      -> ``BadClientInput``
    * anything else -> ``BadClient``

    When ``on_error`` or ``on_always`` is given, a translated error that is a
    gRPC error, a ``BaseJinaException`` or a ``ConnectionError`` is routed to
    those callbacks instead of being raised.

    :param inputs: the inputs to send; stored on ``self.inputs``
    :param on_done: the callback for on_done
    :param on_error: the callback for on_error
    :param on_always: the callback for on_always
    :param compression: name of a ``grpc.Compression`` member, looked up with
        ``getattr`` (so case-sensitive, e.g. ``'Gzip'``); a falsy value means
        ``grpc.Compression.NoCompression``. An unknown name raises ``AttributeError``.
    :param kwargs: kwargs forwarded to ``self._get_requests``
    :yields: the responses streamed back by the server
    """
    try:
        self.compression = (
            getattr(grpc.Compression, compression)
            if compression
            else grpc.Compression.NoCompression
        )

        self.inputs = inputs
        req_iter = self._get_requests(**kwargs)
        async with GrpcConnectionPool.get_grpc_channel(
            f'{self.args.host}:{self.args.port}',
            asyncio=True,
            tls=self.args.tls,
        ) as channel:
            stub = jina_pb2_grpc.JinaRPCStub(channel)
            self.logger.debug(f'connected to {self.args.host}:{self.args.port}')

            with ProgressBar(
                total_length=self._inputs_length, disable=not (self.show_progress)
            ) as p_bar:
                async for resp in stub.Call(
                    req_iter,
                    compression=self.compression,
                ):
                    callback_exec(
                        response=resp,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield resp

    except KeyboardInterrupt:
        self.logger.warning('user cancel the process')
    except asyncio.CancelledError as ex:
        self.logger.warning(f'process error: {ex!r}')
    except (grpc.aio._call.AioRpcError, InternalNetworkError) as err:
        my_code = err.code()
        # details() is not guaranteed to be a string (it may be None), so every
        # use below must tolerate None.
        my_details = err.details()
        msg = f'gRPC error: {my_code} {my_details}'
        try:
            if my_code == grpc.StatusCode.UNAVAILABLE:
                self.logger.error(
                    f'{msg}\nThe ongoing request is terminated as the server is not available or closed already.'
                )
                raise ConnectionError(my_details) from None
            elif my_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                self.logger.error(
                    f'{msg}\nThe ongoing request is terminated due to a server-side timeout.'
                )
                raise ConnectionError(my_details) from None
            elif my_code == grpc.StatusCode.INTERNAL:
                self.logger.error(f'{msg}\ninternal error on the server side')
                raise err
            elif (
                my_code == grpc.StatusCode.UNKNOWN
                # guard: `in None` would raise TypeError and mask the gRPC error
                and my_details
                and 'asyncio.exceptions.TimeoutError' in my_details
            ):
                raise BadClientInput(
                    f'{msg}\n'
                    'often the case is that you define/send a bad input iterator to jina, '
                    'please double check your input iterator'
                ) from err
            else:
                raise BadClient(msg) from err

        except (
            grpc.aio._call.AioRpcError,
            BaseJinaException,
            ConnectionError,
        ) as e:
            # with callbacks registered the error is reported through them and
            # swallowed; without callbacks it propagates to the caller
            if on_error or on_always:
                if on_error:
                    callback_exec_on_error(on_error, e, self.logger)
                if on_always:
                    callback_exec(
                        response=None,
                        on_error=None,
                        on_done=None,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
            else:
                raise
async def _get_results(
    self,
    inputs: 'InputType',
    on_done: 'CallbackFnType',
    on_error: Optional['CallbackFnType'] = None,
    on_always: Optional['CallbackFnType'] = None,
    **kwargs,
):
    """POST every request built from ``inputs`` to the HTTP gateway and yield the replies.

    Each JSON reply is checked with ``self._handle_response_status``. It is then
    rebuilt into a ``DataRequest``, with its ``data`` field turned into a
    ``DocumentArray``, and fed to ``callback_exec`` before being yielded.
    ``aiohttp.ClientError``, ``ValueError`` and ``ConnectionError`` are logged. They
    are then handed to ``on_error``/``on_always`` if either is set, or re-raised.

    :param inputs: the inputs to send; stored on ``self.inputs``
    :param on_done: the callback for on_done
    :param on_error: the callback for on_error
    :param on_always: the callback for on_always
    :param kwargs: kwargs for _get_task_name and _get_requests
    :yields: one ``DataRequest`` per HTTP reply
    """
    with ImportExtensions(required=True):
        import aiohttp

    self.inputs = inputs
    requests = self._get_requests(**kwargs)

    async with AsyncExitStack() as exit_stack:
        try:
            progress = exit_stack.enter_context(
                ProgressBar(
                    total_length=self._inputs_length,
                    disable=not self.show_progress,
                )
            )

            scheme = 'https' if self.args.tls else 'http'
            endpoint = f'{scheme}://{self.args.host}:{self.args.port}/post'
            clientlet = await exit_stack.enter_async_context(
                HTTPClientlet(url=endpoint, logger=self.logger)
            )

            def _request_handler(request: 'Request') -> 'asyncio.Future':
                """Schedule an HTTP POST of ``request``.

                The streamer collects the returned future and yields its result.

                :param request: current request in the iterator
                :return: asyncio Task for sending message
                """
                return asyncio.ensure_future(clientlet.send_message(request=request))

            streamer = RequestStreamer(
                self.args,
                request_handler=_request_handler,
                # replies are passed through unchanged
                result_handler=lambda result: result,
            )

            async for http_response in streamer.stream(requests):
                status = http_response.status
                payload = await http_response.json()
                self._handle_response_status(status, payload, endpoint)

                # the documents travel as a plain dict; pull them out so that
                # DataRequest is built from the header part only
                docs = None
                if 'data' in payload and payload['data'] is not None:
                    from docarray import DocumentArray

                    docs = DocumentArray.from_dict(payload.pop('data'))

                data_request = DataRequest(payload)
                if docs is not None:
                    data_request.data.docs = docs

                callback_exec(
                    response=data_request,
                    on_error=on_error,
                    on_done=on_done,
                    on_always=on_always,
                    continue_on_error=self.continue_on_error,
                    logger=self.logger,
                )
                if self.show_progress:
                    progress.update()
                yield data_request
        except (aiohttp.ClientError, ValueError, ConnectionError) as e:
            self.logger.error(
                f'Error while fetching response from HTTP server {e!r}'
            )
            # without any callback there is nobody to report to: propagate
            if not (on_error or on_always):
                raise e
            if on_error:
                callback_exec_on_error(on_error, e, self.logger)
            if on_always:
                callback_exec(
                    response=None,
                    on_error=None,
                    on_done=None,
                    on_always=on_always,
                    continue_on_error=self.continue_on_error,
                    logger=self.logger,
                )
async def _get_results(
    self,
    inputs: 'InputType',
    on_done: 'CallbackFnType',
    on_error: Optional['CallbackFnType'] = None,
    on_always: Optional['CallbackFnType'] = None,
    **kwargs,
):
    """Stream requests over a websocket to the gateway and yield the matching responses.

    Each request is registered in a buffer keyed by request id before it is sent.
    A background receive task resolves the buffered future when the response
    with that id arrives; responses with unknown ids are logged and dropped.
    ``aiohttp.ClientError`` is logged. It is then handed to
    ``on_error``/``on_always`` if either is set, or re-raised.

    :param inputs: the inputs to send; stored on ``self.inputs``
    :param on_done: the callback for on_done
    :param on_error: the callback for on_error
    :param on_always: the callback for on_always
    :param kwargs: kwargs for _get_task_name and _get_requests
    :yields: generator over results
    """
    with ImportExtensions(required=True):
        import aiohttp

    self.inputs = inputs
    request_iterator = self._get_requests(**kwargs)

    async with AsyncExitStack() as stack:
        try:
            cm1 = ProgressBar(
                total_length=self._inputs_length, disable=not (self.show_progress)
            )
            p_bar = stack.enter_context(cm1)

            proto = 'wss' if self.args.tls else 'ws'
            url = f'{proto}://{self.args.host}:{self.args.port}/'
            iolet = await stack.enter_async_context(
                WebsocketClientlet(url=url, logger=self.logger)
            )

            # request id -> future that _receive resolves with the matching response
            request_buffer: Dict[str, asyncio.Future] = dict()
            # The event loop keeps only weak references to tasks, so the
            # fire-and-forget send tasks are held here until they complete;
            # otherwise they could be garbage-collected mid-execution.
            background_tasks = set()

            def _spawn(coro):
                """Schedule ``coro`` as a task and keep it referenced until it is done."""
                task = asyncio.create_task(coro)
                background_tasks.add(task)
                task.add_done_callback(background_tasks.discard)
                return task

            def _result_handler(result):
                return result

            def _response_handler(response):
                """Resolve the buffered future for ``response``'s request id, or log and drop it."""
                if response.header.request_id in request_buffer:
                    future = request_buffer.pop(response.header.request_id)
                    # the future may already be cancelled; set_result would then
                    # raise InvalidStateError, end _receive and cancel every
                    # other outstanding request
                    if not future.done():
                        future.set_result(response)
                else:
                    self.logger.warning(
                        f'discarding unexpected response with request id {response.header.request_id}'
                    )

            async def _receive():
                """Await messages from WebsocketGateway and process them in the request buffer"""
                try:
                    async for response in iolet.recv_message():
                        _response_handler(response)
                finally:
                    # once the connection stops delivering messages, nothing can
                    # resolve the remaining futures: cancel them
                    if request_buffer:
                        self.logger.warning(
                            f'{self.__class__.__name__} closed, cancelling all outstanding requests'
                        )
                        for future in request_buffer.values():
                            future.cancel()
                        request_buffer.clear()

            def _handle_end_of_iter():
                """Send End of iteration signal to the Gateway"""
                _spawn(iolet.send_eoi())

            def _request_handler(request: 'Request') -> 'asyncio.Future':
                """
                For each request in the iterator, register an empty future under
                its request id in the request buffer, then send the request using
                `iolet.send_message()` in the background.
                The future is resolved with the response by `_receive`.

                :param request: current request in the iterator
                :return: asyncio Future that resolves to the response of this request
                """
                future = get_or_reuse_loop().create_future()
                request_buffer[request.header.request_id] = future
                _spawn(iolet.send_message(request))
                return future

            streamer = RequestStreamer(
                args=self.args,
                request_handler=_request_handler,
                result_handler=_result_handler,
                end_of_iter_handler=_handle_end_of_iter,
            )

            receive_task = get_or_reuse_loop().create_task(_receive())

            # NOTE(review): a freshly created task has not run yet, so this check
            # looks like it can never fire here — confirm the intent.
            if receive_task.done():
                raise RuntimeError(
                    'receive task not running, can not send messages'
                )

            async for response in streamer.stream(request_iterator):
                callback_exec(
                    response=response,
                    on_error=on_error,
                    on_done=on_done,
                    on_always=on_always,
                    continue_on_error=self.continue_on_error,
                    logger=self.logger,
                )
                if self.show_progress:
                    p_bar.update()
                yield response

        except aiohttp.ClientError as e:
            self.logger.error(
                f'Error while streaming response from websocket server {e!r}'
            )

            # with callbacks registered the error is reported through them and
            # swallowed; without callbacks it propagates to the caller
            if on_error or on_always:
                if on_error:
                    callback_exec_on_error(on_error, e, self.logger)
                if on_always:
                    callback_exec(
                        response=None,
                        on_error=None,
                        on_done=None,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
            else:
                raise e
async def _get_results(
    self,
    inputs: 'InputType',
    on_done: 'CallbackFnType',
    on_error: Optional['CallbackFnType'] = None,
    on_always: Optional['CallbackFnType'] = None,
    compression: str = 'NoCompression',
    **kwargs,
):
    """Send ``inputs`` to the gRPC server and yield every response of the ``Call`` stream.

    Each response is passed to ``callback_exec`` (``on_done`` / ``on_error`` /
    ``on_always``) and counted on the progress bar before it is yielded.

    ``KeyboardInterrupt`` and ``asyncio.CancelledError`` are logged as warnings
    and swallowed. gRPC errors are translated by status code:

    * ``UNAVAILABLE`` / ``INTERNAL`` -> the original gRPC error is re-raised
    * ``UNKNOWN`` whose details mention ``asyncio.exceptions.TimeoutError``
      -> ``BadClientInput``
    * anything else -> ``BadClient``

    When ``on_error`` or ``on_always`` is given, a translated error that is a
    gRPC error or a ``BaseJinaException`` is routed to those callbacks instead of
    being raised.

    :param inputs: the inputs to send; stored on ``self.inputs``
    :param on_done: the callback for on_done
    :param on_error: the callback for on_error
    :param on_always: the callback for on_always
    :param compression: compression algorithm name, matched case-insensitively
        against ``GRPC_COMPRESSION_MAP``. An unsupported name triggers a warning
        and falls back to ``'NoCompression'``; ``None`` is treated as ``'NoCompression'``.
    :param kwargs: kwargs forwarded to ``self._get_requests``
    :yields: the responses streamed back by the server
    """
    try:
        # an explicit None would crash on .lower() below; treat it as the default
        if compression is None:
            compression = 'NoCompression'
        if compression.lower() not in GRPC_COMPRESSION_MAP:
            import warnings

            warnings.warn(
                message=f'Your compression "{compression}" is not supported. Supported '
                f'algorithms are `Gzip`, `Deflate` and `NoCompression`. NoCompression will be used as '
                f'default'
            )
            compression = 'NoCompression'

        self.inputs = inputs
        req_iter = self._get_requests(**kwargs)
        async with GrpcConnectionPool.get_grpc_channel(
            f'{self.args.host}:{self.args.port}',
            asyncio=True,
            tls=self.args.tls,
        ) as channel:
            stub = jina_pb2_grpc.JinaRPCStub(channel)
            self.logger.debug(f'connected to {self.args.host}:{self.args.port}')

            with ProgressBar(
                total_length=self._inputs_length, disable=not (self.show_progress)
            ) as p_bar:
                async for resp in stub.Call(
                    req_iter,
                    compression=GRPC_COMPRESSION_MAP.get(
                        compression.lower(), grpc.Compression.NoCompression
                    ),
                ):
                    callback_exec(
                        response=resp,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield resp

    except KeyboardInterrupt:
        self.logger.warning('user cancel the process')
    except asyncio.CancelledError as ex:
        self.logger.warning(f'process error: {ex!r}')
    except grpc.aio._call.AioRpcError as rpc_ex:
        my_code = rpc_ex.code()
        # details() is not guaranteed to be a string (it may be None), so every
        # use below must tolerate None.
        my_details = rpc_ex.details()
        msg = f'gRPC error: {my_code} {my_details}'
        try:
            if my_code == grpc.StatusCode.UNAVAILABLE:
                self.logger.error(
                    f'{msg}\nthe ongoing request is terminated as the server is not available or closed already'
                )
                raise rpc_ex
            elif my_code == grpc.StatusCode.INTERNAL:
                self.logger.error(f'{msg}\ninternal error on the server side')
                raise rpc_ex
            elif (
                my_code == grpc.StatusCode.UNKNOWN
                # guard: `in None` would raise TypeError and mask the gRPC error
                and my_details
                and 'asyncio.exceptions.TimeoutError' in my_details
            ):
                raise BadClientInput(
                    f'{msg}\n'
                    'often the case is that you define/send a bad input iterator to jina, '
                    'please double check your input iterator'
                ) from rpc_ex
            else:
                raise BadClient(msg) from rpc_ex

        except (
            grpc.aio._call.AioRpcError,
            BaseJinaException,
        ) as e:
            # with callbacks registered the error is reported through them and
            # swallowed; without callbacks it propagates to the caller
            if on_error or on_always:
                if on_error:
                    callback_exec_on_error(on_error, e, self.logger)
                if on_always:
                    callback_exec(
                        response=None,
                        on_error=None,
                        on_done=None,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
            else:
                raise