예제 #1
0
def test_no_suppress_other_exception():
    """Check that ``ImportExtensions`` lets non-import exceptions propagate.

    A plain ``Exception`` raised inside the context must reach the caller
    both when the extension is optional and when it is required.
    """
    for is_required in (False, True):
        with pytest.raises(Exception):
            with ImportExtensions(required=is_required, logger=default_logger):
                raise Exception
예제 #2
0
    def mutate(
        self,
        mutation: str,
        variables: Optional[dict] = None,
        timeout: Optional[float] = None,
        headers: Optional[dict] = None,
    ):
        """Send a GraphQL mutation to the ``/graphql`` route of the configured host.

        The URL is built from ``self.args.host`` and ``self.args.port``, using
        ``https`` when ``self.args.tls`` is set and ``http`` otherwise.

        :param mutation: the GraphQL mutation as a single string.
        :param variables: variables to be substituted in the mutation. Not needed if no variables are present in the mutation string.
        :param timeout: HTTP request timeout
        :param headers: HTTP headers
        :return: dict containing the optional keys ``data`` and ``errors``, for response data and errors.
        :raises ConnectionError: if the response carries a non-empty ``errors`` entry.
        """
        with ImportExtensions(required=True):
            from sgqlc.endpoint.http import HTTPEndpoint as SgqlcHTTPEndpoint

            scheme = 'https' if self.args.tls else 'http'
            endpoint = SgqlcHTTPEndpoint(
                f'{scheme}://{self.args.host}:{self.args.port}/graphql'
            )
            response = endpoint(
                mutation,
                variables=variables,
                timeout=timeout,
                extra_headers=headers,
            )

            reported_errors = response.get('errors')
            if reported_errors:
                # every error message is terminated with '. ' and appended to the prefix
                details = ''.join(err['message'] + '. ' for err in reported_errors)
                raise ConnectionError(
                    'GraphQL mutation returned the following errors: ' + details
                )
            return response
예제 #3
0
    def craft(self, buffer: bytes, uri: str, *args, **kwargs) -> Dict:
        """
        Decode an image into a float32 matrix.

        The image is opened from the raw bytes in `buffer` when they are non-empty, otherwise from the
            file path in `uri`. It is converted to RGB and returned as an `ndarray` under the `blob` key.
            When `self.channel_axis` is not -1 the channel axis is moved from the last position to it.

        :param buffer: the image in raw bytes
        :param uri: the image file path
        :param args: unused positional arguments
        :param kwargs: unused keyword arguments
        :return: a dict whose ``blob`` entry holds the image array
        :raises ValueError: if neither ``buffer`` nor ``uri`` holds a value
        """
        with ImportExtensions(
            required=True,
            verbose=True,
            pkg_name='Pillow',
            logger=self.logger,
            help_text='PIL is missing. Install it with `pip install Pillow`',
        ):
            from PIL import Image

            # raw bytes take precedence over the path
            if buffer:
                source = io.BytesIO(buffer)
            elif uri:
                source = uri
            else:
                raise ValueError('no value found in "buffer" and "uri"')

            rgb_image = Image.open(source).convert('RGB')
            pixels = np.array(rgb_image).astype('float32')
            if self.channel_axis != -1:
                pixels = np.moveaxis(pixels, -1, self.channel_axis)
            return {'blob': pixels}
예제 #4
0
def upload_file(
    url: str,
    file_name: str,
    buffer_data: bytes,
    dict_data: Dict,
    headers: Dict,
    stream: bool = False,
    method: str = 'post',
):
    """Upload file to target url as a multipart form.

    The file is added to the form under the ``file`` field and the ``Content-Type``
    header is set to the generated multipart content type. Both are applied to copies,
    so the ``dict_data`` and ``headers`` objects passed in are left untouched.

    :param url: target url
    :param file_name: the file name
    :param buffer_data: the data to upload
    :param dict_data: the dict-style data to upload
    :param headers: the request header
    :param stream: receive stream response
    :param method: the request method, used as the name of the ``requests`` function to call
    :return: the response of request
    """
    with ImportExtensions(required=True):
        import requests

    # work on copies: the caller may reuse its form data / headers for other requests
    form_fields = dict(dict_data)
    form_fields['file'] = (file_name, buffer_data)

    data, ctype = requests.packages.urllib3.filepost.encode_multipart_formdata(
        form_fields
    )

    request_headers = dict(headers)
    request_headers['Content-Type'] = ctype

    send = getattr(requests, method)
    return send(url, data=data, headers=request_headers, stream=stream)
예제 #5
0
파일: __init__.py 프로젝트: jina-ai/jina
    async def async_setup(self):
        """
        Prepare the request timer and wait for the GRPC server setup.

        With a metrics registry, ``self._summary_time`` becomes the timer of a
        ``receiving_request_seconds`` Summary labelled with ``self.args.name``;
        without one it is a no-op context manager.
        """
        if not self.metrics_registry:
            self._summary_time = contextlib.nullcontext()
        else:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
            ):
                from prometheus_client import Summary

            request_summary = Summary(
                'receiving_request_seconds',
                'Time spent processing request',
                registry=self.metrics_registry,
                namespace='jina',
                labelnames=('runtime_name',),
            )
            self._summary_time = request_summary.labels(self.args.name).time()

        await self._async_setup_grpc_server()
예제 #6
0
    def __init__(
        self,
        logger: Optional[JinaLogger] = None,
        compression: Optional[str] = None,
        metrics_registry: Optional['CollectorRegistry'] = None,
    ):
        """Set up logging, compression, the request timer and the connection map.

        :param logger: logger to use; a new ``JinaLogger`` named after the class is created when omitted
        :param compression: name of an attribute of ``grpc.Compression``; ``NoCompression`` is used when omitted
        :param metrics_registry: when given, requests are timed with a ``sending_request_seconds`` Summary
        """
        self._logger = logger or JinaLogger(self.__class__.__name__)

        if compression:
            self.compression = getattr(grpc.Compression, compression)
        else:
            self.compression = grpc.Compression.NoCompression

        if not metrics_registry:
            # no registry: timing becomes a no-op context manager
            self._summary_time = contextlib.nullcontext()
        else:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
            ):
                from prometheus_client import Summary

            sending_summary = Summary(
                'sending_request_seconds',
                'Time spent between sending a request to the Pod and receiving the response',
                registry=metrics_registry,
                namespace='jina',
            )
            self._summary_time = sending_summary.time()

        self._connections = self._ConnectionPoolMap(self._logger, self._summary_time)
        self._deployment_address_map = {}
예제 #7
0
    def _init_monitoring(self, metrics_registry: Optional['CollectorRegistry'] = None):
        """Create the document counter and request-size summary, or disable both.

        :param metrics_registry: registry the metrics are attached to; when it is falsy
            ``self._counter`` and ``self._request_size_metrics`` are set to ``None``
        """
        if not metrics_registry:
            self._counter = None
            self._request_size_metrics = None
            return

        with ImportExtensions(
            required=True,
            help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
        ):
            from prometheus_client import Counter, Summary

            # both metrics carry the same set of labels
            metric_labels = ('executor_endpoint', 'executor', 'runtime_name')

            self._counter = Counter(
                'document_processed',
                'Number of Documents that have been processed by the executor',
                namespace='jina',
                labelnames=metric_labels,
                registry=metrics_registry,
            )

            self._request_size_metrics = Summary(
                'request_size_bytes',
                'The request size in Bytes',
                namespace='jina',
                labelnames=metric_labels,
                registry=metrics_registry,
            )
예제 #8
0
파일: decorators.py 프로젝트: jina-ai/jina
        def arg_wrapper(self, *args, **kwargs):
            """Run the wrapped ``__init__``, holding a per-class file lock when possible.

            The lock is only used when the instance is exactly of class ``cls``; other
            classes call ``func`` directly. ``filelock`` is imported as an optional
            extension, and a no-op context is used in its place if no lock was created.

            :param args: positional arguments forwarded to ``func``
            :param kwargs: keyword arguments forwarded to ``func``
            :return: whatever ``func`` returns
            :raises TypeError: if the decorated function is not named ``__init__``
            """
            if func.__name__ != '__init__':
                raise TypeError(
                    'this decorator should only be used on __init__ method of an executor'
                )

            if self.__class__ != cls:
                return func(self, *args, **kwargs)

            init_lock = nullcontext()
            with ImportExtensions(
                required=False,
                help_text='FileLock is needed to guarantee non-concurrent initialization of replicas in the same '
                'machine.',
            ):
                import filelock

                lock_path = _get_locks_root().joinpath(
                    f'{self.__class__.__name__}.lock'
                )
                # timeout=-1 is passed through to FileLock unchanged
                init_lock = filelock.FileLock(lock_path, timeout=-1)

            with init_lock:
                return func(self, *args, **kwargs)
예제 #9
0
    def __init__(
        self,
        metrics_registry: Optional['CollectorRegistry'] = None,
        runtime_name: Optional[str] = None,
    ):
        """Initialise request bookkeeping and the optional processing-time summary.

        :param metrics_registry: when given, a ``receiving_request_seconds`` Summary is registered on it
            and ``self.request_init_time`` starts as an empty dict; otherwise both are ``None``
        :param runtime_name: value used for the ``runtime_name`` label of the summary
        """
        self.request_init_time = {} if metrics_registry else None
        self._executor_endpoint_mapping = None

        if not metrics_registry:
            self._summary = None
            return

        with ImportExtensions(
            required=True,
            help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
        ):
            from prometheus_client import Summary

        receiving_summary = Summary(
            'receiving_request_seconds',
            'Time spent processing request',
            registry=metrics_registry,
            namespace='jina',
            labelnames=('runtime_name',),
        )
        self._summary = receiving_summary.labels(runtime_name)
예제 #10
0
    async def async_setup(self):
        """
        Set up the runtime.

        Builds the FastAPI app, wraps it in a uvicorn server stored on ``self._server``
        and awaits that server's ``setup``.
        """
        with ImportExtensions(required=True):
            from uvicorn import Config, Server

        class UviServer(Server):
            """uvicorn ``Server`` with start-up and serving split into two coroutines."""

            async def setup(self, sockets=None):
                """
                Load the config if needed, install signal handlers and start the server.

                :param sockets: sockets of server.
                """
                server_config = self.config
                if not server_config.loaded:
                    server_config.load()
                self.lifespan = server_config.lifespan_class(server_config)
                self.install_signal_handlers()
                await self.startup(sockets=sockets)
                if self.should_exit:
                    return

            async def serve(self, **kwargs):
                """
                Run the server main loop.

                :param kwargs: keyword arguments
                """
                await self.main_loop()

        from jina.helper import extend_rest_interface

        uvicorn_kwargs = self.args.uvicorn_kwargs or {}
        # ssl files given as runtime args are used unless uvicorn_kwargs already sets them
        for ssl_option in ('ssl_keyfile', 'ssl_certfile'):
            ssl_value = getattr(self.args, ssl_option)
            if ssl_value and ssl_option not in uvicorn_kwargs:
                uvicorn_kwargs[ssl_option] = ssl_value

        self._set_topology_graph()
        self._set_connection_pool()

        fastapi_app = get_fastapi_app(
            self.args,
            topology_graph=self._topology_graph,
            connection_pool=self._connection_pool,
            logger=self.logger,
            metrics_registry=self.metrics_registry,
        )
        server_config = Config(
            app=extend_rest_interface(fastapi_app),
            host=__default_host__,
            port=self.args.port,
            ws_max_size=1024 * 1024 * 1024,
            log_level=os.getenv('JINA_LOG_LEVEL', 'error').lower(),
            **uvicorn_kwargs,
        )
        self._server = UviServer(config=server_config)
        await self._server.setup()
예제 #11
0
파일: app.py 프로젝트: luojiguicai/jina
def hello_world(args):
    """
    Execute the chatbot example.

    Downloads the CSV dataset into ``args.workdir``, indexes it with a two-executor Flow,
    switches the Flow to the REST gateway, tries to open the static demo page in a browser
    and finally blocks unless ``args.unblock_query_flow`` is set.

    :param args: arguments passed from CLI
    """
    Path(args.workdir).mkdir(parents=True, exist_ok=True)

    with ImportExtensions(
        required=True,
        help_text='this demo requires Pytorch and Transformers to be installed, '
        'if you haven\'t, please do `pip install jina[torch,transformers]`',
    ):
        import transformers, torch

        assert [torch, transformers]  #: prevent pycharm auto remove the above line

    targets = {
        'covid-csv': {
            'url': args.index_data_url,
            'filename': os.path.join(args.workdir, 'dataset.csv'),
        }
    }

    # download the data
    download_data(targets, args.download_proxy, task_name='download csv data')

    # now comes the real work
    # build the index flow from the two executors

    f = (
        Flow()
        .add(uses=MyTransformer, parallel=args.parallel)
        .add(uses=MyIndexer, workspace=args.workdir)
    )

    # index it!
    with f, open(targets['covid-csv']['filename']) as fp:
        f.index(DocumentArray.from_csv(fp, field_resolver={'question': 'text'}))

        # switch to REST gateway at runtime
        f.use_rest_gateway(args.port_expose)

        url_html_path = 'file://' + os.path.abspath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)), 'static/index.html'
            )
        )
        try:
            webbrowser.open(url_html_path, new=2)
        except Exception:
            # intentional best-effort: browser support isn't cross-platform.
            # Only ordinary errors are ignored, so Ctrl-C / SystemExit still propagate.
            pass
        finally:
            default_logger.success(
                f'You should see a demo page opened in your browser, '
                f'if not, you may open {url_html_path} manually'
            )

        if not args.unblock_query_flow:
            f.block()
예제 #12
0
    async def start(self):
        """Open an ``aiohttp.ClientSession``, store it on ``self.session`` and enter it.

        :return: self
        """
        with ImportExtensions(required=True):
            import aiohttp

        session = aiohttp.ClientSession()
        # stored before entering, so the attribute is set even if entering fails
        self.session = session
        await session.__aenter__()
        return self
예제 #13
0
def test_bad_import():
    """Check how ``ImportExtensions`` reacts to modules that cannot be imported.

    Required imports must raise ``ModuleNotFoundError``; optional ones must not raise
    and must expose the tags resolved for the missing package on ``_tags``.
    """
    from jina.logging import default_logger

    ngt_tags = ['ngt', 'index', 'py37']

    # required extensions: the import error reaches the caller
    with pytest.raises(ModuleNotFoundError):
        with ImportExtensions(required=True, logger=default_logger):
            import abcdefg  # not installed and not listed

    with pytest.raises(ModuleNotFoundError):
        with ImportExtensions(required=True, logger=default_logger):
            import ngt  # listed but not installed

    # optional extensions: every import form of the same package yields the same tags
    with ImportExtensions(required=False, logger=default_logger) as plain_import:
        import ngt

    assert plain_import._tags == ngt_tags

    with ImportExtensions(required=False, logger=default_logger) as dotted_import:
        import ngt.abc.edf

    assert dotted_import._tags == ngt_tags

    with ImportExtensions(required=False, logger=default_logger) as from_import:
        from ngt.abc import edf

    assert from_import._tags == ngt_tags

    # an unlisted package has no tags at all
    with ImportExtensions(required=False, logger=default_logger) as unknown_import:
        import abcdefg

    assert not unknown_import._tags
예제 #14
0
파일: debug.py 프로젝트: winstonww/jina
        def _load_image(blob: 'np.ndarray', channel_axis: int):
            """Turn an image array into a `PIL.Image` object.

            :param blob: the image array
            :param channel_axis: passed to ``_move_channel_axis`` together with ``blob``
            :return: the image built from the array cast to ``uint8``
            """
            with ImportExtensions(
                required=True,
                pkg_name='Pillow',
                verbose=True,
                logger=self.logger,
                help_text='PIL is missing. Install it with `pip install Pillow`',
            ):
                from PIL import Image

                rearranged = _move_channel_axis(blob, channel_axis)
                pixels = rearranged.astype('uint8')
                return Image.fromarray(pixels)
예제 #15
0
        def wrapper(*args, **kwargs):
            """Call ``func`` through a ``shelve`` cache stored at ``cache_file``.

            The cache key is built from the function name and the positional arguments only,
            so keyword arguments do not distinguish entries. Passing ``force=True`` as a
            keyword skips the cache lookup. If ``func`` raises ``urllib.error.URLError`` and a
            cached value exists, a warning is logged and the cached value is returned.

            :param args: positional arguments forwarded to ``func``
            :param kwargs: keyword arguments forwarded to ``func``
            :return: a tuple ``(result, cached)`` where ``cached`` tells whether the value came from the cache
            """
            call_hash = f'{func.__name__}({", ".join(map(str, args))})'

            pickle_protocol = 4
            # fall back to a no-op context when no file lock gets created below
            file_lock = nullcontext()
            with ImportExtensions(
                required=False,
                # the trailing space keeps "the" and "cache_file" apart in the joined message
                help_text=f'FileLock is needed to guarantee non-concurrent access to the '
                f'cache_file {cache_file}',
            ):
                import filelock

                file_lock = filelock.FileLock(f'{cache_file}.lock', timeout=-1)

            cache_db = None
            with file_lock:
                try:
                    cache_db = shelve.open(
                        cache_file, protocol=pickle_protocol, writeback=True
                    )
                except Exception:
                    if os.path.exists(cache_file):
                        # cache is in an unsupported format, reset the cache
                        os.remove(cache_file)
                        cache_db = shelve.open(
                            cache_file, protocol=pickle_protocol, writeback=True
                        )

            if cache_db is None:
                # if we failed to load cache, do not raise, it is only an optimization thing
                return func(*args, **kwargs), False
            else:
                with cache_db as dict_db:
                    try:
                        if call_hash in dict_db and not kwargs.get('force', False):
                            return dict_db[call_hash], True

                        result = func(*args, **kwargs)
                        dict_db[call_hash] = result
                    except urllib.error.URLError:
                        # the call failed: serve the stale entry if there is one
                        if call_hash in dict_db:
                            default_logger.warning(
                                message.format(func_name=func.__name__)
                            )
                            return dict_db[call_hash], True
                        else:
                            raise
                return result, False
예제 #16
0
def _load_image(blob: 'np.ndarray', channel_axis: int):
    """
    Build a `PIL.Image` object from an image array.

    :param blob: the image array
    :param channel_axis: passed to ``_move_channel_axis`` together with ``blob``
    :return: the image built from the array cast to ``uint8``
    """
    with ImportExtensions(
        required=True,
        verbose=True,
        pkg_name='Pillow',
        help_text='PIL is missing. Install it with `pip install Pillow`',
    ):
        from PIL import Image

        rearranged = _move_channel_axis(blob, channel_axis)
        pixels = rearranged.astype('uint8')
        return Image.fromarray(pixels)
예제 #17
0
def archive_package(package_folder: 'Path') -> 'io.BytesIO':
    """
    Archives the given folder in zip format and return a data stream.

    Files matching the folder's ``.gitignore`` are left out; when the folder has no
    ``.gitignore`` the bundled ``Python.gitignore`` resource is used instead. ``.git``
    and ``.jina`` are always ignored.

    :param package_folder: the folder path of the package
    :return: the data stream of zip content, rewound to the start
    """

    with ImportExtensions(required=True):
        import pathspec

    root_path = package_folder.resolve()

    gitignore = root_path / '.gitignore'
    if not gitignore.exists():
        gitignore = Path(__resources_path__) / 'Python.gitignore'

    with gitignore.open() as fp:
        # keep non-blank lines that do not start with '#'
        ignore_lines = [
            line.strip() for line in fp if line.strip() and (not line.startswith('#'))
        ]
        ignore_lines += ['.git', '.jina']
        ignored_spec = pathspec.PathSpec.from_lines('gitwildmatch', ignore_lines)

    def _zip(base_path, path, archive):
        """Recursively add every non-ignored file below ``path`` to ``archive``."""
        for p in path.iterdir():
            rel_path = p.relative_to(base_path)
            if ignored_spec.match_file(str(rel_path)):
                continue
            if p.is_dir():
                _zip(base_path, p, archive)
            else:
                archive.write(p, rel_path)

    zip_stream = io.BytesIO()
    # the context manager closes the archive even when adding a file fails;
    # zip_stream itself stays open because ZipFile did not open it
    with zipfile.ZipFile(
        zip_stream, 'w', compression=zipfile.ZIP_DEFLATED
    ) as zfile:
        _zip(root_path, root_path, zfile)

    zip_stream.seek(0)

    return zip_stream
예제 #18
0
    def __init__(self, args: Optional[argparse.Namespace] = None, **kwargs):
        """Store the arguments, create the logger and check the required extensions.

        :param args: parsed arguments; when missing or not a ``Namespace`` they are built from ``kwargs``
        :param kwargs: keyword arguments converted to a namespace with the hub parser
        """
        if args and isinstance(args, argparse.Namespace):
            self.args = args
        else:
            self.args = ArgNamespace.kwargs2namespace(kwargs, set_hub_parser())
        # use the resolved namespace: ``args`` itself is None on the kwargs path,
        # and ``vars(None)`` raises TypeError
        self.logger = JinaLogger(self.__class__.__name__, **vars(self.args))

        with ImportExtensions(required=True):
            import rich
            import cryptography
            import filelock

            assert rich  #: prevent pycharm auto remove the above line
            assert cryptography
            assert filelock
예제 #19
0
    def fetch_meta(
        name: str,
        tag: str,
        secret: Optional[str] = None,
        force: bool = False,
    ) -> HubExecutor:
        """Fetch the executor meta info from Jina Hub.

        :param name: the UUID/Name of the executor
        :param tag: the tag of the executor if available, otherwise, use `None` as the value
        :param secret: the access secret of the executor
        :param force: if set to True, access to fetch_meta will always pull latest Executor metas, otherwise, default
            to local cache
        :return: meta of executor
        :raises Exception: with the response text, when the status is not 200 and the response has a body

        .. note::
            The `name` and `tag` should be passed via ``args`` and `force` and `secret` as ``kwargs``, otherwise,
            cache does not work.
        """
        with ImportExtensions(required=True):
            import requests

        pull_url = get_hubble_url_v1() + f'/executors/{name}/?'
        # only the query parameters that hold a value are sent
        query = {
            key: value for key, value in (('secret', secret), ('tag', tag)) if value
        }
        if query:
            pull_url += urlencode(query)

        resp = requests.get(pull_url, headers=get_request_header())
        if resp.status_code != 200:
            if resp.text:
                raise Exception(resp.text)
            resp.raise_for_status()

        meta = resp.json()

        return HubExecutor(
            uuid=meta['id'],
            name=meta.get('name'),
            sn=meta.get('sn'),
            tag=tag or meta['tag'],
            visibility=meta['visibility'],
            image_name=meta['image'],
            archive_url=meta['package']['download'],
            md5sum=meta['package']['md5'],
        )
예제 #20
0
async def request_generator(
    exec_endpoint: str,
    data: 'GeneratorSourceType',
    request_size: int = 0,
    data_type: DataInputType = DataInputType.AUTO,
    target_executor: Optional[str] = None,
    parameters: Optional[Dict] = None,
    **kwargs,  # do not remove this, add on purpose to suppress unknown kwargs
) -> AsyncIterator['Request']:
    """Asynchronously generate requests from ``data``.

    With ``data=None`` a single request without documents is yielded; otherwise ``data`` is
    chunked by ``request_size`` and one request is yielded per chunk. Any exception raised
    while generating is logged as critical and not re-raised.

    :param exec_endpoint: the endpoint string, by convention starts with `/`
    :param data: the data to use in the request
    :param request_size: the number of Documents per request
    :param data_type: if ``data`` is an iterator over self-contained document, i.e. :class:`DocumentSourceType`;
            or an iterator over possible Document content (set to text, blob and buffer).
    :param parameters: the kwargs that will be sent to the executor
    :param target_executor: a regex string. Only matching Executors will process the request.
    :param kwargs: additional arguments
    :yield: request
    """

    # NOTE(review): this dict is built but never read below; the raw ``kwargs`` is what
    # gets forwarded as ``_kwargs``. Confirm which of the two is intended.
    _kwargs = dict(extra_kwargs=kwargs)

    try:
        if data is None:
            # this allows empty inputs, i.e. a data request with only parameters
            yield _new_data_request(
                endpoint=exec_endpoint,
                target=target_executor,
                parameters=parameters,
            )
            return

        with ImportExtensions(required=True):
            import aiostream

        batches = aiostream.stream.chunks(data, request_size)
        async for batch in batches:
            yield _new_data_request_from_batch(
                _kwargs=kwargs,
                batch=batch,
                data_type=data_type,
                endpoint=exec_endpoint,
                target=target_executor,
                parameters=parameters,
            )
    except Exception as ex:
        # must be handled here, as grpc channel wont handle Python exception
        default_logger.critical(f'inputs is not valid! {ex!r}', exc_info=True)
예제 #21
0
    def _load_docker_client(self):
        """Create the high-level and low-level Docker clients.

        ``self._client`` is built from the environment, ``self._raw_client`` is an ``APIClient``
        pointed at the default named pipe on Windows and the default unix socket elsewhere.

        :raises SystemExit: with code 1 when a ``DockerException`` is raised while creating the clients
        """
        with ImportExtensions(required=True):
            import docker.errors
            from docker import APIClient

            from jina import __windows__

            try:
                self._client = docker.from_env()
                # low-level client
                self._raw_client = APIClient(
                    base_url=docker.constants.DEFAULT_NPIPE
                    if __windows__
                    else docker.constants.DEFAULT_UNIX_SOCKET
                )
            except docker.errors.DockerException:
                self.logger.critical(
                    'Docker daemon seems not running. Please run Docker daemon and try again.'
                )
                # the builtin ``exit`` is injected by the ``site`` module for interactive use
                # and may be absent; raise SystemExit directly instead
                raise SystemExit(1)
예제 #22
0
파일: networking.py 프로젝트: sthagen/jina
    def __init__(
        self,
        logger: Optional[JinaLogger] = None,
        compression: str = 'NoCompression',
        metrics_registry: Optional['CollectorRegistry'] = None,
    ):
        """Set up logging, compression, the request timer and the connection map.

        :param logger: logger to use; a new ``JinaLogger`` named after the class is created when omitted
        :param compression: one of ``NoCompression``, ``Gzip`` or ``Deflate``, matched case-insensitively;
            any other value triggers a warning and falls back to ``NoCompression``
        :param metrics_registry: when given, requests are timed with a ``sending_request_seconds`` Summary
        """
        self._logger = logger or JinaLogger(self.__class__.__name__)

        # lower-cased algorithm name -> grpc compression value
        supported = {
            algorithm.lower(): getattr(grpc.Compression, algorithm)
            for algorithm in ('NoCompression', 'Gzip', 'Deflate')
        }
        requested = compression.lower()
        if requested not in supported:
            import warnings

            warnings.warn(
                message=f'Your compression "{compression}" is not supported. Supported '
                f'algorithms are `Gzip`, `Deflate` and `NoCompression`. NoCompression will be used as '
                f'default'
            )
        self.compression = supported.get(requested, grpc.Compression.NoCompression)

        if not metrics_registry:
            # no registry: timing becomes a no-op context manager
            self._summary_time = contextlib.nullcontext()
        else:
            with ImportExtensions(
                required=True,
                help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
            ):
                from prometheus_client import Summary

            sending_summary = Summary(
                'sending_request_seconds',
                'Time spent between sending a request to the Pod and receiving the response',
                registry=metrics_registry,
                namespace='jina',
            )
            self._summary_time = sending_summary.time()

        self._connections = self._ConnectionPoolMap(self._logger, self._summary_time)
예제 #23
0
    def __init__(
        self,
        args: 'argparse.Namespace',
        cancel_event: Optional[
            Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
        ] = None,
        **kwargs,
    ):
        """Create the event loop, hook up termination handling and run the async setup.

        :param args: arguments of the runtime, forwarded to the parent constructor
        :param cancel_event: event set when termination is requested; a new ``asyncio.Event`` when omitted
        :param kwargs: extra keyword arguments forwarded to the parent constructor
        """
        super().__init__(args, **kwargs)
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self.is_cancel = cancel_event or asyncio.Event()

        def _request_cancel(*_args, **_kwargs):
            # whatever the signal machinery passes in is ignored
            return self.is_cancel.set()

        if __windows__:
            with ImportExtensions(
                required=True,
                logger=self.logger,
                help_text='''If you see a 'DLL load failed' error, please reinstall `pywin32`.
                If you're using conda, please use the command `conda install -c anaconda pywin32`''',
            ):
                import win32api

            win32api.SetConsoleCtrlHandler(_request_cancel, True)
        else:
            # TODO: windows event loops don't support signal handlers
            try:
                for signame in {'SIGINT', 'SIGTERM'}:
                    self._loop.add_signal_handler(
                        getattr(signal, signame), _request_cancel
                    )
            except (ValueError, RuntimeError) as exc:
                self.logger.warning(
                    f' The runtime {self.__class__.__name__} will not be able to handle termination signals. '
                    f' {repr(exc)}'
                )

        self._setup_monitoring()
        self._loop.run_until_complete(self.async_setup())
예제 #24
0
    async def _create_remote_pod(self):
        """Create a workspace and then a Pod on the remote JinaD server.

        On success the identifier returned by the server is stored on ``self.pod_id``.

        :raises DaemonConnectivityError: if the client reports the server as not alive
        :raises DaemonWorkspaceCreationFailed: if no workspace id is returned
        :raises DaemonPodCreationFailed: if the Pod creation is reported as unsuccessful
        """
        with ImportExtensions(required=True):
            # rich & aiohttp are used in `AsyncJinaDClient`
            import rich
            import aiohttp
            from daemon.clients import AsyncJinaDClient

            assert rich
            assert aiohttp

        # NOTE: args.timeout_ready is always set to -1 for JinadRuntime so that wait_for_success doesn't fail in Pod,
        # so it can't be used for Client timeout.
        self.client = AsyncJinaDClient(
            host=self.args.host,
            port=self.args.port_jinad,
            logger=self._logger,
        )

        is_alive = await self.client.alive
        if not is_alive:
            raise DaemonConnectivityError

        # Create a remote workspace with upload_files
        workspace_id = await self.client.workspaces.create(
            paths=self.filepaths,
            id=self.args.workspace_id,
            complete=True,
        )
        if not workspace_id:
            self._logger.critical('remote workspace creation failed')
            raise DaemonWorkspaceCreationFailed

        # Create a remote Pod in the above workspace
        pod_payload = replace_enum_to_str(vars(self._mask_args()))
        success, response = await self.client.pods.create(
            workspace_id=workspace_id, payload=pod_payload, envs=self.envs
        )
        if success:
            self.pod_id = response
        else:
            self._logger.critical('remote pod creation failed')
            raise DaemonPodCreationFailed(response)
예제 #25
0
파일: __init__.py 프로젝트: jina-ai/jina
    def _init_monitoring(self):
        """Create the request-method summary when a metrics registry is available.

        The registry is read from ``self.runtime_args.metrics_registry``. When that attribute
        is missing or falsy, ``self._summary_method`` and ``self._metrics_buffer`` are ``None``.
        """
        registry = getattr(self.runtime_args, 'metrics_registry', None)
        if not registry:
            self._summary_method = None
            self._metrics_buffer = None
            return

        with ImportExtensions(
            required=True,
            help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
        ):
            from prometheus_client import Summary

        self._summary_method = Summary(
            'process_request_seconds',
            'Time spent when calling the executor request method',
            registry=registry,
            namespace='jina',
            labelnames=('executor', 'executor_endpoint', 'runtime_name'),
        )
        # the summary is also reachable by its metric name
        self._metrics_buffer = {'process_request_seconds': self._summary_method}
예제 #26
0
def download_with_resume(
    url: str,
    target_dir: 'Path',
    filename: Optional[str] = None,
    md5sum: Optional[str] = None,
) -> 'Path':
    """
    Download file from url to target_dir, and check md5sum.
    Performs a HTTP(S) download that can be restarted if prematurely terminated.
    The HTTP server must support byte ranges.

    If the target file already exists and matches ``md5sum`` nothing is downloaded. If it exists
    and is smaller than the size reported by the server, the download resumes from its current size.

    :param url: the URL to download
    :param target_dir: the target path for the file
    :param filename: the filename of the downloaded file; defaults to the last path segment of ``url``
    :param md5sum: the MD5 checksum to match

    :return: the filepath of the downloaded file
    :raises RuntimeError: if ``md5sum`` is given and the downloaded file does not match it
    """
    with ImportExtensions(required=True):
        import requests

    def _download(url, target, resume_byte_pos: Optional[int] = None):
        """Stream ``url`` into ``target``, appending from ``resume_byte_pos`` when it is set."""
        resume_header = (
            {'Range': f'bytes={resume_byte_pos}-'} if resume_byte_pos else None
        )

        r = requests.get(url, stream=True, headers=resume_header)

        block_size = 1024
        # append when resuming, otherwise start the file from scratch
        mode = 'ab' if resume_byte_pos else 'wb'

        with target.open(mode=mode) as f:
            for chunk in r.iter_content(32 * block_size):
                f.write(chunk)

    if filename is None:
        filename = url.split('/')[-1]
    filepath = target_dir / filename

    head_info = requests.head(url)
    file_size_online = int(head_info.headers.get('content-length', 0))

    _resume_byte_pos = None
    if filepath.exists():
        if md5sum and md5file(filepath) == md5sum:
            return filepath

        file_size_offline = filepath.stat().st_size
        if file_size_online > file_size_offline:
            _resume_byte_pos = file_size_offline

    _download(url, filepath, _resume_byte_pos)

    if md5sum and md5file(filepath) != md5sum:
        # each literal ends with a space so the joined sentences stay separated
        raise RuntimeError(
            'MD5 checksum failed. '
            'Might happen when the network is unstable, please retry. '
            'If still not work, feel free to raise an issue. '
            'https://github.com/jina-ai/jina/issues/new'
        )

    return filepath
예제 #27
0
def hello_world(args):
    """
    Execute the multimodal example.

    The steps are: create ``args.workdir``, make sure the optional deep-learning
    dependencies can be imported, download the dataset archive (only when it is not
    already on disk) and extract it, index the rows of ``people-img/meta.csv`` with the
    Flow loaded from ``flow-index.yml``, then start the Flow loaded from
    ``flow-search.yml`` behind an HTTP gateway and try to open the static demo page.

    :param args: arguments passed from CLI; the attributes read here are ``workdir``,
        ``index_data_url``, ``download_proxy``, ``port_expose`` and ``unblock_query_flow``
    """
    Path(args.workdir).mkdir(parents=True, exist_ok=True)

    with ImportExtensions(
            required=True,
            help_text=
            'this demo requires Pytorch and Transformers to be installed, '
            'if you haven\'t, please do `pip install jina[torch,transformers]`',
    ):
        import transformers, torch, torchvision

        assert [
            torch,
            transformers,
            torchvision,
        ]  #: prevent pycharm auto remove the above line

    targets = {
        'people-img': {
            'url': args.index_data_url,
            'filename': os.path.join(args.workdir, 'dataset.zip'),
        }
    }

    # download the data, skipped when the archive is already present
    if not os.path.exists(targets['people-img']['filename']):
        download_data(targets,
                      args.download_proxy,
                      task_name='download zip data')

    with zipfile.ZipFile(targets['people-img']['filename'], 'r') as fp:
        fp.extractall(args.workdir)

    # these envs are referred to in the index and query flow YAMLs
    os.environ['HW_WORKDIR'] = args.workdir
    os.environ['PY_MODULE'] = os.path.abspath(
        os.path.join(cur_dir, 'my_executors.py'))
    # now comes the real work
    # load index flow from a YAML file

    # index it!
    f = Flow.load_config('flow-index.yml')

    with f, open(f'{args.workdir}/people-img/meta.csv', newline='') as fp:
        f.index(inputs=DocumentArray.from_csv(fp),
                request_size=10,
                show_progress=True)

    # search it!
    f = Flow.load_config('flow-search.yml')
    # switch to HTTP gateway
    f.protocol = 'http'
    f.port_expose = args.port_expose

    url_html_path = 'file://' + os.path.abspath(
        os.path.join(cur_dir, 'static/index.html'))
    with f:
        try:
            webbrowser.open(url_html_path, new=2)
        except Exception:
            # intentional pass, browser support isn't cross-platform.
            # `Exception` (not a bare `except`) keeps this best-effort while still
            # letting KeyboardInterrupt / SystemExit stop the demo.
            pass
        finally:
            default_logger.info(
                f'You should see a demo page opened in your browser, '
                f'if not, you may open {url_html_path} manually')
        # keep serving until interrupted unless the caller asked not to block
        if not args.unblock_query_flow:
            f.block()
예제 #28
0
파일: app.py 프로젝트: jina-ai/jina
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    What gets registered on the app is driven by ``args``:

    - ``args.cors`` adds a CORS middleware that allows every origin, method and header.
    - unless ``args.no_debug_endpoints`` is set, the ``/``, ``/dry_run``, ``/status`` and
      ``/post`` routes are added.
    - unless ``args.no_crud_endpoints`` is set, ``/index``, ``/search``, ``/delete`` and
      ``/update`` are exposed.
    - ``args.expose_endpoints`` (a JSON string) exposes additional executor endpoints.
    - ``args.expose_graphql_endpoint`` mounts a Strawberry GraphQL router under ``/graphql``.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI, Response, status
        from fastapi.middleware.cors import CORSMiddleware

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
        )

    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description or
        'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize the title and description.',
        version=__version__,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is accessible from any website!')

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    # NOTE: the route handlers below call `_get_singleton_result` and
    # `_generate_exception_header`, which are defined further down in this function.
    # That works because the names are only looked up when a request is served.
    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append({
            'name':
            'Debug',
            'description':
            'Debugging interface. In production, you should hide them by setting '
            '`--no-debug-endpoints` in `Flow`/`Gateway`.',
        })

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of Jina Gateway service',
            response_model=JinaHealthModel,
        )
        async def _gateway_health():
            """
            Get the health of this Gateway service.
            .. # noqa: DAR201

            """
            return {}

        from docarray import DocumentArray

        from jina.proto import jina_pb2
        from jina.serve.executors import __dry_run_endpoint__
        from jina.serve.runtimes.gateway.http.models import (
            PROTO_TO_PYDANTIC_MODELS,
            JinaInfoModel,
        )
        from jina.types.request.status import StatusMessage

        @app.get(
            path='/dry_run',
            summary=
            'Get the readiness of Jina Flow service, sends an empty DocumentArray to the complete Flow to '
            'validate connectivity',
            response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto,
        )
        async def _flow_health():
            """
            Get the health of the complete Flow service.
            .. # noqa: DAR201

            """

            da = DocumentArray()

            try:
                # the result itself is discarded; only success/failure matters here
                _ = await _get_singleton_result(
                    request_generator(
                        exec_endpoint=__dry_run_endpoint__,
                        data=da,
                        data_type=DataInputType.DOCUMENT,
                    ))
                status_message = StatusMessage()
                status_message.set_code(jina_pb2.StatusProto.SUCCESS)
                return status_message.to_dict()
            except Exception as ex:
                # any failure is reported in the response body instead of being raised
                status_message = StatusMessage()
                status_message.set_exception(ex)
                return status_message.to_dict(use_integers_for_enums=True)

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaInfoModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            version, env_info = get_full_version()
            # stringify every value before returning the two mappings
            for k, v in version.items():
                version[k] = str(v)
            for k, v in env_info.items():
                env_info[k] = str(v)
            return {'jina': version, 'envs': env_info}

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug']
            # NOTE(review): an earlier comment here said not to set `response_model`
            # on this debug endpoint, yet it is set above - confirm which is intended
        )
        async def post(
            body: JinaEndpointRequestModel, response: Response
        ):  # 'response' is a FastAPI response, not a Jina response
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # unwrap `{'docs': [...]}` so the generator receives the documents directly
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            try:
                result = await _get_singleton_result(
                    request_generator(**req_generator_input))
            except InternalNetworkError as err:
                import grpc

                # translate the gRPC status code into the closest HTTP status code
                if err.code() == grpc.StatusCode.UNAVAILABLE:
                    response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                elif err.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                    response.status_code = status.HTTP_504_GATEWAY_TIMEOUT
                else:
                    response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                result = bd  # send back the request
                result['header'] = _generate_exception_header(
                    err)  # attach exception details to response header
                logger.error(
                    f'Error while getting responses from deployments: {err.details()}'
                )
            return result

    def _generate_exception_header(error: InternalNetworkError):
        # Build the `header` dict attached to a `/post` response after a network error:
        # the request id plus an ERROR status carrying the original traceback frames.
        import traceback

        from jina.proto.serializer import DataRequest

        exception_dict = {
            'name':
            str(error.__class__),
            'stacks': [
                str(x)
                for x in traceback.extract_tb(error.og_exception.__traceback__)
            ],
            'executor':
            '',
        }
        status_dict = {
            'code': DataRequest().status.ERROR,
            'description': error.details() if error.details() else '',
            'exception': exception_dict,
        }
        header_dict = {'request_id': error.request_id, 'status': status_dict}
        return header_dict

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint
        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        @app.api_route(path=http_path or exec_endpoint,
                       name=http_path or exec_endpoint,
                       **kwargs)
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append({
            'name':
            'CRUD',
            'description':
            'CRUD interface. If your service does not implement those interfaces, you can should '
            'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
        })
        crud = {
            '/index': {
                'methods': ['POST']
            },
            '/search': {
                'methods': ['POST']
            },
            '/delete': {
                'methods': ['DELETE']
            },
            '/update': {
                'methods': ['PUT']
            },
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

            async def get_docs_from_endpoint(data, target_executor, parameters,
                                             exec_endpoint):
                # Shared implementation of the `docs` query and mutation below.
                req_generator_input = {
                    # `data` defaults to None in both resolvers; iterating over None
                    # would raise a TypeError, so None is passed through unchanged
                    # (the REST handlers above also forward `data=None`)
                    'data':
                    None if data is None else [asdict(d) for d in data],
                    'target_executor': target_executor,
                    'parameters': parameters,
                    'exec_endpoint': exec_endpoint,
                    'data_type': DataInputType.DICT,
                }

                try:
                    response = await _get_singleton_result(
                        request_generator(**req_generator_input))
                except InternalNetworkError as err:
                    logger.error(
                        f'Error while getting responses from deployments: {err.details()}'
                    )
                    raise err  # will be handled by Strawberry
                return DocumentArray.from_dict(
                    response['data']).to_strawberry_type()

            @strawberry.type
            class Mutation:
                @strawberry.mutation
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            @strawberry.type
            class Query:
                @strawberry.field
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            schema = strawberry.Schema(query=Query, mutation=Mutation)
            app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        # returning inside the loop stops after the first streamed response
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
예제 #29
0
def run(
    args: 'argparse.Namespace',
    name: str,
    container_name: str,
    net_mode: Optional[str],
    runtime_ctrl_address: str,
    envs: Dict,
    is_started: Union['multiprocessing.Event', 'threading.Event'],
    is_shutdown: Union['multiprocessing.Event', 'threading.Event'],
    is_ready: Union['multiprocessing.Event', 'threading.Event'],
):
    """Start a Container and stream its logs; meant to run in its own process or thread

    This method is the target for the Pod's `thread` or `process`

    .. note::
        :meth:`run` is running in subprocess/thread, the exception can not be propagated to the main process.
        Hence, please do not raise any exception here.

    .. note::
        Please note that env variables are process-specific. Subprocess inherits envs from
        the main process. But Subprocess's envs do NOT affect the main process. It does NOT
        mess up user local system envs.

    :param args: namespace args from the Pod
    :param name: name of the Pod to have proper logging
    :param container_name: name to set the Container to
    :param net_mode: The network mode where to run the container
    :param runtime_ctrl_address: The control address of the runtime in the container
    :param envs: Dictionary of environment variables to be set in the docker image
    :param is_started: concurrency event to communicate runtime is properly started. Used for better logging
    :param is_shutdown: concurrency event to communicate runtime is terminated
    :param is_ready: concurrency event to communicate runtime is ready to receive messages
    """
    import docker

    # the logger gets a copy of the Pod args, forced onto the 'docker' log config
    logger_options = copy.deepcopy(vars(args))
    logger_options['log_config'] = 'docker'
    logger = JinaLogger(name, **logger_options)

    cancel = threading.Event()
    fail_to_start = threading.Event()

    def _request_cancel(*_args, **_kwargs):
        # one callback for both the POSIX signal handlers and the Windows console handler
        return cancel.set()

    if __windows__:
        with ImportExtensions(
                required=True,
                logger=logger,
                help_text=
                '''If you see a 'DLL load failed' error, please reinstall `pywin32`.
                If you're using conda, please use the command `conda install -c anaconda pywin32`''',
        ):
            import win32api

        win32api.SetConsoleCtrlHandler(_request_cancel, True)
    else:
        try:
            for termination_signal in {signal.SIGINT, signal.SIGTERM}:
                signal.signal(termination_signal, _request_cancel)
        except (ValueError, RuntimeError) as exc:
            # registration failed: carry on without signal handling, only warn
            logger.warning(
                f' The process starting the container for {name} will not be able to handle termination signals. '
                f' {exc!r}')

    client = docker.from_env()

    try:
        container = _docker_run(
            client=client,
            args=args,
            container_name=container_name,
            envs=envs,
            net_mode=net_mode,
            logger=logger,
        )
        client.close()

        def _container_exists(target) -> bool:
            # a reload that fails with NotFound means the container is gone
            import docker.errors

            try:
                target.reload()
            except docker.errors.NotFound:
                return False
            return True

        def _runtime_ready():
            return AsyncNewLoopRuntime.is_ready(runtime_ctrl_address)

        async def _wait_until_ready(target):
            # poll until the container disappears, the runtime is ready, or a cancel arrives
            while True:
                if not _container_exists(target):
                    break
                if _runtime_ready():
                    break
                if cancel.is_set():
                    break
                await asyncio.sleep(0.1)

            if not _container_exists(target):
                fail_to_start.set()
                return
            is_started.set()
            is_ready.set()

        async def _forward_logs(target):
            for raw_line in target.logs(stream=True):
                settled = (is_started.is_set() or fail_to_start.is_set()
                           or cancel.is_set())
                if not settled:
                    # hand control back to the readiness poller while still starting
                    await asyncio.sleep(0.01)
                text = raw_line.decode().rstrip()  # type: str
                # drop terminal escape sequences before logging
                logger.debug(re.sub(r'\u001b\[.*?[@-~]', '', text))

        async def _supervise(target):
            await asyncio.gather(_wait_until_ready(target),
                                 _forward_logs(target))

        asyncio.run(_supervise(container))
    finally:
        client.close()
        if not is_started.is_set():
            logger.error(
                ' Process terminated, the container fails to start, check the arguments or entrypoint'
            )
        is_shutdown.set()
        logger.debug('process terminated')
예제 #30
0
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    What gets registered on the app is driven by ``args``:

    - ``args.cors`` adds a CORS middleware that allows every origin, method and header.
    - unless ``args.no_debug_endpoints`` is set, the ``/``, ``/status`` and ``/post``
      routes are added.
    - unless ``args.no_crud_endpoints`` is set, ``/index``, ``/search``, ``/delete`` and
      ``/update`` are exposed.
    - ``args.expose_endpoints`` (a JSON string) exposes additional executor endpoints.
    - ``args.default_swagger_ui`` keeps FastAPI's own ``/docs`` page; otherwise ``/docs``
      serves HTML fetched from ``https://api.jina.ai/swagger``.
    - ``args.expose_graphql_endpoint`` mounts a Strawberry GraphQL router under ``/graphql``.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and sending to different of them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus used if we need to expose metrics from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware
        from fastapi.responses import HTMLResponse
        from starlette.requests import Request

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
            JinaStatusModel,
        )

    docs_url = '/docs'
    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description or
        'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize this text.',
        version=__version__,
        # with the custom swagger page the built-in docs route is turned off
        docs_url=docs_url if args.default_swagger_ui else None,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is now accessible from any website!'
        )

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    # NOTE: the route handlers below call `_get_singleton_result`, which is defined at
    # the end of this function. That works because the name is only looked up when a
    # request is served.
    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append({
            'name':
            'Debug',
            'description':
            'Debugging interface. In production, you should hide them by setting '
            '`--no-debug-endpoints` in `Flow`/`Gateway`.',
        })

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of Jina service',
            response_model=JinaHealthModel,
        )
        async def _health():
            """
            Get the health of this Jina service.
            .. # noqa: DAR201

            """
            return {}

        @app.get(
            path='/status',
            summary='Get the status of Jina service',
            response_model=JinaStatusModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from command line.

            .. # noqa: DAR201
            """
            _info = get_full_version()
            return {
                'jina': _info[0],
                'envs': _info[1],
                'used_memory': used_memory_readable(),
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            tags=['Debug']
            # NOTE(review): an earlier comment here said not to set `response_model`
            # on this debug endpoint, yet it is set above - confirm which is intended
        )
        async def post(body: JinaEndpointRequestModel):
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # The above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            # unwrap `{'docs': [...]}` so the generator receives the documents directly
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Exposing an executor endpoint to http endpoint
        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """

        # set some default kwargs for richer semantics
        # group flow exposed endpoints into `customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        @app.api_route(path=http_path or exec_endpoint,
                       name=http_path or exec_endpoint,
                       **kwargs)
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data'][
                    'docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input))
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append({
            'name':
            'CRUD',
            'description':
            'CRUD interface. If your service does not implement those interfaces, you can should '
            'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
        })
        crud = {
            '/index': {
                'methods': ['POST']
            },
            '/search': {
                'methods': ['POST']
            },
            '/delete': {
                'methods': ['DELETE']
            },
            '/update': {
                'methods': ['PUT']
            },
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if not args.default_swagger_ui:

        async def _render_custom_swagger_html(req: Request) -> HTMLResponse:
            import urllib.request

            swagger_url = 'https://api.jina.ai/swagger'
            # named separately from `req` so the incoming request is not shadowed
            swagger_request = urllib.request.Request(
                swagger_url, headers={'User-Agent': 'Mozilla/5.0'})
            # NOTE(review): `urlopen` is a blocking call made from an async handler;
            # consider moving it off the event loop
            with urllib.request.urlopen(swagger_request) as f:
                return HTMLResponse(f.read().decode())

        app.add_route(docs_url,
                      _render_custom_swagger_html,
                      include_in_schema=False)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

            async def get_docs_from_endpoint(data, target_executor, parameters,
                                             exec_endpoint):
                # Shared implementation of the `docs` query and mutation below.
                req_generator_input = {
                    # `data` defaults to None in both resolvers; iterating over None
                    # would raise a TypeError, so None is passed through unchanged
                    # (the REST handlers above also forward `data=None`)
                    'data':
                    None if data is None else [asdict(d) for d in data],
                    'target_executor': target_executor,
                    'parameters': parameters,
                    'exec_endpoint': exec_endpoint,
                    'data_type': DataInputType.DICT,
                }

                response = await _get_singleton_result(
                    request_generator(**req_generator_input))
                return DocumentArray.from_dict(
                    response['data']).to_strawberry_type()

            @strawberry.type
            class Mutation:
                @strawberry.mutation
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            @strawberry.type
            class Query:
                @strawberry.field
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint)

            schema = strawberry.Schema(query=Query, mutation=Mutation)
            app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Streams results from AsyncPrefetchCall as a dict

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        # returning inside the loop stops after the first streamed response
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app