async def _async_setup_server(self):
    """Build the request streamer, register the gRPC services and start the server.

    Registers the ``JinaRPC``, ``JinaGatewayDryRunRPC`` and ``JinaInfoRPC``
    servicers plus the health and reflection services, marks every registered
    service as ``SERVING`` and binds ``self.server`` to
    ``__default_host__:self.args.port``. TLS is used when both
    ``ssl_keyfile`` and ``ssl_certfile`` are set.

    :raises ValueError: if only one of ``ssl_keyfile`` / ``ssl_certfile`` is set.
    """
    request_handler = RequestHandler(self.metrics_registry, self.name)

    self.streamer = RequestStreamer(
        args=self.args,
        request_handler=request_handler.handle_request(
            graph=self._topology_graph, connection_pool=self._connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    # the streamer itself is registered as the JinaRPC servicer; its `Call`
    # handler is aliased to `stream`
    self.streamer.Call = self.streamer.stream

    jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server)
    jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(self, self.server)
    jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self.server)

    service_names = (
        jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name,
        jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].full_name,
        jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
        reflection.SERVICE_NAME,
    )
    # Mark all services as healthy.
    health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server)
    for service in service_names:
        self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
    reflection.enable_server_reflection(service_names, self.server)

    bind_addr = f'{__default_host__}:{self.args.port}'

    if self.args.ssl_keyfile and self.args.ssl_certfile:
        with open(self.args.ssl_keyfile, 'rb') as f:
            private_key = f.read()
        with open(self.args.ssl_certfile, 'rb') as f:
            certificate_chain = f.read()

        server_credentials = grpc.ssl_server_credentials(
            (
                (
                    private_key,
                    certificate_chain,
                ),
            )
        )
        self.server.add_secure_port(bind_addr, server_credentials)
    elif self.args.ssl_keyfile or self.args.ssl_certfile:
        # Exactly one of the two is set (both-set is handled above). Comparing
        # truthiness instead of raw values avoids rejecting e.g. '' vs None,
        # where neither file is actually configured.
        raise ValueError(
            "you can't pass a ssl_keyfile without a ssl_certfile and vice versa"
        )
    else:
        self.server.add_insecure_port(bind_addr)
    self.logger.debug(f'start server bound to {bind_addr}')
    await self.server.start()
async def async_setup(self):
    """
    The async method to setup.

    Create the gRPC server and expose the port for communication.
    """
    if not self.args.proxy and os.name != 'nt':
        # NOTE(review): presumably done so gRPC traffic bypasses any proxy.
        # Deleting the keys from ``os.environ`` also calls ``os.unsetenv``, so
        # the Python-level mapping and the process environment stay in sync;
        # calling ``os.unsetenv`` alone leaves stale entries in ``os.environ``.
        for proxy_var in ('http_proxy', 'https_proxy'):
            os.environ.pop(proxy_var, None)

    # -1 lifts gRPC's message size limits in both directions
    self.server = grpc.aio.server(
        options=[
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        ]
    )

    self._set_topology_graph()
    self._set_connection_pool()

    self.streamer = RequestStreamer(
        args=self.args,
        request_handler=handle_request(
            graph=self._topology_graph, connection_pool=self._connection_pool
        ),
        result_handler=handle_result,
    )
    self.streamer.Call = self.streamer.stream

    jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server)
    jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(self, self.server)

    bind_addr = f'{__default_host__}:{self.args.port_expose}'
    self.server.add_insecure_port(bind_addr)
    self.logger.debug(f' Start server bound to {bind_addr}')
    await self.server.start()
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware
        from fastapi.responses import HTMLResponse
        from starlette.requests import Request

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
            JinaStatusModel,
        )

    docs_url = '/docs'
    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description
        or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize this text.',
        version=__version__,
        # the built-in Swagger UI is disabled when a custom docs page is served below
        docs_url=docs_url if args.default_swagger_ui else None,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is now accessible from any website!'
        )

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append(
            {
                'name': 'Debug',
                'description': 'Debugging interface. In production, you should hide them by setting '
                '`--no-debug-endpoints` in `Flow`/`Gateway`.',
            }
        )

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of Jina service',
            response_model=JinaHealthModel,
        )
        async def _health():
            """
            Get the health of this Jina service.
            .. # noqa: DAR201

            """
            return {}

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaStatusModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            _info = get_full_version()
            return {
                'jina': _info[0],
                'envs': _info[1],
                'used_memory': used_memory_readable(),
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug'],
            # NOTE(review): a previous comment said this debug endpoint should not
            # restrict the response model, yet `response_model` is set — confirm intent
        )
        async def post(body: JinaEndpointRequestModel):
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # unwrap a `{'docs': [...]}` payload into the bare list of docs
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint
        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        @app.api_route(
            path=http_path or exec_endpoint, name=http_path or exec_endpoint, **kwargs
        )
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append(
            {
                'name': 'CRUD',
                'description': 'CRUD interface. If your service does not implement those interfaces, you can should '
                'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
            }
        )
        crud = {
            '/index': {'methods': ['POST']},
            '/search': {'methods': ['POST']},
            '/delete': {'methods': ['DELETE']},
            '/update': {'methods': ['PUT']},
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if not args.default_swagger_ui:

        def _render_custom_swagger_html(req: Request) -> HTMLResponse:
            # Declared as a plain (sync) endpoint on purpose: Starlette runs sync
            # endpoints in a worker thread, so the blocking urllib download below
            # no longer stalls the event loop (it did when this was `async def`).
            import urllib.request

            swagger_url = 'https://api.jina.ai/swagger'
            swagger_request = urllib.request.Request(
                swagger_url, headers={'User-Agent': 'Mozilla/5.0'}
            )
            with urllib.request.urlopen(swagger_request) as f:
                return HTMLResponse(f.read().decode())

        app.add_route(docs_url, _render_custom_swagger_html, include_in_schema=False)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

        async def get_docs_from_endpoint(
            data, target_executor, parameters, exec_endpoint
        ):
            from jina.enums import DataInputType

            req_generator_input = {
                # `data` is optional in the GraphQL schema; iterating over None
                # used to raise TypeError before any request was sent
                'data': [asdict(d) for d in data] if data is not None else None,
                'target_executor': target_executor,
                'parameters': parameters,
                'exec_endpoint': exec_endpoint,
                'data_type': DataInputType.DICT,
            }
            # (a former `'docs' in <list of dicts>` unwrap check was always False
            # and has been removed)
            response = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return DocumentArray.from_dict(response['data']).to_strawberry_type()

        @strawberry.type
        class Mutation:
            @strawberry.mutation
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        @strawberry.type
        class Query:
            @strawberry.field
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        schema = strawberry.Schema(query=Query, mutation=Mutation)
        app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI, Response, status
        from fastapi.middleware.cors import CORSMiddleware

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
        )

    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description
        or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize the title and description.',
        version=__version__,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning('CORS is enabled. This service is accessible from any website!')

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append(
            {
                'name': 'Debug',
                'description': 'Debugging interface. In production, you should hide them by setting '
                '`--no-debug-endpoints` in `Flow`/`Gateway`.',
            }
        )

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of Jina Gateway service',
            response_model=JinaHealthModel,
        )
        async def _gateway_health():
            """
            Get the health of this Gateway service.
            .. # noqa: DAR201

            """
            return {}

        from docarray import DocumentArray

        from jina.proto import jina_pb2
        from jina.serve.executors import __dry_run_endpoint__
        from jina.serve.runtimes.gateway.http.models import (
            PROTO_TO_PYDANTIC_MODELS,
            JinaInfoModel,
        )
        from jina.types.request.status import StatusMessage

        @app.get(
            path='/dry_run',
            summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to '
            'validate connectivity',
            response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto,
        )
        async def _flow_health():
            """
            Get the health of the complete Flow service.
            .. # noqa: DAR201

            """
            # imported locally (as the other handlers do) so the name is always bound
            from jina.enums import DataInputType

            da = DocumentArray()

            try:
                _ = await _get_singleton_result(
                    request_generator(
                        exec_endpoint=__dry_run_endpoint__,
                        data=da,
                        data_type=DataInputType.DOCUMENT,
                    )
                )
                status_message = StatusMessage()
                status_message.set_code(jina_pb2.StatusProto.SUCCESS)
                return status_message.to_dict()
            except Exception as ex:
                # any failure along the Flow is reported in the status, not raised
                # NOTE(review): the success path serializes enums by name, this one
                # by integer — confirm both validate against the response model
                status_message = StatusMessage()
                status_message.set_exception(ex)
                return status_message.to_dict(use_integers_for_enums=True)

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaInfoModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            version, env_info = get_full_version()
            # build fresh dicts instead of mutating the ones returned by the helper
            return {
                'jina': {k: str(v) for k, v in version.items()},
                'envs': {k: str(v) for k, v in env_info.items()},
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug'],
            # NOTE(review): a previous comment said this debug endpoint should not
            # restrict the response model, yet `response_model` is set — confirm intent
        )
        async def post(
            body: JinaEndpointRequestModel, response: Response
        ):  # 'response' is a FastAPI response, not a Jina response
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            try:
                result = await _get_singleton_result(
                    request_generator(**req_generator_input)
                )
            except InternalNetworkError as err:
                import grpc

                # map the gRPC failure onto the closest HTTP status code
                if err.code() == grpc.StatusCode.UNAVAILABLE:
                    response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                elif err.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                    response.status_code = status.HTTP_504_GATEWAY_TIMEOUT
                else:
                    response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                result = bd  # send back the request
                result['header'] = _generate_exception_header(
                    err
                )  # attach exception details to response header
                logger.error(
                    f'Error while getting responses from deployments: {err.details()}'
                )
            return result

    def _generate_exception_header(error: InternalNetworkError):
        """Build a Jina-style response header describing ``error``.

        :param error: the network error raised while contacting the deployments
        :return: dict with ``request_id`` and an ERROR ``status`` entry
        """
        import traceback

        from jina.proto.serializer import DataRequest

        exception_dict = {
            'name': str(error.__class__),
            'stacks': [
                str(x) for x in traceback.extract_tb(error.og_exception.__traceback__)
            ],
            'executor': '',
        }
        status_dict = {
            'code': DataRequest().status.ERROR,
            'description': error.details() or '',
            'exception': exception_dict,
        }
        return {'request_id': error.request_id, 'status': status_dict}

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint
        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        @app.api_route(
            path=http_path or exec_endpoint, name=http_path or exec_endpoint, **kwargs
        )
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append(
            {
                'name': 'CRUD',
                'description': 'CRUD interface. If your service does not implement those interfaces, you can should '
                'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
            }
        )
        crud = {
            '/index': {'methods': ['POST']},
            '/search': {'methods': ['POST']},
            '/delete': {'methods': ['DELETE']},
            '/update': {'methods': ['PUT']},
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

        async def get_docs_from_endpoint(
            data, target_executor, parameters, exec_endpoint
        ):
            from jina.enums import DataInputType

            req_generator_input = {
                # `data` is optional in the GraphQL schema; iterating over None
                # used to raise TypeError before any request was sent
                'data': [asdict(d) for d in data] if data is not None else None,
                'target_executor': target_executor,
                'parameters': parameters,
                'exec_endpoint': exec_endpoint,
                'data_type': DataInputType.DICT,
            }
            # (a former `'docs' in <list of dicts>` unwrap check was always False
            # and has been removed)
            try:
                response = await _get_singleton_result(
                    request_generator(**req_generator_input)
                )
            except InternalNetworkError as err:
                logger.error(
                    f'Error while getting responses from deployments: {err.details()}'
                )
                raise err  # will be handled by Strawberry
            return DocumentArray.from_dict(response['data']).to_strawberry_type()

        @strawberry.type
        class Mutation:
            @strawberry.mutation
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        @strawberry.type
        class Query:
            @strawberry.field
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        schema = strawberry.Schema(query=Query, mutation=Mutation)
        app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
async def test_request_streamer(prefetch, num_requests, async_iterator):
    """Push requests through a RequestStreamer and check both handlers touched every result.

    The request source is an async generator when ``async_iterator`` is truthy,
    otherwise a plain generator; ``prefetch`` is forwarded via the streamer args.
    """
    seen_requests = []
    seen_results = []

    def request_handler_fn(request):
        seen_requests.append(request)

        async def _mark_request():
            await asyncio.sleep(0.5)
            marked = request.docs
            marked[0].tags['request_handled'] = True
            request.data.docs = marked
            return request

        return asyncio.ensure_future(_mark_request())

    def result_handle_fn(result):
        seen_results.append(result)
        assert isinstance(result, DataRequest)
        marked = result.docs
        marked[0].tags['result_handled'] = True
        result.data.docs = marked
        return result

    def end_of_iter_fn():
        # When the input is exhausted every request has been handed to the
        # request handler, but since each one sleeps 0.5s not every result
        # has been processed yet.
        assert len(seen_requests) == num_requests
        assert len(seen_results) < num_requests

    def _make_request():
        request = DataRequest()
        request.header.request_id = random_identity()
        payload = DocumentArray()
        payload.append(Document())
        request.data.docs = payload
        return request

    def _sync_source(count):
        for _ in range(count):
            yield _make_request()

    async def _async_source(count):
        for _ in range(count):
            yield _make_request()
            await asyncio.sleep(0.1)

    streamer = RequestStreamer(
        args=Namespace(prefetch=prefetch),
        request_handler=request_handler_fn,
        result_handler=result_handle_fn,
        end_of_iter_handler=end_of_iter_fn,
    )

    if async_iterator:
        source = _async_source(num_requests)
    else:
        source = _sync_source(num_requests)

    received = 0
    async for reply in streamer.stream(source):
        received += 1
        assert reply.docs[0].tags['request_handled']
        assert reply.docs[0].tags['result_handled']
    assert received == num_requests
class GRPCGatewayRuntime(GatewayRuntime):
    """Gateway Runtime for gRPC."""

    def __init__(
        self,
        args: argparse.Namespace,
        **kwargs,
    ):
        """Initialize the runtime
        :param args: args from CLI
        :param kwargs: keyword args
        """
        # Created before ``super().__init__``. NOTE(review): presumably because the
        # base initializer may already run setup that registers this servicer; confirm.
        self._health_servicer = health.HealthServicer(experimental_non_blocking=True)
        super().__init__(args, **kwargs)

    async def async_setup(self):
        """
        The async method to setup.

        Create the gRPC server and expose the port for communication.

        :raises PortAlreadyUsed: if ``self.args.port`` is not free on ``__default_host__``
        """
        if not self.args.proxy and os.name != 'nt':
            # Deleting the keys from ``os.environ`` also calls ``os.unsetenv``, so
            # the Python-level mapping and the process environment stay in sync;
            # calling ``os.unsetenv`` alone leaves stale entries in ``os.environ``.
            for proxy_var in ('http_proxy', 'https_proxy'):
                os.environ.pop(proxy_var, None)

        if not (is_port_free(__default_host__, self.args.port)):
            raise PortAlreadyUsed(f'port:{self.args.port}')

        # -1 lifts gRPC's message size limits in both directions
        self.server = grpc.aio.server(
            options=[
                ('grpc.max_send_message_length', -1),
                ('grpc.max_receive_message_length', -1),
            ]
        )

        self._set_topology_graph()
        self._set_connection_pool()

        await self._async_setup_server()

    async def _async_setup_server(self):
        """Build the request streamer, register the gRPC services and start the server.

        :raises ValueError: if only one of ``ssl_keyfile`` / ``ssl_certfile`` is set.
        """
        request_handler = RequestHandler(self.metrics_registry, self.name)

        self.streamer = RequestStreamer(
            args=self.args,
            request_handler=request_handler.handle_request(
                graph=self._topology_graph, connection_pool=self._connection_pool
            ),
            result_handler=request_handler.handle_result(),
        )
        # the streamer itself is registered as the JinaRPC servicer; its `Call`
        # handler is aliased to `stream`
        self.streamer.Call = self.streamer.stream

        jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server)
        jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(self, self.server)
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self.server)

        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        # Mark all services as healthy.
        health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer, self.server)
        for service in service_names:
            self._health_servicer.set(service, health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self.server)

        bind_addr = f'{__default_host__}:{self.args.port}'

        if self.args.ssl_keyfile and self.args.ssl_certfile:
            with open(self.args.ssl_keyfile, 'rb') as f:
                private_key = f.read()
            with open(self.args.ssl_certfile, 'rb') as f:
                certificate_chain = f.read()

            server_credentials = grpc.ssl_server_credentials(
                (
                    (
                        private_key,
                        certificate_chain,
                    ),
                )
            )
            self.server.add_secure_port(bind_addr, server_credentials)
        elif self.args.ssl_keyfile or self.args.ssl_certfile:
            # Exactly one of the two is set (both-set is handled above). Comparing
            # truthiness instead of raw values avoids rejecting e.g. '' vs None,
            # where neither file is actually configured.
            raise ValueError(
                "you can't pass a ssl_keyfile without a ssl_certfile and vice versa"
            )
        else:
            self.server.add_insecure_port(bind_addr)
        self.logger.debug(f'start server bound to {bind_addr}')
        await self.server.start()

    async def async_teardown(self):
        """Close the connection pool"""
        # usually async_cancel should already have been called, but then its a noop
        # if the runtime is stopped without a sigterm (e.g. as a context manager, this can happen)
        self._health_servicer.enter_graceful_shutdown()
        await self.async_cancel()
        await self._connection_pool.close()

    async def async_cancel(self):
        """The async method to stop server."""
        await self.server.stop(0)

    async def async_run_forever(self):
        """The async running of server."""
        self._connection_pool.start()
        await self.server.wait_for_termination()

    async def dry_run(self, empty, context) -> jina_pb2.StatusProto:
        """
        Process the call requested by having a dry run call to every Executor in the graph

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        from docarray import DocumentArray

        from jina.clients.request import request_generator
        from jina.enums import DataInputType
        from jina.serve.executors import __dry_run_endpoint__

        da = DocumentArray()
        try:
            req_iterator = request_generator(
                exec_endpoint=__dry_run_endpoint__,
                data=da,
                data_type=DataInputType.DOCUMENT,
            )
            async for _ in self.streamer.stream(request_iterator=req_iterator):
                pass
            status_message = StatusMessage()
            status_message.set_code(jina_pb2.StatusProto.SUCCESS)
            return status_message.proto
        except Exception as ex:
            # any failure along the Flow is reported in the returned status, not raised
            status_message = StatusMessage()
            status_message.set_exception(ex)
            return status_message.proto

    async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
        """
        Process the call requested and return the JinaInfo of the Runtime

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        info_proto = jina_pb2.JinaInfoProto()
        version, env_info = get_full_version()
        for k, v in version.items():
            info_proto.jina[k] = str(v)
        for k, v in env_info.items():
            info_proto.envs[k] = str(v)
        return info_proto
async def _get_results(
    self,
    inputs: 'InputType',
    on_done: 'CallbackFnType',
    on_error: Optional['CallbackFnType'] = None,
    on_always: Optional['CallbackFnType'] = None,
    **kwargs,
):
    """
    :param inputs: the callable
    :param on_done: the callback for on_done
    :param on_error: the callback for on_error
    :param on_always: the callback for on_always
    :param kwargs: kwargs for _get_task_name and _get_requests
    :yields: generator over results
    """
    with ImportExtensions(required=True):
        import aiohttp

    self.inputs = inputs
    request_iterator = self._get_requests(**kwargs)

    async with AsyncExitStack() as stack:
        try:
            cm1 = ProgressBar(
                total_length=self._inputs_length, disable=not (self.show_progress)
            )
            p_bar = stack.enter_context(cm1)

            proto = 'wss' if self.args.tls else 'ws'
            url = f'{proto}://{self.args.host}:{self.args.port}/'
            iolet = await stack.enter_async_context(
                WebsocketClientlet(url=url, logger=self.logger)
            )

            # request_id -> future that `_receive` resolves when the response arrives
            request_buffer: Dict[str, asyncio.Future] = dict()
            # Strong references to fire-and-forget send tasks: the event loop only
            # keeps weak references to tasks, so an unreferenced task may be
            # garbage collected before it has finished sending.
            pending_sends = set()

            def _spawn(coro):
                """Schedule ``coro`` as a task and keep it referenced until done."""
                task = asyncio.create_task(coro)
                pending_sends.add(task)
                task.add_done_callback(pending_sends.discard)
                return task

            def _result_handler(result):
                return result

            async def _receive():
                """Await messages from WebsocketGateway and process them in the request buffer"""

                def _response_handler(response):
                    if response.header.request_id in request_buffer:
                        future = request_buffer.pop(response.header.request_id)
                        future.set_result(response)
                    else:
                        self.logger.warning(
                            f'discarding unexpected response with request id {response.header.request_id}'
                        )

                try:
                    async for response in iolet.recv_message():
                        _response_handler(response)
                finally:
                    # the receive loop ended: nothing can resolve the remaining futures
                    if request_buffer:
                        self.logger.warning(
                            f'{self.__class__.__name__} closed, cancelling all outstanding requests'
                        )
                        for future in request_buffer.values():
                            future.cancel()
                        request_buffer.clear()

            def _handle_end_of_iter():
                """Send End of iteration signal to the Gateway"""
                _spawn(iolet.send_eoi())

            def _request_handler(request: 'Request') -> 'asyncio.Future':
                """
                For each request in the iterator, we send the `Message` using `iolet.send_message()`.
                For websocket requests from client, for each request in the iterator, we send the request in `bytes`
                using using `iolet.send_message()`.
                Then add {<request-id>: <an-empty-future>} to the request buffer.
                This empty future is used to track the `result` of this request during `receive`.
                :param request: current request in the iterator
                :return: asyncio Future for sending message
                """
                future = get_or_reuse_loop().create_future()
                request_buffer[request.header.request_id] = future
                _spawn(iolet.send_message(request))
                return future

            streamer = RequestStreamer(
                args=self.args,
                request_handler=_request_handler,
                result_handler=_result_handler,
                end_of_iter_handler=_handle_end_of_iter,
            )

            receive_task = get_or_reuse_loop().create_task(_receive())

            # NOTE(review): a just-created task has not run yet, so this check can
            # presumably only trip in unusual loop states; kept for safety.
            if receive_task.done():
                raise RuntimeError('receive task not running, can not send messages')

            async for response in streamer.stream(request_iterator):
                callback_exec(
                    response=response,
                    on_error=on_error,
                    on_done=on_done,
                    on_always=on_always,
                    continue_on_error=self.continue_on_error,
                    logger=self.logger,
                )
                if self.show_progress:
                    p_bar.update()
                yield response

        except aiohttp.ClientError as e:
            self.logger.error(
                f'Error while streaming response from websocket server {e!r}'
            )

            if on_error or on_always:
                if on_error:
                    callback_exec_on_error(on_error, e, self.logger)
                if on_always:
                    callback_exec(
                        response=None,
                        on_error=None,
                        on_done=None,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
            else:
                # no callback to report to: propagate the original error
                raise
async def _get_results(
    self,
    inputs: 'InputType',
    on_done: 'CallbackFnType',
    on_error: Optional['CallbackFnType'] = None,
    on_always: Optional['CallbackFnType'] = None,
    **kwargs,
):
    """
    :param inputs: the callable
    :param on_done: the callback for on_done
    :param on_error: the callback for on_error
    :param on_always: the callback for on_always
    :param kwargs: kwargs for _get_task_name and _get_requests
    :yields: generator over results
    """
    with ImportExtensions(required=True):
        import aiohttp

    self.inputs = inputs
    request_iterator = self._get_requests(**kwargs)

    async with AsyncExitStack() as stack:
        try:
            cm1 = (
                ProgressBar(total_length=self._inputs_length)
                if self.show_progress
                else nullcontext()
            )
            p_bar = stack.enter_context(cm1)

            proto = 'https' if self.args.https else 'http'
            url = f'{proto}://{self.args.host}:{self.args.port}/post'
            iolet = await stack.enter_async_context(
                HTTPClientlet(url=url, logger=self.logger)
            )

            def _request_handler(request: 'Request') -> 'asyncio.Future':
                """
                For HTTP Client, for each request in the iterator, we `send_message` using
                http POST request and add it to the list of tasks which is awaited and yielded.
                :param request: current request in the iterator
                :return: asyncio Task for sending message
                """
                return asyncio.ensure_future(iolet.send_message(request=request))

            def _result_handler(result):
                return result

            streamer = RequestStreamer(
                self.args,
                request_handler=_request_handler,
                result_handler=_result_handler,
            )
            async for response in streamer.stream(request_iterator):
                r_status = response.status

                r_json = await response.json()
                if r_status == 404:
                    raise BadClient(f'no such endpoint {url}')
                elif not 200 <= r_status < 300:
                    # only 2xx is success; the previous `r_status > 300` test
                    # let status 300 through as if it were a valid response
                    raise ValueError(r_json)

                da = None
                if 'data' in r_json and r_json['data'] is not None:
                    from docarray import DocumentArray

                    da = DocumentArray.from_dict(r_json['data'])
                    del r_json['data']

                resp = DataRequest(r_json)
                if da is not None:
                    resp.data.docs = da

                callback_exec(
                    response=resp,
                    on_error=on_error,
                    on_done=on_done,
                    on_always=on_always,
                    continue_on_error=self.continue_on_error,
                    logger=self.logger,
                )
                if self.show_progress:
                    p_bar.update()
                yield resp
        except aiohttp.ClientError as e:
            self.logger.error(f'Error while fetching response from HTTP server {e!r}')
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
):
    """
    Get the app from FastAPI as the Websocket interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    class ConnectionManager:
        """Keeps track of the websockets currently connected to the gateway."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []

        async def connect(self, websocket: WebSocket):
            await websocket.accept()
            logger.debug(
                f'client {websocket.client.host}:{websocket.client.port} connected'
            )
            self.active_connections.append(websocket)

        def disconnect(self, websocket: WebSocket):
            # tolerant of repeated calls so cleanup can always run in `finally`
            if websocket in self.active_connections:
                self.active_connections.remove(websocket)

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.stream import RequestStreamer
    from jina.serve.runtimes.gateway.request_handling import (
        handle_request,
        handle_result,
    )

    streamer = RequestStreamer(
        args=args,
        request_handler=handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=handle_result,
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(websocket: WebSocket):
        await manager.connect(websocket)

        async def req_iter():
            async for request_bytes in websocket.iter_bytes():
                # bytes(True) == b'\x00': a single zero byte ends the stream
                if request_bytes == bytes(True):
                    break
                yield DataRequest(request_bytes)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await websocket.send_bytes(bytes(msg))
        except WebSocketDisconnect:
            logger.debug('Client successfully disconnected from server')
        finally:
            # previously only done on WebSocketDisconnect, so sockets whose
            # stream ended normally (or failed otherwise) leaked in the manager
            manager.disconnect(websocket)

    return app
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the Websocket interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    from jina.serve.runtimes.gateway.http.models import JinaEndpointRequestModel

    with ImportExtensions(required=True):
        from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    class ConnectionManager:
        """Track connected websockets and the subprotocol each client uses."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []
            # keyed by the 'host:port' string returned by `get_client`
            self.protocol_dict: Dict[str, WebsocketSubProtocols] = {}

        def get_client(self, websocket: WebSocket) -> str:
            return f'{websocket.client.host}:{websocket.client.port}'

        def get_subprotocol(self, headers: Dict):
            # both str and bytes header keys are accepted; anything unparsable
            # falls back to JSON
            try:
                if 'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers['sec-websocket-protocol']
                    )
                elif b'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers[b'sec-websocket-protocol'].decode()
                    )
                else:
                    subprotocol = WebsocketSubProtocols.JSON
                    logger.debug(
                        f'no protocol headers passed. Choosing default subprotocol {WebsocketSubProtocols.JSON}'
                    )
            except Exception as e:
                logger.debug(
                    f'got an exception while setting user\'s subprotocol, defaulting to JSON {e}'
                )
                subprotocol = WebsocketSubProtocols.JSON
            return subprotocol

        async def connect(self, websocket: WebSocket):
            await websocket.accept()
            subprotocol = self.get_subprotocol(dict(websocket.scope['headers']))
            logger.info(
                f'client {websocket.client.host}:{websocket.client.port} connected '
                f'with subprotocol {subprotocol}'
            )
            self.active_connections.append(websocket)
            self.protocol_dict[self.get_client(websocket)] = subprotocol

        def disconnect(self, websocket: WebSocket):
            # tolerant of repeated calls so cleanup can always run in `finally`
            self.protocol_dict.pop(self.get_client(websocket), None)
            if websocket in self.active_connections:
                self.active_connections.remove(websocket)

        async def receive(self, websocket: WebSocket) -> Any:
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.receive_json(mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.receive_bytes()

        async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]:
            # ends quietly when the client disconnects
            try:
                while True:
                    yield await self.receive(websocket)
            except WebSocketDisconnect:
                pass

        async def send(self, websocket: WebSocket, data: DataRequest) -> None:
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.send_json(data.to_dict(), mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.send_bytes(data.to_bytes())

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(websocket: WebSocket):
        await manager.connect(websocket)

        async def req_iter():
            async for request in manager.iter(websocket):
                if isinstance(request, dict):
                    if request == {}:
                        # an empty JSON object ends the stream
                        break
                    else:
                        # NOTE: Helps in converting camelCase to snake_case
                        req_generator_input = JinaEndpointRequestModel(
                            **request
                        ).dict()
                        req_generator_input['data_type'] = DataInputType.DICT
                        # .get(): a payload without 'data' must not raise KeyError
                        request_data = request.get('data')
                        if request_data is not None and 'docs' in request_data:
                            req_generator_input['data'] = req_generator_input[
                                'data'
                            ]['docs']

                        # you can't do `yield from` inside an async function
                        for data_request in request_generator(
                            **req_generator_input
                        ):
                            yield data_request
                elif isinstance(request, bytes):
                    # bytes(True) == b'\x00': a single zero byte ends the stream
                    if request == bytes(True):
                        break
                    else:
                        yield DataRequest(request)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await manager.send(websocket, msg)
        except WebSocketDisconnect:
            logger.info('Client successfully disconnected from server')
        finally:
            # previously only done on WebSocketDisconnect, so sockets whose
            # stream ended normally leaked in active_connections/protocol_dict
            manager.disconnect(websocket)

    return app
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the Websocket interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    from jina.serve.runtimes.gateway.http.models import JinaEndpointRequestModel

    with ImportExtensions(required=True):
        from fastapi import FastAPI, Response, WebSocket, WebSocketDisconnect, status

    class ConnectionManager:
        """Track connected websockets and the subprotocol each client uses."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []
            # keyed by the 'host:port' string returned by `get_client`
            self.protocol_dict: Dict[str, WebsocketSubProtocols] = {}

        def get_client(self, websocket: WebSocket) -> str:
            return f'{websocket.client.host}:{websocket.client.port}'

        def get_subprotocol(self, headers: Dict):
            # both str and bytes header keys are accepted; anything unparsable
            # falls back to JSON
            try:
                if 'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers['sec-websocket-protocol']
                    )
                elif b'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers[b'sec-websocket-protocol'].decode()
                    )
                else:
                    subprotocol = WebsocketSubProtocols.JSON
                    logger.debug(
                        f'no protocol headers passed. Choosing default subprotocol {WebsocketSubProtocols.JSON}'
                    )
            except Exception as e:
                logger.debug(
                    f'got an exception while setting user\'s subprotocol, defaulting to JSON {e}'
                )
                subprotocol = WebsocketSubProtocols.JSON
            return subprotocol

        async def connect(self, websocket: WebSocket):
            await websocket.accept()
            subprotocol = self.get_subprotocol(dict(websocket.scope['headers']))
            logger.info(
                f'client {websocket.client.host}:{websocket.client.port} connected '
                f'with subprotocol {subprotocol}'
            )
            self.active_connections.append(websocket)
            self.protocol_dict[self.get_client(websocket)] = subprotocol

        def disconnect(self, websocket: WebSocket):
            # tolerant of repeated calls so cleanup can always run in `finally`
            self.protocol_dict.pop(self.get_client(websocket), None)
            if websocket in self.active_connections:
                self.active_connections.remove(websocket)

        async def receive(self, websocket: WebSocket) -> Any:
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.receive_json(mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.receive_bytes()

        async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]:
            # ends quietly when the client disconnects
            try:
                while True:
                    yield await self.receive(websocket)
            except WebSocketDisconnect:
                pass

        async def send(
            self, websocket: WebSocket, data: Union[DataRequest, StatusMessage]
        ) -> None:
            # requires the socket to still be registered (looks up protocol_dict)
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.send_json(data.to_dict(), mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.send_bytes(data.to_bytes())

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.get(
        path='/',
        summary='Get the health of Jina service',
    )
    async def _health():
        """
        Get the health of this Jina service.
        .. # noqa: DAR201

        """
        return {}

    @app.get(
        path='/status',
        summary='Get the status of Jina service',
    )
    async def _status():
        """
        Get the status of this Jina service.

        This is equivalent to running `jina -vf` from command line.

        .. # noqa: DAR201
        """
        version, env_info = get_full_version()
        for k, v in version.items():
            version[k] = str(v)
        for k, v in env_info.items():
            env_info[k] = str(v)
        return {'jina': version, 'envs': env_info}

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(
        websocket: WebSocket, response: Response
    ):  # 'response' is a FastAPI response, not a Jina response
        await manager.connect(websocket)

        async def req_iter():
            async for request in manager.iter(websocket):
                if isinstance(request, dict):
                    if request == {}:
                        # an empty JSON object ends the stream
                        break
                    else:
                        # NOTE: Helps in converting camelCase to snake_case
                        req_generator_input = JinaEndpointRequestModel(**request).dict()
                        req_generator_input['data_type'] = DataInputType.DICT
                        # .get(): a payload without 'data' must not raise KeyError
                        request_data = request.get('data')
                        if request_data is not None and 'docs' in request_data:
                            req_generator_input['data'] = req_generator_input['data'][
                                'docs'
                            ]

                        # you can't do `yield from` inside an async function
                        for data_request in request_generator(**req_generator_input):
                            yield data_request
                elif isinstance(request, bytes):
                    # bytes(True) == b'\x00': a single zero byte ends the stream
                    if request == bytes(True):
                        break
                    else:
                        yield DataRequest(request)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await manager.send(websocket, msg)
        except InternalNetworkError as err:
            import grpc

            fallback_msg = (
                f'Connection to deployment at {err.dest_addr} timed out. You can adjust `timeout_send` attribute.'
                if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED
                else f'Network error while connecting to deployment at {err.dest_addr}. It may be down.'
            )
            msg = (
                err.details()
                if _fits_ws_close_msg(
                    err.details()
                )  # some messages are too long for ws closing message
                else fallback_msg
            )
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=msg)
        except WebSocketDisconnect:
            logger.info('Client successfully disconnected from server')
        finally:
            # also runs when the stream ends normally, which previously leaked
            # the socket in active_connections/protocol_dict
            manager.disconnect(websocket)

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    from docarray import DocumentArray
    from jina.proto import jina_pb2
    from jina.serve.executors import __dry_run_endpoint__
    from jina.serve.runtimes.gateway.http.models import PROTO_TO_PYDANTIC_MODELS

    @app.get(
        path='/dry_run',
        summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to '
        'validate connectivity',
        response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto,
    )
    async def _dry_run_http():
        """
        Get the health of the complete Flow service.
        .. # noqa: DAR201

        """

        da = DocumentArray()

        try:
            _ = await _get_singleton_result(
                request_generator(
                    exec_endpoint=__dry_run_endpoint__,
                    data=da,
                    data_type=DataInputType.DOCUMENT,
                )
            )
            status_message = StatusMessage()
            status_message.set_code(jina_pb2.StatusProto.SUCCESS)
            return status_message.to_dict()
        except Exception as ex:
            status_message = StatusMessage()
            status_message.set_exception(ex)
            return status_message.to_dict(use_integers_for_enums=True)

    # NOTE: reuses the local name `websocket_endpoint`; the decorator has
    # already registered the '/' route, so shadowing the name is harmless
    @app.websocket('/dry_run')
    async def websocket_endpoint(
        websocket: WebSocket, response: Response
    ):  # 'response' is a FastAPI response, not a Jina response
        from jina.proto import jina_pb2
        from jina.serve.executors import __dry_run_endpoint__

        await manager.connect(websocket)

        da = DocumentArray()
        try:
            async for _ in streamer.stream(
                request_iterator=request_generator(
                    exec_endpoint=__dry_run_endpoint__,
                    data=da,
                    data_type=DataInputType.DOCUMENT,
                )
            ):
                pass
            status_message = StatusMessage()
            status_message.set_code(jina_pb2.StatusProto.SUCCESS)
            await manager.send(websocket, status_message)
        except InternalNetworkError as err:
            msg = (
                err.details()
                if _fits_ws_close_msg(err.details())  # some messages are too long
                else f'Network error while connecting to deployment at {err.dest_addr}. It may be down.'
            )
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=msg)
        except WebSocketDisconnect:
            logger.info('Client successfully disconnected from server')
        except Exception as ex:
            # send BEFORE unregistering: `manager.send` looks the client up in
            # protocol_dict; disconnecting first made this raise KeyError
            status_message = StatusMessage()
            status_message.set_exception(ex)
            await manager.send(websocket, status_message)
        finally:
            manager.disconnect(websocket)

    return app