def test_set_cookie_headers():
    """Cookies set by a resolver on the context response reach the client."""

    @strawberry.type
    class Query:
        @strawberry.field
        def abc(self, info: Info) -> str:
            assert info.context.get("response") is not None
            # Two cookies are set so the test covers more than one
            # Set-Cookie value on a single response.
            info.context["response"].set_cookie(key="strawberry", value="rocks")
            info.context["response"].set_cookie(key="FastAPI", value="rocks")
            return "abc"

    api = FastAPI()
    router = GraphQLRouter(strawberry.Schema(query=Query))
    api.include_router(router, prefix="/graphql")

    result = TestClient(api).post("/graphql", json={"query": "{ abc }"})

    assert result.status_code == 200
    assert result.json() == {"data": {"abc": "abc"}}

    expected_cookies = (
        "strawberry=rocks; Path=/; SameSite=lax, FastAPI=rocks; Path=/; SameSite=lax"
    )
    assert result.headers["set-cookie"] == expected_cookies
def test_missing_path_and_prefix():
    """Including the router with neither a path nor a prefix must fail.

    The router is built without a path and included without a prefix; the
    test expects ``include_router`` to raise with an explanatory message.
    """

    @strawberry.type
    class Query:
        @strawberry.field
        def abc(self) -> str:
            return "abc"

    app = FastAPI()
    schema = strawberry.Schema(query=Query)
    graphql_app = GraphQLRouter(schema)

    with pytest.raises(Exception) as exc_info:
        app.include_router(graphql_app)

    # Inspect the raised exception itself: ``str()`` of the ExceptionInfo
    # wrapper is only a (possibly truncated) repr, not the message.
    assert "Prefix and path cannot be both empty" in str(exc_info.value)
def setup(uvicorn: UvicornServer) -> None:
    """Mount the GraphQL router on the uvicorn app and optionally open GraphiQL."""
    # Monkey patch, due to limitations in configurability.
    strawberry.fastapi.router.JSONResponse = ExplorerJSONResponse  # type: ignore[attr-defined]

    explorer_schema = strawberry.Schema(query=Query)
    request_context_getter = GraphQLContext(uvicorn).create_request_context
    router = GraphQLRouter(explorer_schema, context_getter=request_context_getter)
    uvicorn.app.include_router(router, prefix=route)

    if not graphql.open_graphiql:
        return

    def _open_graphiql_in_browser():
        # Browser.open() needs an unlocked scheduler, so we need to defer that call to a
        # callstack that is not executing a rule.
        return browser.open(uvicorn.request_state, route)

    uvicorn.prerun_tasks.append(_open_graphiql_in_browser)
def test_include_router_prefix():
    """A router mounted under a prefix answers queries posted to that prefix."""

    @strawberry.type
    class Query:
        @strawberry.field
        def abc(self) -> str:
            return "abc"

    api = FastAPI()
    router = GraphQLRouter(strawberry.Schema(query=Query))
    api.include_router(router, prefix="/graphql")

    result = TestClient(api).post("/graphql", json={"query": "{ abc }"})

    assert result.status_code == 200
    assert result.json() == {"data": {"abc": "abc"}}
def test_without_context_getter():
    """Without a custom context getter the context has the request and no extra keys."""

    @strawberry.type
    class Query:
        @strawberry.field
        def abc(self, info: Info) -> str:
            context = info.context
            assert context.get("request") is not None
            # "strawberry" is the custom key this test expects to be absent
            # when no context getter was supplied.
            assert context.get("strawberry") is None
            return "abc"

    api = FastAPI()
    router = GraphQLRouter(strawberry.Schema(query=Query))
    api.include_router(router, prefix="/graphql")

    result = TestClient(api).post("/graphql", json={"query": "{ abc }"})

    assert result.status_code == 200
    assert result.json() == {"data": {"abc": "abc"}}
def test_set_custom_http_response_status():
    """A status code set on the context response by a resolver is sent to the client."""

    @strawberry.type
    class Query:
        @strawberry.field
        def abc(self, info: Info) -> str:
            assert info.context.get("response") is not None
            info.context["response"].status_code = status.HTTP_418_IM_A_TEAPOT
            return "abc"

    api = FastAPI()
    router = GraphQLRouter(strawberry.Schema(query=Query))
    api.include_router(router, prefix="/graphql")

    result = TestClient(api).post("/graphql", json={"query": "{ abc }"})

    # The payload is still delivered; only the HTTP status is overridden.
    assert result.status_code == 418
    assert result.json() == {"data": {"abc": "abc"}}
def test_set_response_headers():
    """A header set on the context response by a resolver is sent to the client."""

    @strawberry.type
    class Query:
        @strawberry.field
        def abc(self, info: Info) -> str:
            assert info.context.get("response") is not None
            info.context["response"].headers["X-Strawberry"] = "rocks"
            return "abc"

    api = FastAPI()
    router = GraphQLRouter(strawberry.Schema(query=Query))
    api.include_router(router, prefix="/graphql")

    result = TestClient(api).post("/graphql", json={"query": "{ abc }"})

    assert result.status_code == 200
    assert result.json() == {"data": {"abc": "abc"}}
    # The header was set as "X-Strawberry" and is looked up in lower case here.
    assert result.headers["x-strawberry"] == "rocks"
def enable_graphql_api(self, app: FastAPI) -> None:
    '''
    Registers CORS middleware and the GraphQL router of this
    instance on the given FastAPI app, and remembers the app
    on the instance.
    '''

    self.app = app

    cors_settings = {
        'allow_origins': self.schema.cors_origins,
        'allow_credentials': True,
        'allow_methods': ["*"],
        'allow_headers': ["*"],
    }
    app.add_middleware(CORSMiddleware, **cors_settings)

    # podserver.dependencies.podrequest_auth.PodApiRequestAuth
    # uses the GRAPHQL_API_URL_PREFIX to evaluate incoming
    # requests
    route_prefix = GRAPHQL_API_URL_PREFIX + str(self.service_id)

    # The GraphiQL toggle follows the debug flag of the config.
    router = GraphQLRouter(self.schema.gql_schema, graphiql=config.debug)
    app.include_router(router, prefix=route_prefix)
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    Depending on ``args`` the app gets: CORS middleware (``cors``), the debug
    endpoints ``/``, ``/status`` and ``/post`` (unless ``no_debug_endpoints``),
    the CRUD endpoints (unless ``no_crud_endpoints``), user-defined endpoints
    (``expose_endpoints``, a JSON string), a custom swagger page (unless
    ``default_swagger_ui``) and a GraphQL router under ``/graphql``
    (``expose_graphql_endpoint``).

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware
        from fastapi.responses import HTMLResponse
        from starlette.requests import Request

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
            JinaStatusModel,
        )

    docs_url = '/docs'
    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description
        or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize this text.',
        version=__version__,
        # The built-in docs page is disabled when the custom swagger page
        # registered further below takes over `docs_url`.
        docs_url=docs_url if args.default_swagger_ui else None,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is now accessible from any website!'
        )

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append(
            {
                'name': 'Debug',
                'description': 'Debugging interface. In production, you should hide them by setting '
                '`--no-debug-endpoints` in `Flow`/`Gateway`.',
            }
        )

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        # NOTE: the docstrings of the endpoint handlers below are published by
        # FastAPI as endpoint descriptions, so they are part of the API output.
        @app.get(
            path='/',
            summary='Get the health of Jina service',
            response_model=JinaHealthModel,
        )
        async def _health():
            """
            Get the health of this Jina service.
            .. # noqa: DAR201

            """
            return {}

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaStatusModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            _info = get_full_version()
            return {
                'jina': _info[0],
                'envs': _info[1],
                'used_memory': used_memory_readable(),
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug']
            # do not add response_model here, this debug endpoint should not restricts the response model
            # NOTE(review): the comment above contradicts the `response_model`
            # argument that is actually passed - confirm which one is intended.
        )
        async def post(body: JinaEndpointRequestModel):
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # Unwrap `{'docs': [...]}` payloads to the bare list of documents.
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint

        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        # No docstring on the handler on purpose: FastAPI would publish it as
        # the description of every exposed endpoint.
        @app.api_route(
            path=http_path or exec_endpoint, name=http_path or exec_endpoint, **kwargs
        )
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append(
            {
                'name': 'CRUD',
                'description': 'CRUD interface. If your service does not implement those interfaces, you can should '
                'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
            }
        )
        crud = {
            '/index': {'methods': ['POST']},
            '/search': {'methods': ['POST']},
            '/delete': {'methods': ['DELETE']},
            '/update': {'methods': ['PUT']},
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v[
                'description'
            ] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if not args.default_swagger_ui:

        async def _render_custom_swagger_html(req: Request) -> HTMLResponse:
            import urllib.request

            swagger_url = 'https://api.jina.ai/swagger'
            # Use a separate name for the outgoing request so that the
            # incoming starlette request `req` is not shadowed.
            swagger_request = urllib.request.Request(
                swagger_url, headers={'User-Agent': 'Mozilla/5.0'}
            )
            # NOTE(review): `urlopen` is a blocking call inside an async
            # handler; consider off-loading it to an executor.
            with urllib.request.urlopen(swagger_request) as f:
                return HTMLResponse(f.read().decode())

        app.add_route(docs_url, _render_custom_swagger_html, include_in_schema=False)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

        # Shared implementation of the `docs` query and mutation below.
        async def get_docs_from_endpoint(
            data, target_executor, parameters, exec_endpoint
        ):
            req_generator_input = {
                # `data` defaults to None in both resolvers; iterating it
                # unconditionally raised a TypeError. None is forwarded as-is,
                # as the REST handlers above do for a missing payload
                # (presumably `request_generator` accepts it - verify).
                # The former `'docs' in data` unwrapping was dropped: `data`
                # is a list of dicts here, so that branch could never run.
                'data': [asdict(d) for d in data] if data is not None else None,
                'target_executor': target_executor,
                'parameters': parameters,
                'exec_endpoint': exec_endpoint,
                'data_type': DataInputType.DICT,
            }

            response = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return DocumentArray.from_dict(response['data']).to_strawberry_type()

        @strawberry.type
        class Mutation:
            @strawberry.mutation
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        @strawberry.type
        class Query:
            @strawberry.field
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        schema = strawberry.Schema(query=Query, mutation=Mutation)
        app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        # Returns inside the loop: only the first streamed result is used.
        # If the stream yields nothing the function falls through with None.
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    Depending on ``args`` the app gets: CORS middleware (``cors``), the debug
    endpoints ``/``, ``/dry_run``, ``/status`` and ``/post`` (unless
    ``no_debug_endpoints``), the CRUD endpoints (unless ``no_crud_endpoints``),
    user-defined endpoints (``expose_endpoints``, a JSON string) and a GraphQL
    router under ``/graphql`` (``expose_graphql_endpoint``).

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI, Response, status
        from fastapi.middleware.cors import CORSMiddleware

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
        )

    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description
        or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize the title and description.',
        version=__version__,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning('CORS is enabled. This service is accessible from any website!')

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append(
            {
                'name': 'Debug',
                'description': 'Debugging interface. In production, you should hide them by setting '
                '`--no-debug-endpoints` in `Flow`/`Gateway`.',
            }
        )

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        # NOTE: the docstrings of the endpoint handlers below are published by
        # FastAPI as endpoint descriptions, so they are part of the API output.
        @app.get(
            path='/',
            summary='Get the health of Jina Gateway service',
            response_model=JinaHealthModel,
        )
        async def _gateway_health():
            """
            Get the health of this Gateway service.
            .. # noqa: DAR201

            """
            return {}

        from docarray import DocumentArray

        from jina.proto import jina_pb2
        from jina.serve.executors import __dry_run_endpoint__
        from jina.serve.runtimes.gateway.http.models import (
            PROTO_TO_PYDANTIC_MODELS,
            JinaInfoModel,
        )
        from jina.types.request.status import StatusMessage

        @app.get(
            path='/dry_run',
            summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to '
            'validate connectivity',
            response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto,
        )
        async def _flow_health():
            """
            Get the health of the complete Flow service.
            .. # noqa: DAR201

            """

            da = DocumentArray()

            try:
                _ = await _get_singleton_result(
                    request_generator(
                        exec_endpoint=__dry_run_endpoint__,
                        data=da,
                        data_type=DataInputType.DOCUMENT,
                    )
                )
                status_message = StatusMessage()
                status_message.set_code(jina_pb2.StatusProto.SUCCESS)
                return status_message.to_dict()
            except Exception as ex:
                # Deliberately broad: any failure of the dry run is reported
                # to the caller as a status payload instead of an HTTP error.
                status_message = StatusMessage()
                status_message.set_exception(ex)
                return status_message.to_dict(use_integers_for_enums=True)

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaInfoModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            version, env_info = get_full_version()
            # Stringify every value into fresh dicts instead of rewriting the
            # dicts returned by `get_full_version()` in place.
            return {
                'jina': {k: str(v) for k, v in version.items()},
                'envs': {k: str(v) for k, v in env_info.items()},
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug']
            # do not add response_model here, this debug endpoint should not restricts the response model
            # NOTE(review): the comment above contradicts the `response_model`
            # argument that is actually passed - confirm which one is intended.
        )
        async def post(
            body: JinaEndpointRequestModel, response: Response
        ):  # 'response' is a FastAPI response, not a Jina response
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # Unwrap `{'docs': [...]}` payloads to the bare list of documents.
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            try:
                result = await _get_singleton_result(
                    request_generator(**req_generator_input)
                )
            except InternalNetworkError as err:
                import grpc

                # Map the gRPC failure onto the closest HTTP status code.
                if err.code() == grpc.StatusCode.UNAVAILABLE:
                    response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                elif err.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                    response.status_code = status.HTTP_504_GATEWAY_TIMEOUT
                else:
                    response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                result = bd  # send back the request
                result['header'] = _generate_exception_header(
                    err
                )  # attach exception details to response header
                logger.error(
                    f'Error while getting responses from deployments: {err.details()}'
                )
            return result

    def _generate_exception_header(error: InternalNetworkError):
        """Build the `header` dict describing `error` for an error response.

        :param error: the network error raised while forwarding the request
        :return: dict with the request id and an error status entry
        """
        import traceback

        from jina.proto.serializer import DataRequest

        exception_dict = {
            'name': str(error.__class__),
            'stacks': [
                str(x) for x in traceback.extract_tb(error.og_exception.__traceback__)
            ],
            'executor': '',
        }
        status_dict = {
            'code': DataRequest().status.ERROR,
            # Falsy details fall back to an empty string, as before, but
            # `details()` is now evaluated only once.
            'description': error.details() or '',
            'exception': exception_dict,
        }
        header_dict = {'request_id': error.request_id, 'status': status_dict}
        return header_dict

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint

        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        # No docstring on the handler on purpose: FastAPI would publish it as
        # the description of every exposed endpoint.
        @app.api_route(
            path=http_path or exec_endpoint, name=http_path or exec_endpoint, **kwargs
        )
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append(
            {
                'name': 'CRUD',
                'description': 'CRUD interface. If your service does not implement those interfaces, you can should '
                'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
            }
        )
        crud = {
            '/index': {'methods': ['POST']},
            '/search': {'methods': ['POST']},
            '/delete': {'methods': ['DELETE']},
            '/update': {'methods': ['PUT']},
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v[
                'description'
            ] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

        # Shared implementation of the `docs` query and mutation below.
        async def get_docs_from_endpoint(
            data, target_executor, parameters, exec_endpoint
        ):
            req_generator_input = {
                # `data` defaults to None in both resolvers; iterating it
                # unconditionally raised a TypeError. None is forwarded as-is,
                # as the REST handlers above do for a missing payload
                # (presumably `request_generator` accepts it - verify).
                # The former `'docs' in data` unwrapping was dropped: `data`
                # is a list of dicts here, so that branch could never run.
                'data': [asdict(d) for d in data] if data is not None else None,
                'target_executor': target_executor,
                'parameters': parameters,
                'exec_endpoint': exec_endpoint,
                'data_type': DataInputType.DICT,
            }

            try:
                response = await _get_singleton_result(
                    request_generator(**req_generator_input)
                )
            except InternalNetworkError as err:
                logger.error(
                    f'Error while getting responses from deployments: {err.details()}'
                )
                raise  # will be handled by Strawberry
            return DocumentArray.from_dict(response['data']).to_strawberry_type()

        @strawberry.type
        class Mutation:
            @strawberry.mutation
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        @strawberry.type
        class Query:
            @strawberry.field
            async def docs(
                self,
                data: Optional[List[StrawberryDocumentInput]] = None,
                target_executor: Optional[str] = None,
                parameters: Optional[JSONScalar] = None,
                exec_endpoint: str = '/search',
            ) -> List[StrawberryDocument]:
                return await get_docs_from_endpoint(
                    data, target_executor, parameters, exec_endpoint
                )

        schema = strawberry.Schema(query=Query, mutation=Mutation)
        app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        # Returns inside the loop: only the first streamed result is used.
        # If the stream yields nothing the function falls through with None.
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
# GraphQL object type for a person ("pessoa").
@strawberry.type
class Pessoa:
    id: Optional[int]
    nome: str  # name
    idade: int  # age
    # Forward reference: `Livro` is declared right below.
    livros: List['Livro']


# GraphQL object type for a book ("livro"); each book refers to one `Pessoa`.
@strawberry.type
class Livro:
    id: Optional[int]
    titulo: str  # title
    pessoa: Pessoa


# Root query type: both list fields are served by the given resolver functions.
@strawberry.type
class Query:
    all_pessoa: List[Pessoa] = strawberry.field(resolver=get_pessoas)
    all_livro: List[Livro] = strawberry.field(resolver=get_livros)


# Root mutation type: both create fields are served by the given resolver functions.
@strawberry.type
class Mutation:
    create_pessoa: Pessoa = strawberry.field(resolver=create_pessoas)
    create_livro: Livro = strawberry.field(resolver=create_livros)


# Schema combining the root types above, wrapped in a router for FastAPI.
schema = strawberry.Schema(query=Query, mutation=Mutation)
graphql_app = GraphQLRouter(schema)