def setup_logger(log_level: str, debug_mode: bool) -> logging.Logger:
    """Configure log levels for this service and its related loggers.

    The effective level comes from the SELDON_LOG_LEVEL environment
    variable when set, falling back to ``log_level``; either source is
    upper-cased so lowercase names (e.g. "debug") are accepted.

    :param log_level: fallback level name used when the env var is unset
    :param debug_mode: when True, metrics-endpoint access-log filtering
        defaults to off so all access-log lines remain visible
    :return: the module-level logger, after configuration
    :raises ValueError: if the resolved level name is not a valid
        ``logging`` level attribute
    """
    # Resolve the level: the environment variable wins over the argument.
    # Upper-case the resolved value so lowercase env settings also map
    # onto logging.DEBUG/INFO/... attribute names (previously only the
    # fallback argument was normalized).
    log_level_raw = os.environ.get(LOG_LEVEL_ENV, log_level).upper()
    log_level_num = getattr(logging, log_level_raw, None)
    if not isinstance(log_level_num, int):
        # Bug fix: the original passed two args to ValueError, so "%s"
        # was never interpolated and the message rendered as a tuple;
        # also report the value that was actually checked (env override).
        raise ValueError(f"Invalid log level: {log_level_raw}")

    logger.setLevel(log_level_num)

    # Set right level on access logs (werkzeug emits Flask's access log).
    flask_logger = logging.getLogger("werkzeug")
    flask_logger.setLevel(log_level_num)
    # Optionally drop /metrics scrape requests from access logs; enabled
    # by default unless running in debug mode.
    if getenv_as_bool(FILTER_METRICS_ACCESS_LOGS_ENV_NAME, default=not debug_mode):
        flask_logger.addFilter(MetricsEndpointFilter())
        gunicorn_logger = logging.getLogger("gunicorn.access")
        gunicorn_logger.addFilter(MetricsEndpointFilter())
    logger.debug("Log level set to %s:%s", log_level, log_level_num)

    # Propagate the level to the imported microservice module, the root
    # logger, and any handlers already attached to this logger.
    seldon_microservice.logger.setLevel(log_level_num)
    logging.getLogger().setLevel(log_level_num)
    for handler in logger.handlers:
        handler.setLevel(log_level_num)
    return logger
def _set_flask_app_configs(app):
    """
    Set the configs for the flask app based on environment variables
    See https://flask.palletsprojects.com/config/#builtin-configuration-values
    :param app:
    :return:
    """
    env_prefix = "FLASK_"
    # Boolean Flask settings that may be toggled via FLASK_<NAME> env vars.
    supported_configs = (
        "DEBUG",
        "EXPLAIN_TEMPLATE_LOADING",
        "JSONIFY_PRETTYPRINT_REGULAR",
        "JSON_SORT_KEYS",
        "PROPAGATE_EXCEPTIONS",
        "PRESERVE_CONTEXT_ON_EXCEPTION",
        "SESSION_COOKIE_HTTPONLY",
        "SESSION_COOKIE_SECURE",
        "SESSION_REFRESH_EACH_REQUEST",
        "TEMPLATES_AUTO_RELOAD",
        "TESTING",
        "TRAP_HTTP_EXCEPTIONS",
        "TRAP_BAD_REQUEST_ERRORS",
        "USE_X_SENDFILE",
    )
    for config_name in supported_configs:
        # default=None distinguishes "unset" from an explicit False value
        parsed_value = getenv_as_bool(f"{env_prefix}{config_name}", default=None)
        if parsed_value is not None:
            app.config[config_name] = parsed_value
    logger.info(f"App Config: {app.config}")
def test_getenv_as_bool(monkeypatch, env_val, expected):
    """getenv_as_bool parses the env var, or falls back to the default."""
    name = "MY_BOOL_VAR"
    # Only set the variable when a value is provided; None exercises the default.
    if env_val is not None:
        monkeypatch.setenv(name, env_val)
    assert scu.getenv_as_bool(name, default=False) == expected
json_to_feedback, getenv_as_bool, ) from seldon_core.flask_utils import get_request, jsonify from seldon_core.flask_utils import ( SeldonMicroserviceException, ANNOTATION_GRPC_MAX_MSG_SIZE, ) from seldon_core.proto import prediction_pb2_grpc from seldon_core.proto import prediction_pb2 logger = logging.getLogger(__name__) PRED_UNIT_ID = os.environ.get("PREDICTIVE_UNIT_ID", "0") METRICS_ENDPOINT = os.environ.get("PREDICTIVE_UNIT_METRICS_ENDPOINT", "/metrics") PAYLOAD_PASSTHROUGH = getenv_as_bool("PAYLOAD_PASSTHROUGH", default=False) def get_rest_microservice(user_model, seldon_metrics): app = Flask(__name__, static_url_path="") CORS(app) _set_flask_app_configs(app) # dict representing the validated model metadata # None value will represent a validation error metadata_data = seldon_core.seldon_methods.init_metadata(user_model) if hasattr(user_model, "model_error_handler"): logger.info("Registering the custom error handler...") app.register_blueprint(user_model.model_error_handler)
def main():
    """CLI entry point: parse arguments, build the user model object and
    start the REST, GRPC and metrics servers (plus any custom service the
    model exposes). A Flask debug server is used for REST when --debug is
    set, otherwise a gunicorn production server.
    """
    LOG_FORMAT = (
        "%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s: %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger.info("Starting microservice.py:main")
    logger.info(f"Seldon Core version: {__version__}")

    # Allow the user's model module to be imported from the working dir.
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument("interface_name", type=str, help="Name of the user interface.")
    parser.add_argument(
        "--service-type",
        type=str,
        choices=["MODEL", "ROUTER", "TRANSFORMER", "COMBINER", "OUTLIER_DETECTOR"],
        default="MODEL",
    )
    # nargs="?" + const=1 makes a bare --persistence behave as --persistence 1
    parser.add_argument("--persistence", nargs="?", default=0, const=1, type=int)
    parser.add_argument(
        "--parameters", type=str, default=os.environ.get(PARAMETERS_ENV_NAME, "[]")
    )
    parser.add_argument(
        "--log-level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default=DEFAULT_LOG_LEVEL,
        help="Log level of the inference server.",
    )
    parser.add_argument(
        "--debug",
        nargs="?",
        type=bool,
        default=getenv_as_bool(DEBUG_ENV, default=False),
        const=True,
        help="Enable debug mode.",
    )
    parser.add_argument(
        "--tracing",
        nargs="?",
        default=int(os.environ.get("TRACING", "0")),
        const=1,
        type=int,
    )

    # gunicorn settings, defaults are from
    # http://docs.gunicorn.org/en/stable/settings.html
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("GUNICORN_WORKERS", "1")),
        help="Number of Gunicorn workers for handling requests.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=int(os.environ.get("GUNICORN_THREADS", "10")),
        help="Number of threads to run per Gunicorn worker.",
    )
    parser.add_argument(
        "--max-requests",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS", "0")),
        help="Maximum number of requests gunicorn worker will process before restarting.",
    )
    parser.add_argument(
        "--max-requests-jitter",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS_JITTER", "0")),
        help="Maximum random jitter to add to max-requests.",
    )
    parser.add_argument(
        "--single-threaded",
        type=int,
        default=int(os.environ.get("FLASK_SINGLE_THREADED", "0")),
        help="Force the Flask app to run single-threaded. Also applies to Gunicorn.",
    )
    parser.add_argument(
        "--http-port",
        type=int,
        default=int(os.environ.get(HTTP_SERVICE_PORT_ENV_NAME, DEFAULT_HTTP_PORT)),
        help="Set http port of seldon service",
    )
    parser.add_argument(
        "--grpc-port",
        type=int,
        default=int(os.environ.get(GRPC_SERVICE_PORT_ENV_NAME, DEFAULT_GRPC_PORT)),
        help="Set grpc port of seldon service",
    )
    parser.add_argument(
        "--metrics-port",
        type=int,
        default=int(os.environ.get(METRICS_SERVICE_PORT_ENV_NAME, DEFAULT_METRICS_PORT)),
        help="Set metrics port of seldon service",
    )
    parser.add_argument(
        "--pidfile", type=str, default=None, help="A file path to use for the PID file"
    )
    parser.add_argument(
        "--access-log",
        nargs="?",
        type=bool,
        default=getenv_as_bool(GUNICORN_ACCESS_LOG_ENV, default=False),
        const=True,
        help="Enable gunicorn access log.",
    )

    # parse_known_args so leftover positionals (the pre-1.5.0 API-type
    # argument) produce a targeted error message instead of argparse's.
    args, remaining = parser.parse_known_args()
    if len(remaining) > 0:
        logger.error(
            f"Unknown args {remaining}. Note since 1.5.0 this CLI does not take API type (REST, GRPC)"
        )
        sys.exit(-1)

    parameters = parse_parameters(json.loads(args.parameters))

    setup_logger(args.log_level, args.debug)

    # set flask trace jaeger extra tags
    jaeger_extra_tags = list(
        filter(
            lambda x: (x != ""),
            [tag.strip() for tag in os.environ.get("JAEGER_EXTRA_TAGS", "").split(",")],
        )
    )
    logger.info("Parse JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)

    annotations = load_annotations()
    logger.info("Annotations: %s", annotations)

    # "pkg.Class" imports Class from pkg; a bare name imports a module and
    # expects a same-named attribute inside it.
    parts = args.interface_name.rsplit(".", 1)
    if len(parts) == 1:
        logger.info("Importing %s", args.interface_name)
        interface_file = importlib.import_module(args.interface_name)
        user_class = getattr(interface_file, args.interface_name)
    else:
        logger.info("Importing submodule %s", parts)
        interface_file = importlib.import_module(parts[0])
        user_class = getattr(interface_file, parts[1])

    if args.persistence:
        logger.info("Restoring persisted component")
        user_object = persistence.restore(user_class, parameters)
        persistence.persist(user_object, parameters.get("push_frequency"))
    else:
        user_object = user_class(**parameters)

    http_port = args.http_port
    grpc_port = args.grpc_port
    metrics_port = args.metrics_port

    # if args.tracing:
    #     tracer = setup_tracing(args.interface_name)

    seldon_metrics = SeldonMetrics(worker_id_func=os.getpid)
    # TODO why 2 ways to create metrics server
    # seldon_metrics = SeldonMetrics(
    #     worker_id_func=lambda: threading.current_thread().name
    # )

    if args.debug:
        # Start Flask debug server
        def rest_prediction_server():
            app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)
            # load() is optional on user models; missing implementations are fine
            try:
                user_object.load()
            except (NotImplementedError, AttributeError):
                pass
            if args.tracing:
                logger.info("Tracing branch is active")
                from flask_opentracing import FlaskTracing

                tracer = setup_tracing(args.interface_name)
                logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
                FlaskTracing(tracer, True, app, jaeger_extra_tags)
            app.run(
                host="0.0.0.0",
                port=http_port,
                threaded=False if args.single_threaded else True,
            )
            logger.info(
                "REST microservice running on port %i single-threaded=%s",
                http_port,
                args.single_threaded,
            )

        server1_func = rest_prediction_server
    else:
        # Start production server
        def rest_prediction_server():
            options = {
                "bind": "%s:%s" % ("0.0.0.0", http_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "threads": threads(args.threads, args.single_threaded),
                "workers": args.workers,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
                # flush per-worker metrics state when a worker exits
                "worker_exit": partial(worker_exit, seldon_metrics=seldon_metrics),
            }
            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)
            UserModelApplication(
                app,
                user_object,
                jaeger_extra_tags,
                args.interface_name,
                options=options,
            ).run()
            logger.info("REST gunicorn microservice running on port %i", http_port)

        server1_func = rest_prediction_server

    def grpc_prediction_server():
        if args.tracing:
            from grpc_opentracing import open_tracing_server_interceptor

            logger.info("Adding tracer")
            tracer = setup_tracing(args.interface_name)
            interceptor = open_tracing_server_interceptor(tracer)
        else:
            interceptor = None
        server = seldon_microservice.get_grpc_server(
            user_object,
            seldon_metrics,
            annotations=annotations,
            trace_interceptor=interceptor,
        )
        try:
            user_object.load()
        except (NotImplementedError, AttributeError):
            pass
        server.add_insecure_port(f"0.0.0.0:{grpc_port}")
        server.start()
        logger.info("GRPC microservice Running on port %i", grpc_port)
        # server.start() does not block; keep this thread alive forever
        while True:
            time.sleep(1000)

    server2_func = grpc_prediction_server

    def rest_metrics_server():
        app = seldon_microservice.get_metrics_microservice(seldon_metrics)
        if args.debug:
            app.run(host="0.0.0.0", port=metrics_port)
        else:
            options = {
                "bind": "%s:%s" % ("0.0.0.0", metrics_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
            }
            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            StandaloneApplication(app, options=options).run()
        logger.info("REST metrics microservice running on port %i", metrics_port)

    metrics_server_func = rest_metrics_server

    # Optional user-provided background service
    if hasattr(user_object, "custom_service") and callable(
        getattr(user_object, "custom_service")
    ):
        server3_func = user_object.custom_service
    else:
        server3_func = None

    logger.info("Starting servers")
    start_servers(server1_func, server2_func, server3_func, metrics_server_func)
def main():
    """CLI entry point (multiprocess-GRPC variant): parse arguments, build
    the user model object and start the REST, GRPC and metrics servers
    plus any custom service exposed by the model. GRPC runs as a pool of
    forked worker processes sharing one SO_REUSEPORT socket.
    """
    LOG_FORMAT = (
        "%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s: %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger.info("Starting microservice.py:main")
    logger.info(f"Seldon Core version: {__version__}")

    # Allow the user's model module to be imported from the working dir.
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument("interface_name", type=str, help="Name of the user interface.")
    parser.add_argument(
        "--service-type",
        type=str,
        choices=["MODEL", "ROUTER", "TRANSFORMER", "COMBINER", "OUTLIER_DETECTOR"],
        default="MODEL",
    )
    parser.add_argument(
        "--persistence",
        nargs="?",
        default=0,
        const=1,
        type=int,
        help="deprecated argument ",
    )
    parser.add_argument(
        "--parameters", type=str, default=os.environ.get(PARAMETERS_ENV_NAME, "[]")
    )
    parser.add_argument(
        "--log-level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default=DEFAULT_LOG_LEVEL,
        help="Log level of the inference server.",
    )
    parser.add_argument(
        "--debug",
        nargs="?",
        type=bool,
        default=getenv_as_bool(DEBUG_ENV, default=False),
        const=True,
        help="Enable debug mode.",
    )
    parser.add_argument(
        "--tracing",
        nargs="?",
        default=int(os.environ.get("TRACING", "0")),
        const=1,
        type=int,
    )

    # gunicorn settings, defaults are from
    # http://docs.gunicorn.org/en/stable/settings.html
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("GUNICORN_WORKERS", "1")),
        help="Number of Gunicorn workers for handling requests.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=int(os.environ.get("GUNICORN_THREADS", "1")),
        help="Number of threads to run per Gunicorn worker.",
    )
    parser.add_argument(
        "--max-requests",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS", "0")),
        help="Maximum number of requests gunicorn worker will process before restarting.",
    )
    parser.add_argument(
        "--max-requests-jitter",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS_JITTER", "0")),
        help="Maximum random jitter to add to max-requests.",
    )
    parser.add_argument(
        "--keepalive",
        type=int,
        default=int(os.environ.get("GUNICORN_KEEPALIVE", "2")),
        help="The number of seconds to wait for requests on a Keep-Alive connection.",
    )
    parser.add_argument(
        "--single-threaded",
        type=int,
        default=int(os.environ.get("FLASK_SINGLE_THREADED", "0")),
        help="Force the Flask app to run single-threaded. Also applies to Gunicorn.",
    )
    parser.add_argument(
        "--http-port",
        type=int,
        default=int(os.environ.get(HTTP_SERVICE_PORT_ENV_NAME, DEFAULT_HTTP_PORT)),
        help="Set http port of seldon service",
    )
    parser.add_argument(
        "--grpc-port",
        type=int,
        default=int(os.environ.get(GRPC_SERVICE_PORT_ENV_NAME, DEFAULT_GRPC_PORT)),
        help="Set grpc port of seldon service",
    )
    parser.add_argument(
        "--metrics-port",
        type=int,
        default=int(os.environ.get(METRICS_SERVICE_PORT_ENV_NAME, DEFAULT_METRICS_PORT)),
        help="Set metrics port of seldon service",
    )
    parser.add_argument(
        "--pidfile", type=str, default=None, help="A file path to use for the PID file"
    )
    parser.add_argument(
        "--access-log",
        nargs="?",
        type=bool,
        default=getenv_as_bool(GUNICORN_ACCESS_LOG_ENV, default=False),
        const=True,
        help="Enable gunicorn access log.",
    )
    # NOTE(review): type=int plus a string default from os.environ.get works
    # because argparse only applies `type` to command-line strings, so the
    # env default stays a str unless overridden — confirm intended.
    parser.add_argument(
        "--grpc-threads",
        type=int,
        default=os.environ.get("GRPC_THREADS", default="1"),
        help="Number of GRPC threads per worker.",
    )
    parser.add_argument(
        "--grpc-workers",
        type=int,
        default=os.environ.get("GRPC_WORKERS", default="1"),
        help="Number of GPRC workers.",
    )

    # parse_known_args so leftover positionals (the pre-1.5.0 API-type
    # argument) produce a targeted error message instead of argparse's.
    args, remaining = parser.parse_known_args()
    if len(remaining) > 0:
        logger.error(
            f"Unknown args {remaining}. Note since 1.5.0 this CLI does not take API type (REST, GRPC)"
        )
        sys.exit(-1)

    parameters = parse_parameters(json.loads(args.parameters))

    setup_logger(args.log_level, args.debug)

    # set flask trace jaeger extra tags
    jaeger_extra_tags = list(
        filter(
            lambda x: (x != ""),
            [tag.strip() for tag in os.environ.get("JAEGER_EXTRA_TAGS", "").split(",")],
        )
    )
    logger.info("Parse JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)

    annotations = load_annotations()
    logger.info("Annotations: %s", annotations)

    # "pkg.Class" imports Class from pkg; a bare name imports a module and
    # expects a same-named attribute inside it.
    parts = args.interface_name.rsplit(".", 1)
    if len(parts) == 1:
        logger.info("Importing %s", args.interface_name)
        interface_file = importlib.import_module(args.interface_name)
        user_class = getattr(interface_file, args.interface_name)
    else:
        logger.info("Importing submodule %s", parts)
        interface_file = importlib.import_module(parts[0])
        user_class = getattr(interface_file, parts[1])

    # Persistence is deprecated in this variant: log and ignore the flag.
    if args.persistence:
        logger.error(f"Persistence: ignored, persistence is deprecated")
    user_object = user_class(**parameters)

    http_port = args.http_port
    grpc_port = args.grpc_port
    metrics_port = args.metrics_port

    seldon_metrics = SeldonMetrics(worker_id_func=os.getpid)
    # TODO why 2 ways to create metrics server
    # seldon_metrics = SeldonMetrics(
    #     worker_id_func=lambda: threading.current_thread().name
    # )

    if args.debug:
        # Start Flask debug server
        def rest_prediction_server():
            app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)
            # load() is optional on user models; missing implementations are fine
            try:
                user_object.load()
            except (NotImplementedError, AttributeError):
                pass
            if args.tracing:
                logger.info("Tracing branch is active")
                from flask_opentracing import FlaskTracing

                tracer = setup_tracing(args.interface_name)
                logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
                FlaskTracing(tracer, True, app, jaeger_extra_tags)
            # Timeout not supported in flask development server
            app.run(
                host="0.0.0.0",
                port=http_port,
                threaded=False if args.single_threaded else True,
            )
            logger.info(
                "REST microservice running on port %i single-threaded=%s",
                http_port,
                args.single_threaded,
            )

        server1_func = rest_prediction_server
    else:
        # Start production server
        def rest_prediction_server():
            rest_timeout = DEFAULT_ANNOTATION_REST_TIMEOUT
            if ANNOTATION_REST_TIMEOUT in annotations:
                # Gunicorn timeout is in seconds so convert as annotation is in miliseconds
                rest_timeout = int(annotations[ANNOTATION_REST_TIMEOUT]) / 1000
            # Converting timeout from float to int and set to 1 if is 0
            rest_timeout = int(rest_timeout) or 1
            options = {
                "bind": "%s:%s" % ("0.0.0.0", http_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": rest_timeout,
                "threads": threads(args.threads, args.single_threaded),
                "workers": args.workers,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
                # flush per-worker metrics state when a worker exits
                "worker_exit": partial(worker_exit, seldon_metrics=seldon_metrics),
                "keepalive": args.keepalive,
            }
            logger.info(f"Gunicorn Config:  {options}")
            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)
            UserModelApplication(
                app,
                user_object,
                args.tracing,
                jaeger_extra_tags,
                args.interface_name,
                options=options,
            ).run()
            logger.info("REST gunicorn microservice running on port %i", http_port)

        server1_func = rest_prediction_server

    def _wait_forever(server):
        # Block until Ctrl-C, then stop the gRPC server immediately.
        try:
            while True:
                time.sleep(60 * 60)
        except KeyboardInterrupt:
            server.stop(None)

    def _run_grpc_server(bind_address):
        """Start a server in a subprocess."""
        logger.info(f"Starting new GRPC server with {args.grpc_threads}.")
        if args.tracing:
            from grpc_opentracing import open_tracing_server_interceptor

            logger.info("Adding tracer")
            tracer = setup_tracing(args.interface_name)
            interceptor = open_tracing_server_interceptor(tracer)
        else:
            interceptor = None
        server = seldon_microservice.get_grpc_server(
            user_object,
            seldon_metrics,
            annotations=annotations,
            trace_interceptor=interceptor,
            num_threads=args.grpc_threads,
        )
        try:
            user_object.load()
        except (NotImplementedError, AttributeError):
            pass
        server.add_insecure_port(bind_address)
        server.start()
        _wait_forever(server)

    @contextlib.contextmanager
    def _reserve_grpc_port():
        """Find and reserve a port for all subprocesses to use."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # SO_REUSEPORT lets every forked worker bind the same port, so the
        # kernel load-balances incoming connections between them.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) != 1:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(("", grpc_port))
        try:
            yield sock.getsockname()[1]
        finally:
            sock.close()

    def grpc_prediction_server():
        with _reserve_grpc_port() as bind_port:
            bind_address = "0.0.0.0:{}".format(bind_port)
            # NOTE(review): the literal '%s' inside this f-string is never
            # interpolated and appears verbatim in the log line.
            logger.info(
                f"GRPC Server Binding to '%s' {bind_address} with {args.workers} processes"
            )
            sys.stdout.flush()
            workers = []
            for _ in range(args.grpc_workers):
                # NOTE: It is imperative that the worker subprocesses be forked before
                # any gRPC servers start up. See
                # https://github.com/grpc/grpc/issues/16001 for more details.
                worker = multiprocessing.Process(
                    target=_run_grpc_server, args=(bind_address,)
                )
                worker.start()
                workers.append(worker)
            for worker in workers:
                worker.join()

    # Allow disabling GRPC entirely with --grpc-workers 0
    server2_func = grpc_prediction_server if args.grpc_workers > 0 else None

    def rest_metrics_server():
        app = seldon_microservice.get_metrics_microservice(seldon_metrics)
        if args.debug:
            app.run(host="0.0.0.0", port=metrics_port)
        else:
            options = {
                "bind": "%s:%s" % ("0.0.0.0", metrics_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
                "keepalive": args.keepalive,
            }
            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            StandaloneApplication(app, options=options).run()
        logger.info("REST metrics microservice running on port %i", metrics_port)

    metrics_server_func = rest_metrics_server

    # Optional user-provided background service
    if hasattr(user_object, "custom_service") and callable(
        getattr(user_object, "custom_service")
    ):
        server3_func = user_object.custom_service
    else:
        server3_func = None

    logger.info("Starting servers")
    start_servers(server1_func, server2_func, server3_func, metrics_server_func)
def main():
    """Legacy CLI entry point: takes a positional api_type (REST, GRPC or
    FBS) and starts the selected prediction server, the metrics server and
    any custom service exposed by the model.
    """
    LOG_FORMAT = (
        "%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s: %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger.info("Starting microservice.py:main")
    logger.info(f"Seldon Core version: {__version__}")

    # Allow the user's model module to be imported from the working dir.
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument("interface_name", type=str, help="Name of the user interface.")
    parser.add_argument("api_type", type=str, choices=["REST", "GRPC", "FBS"])
    parser.add_argument(
        "--service-type",
        type=str,
        choices=["MODEL", "ROUTER", "TRANSFORMER", "COMBINER", "OUTLIER_DETECTOR"],
        default="MODEL",
    )
    # nargs="?" + const=1 makes a bare --persistence behave as --persistence 1
    parser.add_argument("--persistence", nargs="?", default=0, const=1, type=int)
    parser.add_argument(
        "--parameters", type=str, default=os.environ.get(PARAMETERS_ENV_NAME, "[]")
    )
    parser.add_argument(
        "--log-level", type=str, default=os.environ.get(LOG_LEVEL_ENV, "INFO")
    )
    parser.add_argument(
        "--debug",
        nargs="?",
        type=bool,
        default=getenv_as_bool(DEBUG_ENV, default=False),
        const=True,
        help="Enable debug mode.",
    )
    parser.add_argument(
        "--tracing",
        nargs="?",
        default=int(os.environ.get("TRACING", "0")),
        const=1,
        type=int,
    )

    # gunicorn settings, defaults are from
    # http://docs.gunicorn.org/en/stable/settings.html
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("GUNICORN_WORKERS", "1")),
        help="Number of Gunicorn workers for handling requests.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=int(os.environ.get("GUNICORN_THREADS", "10")),
        help="Number of threads to run per Gunicorn worker.",
    )
    parser.add_argument(
        "--max-requests",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS", "0")),
        help="Maximum number of requests gunicorn worker will process before restarting.",
    )
    parser.add_argument(
        "--max-requests-jitter",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS_JITTER", "0")),
        help="Maximum random jitter to add to max-requests.",
    )
    parser.add_argument(
        "--single-threaded",
        type=int,
        default=int(os.environ.get("FLASK_SINGLE_THREADED", "0")),
        help="Force the Flask app to run single-threaded. Also applies to Gunicorn.",
    )
    args = parser.parse_args()

    parameters = parse_parameters(json.loads(args.parameters))

    # set flask trace jaeger extra tags
    jaeger_extra_tags = list(
        filter(
            lambda x: (x != ""),
            [tag.strip() for tag in os.environ.get("JAEGER_EXTRA_TAGS", "").split(",")],
        )
    )
    logger.info("Parse JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)

    # set up log level
    log_level_num = getattr(logging, args.log_level, None)
    if not isinstance(log_level_num, int):
        # NOTE(review): the second arg is never %-formatted into the
        # message; the exception text renders as a tuple.
        raise ValueError("Invalid log level: %s", args.log_level)

    logger.setLevel(log_level_num)
    logger.debug("Log level set to %s:%s", args.log_level, log_level_num)

    annotations = load_annotations()
    logger.info("Annotations: %s", annotations)

    # "pkg.Class" imports Class from pkg; a bare name imports a module and
    # expects a same-named attribute inside it.
    parts = args.interface_name.rsplit(".", 1)
    if len(parts) == 1:
        logger.info("Importing %s", args.interface_name)
        interface_file = importlib.import_module(args.interface_name)
        user_class = getattr(interface_file, args.interface_name)
    else:
        logger.info("Importing submodule %s", parts)
        interface_file = importlib.import_module(parts[0])
        user_class = getattr(interface_file, parts[1])

    if args.persistence:
        logger.info("Restoring persisted component")
        user_object = persistence.restore(user_class, parameters)
        persistence.persist(user_object, parameters.get("push_frequency"))
    else:
        user_object = user_class(**parameters)

    # set log level for the imported microservice type
    seldon_microservice.logger.setLevel(log_level_num)
    logging.getLogger().setLevel(log_level_num)
    for handler in logger.handlers:
        handler.setLevel(log_level_num)

    port = int(os.environ.get(SERVICE_PORT_ENV_NAME, DEFAULT_PORT))
    metrics_port = int(os.environ.get(METRICS_SERVICE_PORT_ENV_NAME, DEFAULT_METRICS_PORT))

    # The tracer is created once here and captured by the server closures.
    if args.tracing:
        tracer = setup_tracing(args.interface_name)

    if args.api_type == "REST":
        seldon_metrics = SeldonMetrics(worker_id_func=os.getpid)

        if args.debug:
            # Start Flask debug server
            def rest_prediction_server():
                app = seldon_microservice.get_rest_microservice(
                    user_object, seldon_metrics
                )
                # load() is optional on user models; missing implementations are fine
                try:
                    user_object.load()
                except (NotImplementedError, AttributeError):
                    pass
                if args.tracing:
                    logger.info("Tracing branch is active")
                    from flask_opentracing import FlaskTracing

                    logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
                    FlaskTracing(tracer, True, app, jaeger_extra_tags)
                app.run(
                    host="0.0.0.0",
                    port=port,
                    threaded=False if args.single_threaded else True,
                )
                logger.info(
                    "REST microservice running on port %i single-threaded=%s",
                    port,
                    args.single_threaded,
                )

            server1_func = rest_prediction_server
        else:
            # Start production server
            def rest_prediction_server():
                options = {
                    "bind": "%s:%s" % ("0.0.0.0", port),
                    # NOTE(review): accesslog() receives the log level here,
                    # not an access-log enable flag — confirm intended.
                    "accesslog": accesslog(args.log_level),
                    "loglevel": args.log_level.lower(),
                    "timeout": 5000,
                    "threads": threads(args.threads, args.single_threaded),
                    "workers": args.workers,
                    "max_requests": args.max_requests,
                    "max_requests_jitter": args.max_requests_jitter,
                }
                app = seldon_microservice.get_rest_microservice(
                    user_object, seldon_metrics
                )
                UserModelApplication(app, user_object, options=options).run()
                logger.info("REST gunicorn microservice running on port %i", port)

            server1_func = rest_prediction_server
    elif args.api_type == "GRPC":
        seldon_metrics = SeldonMetrics(
            worker_id_func=lambda: threading.current_thread().name
        )

        def grpc_prediction_server():
            if args.tracing:
                from grpc_opentracing import open_tracing_server_interceptor

                logger.info("Adding tracer")
                interceptor = open_tracing_server_interceptor(tracer)
            else:
                interceptor = None
            server = seldon_microservice.get_grpc_server(
                user_object,
                seldon_metrics,
                annotations=annotations,
                trace_interceptor=interceptor,
            )
            try:
                user_object.load()
            except (NotImplementedError, AttributeError):
                pass
            server.add_insecure_port(f"0.0.0.0:{port}")
            server.start()
            logger.info("GRPC microservice Running on port %i", port)
            # server.start() does not block; keep this thread alive forever
            while True:
                time.sleep(1000)

        server1_func = grpc_prediction_server
    else:
        # FBS (or anything else) starts no prediction server here
        server1_func = None

    def rest_metrics_server():
        app = seldon_microservice.get_metrics_microservice(seldon_metrics)
        if args.debug:
            app.run(host="0.0.0.0", port=metrics_port)
        else:
            options = {
                "bind": "%s:%s" % ("0.0.0.0", metrics_port),
                "accesslog": accesslog(args.log_level),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
            }
            StandaloneApplication(app, options=options).run()
        logger.info("REST metrics microservice running on port %i", metrics_port)

    metrics_server_func = rest_metrics_server

    # Optional user-provided background service
    if hasattr(user_object, "custom_service") and callable(
        getattr(user_object, "custom_service")
    ):
        server2_func = user_object.custom_service
    else:
        server2_func = None

    logger.info("Starting servers")
    start_servers(server1_func, server2_func, metrics_server_func)
import logging import time import os import ray import numpy as np from seldon_core.utils import getenv_as_bool RAY_PROXY = getenv_as_bool("RAY_PROXY", default=False) MODEL_FILE = "/microservice/pytorch_model.bin" BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "100")) NUM_ACTORS = int(os.environ.get("NUM_ACTORS", "10")) class RobertaModel: def __init__(self, load_on_init=False): if load_on_init: self.load() def load(self): import torch from simpletransformers.model import TransformerModel logging.info("starting RobertaModel...") model = TransformerModel( "roberta", "roberta-base", args=({ "fp16": False,
) from seldon_core.gunicorn_utils import ( StandaloneApplication, UserModelApplication, accesslog, post_worker_init, threads, worker_exit, ) from seldon_core.metrics import SeldonMetrics from seldon_core.utils import getenv_as_bool, setup_tracing # This is related to how multiprocessing is implemeneted on MacOS # See https://github.com/SeldonIO/seldon-core/issues/3410 for discussion. USE_MULTIPROCESS_ENV_NAME = "USE_MULTIPROCESS_PACKAGE" USE_MULTIPROCESS = getenv_as_bool(USE_MULTIPROCESS_ENV_NAME, default=False) if USE_MULTIPROCESS: import multiprocess as mp else: import multiprocessing as mp logger = logging.getLogger(__name__) PARAMETERS_ENV_NAME = "PREDICTIVE_UNIT_PARAMETERS" HTTP_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_HTTP_SERVICE_PORT" GRPC_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_GRPC_SERVICE_PORT" METRICS_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_METRICS_SERVICE_PORT" FILTER_METRICS_ACCESS_LOGS_ENV_NAME = "FILTER_METRICS_ACCESS_LOGS" LOG_LEVEL_ENV = "SELDON_LOG_LEVEL"