def setup_logger(log_level: str, debug_mode: bool) -> logging.Logger:
    # set up log level
    log_level_raw = os.environ.get(LOG_LEVEL_ENV, log_level.upper())
    log_level_num = getattr(logging, log_level_raw, None)
    if not isinstance(log_level_num, int):
        raise ValueError("Invalid log level: %s", log_level)

    logger.setLevel(log_level_num)

    # Apply the same level to the access logs
    flask_logger = logging.getLogger("werkzeug")
    flask_logger.setLevel(log_level_num)

    if getenv_as_bool(FILTER_METRICS_ACCESS_LOGS_ENV_NAME,
                      default=not debug_mode):
        flask_logger.addFilter(MetricsEndpointFilter())
        gunicorn_logger = logging.getLogger("gunicorn.access")
        gunicorn_logger.addFilter(MetricsEndpointFilter())

    logger.debug("Log level set to %s:%s", log_level, log_level_num)

    # set log level for the imported microservice type
    seldon_microservice.logger.setLevel(log_level_num)
    logging.getLogger().setLevel(log_level_num)
    for handler in logger.handlers:
        handler.setLevel(log_level_num)

    return logger
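
A minimal usage sketch for the helper above. It assumes the call happens in the module context where the module-level logger, MetricsEndpointFilter, and the SELDON_LOG_LEVEL constant shown further down in this file are defined; the values below are illustrative, not taken from the library.

import logging
import os

# Hypothetical override: the SELDON_LOG_LEVEL env var takes precedence over the CLI value.
logging.basicConfig(level=logging.INFO)
os.environ["SELDON_LOG_LEVEL"] = "DEBUG"
log = setup_logger("INFO", debug_mode=False)
log.debug("Visible because the env override raised the effective level to DEBUG")
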
Example #2
def _set_flask_app_configs(app):
    """
    Set the configs for the flask app based on environment variables
    See https://flask.palletsprojects.com/config/#builtin-configuration-values
    :param app:
    :return:
    """
    FLASK_CONFIG_IDENTIFIER = "FLASK_"
    FLASK_CONFIGS_ALLOWED = [
        "DEBUG",
        "EXPLAIN_TEMPLATE_LOADING",
        "JSONIFY_PRETTYPRINT_REGULAR",
        "JSON_SORT_KEYS",
        "PROPAGATE_EXCEPTIONS",
        "PRESERVE_CONTEXT_ON_EXCEPTION",
        "SESSION_COOKIE_HTTPONLY",
        "SESSION_COOKIE_SECURE",
        "SESSION_REFRESH_EACH_REQUEST",
        "TEMPLATES_AUTO_RELOAD",
        "TESTING",
        "TRAP_HTTP_EXCEPTIONS",
        "TRAP_BAD_REQUEST_ERRORS",
        "USE_X_SENDFILE",
    ]

    for flask_config in FLASK_CONFIGS_ALLOWED:
        flask_config_value = getenv_as_bool(
            f"{FLASK_CONFIG_IDENTIFIER}{flask_config}", default=None
        )
        if flask_config_value is None:
            continue
        app.config[flask_config] = flask_config_value
    logger.info(f"App Config:  {app.config}")
Example #3
def test_getenv_as_bool(monkeypatch, env_val, expected):
    env_var = "MY_BOOL_VAR"

    if env_val is not None:
        monkeypatch.setenv(env_var, env_val)

    value = scu.getenv_as_bool(env_var, default=False)
    assert value == expected
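
The test above depends on a @pytest.mark.parametrize decorator that is not part of this snippet. Below is a self-contained sketch of what it could look like, assuming `scu` aliases seldon_core.utils; the exact truthy/falsy strings accepted by getenv_as_bool are an assumption.

import pytest
import seldon_core.utils as scu  # the test body refers to it as `scu`


@pytest.mark.parametrize(
    "env_val,expected",
    [
        ("true", True),   # assumed truthy spelling
        ("1", True),
        ("false", False),
        ("0", False),
        (None, False),    # variable unset, falls back to default=False
    ],
)
def test_getenv_as_bool(monkeypatch, env_val, expected):
    env_var = "MY_BOOL_VAR"

    if env_val is not None:
        monkeypatch.setenv(env_var, env_val)

    value = scu.getenv_as_bool(env_var, default=False)
    assert value == expected
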
Example #4
    json_to_feedback,
    getenv_as_bool,
)
from seldon_core.flask_utils import get_request, jsonify
from seldon_core.flask_utils import (
    SeldonMicroserviceException,
    ANNOTATION_GRPC_MAX_MSG_SIZE,
)
from seldon_core.proto import prediction_pb2_grpc
from seldon_core.proto import prediction_pb2

logger = logging.getLogger(__name__)

PRED_UNIT_ID = os.environ.get("PREDICTIVE_UNIT_ID", "0")
METRICS_ENDPOINT = os.environ.get("PREDICTIVE_UNIT_METRICS_ENDPOINT", "/metrics")
PAYLOAD_PASSTHROUGH = getenv_as_bool("PAYLOAD_PASSTHROUGH", default=False)


def get_rest_microservice(user_model, seldon_metrics):
    app = Flask(__name__, static_url_path="")
    CORS(app)

    _set_flask_app_configs(app)

    # Dict representing the validated model metadata.
    # A None value indicates a validation error.
    metadata_data = seldon_core.seldon_methods.init_metadata(user_model)

    if hasattr(user_model, "model_error_handler"):
        logger.info("Registering the custom error handler...")
        app.register_blueprint(user_model.model_error_handler)
def main():
    LOG_FORMAT = (
        "%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s:  %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger.info("Starting microservice.py:main")
    logger.info(f"Seldon Core version: {__version__}")

    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument("interface_name",
                        type=str,
                        help="Name of the user interface.")

    parser.add_argument(
        "--service-type",
        type=str,
        choices=[
            "MODEL", "ROUTER", "TRANSFORMER", "COMBINER", "OUTLIER_DETECTOR"
        ],
        default="MODEL",
    )
    parser.add_argument("--persistence",
                        nargs="?",
                        default=0,
                        const=1,
                        type=int)
    parser.add_argument("--parameters",
                        type=str,
                        default=os.environ.get(PARAMETERS_ENV_NAME, "[]"))
    parser.add_argument(
        "--log-level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default=DEFAULT_LOG_LEVEL,
        help="Log level of the inference server.",
    )
    parser.add_argument(
        "--debug",
        nargs="?",
        type=bool,
        default=getenv_as_bool(DEBUG_ENV, default=False),
        const=True,
        help="Enable debug mode.",
    )
    parser.add_argument(
        "--tracing",
        nargs="?",
        default=int(os.environ.get("TRACING", "0")),
        const=1,
        type=int,
    )

    # gunicorn settings, defaults are from
    # http://docs.gunicorn.org/en/stable/settings.html
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("GUNICORN_WORKERS", "1")),
        help="Number of Gunicorn workers for handling requests.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=int(os.environ.get("GUNICORN_THREADS", "10")),
        help="Number of threads to run per Gunicorn worker.",
    )
    parser.add_argument(
        "--max-requests",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS", "0")),
        help=
        "Maximum number of requests a gunicorn worker will process before restarting.",
    )
    parser.add_argument(
        "--max-requests-jitter",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS_JITTER", "0")),
        help="Maximum random jitter to add to max-requests.",
    )

    parser.add_argument(
        "--single-threaded",
        type=int,
        default=int(os.environ.get("FLASK_SINGLE_THREADED", "0")),
        help=
        "Force the Flask app to run single-threaded. Also applies to Gunicorn.",
    )

    parser.add_argument(
        "--http-port",
        type=int,
        default=int(
            os.environ.get(HTTP_SERVICE_PORT_ENV_NAME, DEFAULT_HTTP_PORT)),
        help="Set http port of seldon service",
    )

    parser.add_argument(
        "--grpc-port",
        type=int,
        default=int(
            os.environ.get(GRPC_SERVICE_PORT_ENV_NAME, DEFAULT_GRPC_PORT)),
        help="Set grpc port of seldon service",
    )

    parser.add_argument(
        "--metrics-port",
        type=int,
        default=int(
            os.environ.get(METRICS_SERVICE_PORT_ENV_NAME,
                           DEFAULT_METRICS_PORT)),
        help="Set metrics port of seldon service",
    )

    parser.add_argument("--pidfile",
                        type=str,
                        default=None,
                        help="A file path to use for the PID file")

    parser.add_argument(
        "--access-log",
        nargs="?",
        type=bool,
        default=getenv_as_bool(GUNICORN_ACCESS_LOG_ENV, default=False),
        const=True,
        help="Enable gunicorn access log.",
    )

    args, remaining = parser.parse_known_args()

    if len(remaining) > 0:
        logger.error(
            f"Unknown args {remaining}. Note since 1.5.0 this CLI does not take API type (REST, GRPC)"
        )
        sys.exit(-1)

    parameters = parse_parameters(json.loads(args.parameters))

    setup_logger(args.log_level, args.debug)

    # set flask trace jaeger extra tags
    jaeger_extra_tags = list(
        filter(
            lambda x: (x != ""),
            [
                tag.strip()
                for tag in os.environ.get("JAEGER_EXTRA_TAGS", "").split(",")
            ],
        ))
    logger.info("Parse JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)

    annotations = load_annotations()
    logger.info("Annotations: %s", annotations)

    parts = args.interface_name.rsplit(".", 1)
    if len(parts) == 1:
        logger.info("Importing %s", args.interface_name)
        interface_file = importlib.import_module(args.interface_name)
        user_class = getattr(interface_file, args.interface_name)
    else:
        logger.info("Importing submodule %s", parts)
        interface_file = importlib.import_module(parts[0])
        user_class = getattr(interface_file, parts[1])

    if args.persistence:
        logger.info("Restoring persisted component")
        user_object = persistence.restore(user_class, parameters)
        persistence.persist(user_object, parameters.get("push_frequency"))
    else:
        user_object = user_class(**parameters)

    http_port = args.http_port
    grpc_port = args.grpc_port
    metrics_port = args.metrics_port

    # if args.tracing:
    #    tracer = setup_tracing(args.interface_name)

    seldon_metrics = SeldonMetrics(worker_id_func=os.getpid)
    # TODO why 2 ways to create metrics server
    # seldon_metrics = SeldonMetrics(
    #    worker_id_func=lambda: threading.current_thread().name
    # )
    if args.debug:
        # Start Flask debug server
        def rest_prediction_server():
            app = seldon_microservice.get_rest_microservice(
                user_object, seldon_metrics)
            try:
                user_object.load()
            except (NotImplementedError, AttributeError):
                pass
            if args.tracing:
                logger.info("Tracing branch is active")
                from flask_opentracing import FlaskTracing

                tracer = setup_tracing(args.interface_name)

                logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
                FlaskTracing(tracer, True, app, jaeger_extra_tags)

            app.run(
                host="0.0.0.0",
                port=http_port,
                threaded=not args.single_threaded,
            )

        logger.info(
            "REST microservice running on port %i single-threaded=%s",
            http_port,
            args.single_threaded,
        )
        server1_func = rest_prediction_server
    else:
        # Start production server
        def rest_prediction_server():
            options = {
                "bind": "%s:%s" % ("0.0.0.0", http_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "threads": threads(args.threads, args.single_threaded),
                "workers": args.workers,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
                "worker_exit": partial(worker_exit,
                                       seldon_metrics=seldon_metrics),
            }
            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            app = seldon_microservice.get_rest_microservice(
                user_object, seldon_metrics)

            UserModelApplication(
                app,
                user_object,
                jaeger_extra_tags,
                args.interface_name,
                options=options,
            ).run()

        logger.info("REST gunicorn microservice running on port %i", http_port)
        server1_func = rest_prediction_server

    def grpc_prediction_server():

        if args.tracing:
            from grpc_opentracing import open_tracing_server_interceptor

            logger.info("Adding tracer")
            tracer = setup_tracing(args.interface_name)
            interceptor = open_tracing_server_interceptor(tracer)
        else:
            interceptor = None

        server = seldon_microservice.get_grpc_server(
            user_object,
            seldon_metrics,
            annotations=annotations,
            trace_interceptor=interceptor,
        )

        try:
            user_object.load()
        except (NotImplementedError, AttributeError):
            pass

        server.add_insecure_port(f"0.0.0.0:{grpc_port}")

        server.start()

        logger.info("GRPC microservice Running on port %i", grpc_port)
        while True:
            time.sleep(1000)

    server2_func = grpc_prediction_server

    def rest_metrics_server():
        app = seldon_microservice.get_metrics_microservice(seldon_metrics)
        if args.debug:
            app.run(host="0.0.0.0", port=metrics_port)
        else:
            options = {
                "bind": "%s:%s" % ("0.0.0.0", metrics_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
            }
            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            StandaloneApplication(app, options=options).run()

    logger.info("REST metrics microservice running on port %i", metrics_port)
    metrics_server_func = rest_metrics_server

    if hasattr(user_object, "custom_service") and callable(
            getattr(user_object, "custom_service")):
        server3_func = user_object.custom_service
    else:
        server3_func = None

    logger.info("Starting servers")
    start_servers(server1_func, server2_func, server3_func,
                  metrics_server_func)
def main():
    LOG_FORMAT = (
        "%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s:  %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger.info("Starting microservice.py:main")
    logger.info(f"Seldon Core version: {__version__}")

    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument("interface_name",
                        type=str,
                        help="Name of the user interface.")

    parser.add_argument(
        "--service-type",
        type=str,
        choices=[
            "MODEL", "ROUTER", "TRANSFORMER", "COMBINER", "OUTLIER_DETECTOR"
        ],
        default="MODEL",
    )
    parser.add_argument(
        "--persistence",
        nargs="?",
        default=0,
        const=1,
        type=int,
        help="deprecated argument ",
    )
    parser.add_argument("--parameters",
                        type=str,
                        default=os.environ.get(PARAMETERS_ENV_NAME, "[]"))
    parser.add_argument(
        "--log-level",
        type=str,
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        default=DEFAULT_LOG_LEVEL,
        help="Log level of the inference server.",
    )
    parser.add_argument(
        "--debug",
        nargs="?",
        type=bool,
        default=getenv_as_bool(DEBUG_ENV, default=False),
        const=True,
        help="Enable debug mode.",
    )
    parser.add_argument(
        "--tracing",
        nargs="?",
        default=int(os.environ.get("TRACING", "0")),
        const=1,
        type=int,
    )

    # gunicorn settings, defaults are from
    # http://docs.gunicorn.org/en/stable/settings.html
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("GUNICORN_WORKERS", "1")),
        help="Number of Gunicorn workers for handling requests.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=int(os.environ.get("GUNICORN_THREADS", "1")),
        help="Number of threads to run per Gunicorn worker.",
    )
    parser.add_argument(
        "--max-requests",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS", "0")),
        help=
        "Maximum number of requests a gunicorn worker will process before restarting.",
    )
    parser.add_argument(
        "--max-requests-jitter",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS_JITTER", "0")),
        help="Maximum random jitter to add to max-requests.",
    )
    parser.add_argument(
        "--keepalive",
        type=int,
        default=int(os.environ.get("GUNICORN_KEEPALIVE", "2")),
        help=
        "The number of seconds to wait for requests on a Keep-Alive connection.",
    )

    parser.add_argument(
        "--single-threaded",
        type=int,
        default=int(os.environ.get("FLASK_SINGLE_THREADED", "0")),
        help=
        "Force the Flask app to run single-threaded. Also applies to Gunicorn.",
    )

    parser.add_argument(
        "--http-port",
        type=int,
        default=int(
            os.environ.get(HTTP_SERVICE_PORT_ENV_NAME, DEFAULT_HTTP_PORT)),
        help="Set http port of seldon service",
    )

    parser.add_argument(
        "--grpc-port",
        type=int,
        default=int(
            os.environ.get(GRPC_SERVICE_PORT_ENV_NAME, DEFAULT_GRPC_PORT)),
        help="Set grpc port of seldon service",
    )

    parser.add_argument(
        "--metrics-port",
        type=int,
        default=int(
            os.environ.get(METRICS_SERVICE_PORT_ENV_NAME,
                           DEFAULT_METRICS_PORT)),
        help="Set metrics port of seldon service",
    )

    parser.add_argument("--pidfile",
                        type=str,
                        default=None,
                        help="A file path to use for the PID file")

    parser.add_argument(
        "--access-log",
        nargs="?",
        type=bool,
        default=getenv_as_bool(GUNICORN_ACCESS_LOG_ENV, default=False),
        const=True,
        help="Enable gunicorn access log.",
    )

    parser.add_argument(
        "--grpc-threads",
        type=int,
        default=int(os.environ.get("GRPC_THREADS", "1")),
        help="Number of GRPC threads per worker.",
    )

    parser.add_argument(
        "--grpc-workers",
        type=int,
        default=int(os.environ.get("GRPC_WORKERS", "1")),
        help="Number of GRPC workers.",
    )

    args, remaining = parser.parse_known_args()

    if len(remaining) > 0:
        logger.error(
            f"Unknown args {remaining}. Note since 1.5.0 this CLI does not take API type (REST, GRPC)"
        )
        sys.exit(-1)

    parameters = parse_parameters(json.loads(args.parameters))

    setup_logger(args.log_level, args.debug)

    # set flask trace jaeger extra tags
    jaeger_extra_tags = list(
        filter(
            lambda x: (x != ""),
            [
                tag.strip()
                for tag in os.environ.get("JAEGER_EXTRA_TAGS", "").split(",")
            ],
        ))
    logger.info("Parse JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)

    annotations = load_annotations()
    logger.info("Annotations: %s", annotations)

    parts = args.interface_name.rsplit(".", 1)
    if len(parts) == 1:
        logger.info("Importing %s", args.interface_name)
        interface_file = importlib.import_module(args.interface_name)
        user_class = getattr(interface_file, args.interface_name)
    else:
        logger.info("Importing submodule %s", parts)
        interface_file = importlib.import_module(parts[0])
        user_class = getattr(interface_file, parts[1])

    if args.persistence:
        logger.error(f"Persistence: ignored, persistence is deprecated")
    user_object = user_class(**parameters)

    http_port = args.http_port
    grpc_port = args.grpc_port
    metrics_port = args.metrics_port

    seldon_metrics = SeldonMetrics(worker_id_func=os.getpid)
    # TODO why 2 ways to create metrics server
    # seldon_metrics = SeldonMetrics(
    #    worker_id_func=lambda: threading.current_thread().name
    # )
    if args.debug:
        # Start Flask debug server
        def rest_prediction_server():
            app = seldon_microservice.get_rest_microservice(
                user_object, seldon_metrics)
            try:
                user_object.load()
            except (NotImplementedError, AttributeError):
                pass
            if args.tracing:
                logger.info("Tracing branch is active")
                from flask_opentracing import FlaskTracing

                tracer = setup_tracing(args.interface_name)

                logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
                FlaskTracing(tracer, True, app, jaeger_extra_tags)

            # Timeout not supported in flask development server
            app.run(
                host="0.0.0.0",
                port=http_port,
                threaded=not args.single_threaded,
            )

        logger.info(
            "REST microservice running on port %i single-threaded=%s",
            http_port,
            args.single_threaded,
        )
        server1_func = rest_prediction_server
    else:
        # Start production server
        def rest_prediction_server():
            rest_timeout = DEFAULT_ANNOTATION_REST_TIMEOUT
            if ANNOTATION_REST_TIMEOUT in annotations:
                # Gunicorn timeout is in seconds, so convert the annotation value from milliseconds
                rest_timeout = int(annotations[ANNOTATION_REST_TIMEOUT]) / 1000
                # Truncate to an int and clamp to a minimum of 1 second
                # (e.g. a 200ms annotation becomes 0.2s, truncates to 0 and is clamped to 1)
                rest_timeout = int(rest_timeout) or 1

            options = {
                "bind": "%s:%s" % ("0.0.0.0", http_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": rest_timeout,
                "threads": threads(args.threads, args.single_threaded),
                "workers": args.workers,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
                "worker_exit": partial(worker_exit,
                                       seldon_metrics=seldon_metrics),
                "keepalive": args.keepalive,
            }
            logger.info(f"Gunicorn Config:  {options}")

            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            app = seldon_microservice.get_rest_microservice(
                user_object, seldon_metrics)

            UserModelApplication(
                app,
                user_object,
                args.tracing,
                jaeger_extra_tags,
                args.interface_name,
                options=options,
            ).run()

        logger.info("REST gunicorn microservice running on port %i", http_port)
        server1_func = rest_prediction_server

    def _wait_forever(server):
        try:
            while True:
                time.sleep(60 * 60)
        except KeyboardInterrupt:
            server.stop(None)

    def _run_grpc_server(bind_address):
        """Start a server in a subprocess."""
        logger.info(f"Starting new GRPC server with {args.grpc_threads}.")

        if args.tracing:
            from grpc_opentracing import open_tracing_server_interceptor

            logger.info("Adding tracer")
            tracer = setup_tracing(args.interface_name)
            interceptor = open_tracing_server_interceptor(tracer)
        else:
            interceptor = None

        server = seldon_microservice.get_grpc_server(
            user_object,
            seldon_metrics,
            annotations=annotations,
            trace_interceptor=interceptor,
            num_threads=args.grpc_threads,
        )

        try:
            user_object.load()
        except (NotImplementedError, AttributeError):
            pass

        server.add_insecure_port(bind_address)
        server.start()
        _wait_forever(server)

    @contextlib.contextmanager
    def _reserve_grpc_port():
        """Find and reserve a port for all subprocesses to use."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) != 1:
            raise RuntimeError("Failed to set SO_REUSEPORT.")
        sock.bind(("", grpc_port))
        try:
            yield sock.getsockname()[1]
        finally:
            sock.close()

    def grpc_prediction_server():
        with _reserve_grpc_port() as bind_port:
            bind_address = "0.0.0.0:{}".format(bind_port)
            logger.info(
                f"GRPC Server Binding to '%s' {bind_address} with {args.workers} processes"
            )
            sys.stdout.flush()
            workers = []
            for _ in range(args.grpc_workers):
                # NOTE: It is imperative that the worker subprocesses be forked before
                # any gRPC servers start up. See
                # https://github.com/grpc/grpc/issues/16001 for more details.
                worker = multiprocessing.Process(target=_run_grpc_server,
                                                 args=(bind_address, ))
                worker.start()
                workers.append(worker)
            for worker in workers:
                worker.join()

    server2_func = grpc_prediction_server if args.grpc_workers > 0 else None

    def rest_metrics_server():
        app = seldon_microservice.get_metrics_microservice(seldon_metrics)
        if args.debug:
            app.run(host="0.0.0.0", port=metrics_port)
        else:
            options = {
                "bind": "%s:%s" % ("0.0.0.0", metrics_port),
                "accesslog": accesslog(args.access_log),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
                "post_worker_init": post_worker_init,
                "keepalive": args.keepalive,
            }
            if args.pidfile is not None:
                options["pidfile"] = args.pidfile
            StandaloneApplication(app, options=options).run()

    logger.info("REST metrics microservice running on port %i", metrics_port)
    metrics_server_func = rest_metrics_server

    if hasattr(user_object, "custom_service") and callable(
            getattr(user_object, "custom_service")):
        server3_func = user_object.custom_service
    else:
        server3_func = None

    logger.info("Starting servers")
    start_servers(server1_func, server2_func, server3_func,
                  metrics_server_func)
Example #7
def main():
    LOG_FORMAT = (
        "%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s:  %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger.info("Starting microservice.py:main")
    logger.info(f"Seldon Core version: {__version__}")

    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument("interface_name",
                        type=str,
                        help="Name of the user interface.")
    parser.add_argument("api_type", type=str, choices=["REST", "GRPC", "FBS"])

    parser.add_argument(
        "--service-type",
        type=str,
        choices=[
            "MODEL", "ROUTER", "TRANSFORMER", "COMBINER", "OUTLIER_DETECTOR"
        ],
        default="MODEL",
    )
    parser.add_argument("--persistence",
                        nargs="?",
                        default=0,
                        const=1,
                        type=int)
    parser.add_argument("--parameters",
                        type=str,
                        default=os.environ.get(PARAMETERS_ENV_NAME, "[]"))
    parser.add_argument("--log-level",
                        type=str,
                        default=os.environ.get(LOG_LEVEL_ENV, "INFO"))
    parser.add_argument(
        "--debug",
        nargs="?",
        type=bool,
        default=getenv_as_bool(DEBUG_ENV, default=False),
        const=True,
        help="Enable debug mode.",
    )
    parser.add_argument(
        "--tracing",
        nargs="?",
        default=int(os.environ.get("TRACING", "0")),
        const=1,
        type=int,
    )

    # gunicorn settings, defaults are from
    # http://docs.gunicorn.org/en/stable/settings.html
    parser.add_argument(
        "--workers",
        type=int,
        default=int(os.environ.get("GUNICORN_WORKERS", "1")),
        help="Number of Gunicorn workers for handling requests.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=int(os.environ.get("GUNICORN_THREADS", "10")),
        help="Number of threads to run per Gunicorn worker.",
    )
    parser.add_argument(
        "--max-requests",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS", "0")),
        help=
        "Maximum number of requests a gunicorn worker will process before restarting.",
    )
    parser.add_argument(
        "--max-requests-jitter",
        type=int,
        default=int(os.environ.get("GUNICORN_MAX_REQUESTS_JITTER", "0")),
        help="Maximum random jitter to add to max-requests.",
    )

    parser.add_argument(
        "--single-threaded",
        type=int,
        default=int(os.environ.get("FLASK_SINGLE_THREADED", "0")),
        help=
        "Force the Flask app to run single-threaded. Also applies to Gunicorn.",
    )

    args = parser.parse_args()

    parameters = parse_parameters(json.loads(args.parameters))

    # set flask trace jaeger extra tags
    jaeger_extra_tags = list(
        filter(
            lambda x: (x != ""),
            [
                tag.strip()
                for tag in os.environ.get("JAEGER_EXTRA_TAGS", "").split(",")
            ],
        ))
    logger.info("Parse JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)

    # set up log level
    log_level_num = getattr(logging, args.log_level, None)
    if not isinstance(log_level_num, int):
        raise ValueError("Invalid log level: %s", args.log_level)

    logger.setLevel(log_level_num)
    logger.debug("Log level set to %s:%s", args.log_level, log_level_num)

    annotations = load_annotations()
    logger.info("Annotations: %s", annotations)

    parts = args.interface_name.rsplit(".", 1)
    if len(parts) == 1:
        logger.info("Importing %s", args.interface_name)
        interface_file = importlib.import_module(args.interface_name)
        user_class = getattr(interface_file, args.interface_name)
    else:
        logger.info("Importing submodule %s", parts)
        interface_file = importlib.import_module(parts[0])
        user_class = getattr(interface_file, parts[1])

    if args.persistence:
        logger.info("Restoring persisted component")
        user_object = persistence.restore(user_class, parameters)
        persistence.persist(user_object, parameters.get("push_frequency"))
    else:
        user_object = user_class(**parameters)

    # set log level for the imported microservice type
    seldon_microservice.logger.setLevel(log_level_num)
    logging.getLogger().setLevel(log_level_num)
    for handler in logger.handlers:
        handler.setLevel(log_level_num)

    port = int(os.environ.get(SERVICE_PORT_ENV_NAME, DEFAULT_PORT))
    metrics_port = int(
        os.environ.get(METRICS_SERVICE_PORT_ENV_NAME, DEFAULT_METRICS_PORT))

    if args.tracing:
        tracer = setup_tracing(args.interface_name)

    if args.api_type == "REST":
        seldon_metrics = SeldonMetrics(worker_id_func=os.getpid)

        if args.debug:
            # Start Flask debug server
            def rest_prediction_server():
                app = seldon_microservice.get_rest_microservice(
                    user_object, seldon_metrics)
                try:
                    user_object.load()
                except (NotImplementedError, AttributeError):
                    pass
                if args.tracing:
                    logger.info("Tracing branch is active")
                    from flask_opentracing import FlaskTracing

                    logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
                    FlaskTracing(tracer, True, app, jaeger_extra_tags)

                app.run(
                    host="0.0.0.0",
                    port=port,
                    threaded=not args.single_threaded,
                )

            logger.info(
                "REST microservice running on port %i single-threaded=%s",
                port,
                args.single_threaded,
            )
            server1_func = rest_prediction_server
        else:
            # Start production server
            def rest_prediction_server():
                options = {
                    "bind": "%s:%s" % ("0.0.0.0", port),
                    "accesslog": accesslog(args.log_level),
                    "loglevel": args.log_level.lower(),
                    "timeout": 5000,
                    "threads": threads(args.threads, args.single_threaded),
                    "workers": args.workers,
                    "max_requests": args.max_requests,
                    "max_requests_jitter": args.max_requests_jitter,
                }
                app = seldon_microservice.get_rest_microservice(
                    user_object, seldon_metrics)
                UserModelApplication(app, user_object, options=options).run()

            logger.info("REST gunicorn microservice running on port %i", port)
            server1_func = rest_prediction_server

    elif args.api_type == "GRPC":
        seldon_metrics = SeldonMetrics(
            worker_id_func=lambda: threading.current_thread().name)

        def grpc_prediction_server():

            if args.tracing:
                from grpc_opentracing import open_tracing_server_interceptor

                logger.info("Adding tracer")
                interceptor = open_tracing_server_interceptor(tracer)
            else:
                interceptor = None

            server = seldon_microservice.get_grpc_server(
                user_object,
                seldon_metrics,
                annotations=annotations,
                trace_interceptor=interceptor,
            )

            try:
                user_object.load()
            except (NotImplementedError, AttributeError):
                pass

            server.add_insecure_port(f"0.0.0.0:{port}")

            server.start()

            logger.info("GRPC microservice Running on port %i", port)
            while True:
                time.sleep(1000)

        server1_func = grpc_prediction_server

    else:
        server1_func = None

    def rest_metrics_server():
        app = seldon_microservice.get_metrics_microservice(seldon_metrics)
        if args.debug:
            app.run(host="0.0.0.0", port=metrics_port)
        else:
            options = {
                "bind": "%s:%s" % ("0.0.0.0", metrics_port),
                "accesslog": accesslog(args.log_level),
                "loglevel": args.log_level.lower(),
                "timeout": 5000,
                "max_requests": args.max_requests,
                "max_requests_jitter": args.max_requests_jitter,
            }
            StandaloneApplication(app, options=options).run()

    logger.info("REST metrics microservice running on port %i", metrics_port)
    metrics_server_func = rest_metrics_server

    if hasattr(user_object, "custom_service") and callable(
            getattr(user_object, "custom_service")):
        server2_func = user_object.custom_service
    else:
        server2_func = None

        logger.info("Starting servers")
    start_servers(server1_func, server2_func, metrics_server_func)
import logging
import time
import os

import ray

import numpy as np
from seldon_core.utils import getenv_as_bool

RAY_PROXY = getenv_as_bool("RAY_PROXY", default=False)
MODEL_FILE = "/microservice/pytorch_model.bin"

BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "100"))
NUM_ACTORS = int(os.environ.get("NUM_ACTORS", "10"))


class RobertaModel:
    def __init__(self, load_on_init=False):
        if load_on_init:
            self.load()

    def load(self):
        import torch
        from simpletransformers.model import TransformerModel

        logging.info("starting RobertaModel...")
        model = TransformerModel(
            "roberta",
            "roberta-base",
            args=({
                "fp16": False,
Example #9
)
from seldon_core.gunicorn_utils import (
    StandaloneApplication,
    UserModelApplication,
    accesslog,
    post_worker_init,
    threads,
    worker_exit,
)
from seldon_core.metrics import SeldonMetrics
from seldon_core.utils import getenv_as_bool, setup_tracing

# This is related to how multiprocessing is implemented on MacOS
# See https://github.com/SeldonIO/seldon-core/issues/3410 for discussion.
USE_MULTIPROCESS_ENV_NAME = "USE_MULTIPROCESS_PACKAGE"
USE_MULTIPROCESS = getenv_as_bool(USE_MULTIPROCESS_ENV_NAME, default=False)
if USE_MULTIPROCESS:
    import multiprocess as mp
else:
    import multiprocessing as mp

logger = logging.getLogger(__name__)

PARAMETERS_ENV_NAME = "PREDICTIVE_UNIT_PARAMETERS"
HTTP_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_HTTP_SERVICE_PORT"
GRPC_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_GRPC_SERVICE_PORT"
METRICS_SERVICE_PORT_ENV_NAME = "PREDICTIVE_UNIT_METRICS_SERVICE_PORT"

FILTER_METRICS_ACCESS_LOGS_ENV_NAME = "FILTER_METRICS_ACCESS_LOGS"

LOG_LEVEL_ENV = "SELDON_LOG_LEVEL"