예제 #1
0
파일: __init__.py 프로젝트: jina-ai/jina
    async def _async_setup_server(self):
        """
        Register the gRPC services on ``self.server``, bind the port and start it.

        The request streamer is registered as the ``JinaRPC`` servicer, while the
        runtime itself is registered for the dry-run and info services. Each name
        in ``service_names`` is marked as ``SERVING`` in the health servicer and
        exposed through server reflection. The port is bound with TLS when both
        ``ssl_keyfile`` and ``ssl_certfile`` are set and without TLS when neither
        of them is set.

        :raises ValueError: if only one of ``ssl_keyfile`` / ``ssl_certfile`` is set.
        """
        request_handler = RequestHandler(self.metrics_registry, self.name)

        self.streamer = RequestStreamer(
            args=self.args,
            request_handler=request_handler.handle_request(
                graph=self._topology_graph,
                connection_pool=self._connection_pool),
            result_handler=request_handler.handle_result(),
        )

        # Expose the streaming entry point under the name `Call`; presumably this
        # is the RPC method name the JinaRPC servicer registration looks up --
        # confirm against the generated stubs.
        self.streamer.Call = self.streamer.stream

        jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server)
        jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(
            self, self.server)
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self.server)

        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].
            full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer,
                                                     self.server)

        # Mark all services as healthy.
        for service in service_names:
            self._health_servicer.set(service,
                                      health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self.server)

        bind_addr = f'{__default_host__}:{self.args.port}'

        # Compare on truthiness rather than on the raw values: two different
        # "unset" values (e.g. '' and None) must not be reported as a partial
        # SSL configuration.
        has_keyfile = bool(self.args.ssl_keyfile)
        has_certfile = bool(self.args.ssl_certfile)

        if has_keyfile != has_certfile:
            # only ssl_keyfile without ssl_certfile, or vice versa
            raise ValueError(
                "you can't pass a ssl_keyfile without a ssl_certfile and vice versa"
            )

        if has_keyfile:
            with open(self.args.ssl_keyfile, 'rb') as f:
                private_key = f.read()
            with open(self.args.ssl_certfile, 'rb') as f:
                certificate_chain = f.read()

            # ssl_server_credentials takes a sequence of (key, chain) pairs
            server_credentials = grpc.ssl_server_credentials(((
                private_key,
                certificate_chain,
            ), ))
            self.server.add_secure_port(bind_addr, server_credentials)
        else:
            self.server.add_insecure_port(bind_addr)

        self.logger.debug(f'start server bound to {bind_addr}')
        await self.server.start()
예제 #2
0
파일: __init__.py 프로젝트: srbhr/jina
    async def async_setup(self):
        """
        Prepare the runtime asynchronously.

        Builds the gRPC server, plugs the request streamer into it and starts
        listening on the exposed port.
        """
        proxy_enabled = self.args.proxy
        if not proxy_enabled and os.name != 'nt':
            # drop the proxy variables from the process environment
            for proxy_var in ('http_proxy', 'https_proxy'):
                os.unsetenv(proxy_var)

        grpc_options = [
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        ]
        self.server = grpc.aio.server(options=grpc_options)

        self._set_topology_graph()
        self._set_connection_pool()

        on_request = handle_request(
            graph=self._topology_graph,
            connection_pool=self._connection_pool,
        )
        self.streamer = RequestStreamer(
            args=self.args,
            request_handler=on_request,
            result_handler=handle_result,
        )
        # make the streaming entry point available under the name `Call`
        self.streamer.Call = self.streamer.stream

        # the streamer serves data requests, the runtime itself control requests
        jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server)
        jina_pb2_grpc.add_JinaControlRequestRPCServicer_to_server(
            self, self.server)

        bind_addr = f'{__default_host__}:{self.args.port_expose}'
        self.server.add_insecure_port(bind_addr)
        self.logger.debug(f' Start server bound to {bind_addr}')

        await self.server.start()
예제 #3
0
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    Depending on ``args`` the app gets debug endpoints (``/``, ``/status``,
    ``/post``), CRUD endpoints (``/index``, ``/search``, ``/delete``,
    ``/update``), user-defined endpoints from ``args.expose_endpoints``, a
    custom Swagger page and a GraphQL router under ``/graphql``. All data
    endpoints forward their request through one shared ``RequestStreamer``.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware
        from fastapi.responses import HTMLResponse
        from starlette.requests import Request

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
            JinaStatusModel,
        )

    docs_url = '/docs'
    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description or
        'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize this text.',
        version=__version__,
        # with the default UI disabled, a custom page is mounted on `docs_url` further below
        docs_url=docs_url if args.default_swagger_ui else None,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is now accessible from any website!'
        )

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    # single streamer shared by every endpoint registered below
    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        # release the connections to the executors when the app stops
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append({
            'name':
            'Debug',
            'description':
            'Debugging interface. In production, you should hide them by setting '
            '`--no-debug-endpoints` in `Flow`/`Gateway`.',
        })

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of Jina service',
            response_model=JinaHealthModel,
        )
        async def _health():
            """
            Get the health of this Jina service.
            .. # noqa: DAR201

            """
            return {}

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaStatusModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            _info = get_full_version()
            return {
                'jina': _info[0],
                'envs': _info[1],
                'used_memory': used_memory_readable(),
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug']
            # NOTE(review): a previous comment here asked NOT to set a
            # `response_model` on this debug endpoint, yet one is set above --
            # confirm which of the two is intended.
        )
        async def post(body: JinaEndpointRequestModel):
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # unwrap `{'docs': [...]}` payloads so only the documents are forwarded
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint
        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        # No docstring on purpose: FastAPI would publish it as the endpoint
        # description whenever `kwargs` carries no explicit `description`.
        @app.api_route(path=http_path or exec_endpoint,
                       name=http_path or exec_endpoint,
                       **kwargs)
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            # a missing body is forwarded as a request without data
            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # unwrap `{'docs': [...]}` payloads so only the documents are forwarded
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append({
            'name':
            'CRUD',
            'description':
            'CRUD interface. If your service does not implement those interfaces, you can should '
            'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
        })
        crud = {
            '/index': {
                'methods': ['POST']
            },
            '/search': {
                'methods': ['POST']
            },
            '/delete': {
                'methods': ['DELETE']
            },
            '/update': {
                'methods': ['PUT']
            },
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        # JSON object mapping an executor endpoint to the kwargs of its HTTP route
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if not args.default_swagger_ui:

        async def _render_custom_swagger_html(req: Request) -> HTMLResponse:
            import urllib.request

            swagger_url = 'https://api.jina.ai/swagger'
            # separate local name so the incoming `req` parameter is not shadowed
            swagger_request = urllib.request.Request(
                swagger_url, headers={'User-Agent': 'Mozilla/5.0'})
            # NOTE(review): `urlopen` is a synchronous call inside an async
            # handler and runs on every hit of the docs page -- consider
            # moving it off the event loop and/or caching the page.
            with urllib.request.urlopen(swagger_request) as f:
                return HTMLResponse(f.read().decode())

        app.add_route(docs_url,
                      _render_custom_swagger_html,
                      include_in_schema=False)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

            async def get_docs_from_endpoint(data, target_executor, parameters,
                                             exec_endpoint):
                """
                Send the GraphQL input through the streamer and return the resulting documents.

                :param data: list of strawberry document inputs, or ``None`` when the caller sent no documents
                :param target_executor: value forwarded as ``target_executor`` to the request generator
                :param parameters: value forwarded as ``parameters`` to the request generator
                :param exec_endpoint: the executor endpoint to call
                :return: the documents of the response converted to strawberry types
                """
                req_generator_input = {
                    # `data` defaults to None in both resolvers below, so it
                    # must not be iterated unconditionally; None is forwarded
                    # as-is, like the HTTP endpoints do for an empty body.
                    # Each converted item is a plain dict, hence no `'docs'`
                    # unwrapping is needed here.
                    'data':
                    None if data is None else [asdict(d) for d in data],
                    'target_executor': target_executor,
                    'parameters': parameters,
                    'exec_endpoint': exec_endpoint,
                    'data_type': DataInputType.DICT,
                }

                response = await _get_singleton_result(
                    request_generator(**req_generator_input))
                return DocumentArray.from_dict(
                    response['data']).to_strawberry_type()

            # Mutation and Query expose the same `docs` field; both delegate
            # to `get_docs_from_endpoint`.
            @strawberry.type
            class Mutation:
                @strawberry.mutation
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            @strawberry.type
            class Query:
                @strawberry.field
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            schema = strawberry.Schema(query=Query, mutation=Mutation)
            app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        # only the first streamed response is used; if the stream yields
        # nothing the function falls through and returns None
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
예제 #4
0
파일: app.py 프로젝트: jina-ai/jina
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    Depending on ``args`` the app gets debug endpoints (``/``, ``/dry_run``,
    ``/status``, ``/post``), CRUD endpoints (``/index``, ``/search``,
    ``/delete``, ``/update``), user-defined endpoints from
    ``args.expose_endpoints`` and a GraphQL router under ``/graphql``. All data
    endpoints forward their request through one shared ``RequestStreamer``.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI, Response, status
        from fastapi.middleware.cors import CORSMiddleware

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
        )

    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description or
        'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize the title and description.',
        version=__version__,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is accessible from any website!')

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    # single streamer shared by every endpoint registered below
    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        # release the connections to the executors when the app stops
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append({
            'name':
            'Debug',
            'description':
            'Debugging interface. In production, you should hide them by setting '
            '`--no-debug-endpoints` in `Flow`/`Gateway`.',
        })

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of Jina Gateway service',
            response_model=JinaHealthModel,
        )
        async def _gateway_health():
            """
            Get the health of this Gateway service.
            .. # noqa: DAR201

            """
            return {}

        from docarray import DocumentArray

        from jina.proto import jina_pb2
        from jina.serve.executors import __dry_run_endpoint__
        from jina.serve.runtimes.gateway.http.models import (
            PROTO_TO_PYDANTIC_MODELS,
            JinaInfoModel,
        )
        from jina.types.request.status import StatusMessage

        @app.get(
            path='/dry_run',
            summary=
            'Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to '
            'validate connectivity',
            response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto,
        )
        async def _flow_health():
            """
            Get the health of the complete Flow service.
            .. # noqa: DAR201

            """

            da = DocumentArray()

            try:
                # the result itself is irrelevant, only whether the call succeeds
                _ = await _get_singleton_result(
                    request_generator(
                        exec_endpoint=__dry_run_endpoint__,
                        data=da,
                        data_type=DataInputType.DOCUMENT,
                    ))
                status_message = StatusMessage()
                status_message.set_code(jina_pb2.StatusProto.SUCCESS)
                return status_message.to_dict()
            except Exception as ex:
                # any failure is reported in the response body instead of being raised
                status_message = StatusMessage()
                status_message.set_exception(ex)
                return status_message.to_dict(use_integers_for_enums=True)

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaInfoModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            version, env_info = get_full_version()
            # stringify every value in place before returning both mappings
            for k, v in version.items():
                version[k] = str(v)
            for k, v in env_info.items():
                env_info[k] = str(v)
            return {'jina': version, 'envs': env_info}

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug']
            # NOTE(review): a previous comment here asked NOT to set a
            # `response_model` on this debug endpoint, yet one is set above --
            # confirm which of the two is intended.
        )
        async def post(
            body: JinaEndpointRequestModel, response: Response
        ):  # 'response' is a FastAPI response, not a Jina response
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # unwrap `{'docs': [...]}` payloads so only the documents are forwarded
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            try:
                result = await _get_singleton_result(
                    request_generator(**req_generator_input))
            except InternalNetworkError as err:
                import grpc

                # translate the gRPC status code of the failure into an HTTP status code
                if err.code() == grpc.StatusCode.UNAVAILABLE:
                    response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                elif err.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                    response.status_code = status.HTTP_504_GATEWAY_TIMEOUT
                else:
                    response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                result = bd  # send back the request
                result['header'] = _generate_exception_header(
                    err)  # attach exception details to response header
                logger.error(
                    f'Error while getting responses from deployments: {err.details()}'
                )
            return result

    def _generate_exception_header(error: InternalNetworkError):
        """
        Build the ``header`` entry of an error response from a network error.

        :param error: the network error raised while waiting for the response
        :return: dict with the request id and an ``ERROR`` status that carries
            the exception name, its stack entries and the error details
        """
        import traceback

        from jina.proto.serializer import DataRequest

        exception_dict = {
            'name':
            str(error.__class__),
            # stack entries of the original exception wrapped by `error`
            'stacks': [
                str(x)
                for x in traceback.extract_tb(error.og_exception.__traceback__)
            ],
            'executor':
            '',
        }
        status_dict = {
            'code': DataRequest().status.ERROR,
            'description': error.details() if error.details() else '',
            'exception': exception_dict,
        }
        header_dict = {'request_id': error.request_id, 'status': status_dict}
        return header_dict

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint
        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        # No docstring on purpose: FastAPI would publish it as the endpoint
        # description whenever `kwargs` carries no explicit `description`.
        @app.api_route(path=http_path or exec_endpoint,
                       name=http_path or exec_endpoint,
                       **kwargs)
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            # a missing body is forwarded as a request without data
            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # unwrap `{'docs': [...]}` payloads so only the documents are forwarded
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append({
            'name':
            'CRUD',
            'description':
            'CRUD interface. If your service does not implement those interfaces, you can should '
            'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
        })
        crud = {
            '/index': {
                'methods': ['POST']
            },
            '/search': {
                'methods': ['POST']
            },
            '/delete': {
                'methods': ['DELETE']
            },
            '/update': {
                'methods': ['PUT']
            },
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        # JSON object mapping an executor endpoint to the kwargs of its HTTP route
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

            async def get_docs_from_endpoint(data, target_executor, parameters,
                                             exec_endpoint):
                """
                Send the GraphQL input through the streamer and return the resulting documents.

                :param data: list of strawberry document inputs, or ``None`` when the caller sent no documents
                :param target_executor: value forwarded as ``target_executor`` to the request generator
                :param parameters: value forwarded as ``parameters`` to the request generator
                :param exec_endpoint: the executor endpoint to call
                :return: the documents of the response converted to strawberry types
                :raises InternalNetworkError: re-raised after logging when the request fails
                """
                req_generator_input = {
                    # `data` defaults to None in both resolvers below, so it
                    # must not be iterated unconditionally; None is forwarded
                    # as-is, like the HTTP endpoints do for an empty body.
                    # Each converted item is a plain dict, hence no `'docs'`
                    # unwrapping is needed here.
                    'data':
                    None if data is None else [asdict(d) for d in data],
                    'target_executor': target_executor,
                    'parameters': parameters,
                    'exec_endpoint': exec_endpoint,
                    'data_type': DataInputType.DICT,
                }

                try:
                    response = await _get_singleton_result(
                        request_generator(**req_generator_input))
                except InternalNetworkError as err:
                    logger.error(
                        f'Error while getting responses from deployments: {err.details()}'
                    )
                    raise  # will be handled by Strawberry
                return DocumentArray.from_dict(
                    response['data']).to_strawberry_type()

            # Mutation and Query expose the same `docs` field; both delegate
            # to `get_docs_from_endpoint`.
            @strawberry.type
            class Mutation:
                @strawberry.mutation
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            @strawberry.type
            class Query:
                @strawberry.field
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            schema = strawberry.Schema(query=Query, mutation=Mutation)
            app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        # only the first streamed response is used; if the stream yields
        # nothing the function falls through and returns None
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
예제 #5
0
async def test_request_streamer(prefetch, num_requests, async_iterator):
    """
    Stream ``num_requests`` requests through a ``RequestStreamer`` and check
    that every response passed through both the request and the result handler.

    :param prefetch: value stored on ``args.prefetch`` for the streamer
    :param num_requests: number of requests to generate
    :param async_iterator: feed the streamer from an async generator if truthy, from a sync one otherwise
    """
    seen_requests = []
    seen_results = []

    def _build_request():
        # a data request carrying exactly one empty document
        docs = DocumentArray()
        docs.append(Document())
        request = DataRequest()
        request.header.request_id = random_identity()
        request.data.docs = docs
        return request

    def _sync_requests(count):
        for _ in range(count):
            yield _build_request()

    async def _async_requests(count):
        for _ in range(count):
            yield _build_request()
            await asyncio.sleep(0.1)

    def _on_request(request):
        seen_requests.append(request)

        async def _process():
            # pretend to do some work, then flag the first document
            await asyncio.sleep(0.5)
            docs = request.docs
            docs[0].tags['request_handled'] = True
            request.data.docs = docs
            return request

        return asyncio.ensure_future(_process())

    def _on_result(result):
        seen_results.append(result)
        assert isinstance(result, DataRequest)
        docs = result.docs
        docs[0].tags['result_handled'] = True
        result.data.docs = docs
        return result

    def _on_end_of_iter():
        # once the input is exhausted, all requests have been handed to the
        # request handler but not all results have been handled yet
        assert len(seen_requests) == num_requests
        assert len(seen_results) < num_requests

    streamer_args = Namespace()
    streamer_args.prefetch = prefetch
    streamer = RequestStreamer(
        args=streamer_args,
        request_handler=_on_request,
        result_handler=_on_result,
        end_of_iter_handler=_on_end_of_iter,
    )

    if async_iterator:
        request_source = _async_requests(num_requests)
    else:
        request_source = _sync_requests(num_requests)

    received = 0
    async for resp in streamer.stream(request_source):
        received += 1
        first_doc_tags = resp.docs[0].tags
        assert first_doc_tags['request_handled']
        assert first_doc_tags['result_handled']

    assert received == num_requests
예제 #6
0
파일: __init__.py 프로젝트: jina-ai/jina
class GRPCGatewayRuntime(GatewayRuntime):
    """Gateway Runtime for gRPC."""

    def __init__(
        self,
        args: argparse.Namespace,
        **kwargs,
    ):
        """Initialize the runtime
        :param args: args from CLI
        :param kwargs: keyword args
        """
        # The health servicer is created before the parent initializer runs.
        self._health_servicer = health.HealthServicer(
            experimental_non_blocking=True)
        super().__init__(args, **kwargs)

    async def async_setup(self):
        """
        The async method to setup.

        Create the gRPC server and expose the port for communication.

        :raises PortAlreadyUsed: if the configured port is not free
        """
        if not self.args.proxy and os.name != 'nt':
            os.unsetenv('http_proxy')
            os.unsetenv('https_proxy')

        if not (is_port_free(__default_host__, self.args.port)):
            raise PortAlreadyUsed(f'port:{self.args.port}')

        # -1 lifts the gRPC message size limits in both directions.
        self.server = grpc.aio.server(options=[
            ('grpc.max_send_message_length', -1),
            ('grpc.max_receive_message_length', -1),
        ])

        self._set_topology_graph()
        self._set_connection_pool()

        await self._async_setup_server()

    async def _async_setup_server(self):
        """
        Register all servicers on the gRPC server, bind the port and start serving.

        The data RPC is served by a `RequestStreamer`; the dry-run and info RPCs are
        served by this runtime itself. Every registered service is marked as SERVING
        in the health servicer and exposed through server reflection.

        :raises ValueError: if only one of `ssl_keyfile` / `ssl_certfile` is given
        """
        request_handler = RequestHandler(self.metrics_registry, self.name)

        self.streamer = RequestStreamer(
            args=self.args,
            request_handler=request_handler.handle_request(
                graph=self._topology_graph,
                connection_pool=self._connection_pool),
            result_handler=request_handler.handle_result(),
        )

        # The generated JinaRPC servicer interface looks up a `Call` attribute.
        self.streamer.Call = self.streamer.stream

        jina_pb2_grpc.add_JinaRPCServicer_to_server(self.streamer, self.server)
        jina_pb2_grpc.add_JinaGatewayDryRunRPCServicer_to_server(
            self, self.server)
        jina_pb2_grpc.add_JinaInfoRPCServicer_to_server(self, self.server)

        service_names = (
            jina_pb2.DESCRIPTOR.services_by_name['JinaRPC'].full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC'].
            full_name,
            jina_pb2.DESCRIPTOR.services_by_name['JinaInfoRPC'].full_name,
            reflection.SERVICE_NAME,
        )
        # Mark all services as healthy.
        health_pb2_grpc.add_HealthServicer_to_server(self._health_servicer,
                                                     self.server)

        for service in service_names:
            self._health_servicer.set(service,
                                      health_pb2.HealthCheckResponse.SERVING)
        reflection.enable_server_reflection(service_names, self.server)

        bind_addr = f'{__default_host__}:{self.args.port}'

        keyfile = self.args.ssl_keyfile
        certfile = self.args.ssl_certfile

        if keyfile and certfile:
            with open(keyfile, 'rb') as f:
                private_key = f.read()
            with open(certfile, 'rb') as f:
                certificate_chain = f.read()

            server_credentials = grpc.ssl_server_credentials(((
                private_key,
                certificate_chain,
            ), ))
            self.server.add_secure_port(bind_addr, server_credentials)
        elif keyfile or certfile:
            # Exactly one of the two files was supplied. Testing truthiness
            # (instead of comparing the two values) keeps two differently
            # "empty" settings, e.g. None and '', on the insecure path.
            raise ValueError(
                "you can't pass a ssl_keyfile without a ssl_certfile and vice versa"
            )
        else:
            self.server.add_insecure_port(bind_addr)
        self.logger.debug(f'start server bound to {bind_addr}')
        await self.server.start()

    async def async_teardown(self):
        """Enter graceful shutdown, stop the server and close the connection pool"""
        # usually async_cancel should already have been called, but then its a noop
        # if the runtime is stopped without a sigterm (e.g. as a context manager, this can happen)
        self._health_servicer.enter_graceful_shutdown()
        await self.async_cancel()
        await self._connection_pool.close()

    async def async_cancel(self):
        """The async method to stop server."""
        # A grace period of 0 is passed to the server's stop call.
        await self.server.stop(0)

    async def async_run_forever(self):
        """The async running of server."""
        self._connection_pool.start()
        await self.server.wait_for_termination()

    async def dry_run(self, empty, context) -> jina_pb2.StatusProto:
        """
        Process the call requested by having a dry run call to every Executor in the graph

        Any exception raised while streaming is captured into the returned status
        instead of being propagated.

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        from docarray import DocumentArray
        from jina.clients.request import request_generator
        from jina.enums import DataInputType
        from jina.serve.executors import __dry_run_endpoint__

        da = DocumentArray()

        try:
            req_iterator = request_generator(
                exec_endpoint=__dry_run_endpoint__,
                data=da,
                data_type=DataInputType.DOCUMENT,
            )
            # Responses are drained and discarded; only success/failure matters.
            async for _ in self.streamer.stream(request_iterator=req_iterator):
                pass
            status_message = StatusMessage()
            status_message.set_code(jina_pb2.StatusProto.SUCCESS)
            return status_message.proto
        except Exception as ex:
            status_message = StatusMessage()
            status_message.set_exception(ex)
            return status_message.proto

    async def _status(self, empty, context) -> jina_pb2.JinaInfoProto:
        """
        Process the call requested and return the JinaInfo of the Runtime

        :param empty: The service expects an empty protobuf message
        :param context: grpc context
        :returns: the response request
        """
        info_proto = jina_pb2.JinaInfoProto()
        version, env_info = get_full_version()
        # Values are stringified before being stored in the proto maps.
        for k, v in version.items():
            info_proto.jina[k] = str(v)
        for k, v in env_info.items():
            info_proto.envs[k] = str(v)
        return info_proto
예제 #7
0
파일: websocket.py 프로젝트: sthagen/jina
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """
        Send the requests over a websocket connection and yield each response.

        Requests are sent as background tasks; a separate receive task matches each
        incoming response to its pending future by request id. On an
        `aiohttp.ClientError` the `on_error` / `on_always` callbacks are executed
        if any was given, otherwise the error is re-raised.

        :param inputs: the callable
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_task_name and _get_requests
        :yields: generator over results
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        async with AsyncExitStack() as stack:
            try:
                cm1 = ProgressBar(
                    total_length=self._inputs_length, disable=not (self.show_progress)
                )
                p_bar = stack.enter_context(cm1)

                proto = 'wss' if self.args.tls else 'ws'
                url = f'{proto}://{self.args.host}:{self.args.port}/'
                iolet = await stack.enter_async_context(
                    WebsocketClientlet(url=url, logger=self.logger)
                )

                request_buffer: Dict[str, asyncio.Future] = dict()

                # The event loop only keeps weak references to tasks, so a
                # fire-and-forget task may be garbage-collected before it
                # finishes. Hold a strong reference until each task is done.
                background_tasks = set()

                def _spawn(coro):
                    """Schedule `coro` as a task and keep it referenced until it completes"""
                    task = asyncio.create_task(coro)
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)
                    return task

                def _result_handler(result):
                    return result

                async def _receive():
                    """Await messages from WebsocketGateway and process them in the request buffer"""

                    def _response_handler(response):
                        if response.header.request_id in request_buffer:
                            future = request_buffer.pop(response.header.request_id)
                            future.set_result(response)
                        else:
                            self.logger.warning(
                                f'discarding unexpected response with request id {response.header.request_id}'
                            )

                    try:
                        async for response in iolet.recv_message():
                            _response_handler(response)
                    finally:
                        # Receiving has ended: cancel whatever is still pending
                        # so it does not wait for a response indefinitely.
                        if request_buffer:
                            self.logger.warning(
                                f'{self.__class__.__name__} closed, cancelling all outstanding requests'
                            )
                            for future in request_buffer.values():
                                future.cancel()
                            request_buffer.clear()

                def _handle_end_of_iter():
                    """Send End of iteration signal to the Gateway"""
                    _spawn(iolet.send_eoi())

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """
                    For each request in the iterator, we send the `Message` using `iolet.send_message()`.
                    For websocket requests from client, for each request in the iterator, we send the request in `bytes`
                    using `iolet.send_message()`.
                    Then add {<request-id>: <an-empty-future>} to the request buffer.
                    This empty future is used to track the `result` of this request during `receive`.
                    :param request: current request in the iterator
                    :return: asyncio Future for sending message
                    """
                    future = get_or_reuse_loop().create_future()
                    # Register the future before sending so the receive task
                    # can find it when the response arrives.
                    request_buffer[request.header.request_id] = future
                    _spawn(iolet.send_message(request))
                    return future

                streamer = RequestStreamer(
                    args=self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                    end_of_iter_handler=_handle_end_of_iter,
                )

                receive_task = get_or_reuse_loop().create_task(_receive())

                if receive_task.done():
                    raise RuntimeError(
                        'receive task not running, can not send messages'
                    )
                async for response in streamer.stream(request_iterator):
                    callback_exec(
                        response=response,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield response

            except aiohttp.ClientError as e:
                self.logger.error(
                    f'Error while streaming response from websocket server {e!r}'
                )

                if on_error or on_always:
                    if on_error:
                        callback_exec_on_error(on_error, e, self.logger)
                    if on_always:
                        callback_exec(
                            response=None,
                            on_error=None,
                            on_done=None,
                            on_always=on_always,
                            continue_on_error=self.continue_on_error,
                            logger=self.logger,
                        )
                else:
                    # No callback can consume the error, so let it propagate.
                    raise
예제 #8
0
파일: http.py 프로젝트: srbhr/jina
    async def _get_results(
        self,
        inputs: 'InputType',
        on_done: 'CallbackFnType',
        on_error: Optional['CallbackFnType'] = None,
        on_always: Optional['CallbackFnType'] = None,
        **kwargs,
    ):
        """
        POST each request to the HTTP gateway and yield the parsed responses.

        An `aiohttp.ClientError` raised while streaming is logged and not
        re-raised.

        :param inputs: the callable
        :param on_done: the callback for on_done
        :param on_error: the callback for on_error
        :param on_always: the callback for on_always
        :param kwargs: kwargs for _get_task_name and _get_requests
        :yields: generator over results
        :raises BadClient: if the server answers with status 404
        :raises ValueError: if the server answers with any other non-2xx status
        """
        with ImportExtensions(required=True):
            import aiohttp

        self.inputs = inputs
        request_iterator = self._get_requests(**kwargs)

        async with AsyncExitStack() as stack:
            try:
                cm1 = (ProgressBar(total_length=self._inputs_length)
                       if self.show_progress else nullcontext())
                p_bar = stack.enter_context(cm1)

                proto = 'https' if self.args.https else 'http'
                url = f'{proto}://{self.args.host}:{self.args.port}/post'
                iolet = await stack.enter_async_context(
                    HTTPClientlet(url=url, logger=self.logger))

                def _request_handler(request: 'Request') -> 'asyncio.Future':
                    """
                    For HTTP Client, for each request in the iterator, we `send_message` using
                    http POST request and add it to the list of tasks which is awaited and yielded.
                    :param request: current request in the iterator
                    :return: asyncio Task for sending message
                    """
                    return asyncio.ensure_future(
                        iolet.send_message(request=request))

                def _result_handler(result):
                    return result

                streamer = RequestStreamer(
                    self.args,
                    request_handler=_request_handler,
                    result_handler=_result_handler,
                )
                async for response in streamer.stream(request_iterator):
                    r_status = response.status

                    # NOTE(review): the body is decoded as JSON before the status
                    # is inspected; presumably error responses carry JSON too - confirm.
                    r_str = await response.json()
                    if r_status == 404:
                        raise BadClient(f'no such endpoint {url}')
                    elif not 200 <= r_status < 300:
                        # Only the 2xx range counts as success; 300 itself is
                        # a redirection status and must be rejected as well.
                        raise ValueError(r_str)

                    da = None
                    if 'data' in r_str and r_str['data'] is not None:
                        from docarray import DocumentArray

                        # The documents are rebuilt separately and removed from
                        # the dict before the rest is turned into a DataRequest.
                        da = DocumentArray.from_dict(r_str['data'])
                        del r_str['data']

                    resp = DataRequest(r_str)
                    if da is not None:
                        resp.data.docs = da

                    callback_exec(
                        response=resp,
                        on_error=on_error,
                        on_done=on_done,
                        on_always=on_always,
                        continue_on_error=self.continue_on_error,
                        logger=self.logger,
                    )
                    if self.show_progress:
                        p_bar.update()
                    yield resp
            except aiohttp.ClientError as e:
                self.logger.error(
                    f'Error while fetching response from HTTP server {e!r}')
예제 #9
0
파일: app.py 프로젝트: srbhr/jina
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
):
    """
    Get the app from FastAPI as the Websocket interface.

    The app exposes a single websocket route at `/`: incoming binary messages are
    turned into `DataRequest` objects, streamed through a `RequestStreamer`, and
    every result is sent back as bytes. The connection pool is closed on app shutdown.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :return: fastapi app
    """

    with ImportExtensions(required=True):
        from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    class ConnectionManager:
        """Keep track of the websocket connections that are currently accepted."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []

        async def connect(self, websocket: WebSocket):
            """Accept the websocket and register it as active."""
            await websocket.accept()
            logger.debug(
                f'client {websocket.client.host}:{websocket.client.port} connected'
            )
            self.active_connections.append(websocket)

        def disconnect(self, websocket: WebSocket):
            """Remove the websocket from the active connections."""
            self.active_connections.remove(websocket)

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.stream import RequestStreamer
    from jina.serve.runtimes.gateway.request_handling import (
        handle_request,
        handle_result,
    )

    streamer = RequestStreamer(
        args=args,
        request_handler=handle_request(graph=topology_graph,
                                       connection_pool=connection_pool),
        result_handler=handle_result,
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(websocket: WebSocket):

        await manager.connect(websocket)

        async def req_iter():
            async for request_bytes in websocket.iter_bytes():
                # `bytes(True)` is the single byte b'\x00'; receiving it ends the
                # request stream.
                if request_bytes == bytes(True):
                    break
                yield DataRequest(request_bytes)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await websocket.send_bytes(bytes(msg))
        except WebSocketDisconnect:
            logger.debug('Client successfully disconnected from server')
        finally:
            # Unregister on every exit path. A stream that ends normally would
            # otherwise leave the socket in `active_connections` forever.
            manager.disconnect(websocket)

    return app
예제 #10
0
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the Websocket interface.

    The app exposes a single websocket route at `/`. Each connection uses either
    the JSON or the BYTES subprotocol, chosen from the `sec-websocket-protocol`
    header; incoming messages are converted to requests, streamed through a
    `RequestStreamer`, and every result is sent back using the same subprotocol.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """

    from jina.serve.runtimes.gateway.http.models import JinaEndpointRequestModel

    with ImportExtensions(required=True):
        from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    class ConnectionManager:
        """Track the accepted websocket connections and the subprotocol of each client."""

        def __init__(self):
            self.active_connections: List[WebSocket] = []
            # Maps the '<host>:<port>' key of a client to its subprotocol.
            self.protocol_dict: Dict[str, WebsocketSubProtocols] = {}

        def get_client(self, websocket: WebSocket) -> str:
            """Return the '<host>:<port>' key identifying the client of the websocket."""
            return f'{websocket.client.host}:{websocket.client.port}'

        def get_subprotocol(self, headers: Dict):
            """
            Pick the subprotocol from the connection headers.

            The header key may be given as `str` or as `bytes`. JSON is used when the
            header is missing or when building the subprotocol from its value fails.

            :param headers: the headers of the websocket connection
            :return: the subprotocol to use for this connection
            """
            try:
                if 'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers['sec-websocket-protocol'])
                elif b'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers[b'sec-websocket-protocol'].decode())
                else:
                    subprotocol = WebsocketSubProtocols.JSON
                    logger.debug(
                        f'no protocol headers passed. Choosing default subprotocol {WebsocketSubProtocols.JSON}'
                    )
            except Exception as e:
                logger.debug(
                    f'got an exception while setting user\'s subprotocol, defaulting to JSON {e}'
                )
                subprotocol = WebsocketSubProtocols.JSON
            return subprotocol

        async def connect(self, websocket: WebSocket):
            """Accept the websocket, then register it together with its subprotocol."""
            await websocket.accept()
            subprotocol = self.get_subprotocol(dict(
                websocket.scope['headers']))
            logger.info(
                f'client {websocket.client.host}:{websocket.client.port} connected '
                f'with subprotocol {subprotocol}')
            self.active_connections.append(websocket)
            self.protocol_dict[self.get_client(websocket)] = subprotocol

        def disconnect(self, websocket: WebSocket):
            """Forget the websocket and the subprotocol stored for its client."""
            self.protocol_dict.pop(self.get_client(websocket))
            self.active_connections.remove(websocket)

        async def receive(self, websocket: WebSocket) -> Any:
            """
            Receive one message using the subprotocol of the client.

            Falls through and returns `None` for a subprotocol other than JSON or BYTES.
            """
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.receive_json(mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.receive_bytes()

        async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]:
            """Yield received messages until the client disconnects."""
            try:
                while True:
                    yield await self.receive(websocket)
            except WebSocketDisconnect:
                # A disconnect simply ends the iteration.
                pass

        async def send(self, websocket: WebSocket, data: DataRequest) -> None:
            """Send `data` to the client, encoded according to its subprotocol."""
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.send_json(data.to_dict(), mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.send_bytes(data.to_bytes())

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool),
        result_handler=request_handler.handle_result(),
    )

    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(websocket: WebSocket):
        await manager.connect(websocket)

        async def req_iter():
            async for request in manager.iter(websocket):
                if isinstance(request, dict):
                    # An empty JSON object ends the request stream.
                    if request == {}:
                        break
                    else:
                        # NOTE: Helps in converting camelCase to snake_case
                        req_generator_input = JinaEndpointRequestModel(
                            **request).dict()
                        req_generator_input['data_type'] = DataInputType.DICT
                        # NOTE(review): `request['data']` raises KeyError when the
                        # key is absent - confirm clients always send it.
                        if request['data'] is not None and 'docs' in request[
                                'data']:
                            req_generator_input['data'] = req_generator_input[
                                'data']['docs']

                        # you can't do `yield from` inside an async function
                        for data_request in request_generator(
                                **req_generator_input):
                            yield data_request
                elif isinstance(request, bytes):
                    # `bytes(True)` is the single byte b'\x00'; receiving it ends
                    # the request stream.
                    if request == bytes(True):
                        break
                    else:
                        yield DataRequest(request)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await manager.send(websocket, msg)
        except WebSocketDisconnect:
            logger.info('Client successfully disconnected from server')
        finally:
            # Unregister on every exit path. A stream that ends normally would
            # otherwise leave stale entries in `active_connections` and
            # `protocol_dict`.
            manager.disconnect(websocket)

    return app
예제 #11
0
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the Websocket interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """

    from jina.serve.runtimes.gateway.http.models import JinaEndpointRequestModel

    with ImportExtensions(required=True):
        from fastapi import FastAPI, Response, WebSocket, WebSocketDisconnect, status

    class ConnectionManager:
        def __init__(self):
            self.active_connections: List[WebSocket] = []
            self.protocol_dict: Dict[str, WebsocketSubProtocols] = {}

        def get_client(self, websocket: WebSocket) -> str:
            return f'{websocket.client.host}:{websocket.client.port}'

        def get_subprotocol(self, headers: Dict):
            try:
                if 'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers['sec-websocket-protocol']
                    )
                elif b'sec-websocket-protocol' in headers:
                    subprotocol = WebsocketSubProtocols(
                        headers[b'sec-websocket-protocol'].decode()
                    )
                else:
                    subprotocol = WebsocketSubProtocols.JSON
                    logger.debug(
                        f'no protocol headers passed. Choosing default subprotocol {WebsocketSubProtocols.JSON}'
                    )
            except Exception as e:
                logger.debug(
                    f'got an exception while setting user\'s subprotocol, defaulting to JSON {e}'
                )
                subprotocol = WebsocketSubProtocols.JSON
            return subprotocol

        async def connect(self, websocket: WebSocket):
            await websocket.accept()
            subprotocol = self.get_subprotocol(dict(websocket.scope['headers']))
            logger.info(
                f'client {websocket.client.host}:{websocket.client.port} connected '
                f'with subprotocol {subprotocol}'
            )
            self.active_connections.append(websocket)
            self.protocol_dict[self.get_client(websocket)] = subprotocol

        def disconnect(self, websocket: WebSocket):
            self.protocol_dict.pop(self.get_client(websocket))
            self.active_connections.remove(websocket)

        async def receive(self, websocket: WebSocket) -> Any:
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.receive_json(mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.receive_bytes()

        async def iter(self, websocket: WebSocket) -> AsyncIterator[Any]:
            try:
                while True:
                    yield await self.receive(websocket)
            except WebSocketDisconnect:
                pass

        async def send(
            self, websocket: WebSocket, data: Union[DataRequest, StatusMessage]
        ) -> None:
            subprotocol = self.protocol_dict[self.get_client(websocket)]
            if subprotocol == WebsocketSubProtocols.JSON:
                return await websocket.send_json(data.to_dict(), mode='text')
            elif subprotocol == WebsocketSubProtocols.BYTES:
                return await websocket.send_bytes(data.to_bytes())

    manager = ConnectionManager()

    app = FastAPI()

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )

    streamer.Call = streamer.stream

    @app.get(
        path='/',
        summary='Get the health of Jina service',
    )
    async def _health():
        """
        Get the health of this Jina service.
        .. # noqa: DAR201

        """
        return {}

    @app.get(
        path='/status',
        summary='Get the status of Jina service',
    )
    async def _status():
        """
        Get the status of this Jina service.

        This is equivalent to running `jina -vf` from command line.

        .. # noqa: DAR201
        """
        version, env_info = get_full_version()
        for k, v in version.items():
            version[k] = str(v)
        for k, v in env_info.items():
            env_info[k] = str(v)
        return {'jina': version, 'envs': env_info}

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    @app.websocket('/')
    async def websocket_endpoint(
        websocket: WebSocket, response: Response
    ):  # 'response' is a FastAPI response, not a Jina response
        await manager.connect(websocket)

        async def req_iter():
            async for request in manager.iter(websocket):
                if isinstance(request, dict):
                    if request == {}:
                        break
                    else:
                        # NOTE: Helps in converting camelCase to snake_case
                        req_generator_input = JinaEndpointRequestModel(**request).dict()
                        req_generator_input['data_type'] = DataInputType.DICT
                        if request['data'] is not None and 'docs' in request['data']:
                            req_generator_input['data'] = req_generator_input['data'][
                                'docs'
                            ]

                        # you can't do `yield from` inside an async function
                        for data_request in request_generator(**req_generator_input):
                            yield data_request
                elif isinstance(request, bytes):
                    if request == bytes(True):
                        break
                    else:
                        yield DataRequest(request)

        try:
            async for msg in streamer.stream(request_iterator=req_iter()):
                await manager.send(websocket, msg)
        except InternalNetworkError as err:
            import grpc

            manager.disconnect(websocket)
            fallback_msg = (
                f'Connection to deployment at {err.dest_addr} timed out. You can adjust `timeout_send` attribute.'
                if err.code() == grpc.StatusCode.DEADLINE_EXCEEDED
                else f'Network error while connecting to deployment at {err.dest_addr}. It may be down.'
            )
            msg = (
                err.details()
                if _fits_ws_close_msg(
                    err.details()
                )  # some messages are too long for ws closing message
                else fallback_msg
            )
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=msg)
        except WebSocketDisconnect:
            logger.info('Client successfully disconnected from server')
            manager.disconnect(websocket)

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    from docarray import DocumentArray

    from jina.proto import jina_pb2
    from jina.serve.executors import __dry_run_endpoint__
    from jina.serve.runtimes.gateway.http.models import PROTO_TO_PYDANTIC_MODELS

    @app.get(
        path='/dry_run',
        summary='Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to '
        'validate connectivity',
        response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto,
    )
    async def _dry_run_http():
        """
        Get the health of the complete Flow service.
        .. # noqa: DAR201

        """
        # The docstring above is kept verbatim on purpose: FastAPI publishes the docstring
        # of an HTTP endpoint as the route description in the generated OpenAPI schema.

        empty_docs = DocumentArray()

        try:
            # Build the dry-run request stream from the empty DocumentArray and push it
            # through the streamer; the returned payload is ignored, only whether the
            # call raises is of interest.
            dry_run_requests = request_generator(
                exec_endpoint=__dry_run_endpoint__,
                data=empty_docs,
                data_type=DataInputType.DOCUMENT,
            )
            await _get_singleton_result(dry_run_requests)

            # Reaching this point means the call above did not raise.
            success = StatusMessage()
            success.set_code(jina_pb2.StatusProto.SUCCESS)
            return success.to_dict()
        except Exception as ex:
            # Any failure is reported in the response body instead of being propagated.
            failure = StatusMessage()
            failure.set_exception(ex)
            return failure.to_dict(use_integers_for_enums=True)

    @app.websocket('/dry_run')
    async def websocket_endpoint(
        websocket: WebSocket, response: Response
    ):  # 'response' is a FastAPI response, not a Jina response
        """
        Check the readiness of the complete Flow over a WebSocket connection.

        An empty DocumentArray is streamed to the dry-run endpoint and every response is
        discarded. The outcome is then reported to the client: a success ``StatusMessage``
        when the stream completes, a ``StatusMessage`` carrying the exception on an
        unexpected error, or a socket close with an internal-error code on an
        ``InternalNetworkError``.

        :param websocket: the client WebSocket connection
        :param response: FastAPI response object, not used by this endpoint
        """
        from jina.proto import jina_pb2
        from jina.serve.executors import __dry_run_endpoint__

        await manager.connect(websocket)

        da = DocumentArray()
        try:
            # Drain the stream; the responses themselves carry no information here.
            async for _ in streamer.stream(
                request_iterator=request_generator(
                    exec_endpoint=__dry_run_endpoint__,
                    data=da,
                    data_type=DataInputType.DOCUMENT,
                )
            ):
                pass
            status_message = StatusMessage()
            status_message.set_code(jina_pb2.StatusProto.SUCCESS)
            await manager.send(websocket, status_message)
        except InternalNetworkError as err:
            manager.disconnect(websocket)
            msg = (
                err.details()
                if _fits_ws_close_msg(err.details())  # some messages are too long
                else f'Network error while connecting to deployment at {err.dest_addr}. It may be down.'
            )
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=msg)
        except WebSocketDisconnect:
            logger.info('Client successfully disconnected from server')
            manager.disconnect(websocket)
        except Exception as ex:
            # Report the failure BEFORE unregistering the socket: the message is sent
            # through the manager, so the connection must still be known to it.
            # NOTE(review): presumably `manager.send` relies on state recorded at
            # `manager.connect` time (e.g. the negotiated subprotocol) - confirm against
            # the connection manager implementation.
            try:
                status_message = StatusMessage()
                status_message.set_exception(ex)
                await manager.send(websocket, status_message)
            finally:
                # Always unregister, even when building or sending the report fails.
                manager.disconnect(websocket)

    return app