Ejemplo n.º 1
0
def start(args):
    """Prepare the ONNX inference session for one API and start serving.

    Builds a ``Context`` from ``args``, looks up the API identified by
    ``args.api``, optionally loads the user-supplied request handler, creates
    an ``rt.InferenceSession`` for the model file under ``args.model_dir`` and
    stores the results in ``local_cache`` (keys ``api``, ``ctx``,
    ``request_handler``, ``sess``, ``input_metadata``, ``output_metadata``).

    Any exception raised during that setup is logged with its traceback and
    the process exits with status 1. Loading the class set for classification
    trackers is best-effort: a failure there is logged as a warning and
    serving still starts.

    Args:
        args: Parsed options. Reads ``context``, ``cache_dir``,
            ``workload_id``, ``api``, ``model_dir``, ``project_dir`` and
            ``port``.
    """
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("onnx") is None:
            raise CortexException(api["name"], "onnx key not configured")

        # Only the basename of the S3 prefix is used to locate the model file
        # inside the local model directory.
        _, prefix = ctx.storage.deconstruct_s3_path(api["onnx"]["model"])
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        if api["onnx"].get("request_handler") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)
        request_handler = local_cache.get("request_handler")

        # Report which optional hooks the request handler provides.
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            cx_logger().info(
                "using pre_inference request handler provided in {}".format(
                    api["onnx"]["request_handler"]))
        else:
            cx_logger().info("pre_inference request handler not found")

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            cx_logger().info(
                "using post_inference request handler provided in {}".format(
                    api["onnx"]["request_handler"]))
        else:
            cx_logger().info("post_inference request handler not found")

        sess = rt.InferenceSession(model_path)
        local_cache["sess"] = sess
        local_cache["input_metadata"] = sess.get_inputs()
        cx_logger().info("input_metadata: {}".format(
            truncate(extract_signature(local_cache["input_metadata"]))))
        local_cache["output_metadata"] = sess.get_outputs()
        cx_logger().info("output_metadata: {}".format(
            truncate(extract_signature(local_cache["output_metadata"]))))

    except Exception:
        # Top-level boundary: logger.exception records the active traceback,
        # so the exception object itself does not need to be bound.
        cx_logger().exception("failed to start api")
        sys.exit(1)

    # Setup succeeded (the handler above exits otherwise), so `api` and `ctx`
    # are bound from here on.
    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception:
            # Best-effort: keep serving even if the class set cannot be loaded.
            # `warning` replaces the deprecated `warn` alias.
            cx_logger().warning(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("API is ready")
    serve(app, listen="*:{}".format(args.port))
Ejemplo n.º 2
0
def start(args):
    """Prepare the inference session for one API and start serving.

    Builds a ``Context`` from ``args``, looks up the API identified by
    ``args.api``, optionally loads the user-supplied request handler, creates
    an ``rt.InferenceSession`` for the model file under ``args.model_dir`` and
    stores the results in ``local_cache`` (keys ``api``, ``ctx``,
    ``request_handler``, ``sess``, ``input_metadata``, ``output_metadata``).

    Any exception raised during that setup is logged with its traceback and
    the process exits with status 1. Loading the class set for classification
    trackers is best-effort: a failure there is logged as a warning and
    serving still starts.

    Args:
        args: Parsed options. Reads ``context``, ``cache_dir``,
            ``workload_id``, ``api``, ``model_dir``, ``project_dir`` and
            ``port``.
    """
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        # Only the basename of the S3 prefix is used to locate the model file
        # inside the local model directory.
        _, prefix = ctx.storage.deconstruct_s3_path(api["model"])
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        if api.get("request_handler") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)

        sess = rt.InferenceSession(model_path)
        local_cache["sess"] = sess
        local_cache["input_metadata"] = sess.get_inputs()
        logger.info("input_metadata: {}".format(
            truncate(extract_signature(local_cache["input_metadata"]))))
        local_cache["output_metadata"] = sess.get_outputs()
        logger.info("output_metadata: {}".format(
            truncate(extract_signature(local_cache["output_metadata"]))))

    except Exception:
        # Top-level boundary: logger.exception records the active traceback,
        # so the exception object itself does not need to be bound.
        logger.exception("failed to start api")
        sys.exit(1)

    # Setup succeeded (the handler above exits otherwise), so `api` and `ctx`
    # are bound from here on.
    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception:
            # Best-effort: keep serving even if the class set cannot be loaded.
            # `warning` replaces the deprecated `warn` alias.
            logger.warning("an error occurred while attempting to load classes",
                           exc_info=True)

    serve(app, listen="*:{}".format(args.port))
Ejemplo n.º 3
0
def debug_obj(name, payload, debug):
    """Log ``payload`` under the label ``name`` when ``debug`` is truthy.

    The payload is passed through ``stringify.truncate`` before being
    formatted into an info-level message of the form ``"<name>: <payload>"``.
    Nothing is logged when ``debug`` is falsy.
    """
    if debug:
        # Resolve the logger method before building the message so the
        # evaluation order matches a direct `cx_logger().info(...)` call.
        emit = cx_logger().info
        emit("{}: {}".format(name, stringify.truncate(payload)))
Ejemplo n.º 4
0
def debug_obj(name, sample, debug):
    """Log ``sample`` under the label ``name`` when ``debug`` is truthy.

    The sample is passed through ``stringify.truncate`` before being
    formatted into an info-level message of the form ``"<name>: <sample>"``.
    Nothing is logged when ``debug`` is falsy.
    """
    if debug:
        # Resolve the logger method before building the message so the
        # evaluation order matches a direct `logger.info(...)` call.
        emit = logger.info
        emit("{}: {}".format(name, stringify.truncate(sample)))