Code Example #1
File: run.py Project: dd-dos/mlchain-python
def get_model(module, serve_model=False):
    import_name = prepare_import(module)

    try:
        module = importlib.import_module(import_name)
    except Exception as ex:
        logger.error(traceback.format_exc())
        return None

    serve_models = [
        v for v in module.__dict__.values() if isinstance(v, ServeModel)
    ]
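    # Note: serve_model doubles as a flag here; when it is truthy, it is
    # rebound below to the first ServeModel instance found in the module.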
    if len(serve_models) > 0 and serve_model:
        serve_model = serve_models[0]
        return serve_model
    apps = [v for v in module.__dict__.values() if isinstance(v, MLServer)]
    if len(apps) > 0:
        return apps[0]
    if len(serve_models) > 0:
        return serve_models[0]

    # Could not find model
    logger.debug("Could not find ServeModel")
    serve_models = [
        v for v in module.__dict__.values() if not isinstance(v, type)
    ]
    if len(serve_models) > 0 and serve_model:
        serve_model = ServeModel(serve_models[-1])
        return serve_model

    logger.error(
        "Could not find any instance to serve. Please check your mlconfig.yaml or server file!"
    )
    return None
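
A minimal usage sketch for the helper above, assuming get_model as defined here is in scope; the entry-file name and the failure message are illustrative, not taken from the project.

# Hypothetical usage: resolve a servable object from an entry file.
app = get_model("server.py", serve_model=True)
if app is None:
    raise SystemExit("server.py does not expose a ServeModel or MLServer instance")
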
Code Example #2
def get_model(module, serve_model=False, queue=None, trace=False, name=None):
    import_name = prepare_import(module)

    module = importlib.import_module(import_name)
    serve_models = [
        v for v in module.__dict__.values() if isinstance(v, ServeModel)
    ]
    if len(serve_models) > 0 and serve_model:
        serve_model = serve_models[0]
        if queue == 'rabbit':
            logger.debug('load rabbit {0}'.format(serve_model))
            from mlchain.queue.rabbit_queue import RabbitQueue
            if not isinstance(serve_model, RabbitQueue):
                serve_model = RabbitQueue(serve_model,
                                          module_name=name,
                                          trace=trace)
            serve_model.run(threading=True)
        elif queue == 'redis':
            logger.debug('load redis {0}'.format(serve_model))
            from mlchain.queue.redis_queue import RedisQueue
            if not isinstance(serve_model, RedisQueue):
                serve_model = RedisQueue(serve_model,
                                         module_name=name,
                                         trace=trace)
            serve_model.run(threading=True)
        return serve_model
    apps = [v for v in module.__dict__.values() if isinstance(v, MLServer)]
    if len(apps) > 0:
        return apps[0]
    if len(serve_models) > 0:
        return serve_models[0]

    # Could not find model
    logger.debug("Could not find ServeModel")
    serve_models = [
        v for v in module.__dict__.values() if not isinstance(v, type)
    ]
    if len(serve_models) > 0 and serve_model:
        serve_model = ServeModel(serve_models[-1])
        if queue == 'rabbit':
            logger.debug('Load rabbit {0}'.format(serve_model))
            from mlchain.queue.rabbit_queue import RabbitQueue
            if not isinstance(serve_model, RabbitQueue):
                serve_model = RabbitQueue(serve_model,
                                          module_name=name,
                                          trace=trace)
            serve_model.run(threading=True)
        elif queue == 'redis':
            logger.debug('load redis {0}'.format(serve_model))
            from mlchain.queue.redis_queue import RedisQueue
            if not isinstance(serve_model, RedisQueue):
                serve_model = RedisQueue(serve_model,
                                         module_name=name,
                                         trace=trace)
            serve_model.run(threading=True)
        return serve_model

    logger.error("Could not find any instance to serve")
    return None
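
The rabbit and redis branches above differ only in the wrapper class. An illustrative refactor is sketched below; the helper name is hypothetical, while the class paths, constructor arguments, and run(threading=True) call are taken from the code above.

def _wrap_in_queue(serve_model, queue, name, trace):
    # Hypothetical helper: pick the queue wrapper class, wrap the model if
    # it is not already wrapped, and start consuming in a background thread.
    if queue == 'rabbit':
        from mlchain.queue.rabbit_queue import RabbitQueue
        wrapper_cls = RabbitQueue
    elif queue == 'redis':
        from mlchain.queue.redis_queue import RedisQueue
        wrapper_cls = RedisQueue
    else:
        return serve_model
    if not isinstance(serve_model, wrapper_cls):
        serve_model = wrapper_cls(serve_model, module_name=name, trace=trace)
    serve_model.run(threading=True)
    return serve_model
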
Code Example #3
File: view.py Project: dd-dos/mlchain-python
    def __call__(self, function_name, **kws):
        with push_scope() as scope:
            transaction_name = "{0}  ||  {1}".format(
                mlconfig.MLCHAIN_SERVER_NAME, function_name)
            scope.transaction = transaction_name

            with start_transaction(op="task", name=transaction_name):
                uid = self.init_context()

                request_context = {'api_version': self.server.version}
                try:
                    headers, form, files, data = self.parse_data()
                    mlchain_context['REQUESTS_HEADERS'] = headers
                    mlchain_context['REQUESTS_FORM'] = form
                    mlchain_context['REQUESTS_FILES'] = files
                    mlchain_context['REQUESTS_DATA'] = data

                except Exception as ex:
                    request_context['time_process'] = 0
                    output = self.normalize_output(self.base_format,
                                                   function_name, {}, None, ex,
                                                   request_context)
                    return self.make_response(output)

                formatter = self.get_format(headers, form, files, data)
                start_time = time.time()
                try:
                    if self.authentication is not None:
                        self.authentication.check(headers)
                    args, kwargs = formatter.parse_request(
                        function_name, headers, form, files, data,
                        request_context)
                    func = self.server.model.get_function(function_name)
                    kwargs = self.server.get_kwargs(func, *args, **kwargs)
                    kwargs = self.server._normalize_kwargs_to_valid_format(
                        kwargs, func)

                    uid = self.init_context_with_headers(headers, uid)
                    scope.set_tag("transaction_id", uid)
                    logger.debug("Mlchain transaction id: {0}".format(uid))

                    output = self.server.model.call_function(
                        function_name, uid, **kwargs)
                    exception = None
                except MlChainError as ex:
                    exception = ex
                    output = None
                except Exception as ex:
                    exception = ex
                    output = None

                time_process = time.time() - start_time
                request_context['time_process'] = time_process
                output = self.normalize_output(formatter, function_name,
                                               headers, output, exception,
                                               request_context)
                return self.make_response(output)
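
A stripped-down sketch of the Sentry tracing pattern used above, assuming push_scope and start_transaction come from sentry_sdk (the excerpt does not show its imports); the wrapper name and server name are hypothetical.

from sentry_sdk import push_scope, start_transaction

def traced_call(function_name, uid, handler, *args, **kwargs):
    # Hypothetical wrapper: isolate a Sentry scope per request, name the
    # transaction, tag it with the request id, and run the handler inside
    # the transaction so its duration is recorded.
    with push_scope() as scope:
        transaction_name = "demo-server  ||  " + function_name
        scope.transaction = transaction_name
        with start_transaction(op="task", name=transaction_name):
            scope.set_tag("transaction_id", uid)
            return handler(*args, **kwargs)
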
Code Example #4
File: run.py Project: dd-dos/mlchain-python
def run_command(entry_file, host, port, bind, wrapper, server, workers, config,
                name, mode, api_format, ngrok, kws):
    kws = list(kws)
    if isinstance(entry_file, str) and not os.path.exists(entry_file):
        kws = [f'--entry_file={entry_file}'] + kws
        entry_file = None
    from mlchain import config as mlconfig
    default_config = False

    if config is None:
        default_config = True
        config = 'mlconfig.yaml'

    config_path = copy.deepcopy(config)
    if os.path.isfile(config_path) and os.path.exists(config_path):
        config = mlconfig.load_file(config_path)
        if config is None:
            raise SystemExit(
                "Config file {0} is not supported".format(config_path))
    else:
        if not default_config:
            raise SystemExit("Can't find config file {0}".format(config_path))
        else:
            raise SystemExit(
                "Can't find mlchain config file. Please double-check your current working directory, or use `mlchain init` to initialize a new one here."
            )
    if 'mode' in config and 'env' in config['mode']:
        if mode in config['mode']['env']:
            config['mode']['default'] = mode
        elif mode is not None:
            available_mode = list(config['mode']['env'].keys())
            available_mode = [
                each for each in available_mode if each != 'default'
            ]
            raise SystemExit(
                f"No {mode} mode is available. Found these modes in the config file: {available_mode}"
            )
    mlconfig.load_config(config)
    for kw in kws:
        if kw.startswith('--'):
            tokens = kw[2:].split('=', 1)
            if len(tokens) == 2:
                key, value = tokens
                mlconfig.mlconfig.update({key: value})
            else:
                raise AssertionError("Unexpected param {0}".format(kw))
        else:
            raise AssertionError("Unexpected param {0}".format(kw))
    model_id = mlconfig.get_value(None, config, 'model_id', None)
    entry_file = mlconfig.get_value(entry_file, config, 'entry_file',
                                    'server.py')
    if entry_file.strip() == '':
        raise SystemExit("Entry file cannot be empty")
    if not os.path.exists(entry_file):
        raise SystemExit(
            f"Entry file {entry_file} not found in current working directory.")
    host = mlconfig.get_value(host, config, 'host', 'localhost')
    port = mlconfig.get_value(port, config, 'port', 5000)
    server = mlconfig.get_value(server, config, 'server', 'flask')
    if len(bind) == 0:
        bind = None
    bind = mlconfig.get_value(bind, config, 'bind', [])
    wrapper = mlconfig.get_value(wrapper, config, 'wrapper', None)
    if wrapper == 'gunicorn' and os.name == 'nt':
        logger.warning(
            'Gunicorn wrapper is not supported on Windows. Switching to None instead.'
        )
        wrapper = None
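    # Note: the workers value passed in is discarded here and re-read from the
    # gunicorn/hypercorn sections of the config below, defaulting to 1.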
    workers = None
    if 'gunicorn' in config:
        workers = mlconfig.get_value(workers, config['gunicorn'], 'workers',
                                     None)
    if workers is None and 'hypercorn' in config.keys():
        workers = mlconfig.get_value(workers, config['hypercorn'], 'workers',
                                     None)
    workers = int(workers) if workers is not None else 1
    name = mlconfig.get_value(name, config, 'name', None)
    cors = mlconfig.get_value(None, config, 'cors', False)

    static_folder = mlconfig.get_value(None, config, 'static_folder', None)
    static_url_path = mlconfig.get_value(None, config, 'static_url_path', None)
    template_folder = mlconfig.get_value(None, config, 'template_folder', None)

    version = mlconfig.get_value(None, config, 'version', '0.0')
    version = str(version)
    api_format = mlconfig.get_value(api_format, config, 'api_format', None)
    api_keys = os.getenv('API_KEYS', None)
    if api_keys is not None:
        api_keys = api_keys.split(';')
    api_keys = mlconfig.get_value(api_keys, config, 'api_keys', None)
    if api_keys is None:
        authentication = None
    else:
        authentication = Authentication(api_keys)
    import logging
    logging.root = logging.getLogger(name)
    logger.debug(
        dict(entry_file=entry_file,
             host=host,
             port=port,
             bind=bind,
             wrapper=wrapper,
             server=server,
             workers=workers,
             name=name,
             mode=mode,
             api_format=api_format,
             kws=kws))
    bind = list(bind)
    if ngrok:
        from pyngrok import ngrok as pyngrok
        endpoint = pyngrok.connect(port=port)
        logger.info("Ngrok url: {0}".format(endpoint))
        os.environ['NGROK_URL'] = endpoint
    if server == 'grpc':
        from mlchain.server.grpc_server import GrpcServer
        app = get_model(entry_file, serve_model=True)

        if app is None:
            raise Exception(
                "Cannot initialize model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!"
                .format(entry_file))

        app = GrpcServer(app, name=name)
        app.run(host, port)
    elif wrapper == 'gunicorn':
        from gunicorn.app.base import BaseApplication
        gpus = select_gpu()

        class GunicornWrapper(BaseApplication):
            def __init__(self, server_, **kwargs):
                assert server_.lower() in ['quart', 'flask']
                self.server = server_.lower()
                self.options = kwargs
                self.autofrontend = False
                super(GunicornWrapper, self).__init__()

            def load_config(self):
                config = {
                    key: value
                    for key, value in self.options.items()
                    if key in self.cfg.settings and value is not None
                }
                for key, value in config.items():
                    self.cfg.set(key.lower(), value)

                from mlchain.base.gunicorn_config import post_worker_init
                self.cfg.set("post_worker_init", post_worker_init)

            def load(self):
                original_cuda_variable = os.environ.get('CUDA_VISIBLE_DEVICES')
                if original_cuda_variable is None:
                    os.environ['CUDA_VISIBLE_DEVICES'] = str(next(gpus))
                else:
                    logger.info(
                        f"Skipping automatic GPU selection for gunicorn worker since CUDA_VISIBLE_DEVICES environment variable is already set to {original_cuda_variable}"
                    )
                serve_model = get_model(entry_file, serve_model=True)

                if serve_model is None:
                    raise Exception(
                        f"Cannot initialize model class from {entry_file}. Please check mlconfig.yaml or {entry_file} or mlchain run -m {{mode}}!"
                    )

                if isinstance(serve_model, ServeModel):
                    if (not self.autofrontend) and model_id is not None:
                        from mlchain.server.autofrontend import register_autofrontend
                        register_autofrontend(model_id,
                                              serve_model=serve_model,
                                              version=version,
                                              endpoint=os.getenv('NGROK_URL'))
                        self.autofrontend = True

                    if self.server == 'flask':
                        from mlchain.server.flask_server import FlaskServer
                        app = FlaskServer(serve_model,
                                          name=name,
                                          api_format=api_format,
                                          version=version,
                                          authentication=authentication,
                                          static_url_path=static_url_path,
                                          static_folder=static_folder,
                                          template_folder=template_folder)
                        app.register_swagger()
                        if cors:
                            from flask_cors import CORS
                            CORS(app.app)
                        return app.app
                    if self.server == 'quart':
                        from mlchain.server.quart_server import QuartServer
                        app = QuartServer(serve_model,
                                          name=name,
                                          api_format=api_format,
                                          version=version,
                                          authentication=authentication,
                                          static_url_path=static_url_path,
                                          static_folder=static_folder,
                                          template_folder=template_folder)
                        app.register_swagger()
                        if cors:
                            from quart_cors import cors as CORS
                            CORS(app.app)
                        return app.app
                return None

        if host is not None and port is not None:
            bind.append('{0}:{1}'.format(host, port))

        bind = list(set(bind))
        gunicorn_config = config.get('gunicorn', {})
        gunicorn_env = ['worker_class', 'threads', 'workers']
        if workers is not None:
            gunicorn_config['workers'] = workers

        for k in gunicorn_env:
            if get_env(k) in os.environ:
                gunicorn_config[k] = os.environ[get_env(k)]
        if server == 'flask' and 'worker_class' in gunicorn_config:
            if 'uvicorn' in gunicorn_config['worker_class']:
                logger.warning(
                    "Can't use Flask with a uvicorn worker class. Changing to gthread.")
                gunicorn_config['worker_class'] = 'gthread'

        GunicornWrapper(server, bind=bind, **gunicorn_config).run()
    elif wrapper == 'hypercorn' and server == 'quart':
        from mlchain.server.quart_server import QuartServer
        app = get_model(entry_file, serve_model=True)

        if app is None:
            raise Exception(
                "Cannot initialize model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!"
                .format(entry_file))

        app = QuartServer(app,
                          name=name,
                          version=version,
                          api_format=api_format,
                          authentication=authentication,
                          static_url_path=static_url_path,
                          static_folder=static_folder,
                          template_folder=template_folder)
        app.run(host,
                port,
                bind=bind,
                cors=cors,
                gunicorn=False,
                hypercorn=True,
                **config.get('hypercorn', {}),
                model_id=model_id)

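    # Default path (no grpc server, gunicorn wrapper, or hypercorn+quart branch
    # above): load the model and serve it directly with the selected server.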
    app = get_model(entry_file)

    if app is None:
        raise Exception(
            "Cannot initialize model class from {0}. Please check mlconfig.yaml or {0} or mlchain run -m {{mode}}!"
            .format(entry_file))

    if isinstance(app, MLServer):
        if app.__class__.__name__ == 'FlaskServer':
            app.run(host, port, cors=cors, gunicorn=False)
        elif app.__class__.__name__ == 'QuartServer':
            app.run(host, port, cors=cors, gunicorn=False, hypercorn=False)
        elif app.__class__.__name__ == 'GrpcServer':
            app.run(host, port)
    elif isinstance(app, ServeModel):
        if server not in ['quart', 'grpc']:
            server = 'flask'
        if server == 'flask':
            from mlchain.server.flask_server import FlaskServer
            app = FlaskServer(app,
                              name=name,
                              api_format=api_format,
                              version=version,
                              authentication=authentication,
                              static_url_path=static_url_path,
                              static_folder=static_folder,
                              template_folder=template_folder)
            app.run(host,
                    port,
                    cors=cors,
                    gunicorn=False,
                    model_id=model_id,
                    threads=workers > 1)
        elif server == 'quart':
            from mlchain.server.quart_server import QuartServer
            app = QuartServer(app,
                              name=name,
                              api_format=api_format,
                              version=version,
                              authentication=authentication,
                              static_url_path=static_url_path,
                              static_folder=static_folder,
                              template_folder=template_folder)
            app.run(host,
                    port,
                    cors=cors,
                    gunicorn=False,
                    hypercorn=False,
                    model_id=model_id,
                    workers=workers)

        elif server == 'grpc':
            from mlchain.server.grpc_server import GrpcServer
            app = GrpcServer(app, name=name)
            app.run(host, port)
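
The GunicornWrapper classes above follow gunicorn's standard embedding pattern: subclass BaseApplication, push options in load_config(), and return the WSGI app from load(). A self-contained sketch of that pattern with a placeholder Flask app follows; the class name, route, and bind address are illustrative, not from the project.

from flask import Flask
from gunicorn.app.base import BaseApplication

demo_app = Flask(__name__)

@demo_app.route("/health")
def health():
    return "ok"

class StandaloneGunicorn(BaseApplication):
    def __init__(self, app, **options):
        self.application = app
        self.options = options
        super().__init__()

    def load_config(self):
        # Copy only options that gunicorn actually recognizes.
        for key, value in self.options.items():
            if key in self.cfg.settings and value is not None:
                self.cfg.set(key.lower(), value)

    def load(self):
        # gunicorn calls this in each worker process to obtain the WSGI app.
        return self.application

if __name__ == "__main__":
    StandaloneGunicorn(demo_app, bind="127.0.0.1:5000", workers=2).run()
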
Code Example #5
def run_command(entry_file, host, port, bind, wrapper, server, queue, workers,
                trace, log, config, name, mode, kws):
    from mlchain import config as mlconfig
    default_config = False
    if config is None:
        default_config = True
        config = 'mlconfig.yaml'
    if os.path.isfile(config):
        if config.endswith('.json'):
            config = mlconfig.load_json(config)
        elif config.endswith('.yaml') or config.endswith('.yml'):
            config = mlconfig.load_yaml(config)
        else:
            raise AssertionError("Unsupported config file {0}".format(config))
    else:
        if not default_config:
            raise FileNotFoundError("Config file {0} not found".format(config))
        config = {}
    if 'mode' in config and 'env' in config['mode']:
        if mode in config['mode']['env']:
            config['mode']['default'] = mode
    mlconfig.load_config(config)
    entry_file = mlconfig.get_value(entry_file, config, 'entry_file',
                                    'server.py')
    host = mlconfig.get_value(host, config, 'host', 'localhost')
    port = mlconfig.get_value(port, config, 'port', 5000)
    server = mlconfig.get_value(server, config, 'server', 'flask')
    if len(bind) == 0:
        bind = None
    bind = mlconfig.get_value(bind, config, 'bind', [])
    wrapper = mlconfig.get_value(wrapper, config, 'wrapper', None)
    queue = mlconfig.get_value(queue, config, 'queue', None)
    trace = mlconfig.get_value(trace, config, 'trace', False)
    log = mlconfig.get_value(log, config, 'log', False)
    workers = mlconfig.get_value(workers, config, 'workers', 1)
    name = mlconfig.get_value(name, config, 'name', None)
    monitor_sampling_rate = mlconfig.get_value(None, config,
                                               'monitor_sampling_rate', 1)
    cors = mlconfig.get_value(False, config, 'cors', False)
    logger.debug(
        dict(entry_file=entry_file,
             host=host,
             port=port,
             bind=bind,
             wrapper=wrapper,
             server=server,
             queue=queue,
             workers=workers,
             log=log,
             name=name,
             trace=trace,
             kws=kws))
    bind = list(bind)

    if server == 'grpc':
        from mlchain.rpc.server.grpc_server import GrpcServer
        app = get_model(entry_file, serve_model=True, queue=queue, trace=trace)
        app = GrpcServer(app,
                         name=name,
                         trace=trace,
                         log=log,
                         monitor_sampling_rate=monitor_sampling_rate)
        app.run(host, port)
    elif wrapper == 'gunicorn':
        from gunicorn.app.base import BaseApplication

        class GunicornWrapper(BaseApplication):
            def __init__(self, server_, **kwargs):
                assert server_.lower() in ['quart', 'flask']
                self.server = server_.lower()
                self.options = kwargs
                super(GunicornWrapper, self).__init__()

            def load_config(self):
                config = {
                    key: value
                    for key, value in self.options.items()
                    if key in self.cfg.settings and value is not None
                }
                for key, value in config.items():
                    self.cfg.set(key.lower(), value)

            def load(self):
                app = get_model(entry_file,
                                serve_model=True,
                                queue=queue,
                                trace=trace,
                                name=name)
                if isinstance(app, ServeModel):
                    if self.server == 'flask':
                        from mlchain.rpc.server.flask_server import FlaskServer
                        app = FlaskServer(
                            app,
                            name=name,
                            log=log,
                            trace=trace,
                            monitor_sampling_rate=monitor_sampling_rate)
                        app.register_swagger(host, port)
                        if cors:
                            from flask_cors import CORS
                            CORS(app.app)
                        return app.app
                    elif self.server == 'quart':
                        from mlchain.rpc.server.quart_server import QuartServer
                        app = QuartServer(
                            app,
                            name=name,
                            log=log,
                            trace=trace,
                            monitor_sampling_rate=monitor_sampling_rate)
                        if cors:
                            from quart_cors import cors as CORS
                            CORS(app.app)
                        return app.app
                return None

        if host is not None and port is not None:
            bind.append('{0}:{1}'.format(host, port))
        bind = list(set(bind))
        gunicorn_config = config.get('gunicorn', {})
        gunicorn_env = ['worker_class', 'threads', 'workers']
        gunicorn_config['workers'] = workers

        def get_env(_k):
            return 'GUNICORN_' + _k.upper()

        for k in gunicorn_env:
            if get_env(k) in os.environ:
                gunicorn_config[k] = os.environ[get_env(k)]
        GunicornWrapper(server, bind=bind, **gunicorn_config).run()
    elif wrapper == 'hypercorn' and server == 'quart':
        from mlchain.rpc.server.quart_server import QuartServer
        app = get_model(entry_file,
                        serve_model=True,
                        queue=queue,
                        trace=trace,
                        name=name)
        app = QuartServer(app,
                          name=name,
                          log=log,
                          trace=trace,
                          monitor_sampling_rate=monitor_sampling_rate)
        app.run(host,
                port,
                bind=bind,
                cors=cors,
                gunicorn=False,
                hypercorn=True,
                **config.get('hypercorn', {}))

    app = get_model(entry_file,
                    queue=queue,
                    trace=trace,
                    serve_model=True,
                    name=name)
    if isinstance(app, MLServer):
        if app.__class__.__name__ == 'FlaskServer':
            app.run(host, port, cors=cors, gunicorn=False)
        elif app.__class__.__name__ == 'QuartServer':
            app.run(host, port, cors=cors, gunicorn=False, hypercorn=False)
        elif app.__class__.__name__ == 'GrpcServer':
            app.run(host, port)
    elif isinstance(app, ServeModel):
        if server not in ['quart', 'grpc']:
            server = 'flask'
        if server == 'flask':
            from mlchain.rpc.server.flask_server import FlaskServer
            app = FlaskServer(app,
                              name=name,
                              trace=trace,
                              log=log,
                              monitor_sampling_rate=monitor_sampling_rate)
            app.run(host, port, cors=cors, gunicorn=False)
        elif server == 'quart':
            from mlchain.rpc.server.quart_server import QuartServer
            app = QuartServer(app,
                              name=name,
                              trace=trace,
                              log=log,
                              monitor_sampling_rate=monitor_sampling_rate)
            app.run(host, port, cors=cors, gunicorn=False, hypercorn=False)

        elif server == 'grpc':
            from mlchain.rpc.server.grpc_server import GrpcServer
            app = GrpcServer(app,
                             name=name,
                             trace=trace,
                             log=log,
                             monitor_sampling_rate=monitor_sampling_rate)
            app.run(host, port)
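
The loop over gunicorn_env above lets GUNICORN_* environment variables override values from the config file. An isolated sketch of that override follows; the function name and example values are illustrative.

import os

def gunicorn_overrides(base_config, keys=('worker_class', 'threads', 'workers')):
    # Environment variables such as GUNICORN_WORKERS take precedence over the
    # config file; note that they arrive as strings, not integers.
    config = dict(base_config)
    for key in keys:
        env_name = 'GUNICORN_' + key.upper()
        if env_name in os.environ:
            config[key] = os.environ[env_name]
    return config

# Example: with GUNICORN_WORKERS=4 set in the environment,
# gunicorn_overrides({'workers': 1}) returns {'workers': '4'}.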