def get_model(module, serve_model=False):
    """Import *module* and locate an object to serve.

    Search order:
      1. If ``serve_model`` is truthy, the first ``ServeModel`` instance.
      2. The first ``MLServer`` instance.
      3. The first ``ServeModel`` instance (even when ``serve_model`` is falsy).
      4. As a last resort (only when ``serve_model`` is truthy), the most
         recently defined non-class object, wrapped in ``ServeModel``.

    Returns ``None`` when the module cannot be imported or nothing servable
    is found (both cases are logged).
    """
    import_name = prepare_import(module)
    try:
        module = importlib.import_module(import_name)
    except Exception:
        # Import failures are reported to the caller as None; the full
        # traceback goes to the log so the user can diagnose the entry file.
        logger.error(traceback.format_exc())
        return None

    serve_models = [
        v for v in module.__dict__.values() if isinstance(v, ServeModel)
    ]
    # Caller explicitly asked for a ServeModel: return it directly instead of
    # reusing the boolean flag parameter as the result variable.
    if serve_models and serve_model:
        return serve_models[0]

    apps = [v for v in module.__dict__.values() if isinstance(v, MLServer)]
    if apps:
        return apps[0]
    if serve_models:
        return serve_models[0]

    # Could not find model
    logger.debug("Could not find ServeModel")
    # Fallback: pick the last module-level object that is not a class and
    # wrap it, mirroring "serve the last thing defined in the entry file".
    candidates = [
        v for v in module.__dict__.values() if not isinstance(v, type)
    ]
    if candidates and serve_model:
        return ServeModel(candidates[-1])
    logger.error(
        "Could not find any instance to serve. So please check again the mlconfig.yaml or server file!"
    )
    return None
def get_model(module, serve_model=False, queue=None, trace=False, name=None):
    """Import *module* and locate an object to serve, optionally queue-wrapped.

    Same lookup order as the basic ``get_model``; additionally, when
    ``queue`` is ``'rabbit'`` or ``'redis'``, the resolved ``ServeModel`` is
    wrapped in the matching queue consumer and started on a background thread.

    Returns ``None`` (after logging) when nothing servable is found.
    Note: unlike the basic variant, import errors propagate to the caller.
    """

    def _wrap_with_queue(model):
        # Wrap *model* in the configured queue consumer (if any) and start
        # it on a background thread; the consumer's run() is invoked even
        # when the model is already wrapped, matching previous behavior.
        if queue == 'rabbit':
            logger.debug('load rabbit {0}'.format(model))
            from mlchain.queue.rabbit_queue import RabbitQueue
            if not isinstance(model, RabbitQueue):
                model = RabbitQueue(model, module_name=name, trace=trace)
            model.run(threading=True)
        elif queue == 'redis':
            logger.debug('load redis {0}'.format(model))
            from mlchain.queue.redis_queue import RedisQueue
            if not isinstance(model, RedisQueue):
                model = RedisQueue(model, module_name=name, trace=trace)
            model.run(threading=True)
        return model

    import_name = prepare_import(module)
    module = importlib.import_module(import_name)

    serve_models = [
        v for v in module.__dict__.values() if isinstance(v, ServeModel)
    ]
    if serve_models and serve_model:
        return _wrap_with_queue(serve_models[0])

    apps = [v for v in module.__dict__.values() if isinstance(v, MLServer)]
    if apps:
        return apps[0]
    if serve_models:
        return serve_models[0]

    # Could not find model
    logger.debug("Could not find ServeModel")
    # Fallback: wrap the most recently defined non-class object.
    candidates = [
        v for v in module.__dict__.values() if not isinstance(v, type)
    ]
    if candidates and serve_model:
        return _wrap_with_queue(ServeModel(candidates[-1]))
    logger.error("Could not find any instance to serve")
    return None
def __call__(self, function_name, **kws):
    """Handle one invocation of *function_name* on the served model.

    Parses the request, runs authentication and argument formatting,
    calls the model function, and always returns a response built via
    ``normalize_output`` + ``make_response`` — errors are reported in
    the normalized payload rather than raised to the web framework.
    """
    # Each request gets its own Sentry-style scope and transaction so
    # tags/errors do not leak between concurrent requests.
    with push_scope() as scope:
        transaction_name = "{0} || {1}".format(
            mlconfig.MLCHAIN_SERVER_NAME, function_name)
        scope.transaction = transaction_name
        with start_transaction(op="task", name=transaction_name):
            uid = self.init_context()
            request_context = {'api_version': self.server.version}
            try:
                headers, form, files, data = self.parse_data()
                # Expose the raw request pieces to downstream code through
                # the shared mlchain context.
                mlchain_context['REQUESTS_HEADERS'] = headers
                mlchain_context['REQUESTS_FORM'] = form
                mlchain_context['REQUESTS_FILES'] = files
                mlchain_context['REQUESTS_DATA'] = data
            except Exception as ex:
                # Request parsing failed: answer immediately using the
                # server's base formatter (no per-request formatter yet).
                request_context['time_process'] = 0
                output = self.normalize_output(self.base_format, function_name,
                                               {}, None, ex, request_context)
                return self.make_response(output)
            formatter = self.get_format(headers, form, files, data)
            start_time = time.time()
            try:
                if self.authentication is not None:
                    self.authentication.check(headers)
                args, kwargs = formatter.parse_request(
                    function_name, headers, form, files, data, request_context)
                func = self.server.model.get_function(function_name)
                kwargs = self.server.get_kwargs(func, *args, **kwargs)
                kwargs = self.server._normalize_kwargs_to_valid_format(
                    kwargs, func)
                # The uid may be replaced by a header-provided transaction id.
                uid = self.init_context_with_headers(headers, uid)
                scope.set_tag("transaction_id", uid)
                logger.debug("Mlchain transaction id: {0}".format(uid))
                output = self.server.model.call_function(
                    function_name, uid, **kwargs)
                exception = None
            except MlChainError as ex:
                # Domain errors and unexpected errors are handled the same
                # way here; the distinction is preserved in `exception` for
                # normalize_output to report.
                exception = ex
                output = None
            except Exception as ex:
                exception = ex
                output = None
            time_process = time.time() - start_time
            request_context['time_process'] = time_process
            output = self.normalize_output(formatter, function_name, headers,
                                           output, exception, request_context)
            return self.make_response(output)
def run_command(entry_file, host, port, bind, wrapper, server, workers, config,
                name, mode, api_format, ngrok, kws):
    """CLI entry point: load mlconfig, resolve settings, and start a server.

    Resolution order for each setting is CLI value -> config file -> default
    (via ``mlconfig.get_value``). Depending on ``server``/``wrapper`` this
    starts a gRPC server, a Gunicorn-wrapped Flask/Quart app, a
    Hypercorn-wrapped Quart app, or a plain development server.
    Raises ``SystemExit`` on configuration errors.
    """
    kws = list(kws)
    # Allow `mlchain run <something>` where <something> is not a file path:
    # treat it as the --entry_file param instead.
    if isinstance(entry_file, str) and not os.path.exists(entry_file):
        kws = [f'--entry_file={entry_file}'] + kws
        entry_file = None
    from mlchain import config as mlconfig
    default_config = False
    if config is None:
        default_config = True
        config = 'mlconfig.yaml'
    config_path = copy.deepcopy(config)
    if os.path.isfile(config_path) and os.path.exists(config_path):
        config = mlconfig.load_file(config_path)
        if config is None:
            raise SystemExit(
                "Config file {0} are not supported".format(config_path))
    else:
        # A config file is mandatory: either the user-supplied path is wrong
        # or there is no mlconfig.yaml in the working directory.
        if not default_config:
            raise SystemExit("Can't find config file {0}".format(config_path))
        else:
            raise SystemExit(
                "Can't find mlchain config file. Please double check your current working directory. Or use `mlchain init` to initialize a new ones here."
            )
    # Select the environment overlay declared under mode.env, if requested.
    if 'mode' in config and 'env' in config['mode']:
        if mode in config['mode']['env']:
            config['mode']['default'] = mode
        elif mode is not None:
            available_mode = list(config['mode']['env'].keys())
            available_mode = [
                each for each in available_mode if each != 'default'
            ]
            raise SystemExit(
                f"No {mode} mode are available. Found these mode in config file: {available_mode}"
            )
    mlconfig.load_config(config)
    # Extra CLI params must all look like --key=value; they are merged into
    # the global mlconfig mapping.
    for kw in kws:
        if kw.startswith('--'):
            tokens = kw[2:].split('=', 1)
            if len(tokens) == 2:
                key, value = tokens
                mlconfig.mlconfig.update({key: value})
            else:
                raise AssertionError("Unexpected param {0}".format(kw))
        else:
            raise AssertionError("Unexpected param {0}".format(kw))
    model_id = mlconfig.get_value(None, config, 'model_id', None)
    entry_file = mlconfig.get_value(entry_file, config, 'entry_file',
                                    'server.py')
    if entry_file.strip() == '':
        raise SystemExit(f"Entry file cannot be empty")
    if not os.path.exists(entry_file):
        raise SystemExit(
            f"Entry file {entry_file} not found in current working directory.")
    host = mlconfig.get_value(host, config, 'host', 'localhost')
    port = mlconfig.get_value(port, config, 'port', 5000)
    server = mlconfig.get_value(server, config, 'server', 'flask')
    # An empty bind tuple from the CLI means "not provided" -> use config.
    if len(bind) == 0:
        bind = None
    bind = mlconfig.get_value(bind, config, 'bind', [])
    wrapper = mlconfig.get_value(wrapper, config, 'wrapper', None)
    if wrapper == 'gunicorn' and os.name == 'nt':
        # Gunicorn relies on fork(); it cannot run on Windows.
        logger.warning(
            'Gunicorn warper are not supported on Windows. Switching to None instead.'
        )
        wrapper = None
    # NOTE(review): the CLI `workers` argument is discarded here — worker
    # count is taken only from the gunicorn/hypercorn config sections.
    # Confirm this is intentional.
    workers = None
    if 'gunicorn' in config:
        workers = mlconfig.get_value(workers, config['gunicorn'], 'workers',
                                     None)
    if workers is None and 'hypercorn' in config.keys():
        workers = mlconfig.get_value(workers, config['hypercorn'], 'workers',
                                     None)
    workers = int(workers) if workers is not None else 1
    name = mlconfig.get_value(name, config, 'name', None)
    cors = mlconfig.get_value(None, config, 'cors', False)
    static_folder = mlconfig.get_value(None, config, 'static_folder', None)
    static_url_path = mlconfig.get_value(None, config, 'static_url_path', None)
    template_folder = mlconfig.get_value(None, config, 'template_folder', None)
    version = mlconfig.get_value(None, config, 'version', '0.0')
    version = str(version)
    api_format = mlconfig.get_value(api_format, config, 'api_format', None)
    # API keys can come from the environment (semicolon-separated) or config.
    api_keys = os.getenv('API_KEYS', None)
    if api_keys is not None:
        api_keys = api_keys.split(';')
    api_keys = mlconfig.get_value(api_keys, config, 'api_keys', None)
    if api_keys is None:
        authentication = None
    else:
        authentication = Authentication(api_keys)
    import logging
    # NOTE(review): replacing logging.root with a named logger is unusual and
    # affects all libraries using the root logger — confirm this is intended.
    logging.root = logging.getLogger(name)
    logger.debug(
        dict(entry_file=entry_file,
             host=host,
             port=port,
             bind=bind,
             wrapper=wrapper,
             server=server,
             workers=workers,
             name=name,
             mode=mode,
             api_format=api_format,
             kws=kws))
    bind = list(bind)
    if ngrok:
        # Expose the local port through an ngrok tunnel; the public URL is
        # shared with workers via the NGROK_URL environment variable.
        from pyngrok import ngrok as pyngrok
        endpoint = pyngrok.connect(port=port)
        logger.info("Ngrok url: {0}".format(endpoint))
        os.environ['NGROK_URL'] = endpoint
    if server == 'grpc':
        from mlchain.server.grpc_server import GrpcServer
        app = get_model(entry_file, serve_model=True)
        if app is None:
            raise Exception(
                "Can not init model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!"
                .format(entry_file))
        app = GrpcServer(app, name=name)
        app.run(host, port)
    elif wrapper == 'gunicorn':
        from gunicorn.app.base import BaseApplication
        gpus = select_gpu()

        class GunicornWrapper(BaseApplication):
            """Gunicorn custom application: builds the WSGI/ASGI app lazily
            in each worker via load()."""

            def __init__(self, server_, **kwargs):
                assert server_.lower() in ['quart', 'flask']
                self.server = server_.lower()
                self.options = kwargs
                self.autofrontend = False
                super(GunicornWrapper, self).__init__()

            def load_config(self):
                # Forward only options Gunicorn actually recognizes.
                config = {
                    key: value
                    for key, value in self.options.items()
                    if key in self.cfg.settings and value is not None
                }
                for key, value in config.items():
                    self.cfg.set(key.lower(), value)
                from mlchain.base.gunicorn_config import post_worker_init
                self.cfg.set("post_worker_init", post_worker_init)

            def load(self):
                # Runs inside each worker process: pin a GPU (unless the
                # user already pinned one) before importing the model.
                original_cuda_variable = os.environ.get('CUDA_VISIBLE_DEVICES')
                if original_cuda_variable is None:
                    os.environ['CUDA_VISIBLE_DEVICES'] = str(next(gpus))
                else:
                    logger.info(
                        f"Skipping automatic GPU selection for gunicorn worker since CUDA_VISIBLE_DEVICES environment variable is already set to {original_cuda_variable}"
                    )
                serve_model = get_model(entry_file, serve_model=True)
                if serve_model is None:
                    raise Exception(
                        f"Can not init model class from {entry_file}. Please check mlconfig.yaml or {entry_file} or mlchain run -m {{mode}}!"
                    )
                if isinstance(serve_model, ServeModel):
                    # Register the model with autofrontend at most once.
                    if (not self.autofrontend) and model_id is not None:
                        from mlchain.server.autofrontend import register_autofrontend
                        register_autofrontend(model_id,
                                              serve_model=serve_model,
                                              version=version,
                                              endpoint=os.getenv('NGROK_URL'))
                        self.autofrontend = True
                    if self.server == 'flask':
                        from mlchain.server.flask_server import FlaskServer
                        app = FlaskServer(serve_model,
                                          name=name,
                                          api_format=api_format,
                                          version=version,
                                          authentication=authentication,
                                          static_url_path=static_url_path,
                                          static_folder=static_folder,
                                          template_folder=template_folder)
                        app.register_swagger()
                        if cors:
                            from flask_cors import CORS
                            CORS(app.app)
                        return app.app
                    if self.server == 'quart':
                        from mlchain.server.quart_server import QuartServer
                        app = QuartServer(serve_model,
                                          name=name,
                                          api_format=api_format,
                                          version=version,
                                          authentication=authentication,
                                          static_url_path=static_url_path,
                                          static_folder=static_folder,
                                          template_folder=template_folder)
                        app.register_swagger()
                        if cors:
                            from quart_cors import cors as CORS
                            CORS(app.app)
                        return app.app
                return None

        if host is not None and port is not None:
            bind.append('{0}:{1}'.format(host, port))
        # De-duplicate bind addresses (order is not significant to gunicorn).
        bind = list(set(bind))
        gunicorn_config = config.get('gunicorn', {})
        gunicorn_env = ['worker_class', 'threads', 'workers']
        if workers is not None:
            gunicorn_config['workers'] = workers
        # Environment variables (e.g. GUNICORN_WORKERS) override config.
        for k in gunicorn_env:
            if get_env(k) in os.environ:
                gunicorn_config[k] = os.environ[get_env(k)]
        if server == 'flask' and 'worker_class' in gunicorn_config:
            # Flask is WSGI-only; uvicorn workers require ASGI.
            if 'uvicorn' in gunicorn_config['worker_class']:
                logger.warning(
                    "Can't use flask with uvicorn. change to gthread")
                gunicorn_config['worker_class'] = 'gthread'
        GunicornWrapper(server, bind=bind, **gunicorn_config).run()
    elif wrapper == 'hypercorn' and server == 'quart':
        from mlchain.server.quart_server import QuartServer
        app = get_model(entry_file, serve_model=True)
        if app is None:
            raise Exception(
                "Can not init model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!"
                .format(entry_file))
        app = QuartServer(app,
                          name=name,
                          version=version,
                          api_format=api_format,
                          authentication=authentication,
                          static_url_path=static_url_path,
                          static_folder=static_folder,
                          template_folder=template_folder)
        app.run(host,
                port,
                bind=bind,
                cors=cors,
                gunicorn=False,
                hypercorn=True,
                **config.get('hypercorn', {}),
                model_id=model_id)
    # Fallback: no wrapper — run a development server directly, dispatching
    # on whatever the entry file provides (MLServer instance or ServeModel).
    app = get_model(entry_file)
    if app is None:
        raise Exception(
            "Can not init model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!"
            .format(entry_file))
    if isinstance(app, MLServer):
        if app.__class__.__name__ == 'FlaskServer':
            app.run(host, port, cors=cors, gunicorn=False)
        elif app.__class__.__name__ == 'QuartServer':
            app.run(host, port, cors=cors, gunicorn=False, hypercorn=False)
        elif app.__class__.__name__ == 'GrpcServer':
            app.run(host, port)
    elif isinstance(app, ServeModel):
        if server not in ['quart', 'grpc']:
            server = 'flask'
        if server == 'flask':
            from mlchain.server.flask_server import FlaskServer
            app = FlaskServer(app,
                              name=name,
                              api_format=api_format,
                              version=version,
                              authentication=authentication,
                              static_url_path=static_url_path,
                              static_folder=static_folder,
                              template_folder=template_folder)
            # threads=True enables threaded serving when >1 worker requested.
            app.run(host,
                    port,
                    cors=cors,
                    gunicorn=False,
                    model_id=model_id,
                    threads=workers > 1)
        elif server == 'quart':
            from mlchain.server.quart_server import QuartServer
            app = QuartServer(app,
                              name=name,
                              api_format=api_format,
                              version=version,
                              authentication=authentication,
                              static_url_path=static_url_path,
                              static_folder=static_folder,
                              template_folder=template_folder)
            app.run(host,
                    port,
                    cors=cors,
                    gunicorn=False,
                    hypercorn=False,
                    model_id=model_id,
                    workers=workers)
        elif server == 'grpc':
            from mlchain.server.grpc_server import GrpcServer
            app = GrpcServer(app, name=name)
            app.run(host, port)
def run_command(entry_file, host, port, bind, wrapper, server, queue, workers,
                trace, log, config, name, mode, kws):
    """CLI entry point (queue-aware variant): load config and start a server.

    Each setting resolves CLI value -> config file -> default via
    ``mlconfig.get_value``. Supports gRPC, Gunicorn-wrapped Flask/Quart,
    Hypercorn-wrapped Quart, or a plain development server, optionally
    consuming from a rabbit/redis queue (see ``get_model``).
    """
    from mlchain import config as mlconfig
    default_config = False
    if config is None:
        default_config = True
        config = 'mlconfig.yaml'
    if os.path.isfile(config):
        # Config format is inferred from the file extension.
        if config.endswith('.json'):
            config = mlconfig.load_json(config)
        elif config.endswith('.yaml') or config.endswith('.yml'):
            config = mlconfig.load_yaml(config)
        else:
            raise AssertionError("Not support file config {0}".format(config))
    else:
        # Missing file is only an error when the user named it explicitly;
        # otherwise fall back to an empty config.
        if not default_config:
            raise FileNotFoundError("Not found file {0}".format(config))
        config = {}
    # Select the environment overlay declared under mode.env, if present.
    if 'mode' in config and 'env' in config['mode']:
        if mode in config['mode']['env']:
            config['mode']['default'] = mode
    mlconfig.load_config(config)
    entry_file = mlconfig.get_value(entry_file, config, 'entry_file',
                                    'server.py')
    host = mlconfig.get_value(host, config, 'host', 'localhost')
    port = mlconfig.get_value(port, config, 'port', 5000)
    server = mlconfig.get_value(server, config, 'server', 'flask')
    # An empty bind tuple from the CLI means "not provided" -> use config.
    if len(bind) == 0:
        bind = None
    bind = mlconfig.get_value(bind, config, 'bind', [])
    wrapper = mlconfig.get_value(wrapper, config, 'wrapper', None)
    queue = mlconfig.get_value(queue, config, 'queue', None)
    trace = mlconfig.get_value(trace, config, 'trace', False)
    log = mlconfig.get_value(log, config, 'log', False)
    workers = mlconfig.get_value(workers, config, 'workers', 1)
    name = mlconfig.get_value(name, config, 'name', None)
    monitor_sampling_rate = mlconfig.get_value(None, config,
                                               'monitor_sampling_rate', 1)
    cors = mlconfig.get_value(False, config, 'cors', False)
    logger.debug(
        dict(entry_file=entry_file,
             host=host,
             port=port,
             bind=bind,
             wrapper=wrapper,
             server=server,
             queue=queue,
             workers=workers,
             log=log,
             name=name,
             trace=trace,
             kws=kws))
    bind = list(bind)
    if server == 'grpc':
        from mlchain.rpc.server.grpc_server import GrpcServer
        app = get_model(entry_file, serve_model=True, queue=queue, trace=trace)
        app = GrpcServer(app,
                         name=name,
                         trace=trace,
                         log=log,
                         monitor_sampling_rate=monitor_sampling_rate)
        app.run(host, port)
    elif wrapper == 'gunicorn':
        from gunicorn.app.base import BaseApplication

        class GunicornWrapper(BaseApplication):
            """Gunicorn custom application: builds the WSGI/ASGI app lazily
            in each worker via load()."""

            def __init__(self, server_, **kwargs):
                assert server_.lower() in ['quart', 'flask']
                self.server = server_.lower()
                self.options = kwargs
                super(GunicornWrapper, self).__init__()

            def load_config(self):
                # Forward only options Gunicorn actually recognizes.
                config = {
                    key: value
                    for key, value in self.options.items()
                    if key in self.cfg.settings and value is not None
                }
                for key, value in config.items():
                    self.cfg.set(key.lower(), value)

            def load(self):
                # Runs inside each worker: import the entry file and wrap
                # the ServeModel in the selected web server.
                app = get_model(entry_file,
                                serve_model=True,
                                queue=queue,
                                trace=trace,
                                name=name)
                if isinstance(app, ServeModel):
                    if self.server == 'flask':
                        from mlchain.rpc.server.flask_server import FlaskServer
                        app = FlaskServer(
                            app,
                            name=name,
                            log=log,
                            trace=trace,
                            monitor_sampling_rate=monitor_sampling_rate)
                        app.register_swagger(host, port)
                        if cors:
                            from flask_cors import CORS
                            CORS(app.app)
                        return app.app
                    elif self.server == 'quart':
                        from mlchain.rpc.server.quart_server import QuartServer
                        app = QuartServer(
                            app,
                            name=name,
                            log=log,
                            trace=trace,
                            monitor_sampling_rate=monitor_sampling_rate)
                        if cors:
                            from quart_cors import cors as CORS
                            CORS(app.app)
                        return app.app
                return None

        if host is not None and port is not None:
            bind.append('{0}:{1}'.format(host, port))
        # De-duplicate bind addresses (order is not significant to gunicorn).
        bind = list(set(bind))
        gunicorn_config = config.get('gunicorn', {})
        gunicorn_env = ['worker_class', 'threads', 'workers']
        gunicorn_config['workers'] = workers

        def get_env(_k):
            # Map a gunicorn setting name to its env-var override name.
            return 'GUNICORN_' + _k.upper()

        # Environment variables (e.g. GUNICORN_WORKERS) override config.
        for k in gunicorn_env:
            if get_env(k) in os.environ:
                gunicorn_config[k] = os.environ[get_env(k)]
        GunicornWrapper(server, bind=bind, **gunicorn_config).run()
    elif wrapper == 'hypercorn' and server == 'quart':
        from mlchain.rpc.server.quart_server import QuartServer
        app = get_model(entry_file,
                        serve_model=True,
                        queue=queue,
                        trace=trace,
                        name=name)
        app = QuartServer(app,
                          name=name,
                          log=log,
                          trace=trace,
                          monitor_sampling_rate=monitor_sampling_rate)
        app.run(host,
                port,
                bind=bind,
                cors=cors,
                gunicorn=False,
                hypercorn=True,
                **config.get('hypercorn', {}))
    # Fallback: no wrapper — run a development server directly, dispatching
    # on whatever the entry file provides (MLServer instance or ServeModel).
    app = get_model(entry_file,
                    queue=queue,
                    trace=trace,
                    serve_model=True,
                    name=name)
    if isinstance(app, MLServer):
        if app.__class__.__name__ == 'FlaskServer':
            app.run(host, port, cors=cors, gunicorn=False)
        elif app.__class__.__name__ == 'QuartServer':
            app.run(host, port, cors=cors, gunicorn=False, hypercorn=False)
        elif app.__class__.__name__ == 'GrpcServer':
            app.run(host, port)
    elif isinstance(app, ServeModel):
        if server not in ['quart', 'grpc']:
            server = 'flask'
        if server == 'flask':
            from mlchain.rpc.server.flask_server import FlaskServer
            app = FlaskServer(app,
                              name=name,
                              trace=trace,
                              log=log,
                              monitor_sampling_rate=monitor_sampling_rate)
            app.run(host, port, cors=cors, gunicorn=False)
        elif server == 'quart':
            from mlchain.rpc.server.quart_server import QuartServer
            app = QuartServer(app,
                              name=name,
                              trace=trace,
                              log=log,
                              monitor_sampling_rate=monitor_sampling_rate)
            app.run(host, port, cors=cors, gunicorn=False, hypercorn=False)
        elif server == 'grpc':
            from mlchain.rpc.server.grpc_server import GrpcServer
            app = GrpcServer(app,
                             name=name,
                             trace=trace,
                             log=log,
                             monitor_sampling_rate=monitor_sampling_rate)
            app.run(host, port)