def test_no_suppress_other_exception():
    with pytest.raises(Exception):
        with ImportExtensions(required=False, logger=default_logger):
            raise Exception

    with pytest.raises(Exception):
        with ImportExtensions(required=True, logger=default_logger):
            raise Exception
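The contract exercised by this test: `ImportExtensions(required=False)` suppresses only missing-import errors, while `required=True` propagates everything. A minimal sketch of that behavior, as a simplified stand-in rather than jina's actual implementation:

import contextlib


class MiniImportExtensions(contextlib.AbstractContextManager):
    # simplified stand-in for jina's ImportExtensions, for illustration only
    def __init__(self, required: bool, logger=None):
        self._required = required
        self._logger = logger

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._required and exc_type is ModuleNotFoundError:
            if self._logger:
                self._logger.warning(f'optional dependency missing: {exc_val}')
            return True  # suppress only missing optional imports
        return False  # propagate anything else, including plain Exception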
def mutate(
    self,
    mutation: str,
    variables: Optional[dict] = None,
    timeout: Optional[float] = None,
    headers: Optional[dict] = None,
):
    """Perform a GraphQL mutation.

    :param mutation: the GraphQL mutation as a single string.
    :param variables: variables to be substituted in the mutation. Not needed if no
        variables are present in the mutation string.
    :param timeout: HTTP request timeout
    :param headers: HTTP headers
    :return: dict containing the optional keys ``data`` and ``errors``, for response
        data and errors.
    """
    with ImportExtensions(required=True):
        from sgqlc.endpoint.http import HTTPEndpoint as SgqlcHTTPEndpoint

    proto = 'https' if self.args.tls else 'http'
    graphql_url = f'{proto}://{self.args.host}:{self.args.port}/graphql'
    endpoint = SgqlcHTTPEndpoint(graphql_url)
    res = endpoint(mutation, variables=variables, timeout=timeout, extra_headers=headers)
    if 'errors' in res and res['errors']:
        msg = 'GraphQL mutation returned the following errors: '
        for err in res['errors']:
            msg += err['message'] + '. '
        raise ConnectionError(msg)

    return res
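A hedged usage sketch for `mutate`; `client` stands for any object exposing the method (e.g. a jina Client), and the mutation string and variables below are made up for illustration:

# hypothetical usage of `mutate`; the client object, mutation and variables are illustrative
mutation = '''
mutation ($input: DeployParams!) {
    deployFlow(input: $input) {
        id
    }
}
'''
response = client.mutate(
    mutation=mutation,
    variables={'input': {'name': 'my-flow'}},
    timeout=10.0,
)
print(response.get('data'), response.get('errors'))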
def craft(self, buffer: bytes, uri: str, *args, **kwargs) -> Dict:
    """
    Read an image file and craft it into an image matrix.

    Read the image from the raw bytes in `buffer` or the file path in `uri`, and
    save the `ndarray` of the image in the `blob` of the document.

    :param buffer: the image in raw bytes
    :param uri: the image file path
    """
    with ImportExtensions(
        required=True,
        verbose=True,
        pkg_name='Pillow',
        logger=self.logger,
        help_text='PIL is missing. Install it with `pip install Pillow`',
    ):
        from PIL import Image

    if buffer:
        raw_img = Image.open(io.BytesIO(buffer))
    elif uri:
        raw_img = Image.open(uri)
    else:
        raise ValueError('no value found in "buffer" and "uri"')
    raw_img = raw_img.convert('RGB')
    img = np.array(raw_img).astype('float32')
    if self.channel_axis != -1:
        img = np.moveaxis(img, -1, self.channel_axis)
    return dict(blob=img)
def upload_file(
    url: str,
    file_name: str,
    buffer_data: bytes,
    dict_data: Dict,
    headers: Dict,
    stream: bool = False,
    method: str = 'post',
):
    """Upload a file to the target url.

    :param url: target url
    :param file_name: the file name
    :param buffer_data: the data to upload
    :param dict_data: the dict-style data to upload
    :param headers: the request header
    :param stream: receive stream response
    :param method: the request method
    :return: the response of the request
    """
    with ImportExtensions(required=True):
        import requests

    dict_data.update({'file': (file_name, buffer_data)})

    (data, ctype) = requests.packages.urllib3.filepost.encode_multipart_formdata(
        dict_data
    )

    headers.update({'Content-Type': ctype})

    response = getattr(requests, method)(url, data=data, headers=headers, stream=stream)

    return response
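A hypothetical call to `upload_file`; the URL, file and metadata below are made up:

# hypothetical usage of `upload_file`; URL and payload are illustrative
with open('model.bin', 'rb') as fp:
    resp = upload_file(
        url='https://example.com/upload',
        file_name='model.bin',
        buffer_data=fp.read(),
        dict_data={'meta': 'v1'},
        headers={},
        method='post',
    )
print(resp.status_code)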
async def async_setup(self):
    """
    Start the DataRequestHandler and wait for the GRPC and Monitoring servers to start.
    """
    if self.metrics_registry:
        with ImportExtensions(
            required=True,
            help_text='You need to install `prometheus_client` to use the monitoring functionality of jina',
        ):
            from prometheus_client import Summary

        self._summary_time = (
            Summary(
                'receiving_request_seconds',
                'Time spent processing request',
                registry=self.metrics_registry,
                namespace='jina',
                labelnames=('runtime_name',),
            )
            .labels(self.args.name)
            .time()
        )
    else:
        self._summary_time = contextlib.nullcontext()

    await self._async_setup_grpc_server()
def __init__(
    self,
    logger: Optional[JinaLogger] = None,
    compression: Optional[str] = None,
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    self._logger = logger or JinaLogger(self.__class__.__name__)
    self.compression = (
        getattr(grpc.Compression, compression)
        if compression
        else grpc.Compression.NoCompression
    )

    if metrics_registry:
        with ImportExtensions(
            required=True,
            help_text='You need to install `prometheus_client` to use the monitoring functionality of jina',
        ):
            from prometheus_client import Summary

        self._summary_time = Summary(
            'sending_request_seconds',
            'Time spent between sending a request to the Pod and receiving the response',
            registry=metrics_registry,
            namespace='jina',
        ).time()
    else:
        self._summary_time = contextlib.nullcontext()

    self._connections = self._ConnectionPoolMap(self._logger, self._summary_time)
    self._deployment_address_map = {}
def _init_monitoring(self, metrics_registry: Optional['CollectorRegistry'] = None):
    if metrics_registry:
        with ImportExtensions(
            required=True,
            help_text='You need to install `prometheus_client` to use the monitoring functionality of jina',
        ):
            from prometheus_client import Counter, Summary

        self._counter = Counter(
            'document_processed',
            'Number of Documents that have been processed by the executor',
            namespace='jina',
            labelnames=('executor_endpoint', 'executor', 'runtime_name'),
            registry=metrics_registry,
        )

        self._request_size_metrics = Summary(
            'request_size_bytes',
            'The request size in Bytes',
            namespace='jina',
            labelnames=('executor_endpoint', 'executor', 'runtime_name'),
            registry=metrics_registry,
        )
    else:
        self._counter = None
        self._request_size_metrics = None
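For context, a sketch of how metrics initialized this way are typically consumed per request; the label values and the `request`/`docs` objects are assumptions, not taken from the snippet above:

# illustrative only: label values, `request` and `docs` are assumed, not from the snippet above
if self._counter:
    self._counter.labels('/search', 'MyExecutor', 'my-runtime').inc(len(docs))
if self._request_size_metrics:
    self._request_size_metrics.labels('/search', 'MyExecutor', 'my-runtime').observe(
        request.nbytes  # hypothetical size attribute
    )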
def arg_wrapper(self, *args, **kwargs):
    if func.__name__ != '__init__':
        raise TypeError(
            'this decorator should only be used on the __init__ method of an executor'
        )

    if self.__class__ == cls:
        file_lock = nullcontext()
        with ImportExtensions(
            required=False,
            help_text='FileLock is needed to guarantee non-concurrent initialization '
            'of replicas on the same machine.',
        ):
            import filelock

            locks_root = _get_locks_root()
            lock_file = locks_root.joinpath(f'{self.__class__.__name__}.lock')
            file_lock = filelock.FileLock(lock_file, timeout=-1)

        with file_lock:
            f = func(self, *args, **kwargs)
        return f
    else:
        return func(self, *args, **kwargs)
def __init__(
    self,
    metrics_registry: Optional['CollectorRegistry'] = None,
    runtime_name: Optional[str] = None,
):
    self.request_init_time = {} if metrics_registry else None
    self._executor_endpoint_mapping = None

    if metrics_registry:
        with ImportExtensions(
            required=True,
            help_text='You need to install `prometheus_client` to use the monitoring functionality of jina',
        ):
            from prometheus_client import Summary

        self._summary = Summary(
            'receiving_request_seconds',
            'Time spent processing request',
            registry=metrics_registry,
            namespace='jina',
            labelnames=('runtime_name',),
        ).labels(runtime_name)
    else:
        self._summary = None
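A sketch of how the labeled `Summary` above is typically consumed; the surrounding request handling (`handle`, `request`) is assumed for illustration:

# illustrative consumption of self._summary; `handle` and `request` are assumptions
if self._summary:
    with self._summary.time():  # a prometheus Summary timer works as a context manager
        response = await handle(request)
else:
    response = await handle(request)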
async def async_setup(self):
    """
    The async method to set up the runtime. Sets up the uvicorn server.
    """
    with ImportExtensions(required=True):
        from uvicorn import Config, Server

    class UviServer(Server):
        """The uvicorn server."""

        async def setup(self, sockets=None):
            """
            Set up the uvicorn server.

            :param sockets: sockets of server.
            """
            config = self.config
            if not config.loaded:
                config.load()
            self.lifespan = config.lifespan_class(config)
            self.install_signal_handlers()
            await self.startup(sockets=sockets)
            if self.should_exit:
                return

        async def serve(self, **kwargs):
            """
            Start the server.

            :param kwargs: keyword arguments
            """
            await self.main_loop()

    from jina.helper import extend_rest_interface

    uvicorn_kwargs = self.args.uvicorn_kwargs or {}
    for ssl_file in ['ssl_keyfile', 'ssl_certfile']:
        if getattr(self.args, ssl_file):
            if ssl_file not in uvicorn_kwargs:
                uvicorn_kwargs[ssl_file] = getattr(self.args, ssl_file)

    self._set_topology_graph()
    self._set_connection_pool()

    self._server = UviServer(
        config=Config(
            app=extend_rest_interface(
                get_fastapi_app(
                    self.args,
                    topology_graph=self._topology_graph,
                    connection_pool=self._connection_pool,
                    logger=self.logger,
                    metrics_registry=self.metrics_registry,
                )
            ),
            host=__default_host__,
            port=self.args.port,
            ws_max_size=1024 * 1024 * 1024,
            log_level=os.getenv('JINA_LOG_LEVEL', 'error').lower(),
            **uvicorn_kwargs,
        )
    )
    await self._server.setup()
def hello_world(args):
    """
    Execute the chatbot example.

    :param args: arguments passed from CLI
    """
    Path(args.workdir).mkdir(parents=True, exist_ok=True)

    with ImportExtensions(
        required=True,
        help_text='this demo requires Pytorch and Transformers to be installed, '
        'if you haven\'t, please do `pip install jina[torch,transformers]`',
    ):
        import torch
        import transformers

        assert [torch, transformers]  #: prevent pycharm from auto-removing the imports above

    targets = {
        'covid-csv': {
            'url': args.index_data_url,
            'filename': os.path.join(args.workdir, 'dataset.csv'),
        }
    }

    # download the data
    download_data(targets, args.download_proxy, task_name='download csv data')

    # now comes the real work
    # load index flow from a YAML file
    f = (
        Flow()
        .add(uses=MyTransformer, parallel=args.parallel)
        .add(uses=MyIndexer, workspace=args.workdir)
    )

    # index it!
    with f, open(targets['covid-csv']['filename']) as fp:
        f.index(DocumentArray.from_csv(fp, field_resolver={'question': 'text'}))

        # switch to REST gateway at runtime
        f.use_rest_gateway(args.port_expose)

        url_html_path = 'file://' + os.path.abspath(
            os.path.join(os.path.dirname(os.path.realpath(__file__)), 'static/index.html')
        )
        try:
            webbrowser.open(url_html_path, new=2)
        except:
            pass  # intentional pass, browser support isn't cross-platform
        finally:
            default_logger.success(
                f'You should see a demo page opened in your browser, '
                f'if not, you may open {url_html_path} manually'
            )

        if not args.unblock_query_flow:
            f.block()
async def start(self):
    """Create the ClientSession and enter the context.

    :return: self
    """
    with ImportExtensions(required=True):
        import aiohttp

    self.session = aiohttp.ClientSession()
    await self.session.__aenter__()
    return self
def test_bad_import():
    from jina.logging import default_logger

    with pytest.raises(ModuleNotFoundError):
        with ImportExtensions(required=True, logger=default_logger):
            import abcdefg  # no install and unlist

    with pytest.raises(ModuleNotFoundError):
        with ImportExtensions(required=True, logger=default_logger):
            import ngt  # list but no install

    with ImportExtensions(required=False, logger=default_logger) as ie:
        import ngt

    assert ie._tags == ['ngt', 'index', 'py37']

    with ImportExtensions(required=False, logger=default_logger) as ie:
        import ngt.abc.edf

    assert ie._tags == ['ngt', 'index', 'py37']

    with ImportExtensions(required=False, logger=default_logger) as ie:
        from ngt.abc import edf

    assert ie._tags == ['ngt', 'index', 'py37']

    with ImportExtensions(required=False, logger=default_logger) as ie:
        import abcdefg

    assert not ie._tags
def _load_image(blob: 'np.ndarray', channel_axis: int):
    with ImportExtensions(
        required=True,
        pkg_name='Pillow',
        verbose=True,
        logger=self.logger,
        help_text='PIL is missing. Install it with `pip install Pillow`',
    ):
        from PIL import Image

    img = _move_channel_axis(blob, channel_axis)
    return Image.fromarray(img.astype('uint8'))
def wrapper(*args, **kwargs):
    call_hash = f'{func.__name__}({", ".join(map(str, args))})'

    pickle_protocol = 4
    file_lock = nullcontext()
    with ImportExtensions(
        required=False,
        help_text=f'FileLock is needed to guarantee non-concurrent access to the '
        f'cache_file {cache_file}',
    ):
        import filelock

        file_lock = filelock.FileLock(f'{cache_file}.lock', timeout=-1)

    cache_db = None
    with file_lock:
        try:
            cache_db = shelve.open(cache_file, protocol=pickle_protocol, writeback=True)
        except Exception:
            if os.path.exists(cache_file):
                # cache is in an unsupported format, reset the cache
                os.remove(cache_file)
                cache_db = shelve.open(cache_file, protocol=pickle_protocol, writeback=True)

    if cache_db is None:
        # if we failed to load the cache, do not raise; caching is only an optimization
        return func(*args, **kwargs), False
    else:
        with cache_db as dict_db:
            try:
                if call_hash in dict_db and not kwargs.get('force', False):
                    return dict_db[call_hash], True

                result = func(*args, **kwargs)
                dict_db[call_hash] = result
            except urllib.error.URLError:
                if call_hash in dict_db:
                    default_logger.warning(message.format(func_name=func.__name__))
                    return dict_db[call_hash], True
                else:
                    raise
        return result, False
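`wrapper` closes over `func`, `cache_file` and `message`, so it is the inner function of a caching decorator; a hedged sketch of the enclosing shape it implies (the name `cached` and its defaults are assumptions, not from the snippet):

import functools


def cached(cache_file: str = 'api_cache', message: str = '{func_name} served from cache'):
    # hypothetical enclosing decorator; only the closure shape is implied by `wrapper` above
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ...  # body as shown above, closing over func, cache_file and message

        return wrapper

    return decorator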
def _load_image(blob: 'np.ndarray', channel_axis: int):
    """
    Load an image array and return a `PIL.Image` object.
    """
    with ImportExtensions(
        required=True,
        verbose=True,
        pkg_name='Pillow',
        help_text='PIL is missing. Install it with `pip install Pillow`',
    ):
        from PIL import Image

    img = _move_channel_axis(blob, channel_axis)
    return Image.fromarray(img.astype('uint8'))
def archive_package(package_folder: 'Path') -> 'io.BytesIO':
    """
    Archive the given folder in zip format and return a data stream.

    :param package_folder: the folder path of the package
    :return: the data stream of zip content
    """
    with ImportExtensions(required=True):
        import pathspec

    root_path = package_folder.resolve()

    gitignore = root_path / '.gitignore'
    if not gitignore.exists():
        gitignore = Path(__resources_path__) / 'Python.gitignore'

    with gitignore.open() as fp:
        ignore_lines = [
            line.strip() for line in fp if line.strip() and not line.startswith('#')
        ]
    ignore_lines += ['.git', '.jina']
    ignored_spec = pathspec.PathSpec.from_lines('gitwildmatch', ignore_lines)

    zip_stream = io.BytesIO()
    zfile = zipfile.ZipFile(zip_stream, 'w', compression=zipfile.ZIP_DEFLATED)

    def _zip(base_path, path, archive):
        for p in path.iterdir():
            rel_path = p.relative_to(base_path)
            if ignored_spec.match_file(str(rel_path)):
                continue
            if p.is_dir():
                _zip(base_path, p, archive)
            else:
                archive.write(p, rel_path)

    _zip(root_path, root_path, zfile)

    zfile.close()
    zip_stream.seek(0)

    return zip_stream
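A hypothetical usage of `archive_package`; the paths below are illustrative:

# hypothetical usage of `archive_package`; paths are illustrative
from pathlib import Path

zip_stream = archive_package(Path('./my_executor'))
with open('my_executor.zip', 'wb') as fp:
    fp.write(zip_stream.getvalue())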
def __init__(self, args: Optional[argparse.Namespace] = None, **kwargs):
    if args and isinstance(args, argparse.Namespace):
        self.args = args
    else:
        self.args = ArgNamespace.kwargs2namespace(kwargs, set_hub_parser())
    # use self.args here: `args` may be None when kwargs are used instead
    self.logger = JinaLogger(self.__class__.__name__, **vars(self.args))

    with ImportExtensions(required=True):
        import cryptography
        import filelock
        import rich

        assert rich  #: prevent pycharm from auto-removing the imports above
        assert cryptography
        assert filelock
def fetch_meta(
    name: str,
    tag: str,
    secret: Optional[str] = None,
    force: bool = False,
) -> HubExecutor:
    """Fetch the executor meta info from Jina Hub.

    :param name: the UUID/Name of the executor
    :param tag: the tag of the executor if available, otherwise use `None` as the value
    :param secret: the access secret of the executor
    :param force: if set to True, always pull the latest Executor metas; otherwise
        default to the local cache
    :return: meta of executor

    .. note::
        The `name` and `tag` should be passed via ``args``, and `force` and `secret`
        via ``kwargs``; otherwise the cache does not work.
    """
    with ImportExtensions(required=True):
        import requests

    pull_url = get_hubble_url_v1() + f'/executors/{name}/?'
    path_params = {}
    if secret:
        path_params['secret'] = secret
    if tag:
        path_params['tag'] = tag

    if path_params:
        pull_url += urlencode(path_params)

    resp = requests.get(pull_url, headers=get_request_header())
    if resp.status_code != 200:
        if resp.text:
            raise Exception(resp.text)
        resp.raise_for_status()

    resp = resp.json()

    return HubExecutor(
        uuid=resp['id'],
        name=resp.get('name', None),
        sn=resp.get('sn', None),
        tag=tag or resp['tag'],
        visibility=resp['visibility'],
        image_name=resp['image'],
        archive_url=resp['package']['download'],
        md5sum=resp['package']['md5'],
    )
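A hypothetical call to `fetch_meta`; the executor name is made up. Per the docstring's note, `name` and `tag` go in positionally and `secret`/`force` as kwargs so the cache keying works:

# hypothetical call; the executor name is illustrative
executor = fetch_meta('MyExecutor', 'latest', secret=None, force=False)
print(executor.image_name, executor.md5sum)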
async def request_generator(
    exec_endpoint: str,
    data: 'GeneratorSourceType',
    request_size: int = 0,
    data_type: DataInputType = DataInputType.AUTO,
    target_executor: Optional[str] = None,
    parameters: Optional[Dict] = None,
    **kwargs,  # do not remove this; added on purpose to suppress unknown kwargs
) -> AsyncIterator['Request']:
    """An async :function:`request_generator`.

    :param exec_endpoint: the endpoint string, by convention starting with `/`
    :param data: the data to use in the request
    :param request_size: the number of Documents per request
    :param data_type: if ``data`` is an iterator over self-contained documents, i.e.
        :class:`DocumentSourceType`; or an iterator over possible Document content
        (set to text, blob and buffer).
    :param parameters: the kwargs that will be sent to the executor
    :param target_executor: a regex string. Only matching Executors will process the request.
    :param kwargs: additional arguments
    :yield: request
    """
    _kwargs = dict(extra_kwargs=kwargs)

    try:
        if data is None:
            # this allows empty inputs, i.e. a data request with only parameters
            yield _new_data_request(
                endpoint=exec_endpoint, target=target_executor, parameters=parameters
            )
        else:
            with ImportExtensions(required=True):
                import aiostream

            async for batch in aiostream.stream.chunks(data, request_size):
                yield _new_data_request_from_batch(
                    _kwargs=kwargs,
                    batch=batch,
                    data_type=data_type,
                    endpoint=exec_endpoint,
                    target=target_executor,
                    parameters=parameters,
                )
    except Exception as ex:
        # must be handled here, as the grpc channel won't handle Python exceptions
        default_logger.critical(f'inputs is not valid! {ex!r}', exc_info=True)
def _load_docker_client(self):
    with ImportExtensions(required=True):
        import docker.errors
        from docker import APIClient

        from jina import __windows__

        try:
            self._client = docker.from_env()

            # low-level client
            self._raw_client = APIClient(
                base_url=docker.constants.DEFAULT_NPIPE
                if __windows__
                else docker.constants.DEFAULT_UNIX_SOCKET
            )
        except docker.errors.DockerException:
            self.logger.critical(
                'The Docker daemon does not seem to be running. Please start the Docker daemon and try again.'
            )
            exit(1)
def __init__(
    self,
    logger: Optional[JinaLogger] = None,
    compression: str = 'NoCompression',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    self._logger = logger or JinaLogger(self.__class__.__name__)

    GRPC_COMPRESSION_MAP = {
        'NoCompression'.lower(): grpc.Compression.NoCompression,
        'Gzip'.lower(): grpc.Compression.Gzip,
        'Deflate'.lower(): grpc.Compression.Deflate,
    }
    if compression.lower() not in GRPC_COMPRESSION_MAP:
        import warnings

        warnings.warn(
            message=f'Your compression "{compression}" is not supported. Supported '
            f'algorithms are `Gzip`, `Deflate` and `NoCompression`. `NoCompression` '
            f'will be used as the default.'
        )
    self.compression = GRPC_COMPRESSION_MAP.get(
        compression.lower(), grpc.Compression.NoCompression
    )

    if metrics_registry:
        with ImportExtensions(
            required=True,
            help_text='You need to install `prometheus_client` to use the monitoring functionality of jina',
        ):
            from prometheus_client import Summary

        self._summary_time = Summary(
            'sending_request_seconds',
            'Time spent between sending a request to the Pod and receiving the response',
            registry=metrics_registry,
            namespace='jina',
        ).time()
    else:
        self._summary_time = contextlib.nullcontext()

    self._connections = self._ConnectionPoolMap(self._logger, self._summary_time)
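A hedged sketch of what the compression normalization above means at construction time; the class name `GrpcConnectionPool` is an assumption about the enclosing class:

# illustrative only; the enclosing class name is assumed
pool = GrpcConnectionPool(compression='gzip')    # matched case-insensitively -> grpc.Compression.Gzip
pool = GrpcConnectionPool(compression='snappy')  # unsupported -> warns, falls back to NoCompression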
def __init__(
    self,
    args: 'argparse.Namespace',
    cancel_event: Optional[
        Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
    ] = None,
    **kwargs,
):
    super().__init__(args, **kwargs)
    self._loop = asyncio.new_event_loop()
    asyncio.set_event_loop(self._loop)
    self.is_cancel = cancel_event or asyncio.Event()

    if not __windows__:
        # TODO: windows event loops don't support signal handlers
        try:
            for signame in {'SIGINT', 'SIGTERM'}:
                self._loop.add_signal_handler(
                    getattr(signal, signame),
                    lambda *args, **kwargs: self.is_cancel.set(),
                )
        except (ValueError, RuntimeError) as exc:
            self.logger.warning(
                f'The runtime {self.__class__.__name__} will not be able to handle '
                f'termination signals. {repr(exc)}'
            )
    else:
        with ImportExtensions(
            required=True,
            logger=self.logger,
            help_text='''If you see a 'DLL load failed' error, please reinstall `pywin32`.
                If you're using conda, please use the command `conda install -c anaconda pywin32`''',
        ):
            import win32api

            win32api.SetConsoleCtrlHandler(
                lambda *args, **kwargs: self.is_cancel.set(), True
            )

    self._setup_monitoring()
    self._loop.run_until_complete(self.async_setup())
async def _create_remote_pod(self):
    """Create the Workspace and Pod on the remote JinaD server."""
    with ImportExtensions(required=True):
        # rich & aiohttp are used in `AsyncJinaDClient`
        import aiohttp
        import rich

        from daemon.clients import AsyncJinaDClient

        assert rich
        assert aiohttp

    # NOTE: args.timeout_ready is always set to -1 for JinadRuntime so that
    # wait_for_success doesn't fail in Pod; hence it can't be used for the Client timeout.
    self.client = AsyncJinaDClient(
        host=self.args.host, port=self.args.port_jinad, logger=self._logger
    )

    if not await self.client.alive:
        raise DaemonConnectivityError

    # create a remote workspace with upload_files
    workspace_id = await self.client.workspaces.create(
        paths=self.filepaths,
        id=self.args.workspace_id,
        complete=True,
    )
    if not workspace_id:
        self._logger.critical('remote workspace creation failed')
        raise DaemonWorkspaceCreationFailed

    payload = replace_enum_to_str(vars(self._mask_args()))
    # create a remote Pod in the above workspace
    success, response = await self.client.pods.create(
        workspace_id=workspace_id, payload=payload, envs=self.envs
    )
    if not success:
        self._logger.critical('remote pod creation failed')
        raise DaemonPodCreationFailed(response)
    else:
        self.pod_id = response
def _init_monitoring(self):
    if (
        hasattr(self.runtime_args, 'metrics_registry')
        and self.runtime_args.metrics_registry
    ):
        with ImportExtensions(
            required=True,
            help_text='You need to install `prometheus_client` to use the monitoring functionality of jina',
        ):
            from prometheus_client import Summary

        self._summary_method = Summary(
            'process_request_seconds',
            'Time spent when calling the executor request method',
            registry=self.runtime_args.metrics_registry,
            namespace='jina',
            labelnames=('executor', 'executor_endpoint', 'runtime_name'),
        )
        self._metrics_buffer = {'process_request_seconds': self._summary_method}
    else:
        self._summary_method = None
        self._metrics_buffer = None
def download_with_resume(
    url: str,
    target_dir: 'Path',
    filename: Optional[str] = None,
    md5sum: Optional[str] = None,
) -> 'Path':
    """
    Download a file from url to target_dir, and check its md5sum.

    Performs an HTTP(S) download that can be restarted if prematurely terminated.
    The HTTP server must support byte ranges.

    :param url: the URL to download
    :param target_dir: the target path for the file
    :param filename: the filename of the downloaded file
    :param md5sum: the MD5 checksum to match
    :return: the filepath of the downloaded file
    """
    with ImportExtensions(required=True):
        import requests

    def _download(url, target, resume_byte_pos: int = None):
        resume_header = (
            {'Range': f'bytes={resume_byte_pos}-'} if resume_byte_pos else None
        )

        try:
            r = requests.get(url, stream=True, headers=resume_header)
        except requests.exceptions.RequestException as e:
            raise e

        block_size = 1024
        mode = 'ab' if resume_byte_pos else 'wb'

        with target.open(mode=mode) as f:
            for chunk in r.iter_content(32 * block_size):
                f.write(chunk)

    if filename is None:
        filename = url.split('/')[-1]
    filepath = target_dir / filename

    head_info = requests.head(url)
    file_size_online = int(head_info.headers.get('content-length', 0))

    _resume_byte_pos = None
    if filepath.exists():
        if md5sum and md5file(filepath) == md5sum:
            return filepath

        file_size_offline = filepath.stat().st_size
        if file_size_online > file_size_offline:
            _resume_byte_pos = file_size_offline

    _download(url, filepath, _resume_byte_pos)

    if md5sum and md5file(filepath) != md5sum:
        raise RuntimeError(
            'MD5 checksum failed. '
            'This might happen when the network is unstable, please retry. '
            'If it still does not work, feel free to raise an issue: '
            'https://github.com/jina-ai/jina/issues/new'
        )

    return filepath
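A hypothetical call to `download_with_resume`; the URL and target directory are made up. If the file already exists but is shorter than the online size, the download resumes from the current offset via a `Range` header:

# hypothetical usage; URL and target directory are illustrative
from pathlib import Path

path = download_with_resume(
    url='https://example.com/dataset.csv',
    target_dir=Path('/tmp/data'),
    filename='dataset.csv',
    md5sum=None,  # no checksum verification in this sketch
)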
def hello_world(args):
    """
    Execute the multimodal example.

    :param args: arguments passed from CLI
    """
    Path(args.workdir).mkdir(parents=True, exist_ok=True)

    with ImportExtensions(
        required=True,
        help_text='this demo requires Pytorch and Transformers to be installed, '
        'if you haven\'t, please do `pip install jina[torch,transformers]`',
    ):
        import torch
        import torchvision
        import transformers

        assert [
            torch,
            transformers,
            torchvision,
        ]  #: prevent pycharm from auto-removing the imports above

    targets = {
        'people-img': {
            'url': args.index_data_url,
            'filename': os.path.join(args.workdir, 'dataset.zip'),
        }
    }

    # download the data
    if not os.path.exists(targets['people-img']['filename']):
        download_data(targets, args.download_proxy, task_name='download zip data')
        with zipfile.ZipFile(targets['people-img']['filename'], 'r') as fp:
            fp.extractall(args.workdir)

    # these envs are referenced in the index and query flow YAMLs
    os.environ['HW_WORKDIR'] = args.workdir
    os.environ['PY_MODULE'] = os.path.abspath(os.path.join(cur_dir, 'my_executors.py'))

    # now comes the real work
    # load index flow from a YAML file
    # index it!
    f = Flow.load_config('flow-index.yml')
    with f, open(f'{args.workdir}/people-img/meta.csv', newline='') as fp:
        f.index(inputs=DocumentArray.from_csv(fp), request_size=10, show_progress=True)

    # search it!
    f = Flow.load_config('flow-search.yml')
    # switch to HTTP gateway
    f.protocol = 'http'
    f.port_expose = args.port_expose

    url_html_path = 'file://' + os.path.abspath(
        os.path.join(cur_dir, 'static/index.html')
    )
    with f:
        try:
            webbrowser.open(url_html_path, new=2)
        except:
            pass  # intentional pass, browser support isn't cross-platform
        finally:
            default_logger.info(
                f'You should see a demo page opened in your browser, '
                f'if not, you may open {url_html_path} manually'
            )
        if not args.unblock_query_flow:
            f.block()
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and route requests among them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus, used if we need to expose metrics
        from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI, Response, status
        from fastapi.middleware.cors import CORSMiddleware

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
        )

    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description
        or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize the title and description.',
        version=__version__,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning('CORS is enabled. This service is accessible from any website!')

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append(
            {
                'name': 'Debug',
                'description': 'Debugging interface. In production, you should hide these endpoints by setting '
                '`--no-debug-endpoints` in `Flow`/`Gateway`.',
            }
        )

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of the Jina Gateway service',
            response_model=JinaHealthModel,
        )
        async def _gateway_health():
            """
            Get the health of this Gateway service.
            .. # noqa: DAR201
            """
            return {}

        from docarray import DocumentArray

        from jina.proto import jina_pb2
        from jina.serve.executors import __dry_run_endpoint__
        from jina.serve.runtimes.gateway.http.models import (
            PROTO_TO_PYDANTIC_MODELS,
            JinaInfoModel,
        )
        from jina.types.request.status import StatusMessage

        @app.get(
            path='/dry_run',
            summary='Get the readiness of the Jina Flow service; sends an empty DocumentArray '
            'to the complete Flow to validate connectivity',
            response_model=PROTO_TO_PYDANTIC_MODELS.StatusProto,
        )
        async def _flow_health():
            """
            Get the health of the complete Flow service.
            .. # noqa: DAR201
            """
            da = DocumentArray()

            try:
                _ = await _get_singleton_result(
                    request_generator(
                        exec_endpoint=__dry_run_endpoint__,
                        data=da,
                        data_type=DataInputType.DOCUMENT,
                    )
                )
                status_message = StatusMessage()
                status_message.set_code(jina_pb2.StatusProto.SUCCESS)
                return status_message.to_dict()
            except Exception as ex:
                status_message = StatusMessage()
                status_message.set_exception(ex)
                return status_message.to_dict(use_integers_for_enums=True)

        @app.get(
            path='/status',
            summary='Get the status of the Jina service',
            response_model=JinaInfoModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from the command line.

            .. # noqa: DAR201
            """
            version, env_info = get_full_version()
            for k, v in version.items():
                version[k] = str(v)
            for k, v in env_info.items():
                env_info[k] = str(v)
            return {'jina': version, 'envs': env_info}

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            # do not add response_model here; this debug endpoint should not restrict the response model
            tags=['Debug'],
        )
        async def post(
            body: JinaEndpointRequestModel, response: Response
        ):  # 'response' is a FastAPI response, not a Jina response
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # the above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            try:
                result = await _get_singleton_result(
                    request_generator(**req_generator_input)
                )
            except InternalNetworkError as err:
                import grpc

                if err.code() == grpc.StatusCode.UNAVAILABLE:
                    response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                elif err.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
                    response.status_code = status.HTTP_504_GATEWAY_TIMEOUT
                else:
                    response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                result = bd  # send back the request
                result['header'] = _generate_exception_header(
                    err
                )  # attach exception details to the response header
                logger.error(
                    f'Error while getting responses from deployments: {err.details()}'
                )
            return result

        def _generate_exception_header(error: InternalNetworkError):
            import traceback

            from jina.proto.serializer import DataRequest

            exception_dict = {
                'name': str(error.__class__),
                'stacks': [
                    str(x)
                    for x in traceback.extract_tb(error.og_exception.__traceback__)
                ],
                'executor': '',
            }
            status_dict = {
                'code': DataRequest().status.ERROR,
                'description': error.details() if error.details() else '',
                'exception': exception_dict,
            }
            header_dict = {'request_id': error.request_id, 'status': status_dict}
            return header_dict

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Expose an executor endpoint as an http endpoint.

        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """
        # set some default kwargs for richer semantics
        # group flow-exposed endpoints into the `Customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use the standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        @app.api_route(
            path=http_path or exec_endpoint, name=http_path or exec_endpoint, **kwargs
        )
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append(
            {
                'name': 'CRUD',
                'description': 'CRUD interface. If your service does not implement these interfaces, you should '
                'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
            }
        )
        crud = {
            '/index': {'methods': ['POST']},
            '/search': {'methods': ['POST']},
            '/delete': {'methods': ['DELETE']},
            '/update': {'methods': ['PUT']},
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = (
                f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            )
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

            async def get_docs_from_endpoint(
                data, target_executor, parameters, exec_endpoint
            ):
                req_generator_input = {
                    'data': [asdict(d) for d in data],
                    'target_executor': target_executor,
                    'parameters': parameters,
                    'exec_endpoint': exec_endpoint,
                    'data_type': DataInputType.DICT,
                }

                if (
                    req_generator_input['data'] is not None
                    and 'docs' in req_generator_input['data']
                ):
                    req_generator_input['data'] = req_generator_input['data']['docs']
                try:
                    response = await _get_singleton_result(
                        request_generator(**req_generator_input)
                    )
                except InternalNetworkError as err:
                    logger.error(
                        f'Error while getting responses from deployments: {err.details()}'
                    )
                    raise err  # will be handled by Strawberry
                return DocumentArray.from_dict(response['data']).to_strawberry_type()

            @strawberry.type
            class Mutation:
                @strawberry.mutation
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint
                    )

            @strawberry.type
            class Query:
                @strawberry.field
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint
                    )

            schema = strawberry.Schema(query=Query, mutation=Mutation)
            app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Stream results from AsyncPrefetchCall as a dict.

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app
def run(
    args: 'argparse.Namespace',
    name: str,
    container_name: str,
    net_mode: Optional[str],
    runtime_ctrl_address: str,
    envs: Dict,
    is_started: Union['multiprocessing.Event', 'threading.Event'],
    is_shutdown: Union['multiprocessing.Event', 'threading.Event'],
    is_ready: Union['multiprocessing.Event', 'threading.Event'],
):
    """Method to be run in a process that streams logs from a Container.

    This method is the target for the Pod's `thread` or `process`.

    .. note::
        :meth:`run` runs in a subprocess/thread, so exceptions cannot be propagated
        to the main process. Hence, please do not raise any exception here.

    .. note::
        Please note that env variables are process-specific. A subprocess inherits envs from
        the main process, but the subprocess's envs do NOT affect the main process. It does
        NOT mess up the user's local system envs.

    :param args: namespace args from the Pod
    :param name: name of the Pod, to have proper logging
    :param container_name: name to set the Container to
    :param net_mode: the network mode where to run the container
    :param runtime_ctrl_address: the control address of the runtime in the container
    :param envs: dictionary of environment variables to be set in the docker image
    :param is_started: concurrency event to communicate that the runtime is properly started. Used for better logging
    :param is_shutdown: concurrency event to communicate that the runtime is terminated
    :param is_ready: concurrency event to communicate that the runtime is ready to receive messages
    """
    import docker

    log_kwargs = copy.deepcopy(vars(args))
    log_kwargs['log_config'] = 'docker'
    logger = JinaLogger(name, **log_kwargs)

    cancel = threading.Event()
    fail_to_start = threading.Event()

    if not __windows__:
        try:
            for signame in {signal.SIGINT, signal.SIGTERM}:
                signal.signal(signame, lambda *args, **kwargs: cancel.set())
        except (ValueError, RuntimeError) as exc:
            logger.warning(
                f'The process starting the container for {name} will not be able to '
                f'handle termination signals. {repr(exc)}'
            )
    else:
        with ImportExtensions(
            required=True,
            logger=logger,
            help_text='''If you see a 'DLL load failed' error, please reinstall `pywin32`.
                If you're using conda, please use the command `conda install -c anaconda pywin32`''',
        ):
            import win32api

            win32api.SetConsoleCtrlHandler(lambda *args, **kwargs: cancel.set(), True)

    client = docker.from_env()

    try:
        container = _docker_run(
            client=client,
            args=args,
            container_name=container_name,
            envs=envs,
            net_mode=net_mode,
            logger=logger,
        )
        client.close()

        def _is_ready():
            return AsyncNewLoopRuntime.is_ready(runtime_ctrl_address)

        def _is_container_alive(container) -> bool:
            import docker.errors

            try:
                container.reload()
            except docker.errors.NotFound:
                return False
            return True

        async def _check_readiness(container):
            while (
                _is_container_alive(container)
                and not _is_ready()
                and not cancel.is_set()
            ):
                await asyncio.sleep(0.1)
            if _is_container_alive(container):
                is_started.set()
                is_ready.set()
            else:
                fail_to_start.set()

        async def _stream_starting_logs(container):
            for line in container.logs(stream=True):
                if (
                    not is_started.is_set()
                    and not fail_to_start.is_set()
                    and not cancel.is_set()
                ):
                    await asyncio.sleep(0.01)
                msg = line.decode().rstrip()  # type: str
                logger.debug(re.sub(r'\u001b\[.*?[@-~]', '', msg))

        async def _run_async(container):
            await asyncio.gather(
                *[_check_readiness(container), _stream_starting_logs(container)]
            )

        asyncio.run(_run_async(container))
    finally:
        client.close()
        if not is_started.is_set():
            logger.error(
                'Process terminated, the container failed to start; check the arguments or entrypoint'
            )
        is_shutdown.set()
        logger.debug('process terminated')
def get_fastapi_app(
    args: 'argparse.Namespace',
    topology_graph: 'TopologyGraph',
    connection_pool: 'GrpcConnectionPool',
    logger: 'JinaLogger',
    metrics_registry: Optional['CollectorRegistry'] = None,
):
    """
    Get the app from FastAPI as the REST interface.

    :param args: passed arguments.
    :param topology_graph: topology graph that manages the logic of sending to the proper executors.
    :param connection_pool: Connection Pool to handle multiple replicas and route requests among them
    :param logger: Jina logger.
    :param metrics_registry: optional metrics registry for prometheus, used if we need to expose metrics
        from the executor or from the data request handler
    :return: fastapi app
    """
    with ImportExtensions(required=True):
        from fastapi import FastAPI
        from fastapi.middleware.cors import CORSMiddleware
        from fastapi.responses import HTMLResponse
        from starlette.requests import Request

        from jina.serve.runtimes.gateway.http.models import (
            JinaEndpointRequestModel,
            JinaRequestModel,
            JinaResponseModel,
            JinaStatusModel,
        )

    docs_url = '/docs'
    app = FastAPI(
        title=args.title or 'My Jina Service',
        description=args.description
        or 'This is my awesome service. You can set `title` and `description` in your `Flow` or `Gateway` '
        'to customize this text.',
        version=__version__,
        docs_url=docs_url if args.default_swagger_ui else None,
    )

    if args.cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        logger.warning(
            'CORS is enabled. This service is now accessible from any website!'
        )

    from jina.serve.runtimes.gateway.request_handling import RequestHandler
    from jina.serve.stream import RequestStreamer

    request_handler = RequestHandler(metrics_registry, args.name)

    streamer = RequestStreamer(
        args=args,
        request_handler=request_handler.handle_request(
            graph=topology_graph, connection_pool=connection_pool
        ),
        result_handler=request_handler.handle_result(),
    )
    streamer.Call = streamer.stream

    @app.on_event('shutdown')
    async def _shutdown():
        await connection_pool.close()

    openapi_tags = []
    if not args.no_debug_endpoints:
        openapi_tags.append(
            {
                'name': 'Debug',
                'description': 'Debugging interface. In production, you should hide these endpoints by setting '
                '`--no-debug-endpoints` in `Flow`/`Gateway`.',
            }
        )

        from jina.serve.runtimes.gateway.http.models import JinaHealthModel

        @app.get(
            path='/',
            summary='Get the health of the Jina service',
            response_model=JinaHealthModel,
        )
        async def _health():
            """
            Get the health of this Jina service.
            .. # noqa: DAR201
            """
            return {}

        @app.get(
            path='/status',
            summary='Get the status of the Jina service',
            response_model=JinaStatusModel,
            tags=['Debug'],
        )
        async def _status():
            """
            Get the status of this Jina service.

            This is equivalent to running `jina -vf` from the command line.

            .. # noqa: DAR201
            """
            _info = get_full_version()
            return {
                'jina': _info[0],
                'envs': _info[1],
                'used_memory': used_memory_readable(),
            }

        @app.post(
            path='/post',
            summary='Post a data request to some endpoint',
            response_model=JinaResponseModel,
            # do not add response_model here; this debug endpoint should not restrict the response model
            tags=['Debug'],
        )
        async def post(body: JinaEndpointRequestModel):
            """
            Post a data request to some endpoint.

            This is equivalent to the following:

                from jina import Flow

                f = Flow().add(...)

                with f:
                    f.post(endpoint, ...)

            .. # noqa: DAR201
            .. # noqa: DAR101
            """
            # the above comment is written in Markdown for better rendering in FastAPI
            from jina.enums import DataInputType

            bd = body.dict()  # type: Dict
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    def expose_executor_endpoint(exec_endpoint, http_path=None, **kwargs):
        """Expose an executor endpoint as an http endpoint.

        :param exec_endpoint: the executor endpoint
        :param http_path: the http endpoint
        :param kwargs: kwargs accepted by FastAPI
        """
        # set some default kwargs for richer semantics
        # group flow-exposed endpoints into the `Customized` group
        kwargs['tags'] = kwargs.get('tags', ['Customized'])
        kwargs['response_model'] = kwargs.get(
            'response_model',
            JinaResponseModel,  # use the standard response model by default
        )
        kwargs['methods'] = kwargs.get('methods', ['POST'])

        @app.api_route(
            path=http_path or exec_endpoint, name=http_path or exec_endpoint, **kwargs
        )
        async def foo(body: JinaRequestModel):
            from jina.enums import DataInputType

            bd = body.dict() if body else {'data': None}
            bd['exec_endpoint'] = exec_endpoint
            req_generator_input = bd
            req_generator_input['data_type'] = DataInputType.DICT
            if bd['data'] is not None and 'docs' in bd['data']:
                req_generator_input['data'] = req_generator_input['data']['docs']

            result = await _get_singleton_result(
                request_generator(**req_generator_input)
            )
            return result

    if not args.no_crud_endpoints:
        openapi_tags.append(
            {
                'name': 'CRUD',
                'description': 'CRUD interface. If your service does not implement these interfaces, you should '
                'hide them by setting `--no-crud-endpoints` in `Flow`/`Gateway`.',
            }
        )
        crud = {
            '/index': {'methods': ['POST']},
            '/search': {'methods': ['POST']},
            '/delete': {'methods': ['DELETE']},
            '/update': {'methods': ['PUT']},
        }
        for k, v in crud.items():
            v['tags'] = ['CRUD']
            v['description'] = (
                f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
            )
            expose_executor_endpoint(exec_endpoint=k, **v)

    if openapi_tags:
        app.openapi_tags = openapi_tags

    if args.expose_endpoints:
        endpoints = json.loads(args.expose_endpoints)  # type: Dict[str, Dict]
        for k, v in endpoints.items():
            expose_executor_endpoint(exec_endpoint=k, **v)

    if not args.default_swagger_ui:

        async def _render_custom_swagger_html(req: Request) -> HTMLResponse:
            import urllib.request

            swagger_url = 'https://api.jina.ai/swagger'
            req = urllib.request.Request(
                swagger_url, headers={'User-Agent': 'Mozilla/5.0'}
            )
            with urllib.request.urlopen(req) as f:
                return HTMLResponse(f.read().decode())

        app.add_route(docs_url, _render_custom_swagger_html, include_in_schema=False)

    if args.expose_graphql_endpoint:
        with ImportExtensions(required=True):
            from dataclasses import asdict

            import strawberry
            from docarray import DocumentArray
            from docarray.document.strawberry_type import (
                JSONScalar,
                StrawberryDocument,
                StrawberryDocumentInput,
            )
            from strawberry.fastapi import GraphQLRouter

            async def get_docs_from_endpoint(
                data, target_executor, parameters, exec_endpoint
            ):
                req_generator_input = {
                    'data': [asdict(d) for d in data],
                    'target_executor': target_executor,
                    'parameters': parameters,
                    'exec_endpoint': exec_endpoint,
                    'data_type': DataInputType.DICT,
                }

                if (
                    req_generator_input['data'] is not None
                    and 'docs' in req_generator_input['data']
                ):
                    req_generator_input['data'] = req_generator_input['data']['docs']

                response = await _get_singleton_result(
                    request_generator(**req_generator_input)
                )
                return DocumentArray.from_dict(response['data']).to_strawberry_type()

            @strawberry.type
            class Mutation:
                @strawberry.mutation
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint
                    )

            @strawberry.type
            class Query:
                @strawberry.field
                async def docs(
                    self,
                    data: Optional[List[StrawberryDocumentInput]] = None,
                    target_executor: Optional[str] = None,
                    parameters: Optional[JSONScalar] = None,
                    exec_endpoint: str = '/search',
                ) -> List[StrawberryDocument]:
                    return await get_docs_from_endpoint(
                        data, target_executor, parameters, exec_endpoint
                    )

            schema = strawberry.Schema(query=Query, mutation=Mutation)
            app.include_router(GraphQLRouter(schema), prefix='/graphql')

    async def _get_singleton_result(request_iterator) -> Dict:
        """
        Stream results from AsyncPrefetchCall as a dict.

        :param request_iterator: request iterator, with length of 1
        :return: the first result from the request iterator
        """
        async for k in streamer.stream(request_iterator=request_iterator):
            request_dict = k.to_dict()
            return request_dict

    return app