Example #1
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("onnx") is None:
            raise CortexException(api["name"], "onnx key not configured")

        _, prefix = ctx.storage.deconstruct_s3_path(api["onnx"]["model"])
        model_path = os.path.join(args.model_dir, os.path.basename(prefix))
        if api["onnx"].get("request_handler") is not None:
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)
        request_handler = local_cache.get("request_handler")

        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            cx_logger().info(
                "using pre_inference request handler provided in {}".format(
                    api["onnx"]["request_handler"]))
        else:
            cx_logger().info("pre_inference request handler not found")

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            cx_logger().info(
                "using post_inference request handler provided in {}".format(
                    api["onnx"]["request_handler"]))
        else:
            cx_logger().info("post_inference request handler not found")

        sess = rt.InferenceSession(model_path)
        local_cache["sess"] = sess
        local_cache["input_metadata"] = sess.get_inputs()
        cx_logger().info("input_metadata: {}".format(
            truncate(extract_signature(local_cache["input_metadata"]))))
        local_cache["output_metadata"] = sess.get_outputs()
        cx_logger().info("output_metadata: {}".format(
            truncate(extract_signature(local_cache["output_metadata"]))))

    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("API is ready")
    serve(app, listen="*:{}".format(args.port))
Example #2
def predict(app_name, api_name):
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        sample = request.get_json()
    except Exception as e:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("sample", sample, debug)

        prepared_sample = sample
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            try:
                prepared_sample = request_handler.pre_inference(
                    sample, input_metadata)
                debug_obj("pre_inference", prepared_sample, debug)
            except Exception as e:
                raise UserRuntimeException(api["request_handler"],
                                           "pre_inference request handler",
                                           str(e)) from e

        inference_input = convert_to_onnx_input(prepared_sample,
                                                input_metadata)
        model_outputs = sess.run([], inference_input)
        result = []
        for model_output in model_outputs:
            if type(model_output) is np.ndarray:
                result.append(model_output.tolist())
            else:
                result.append(model_output)

        debug_obj("inference", result, debug)

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    result, output_metadata)
            except Exception as e:
                raise UserRuntimeException(api["request_handler"],
                                           "post_inference request handler",
                                           str(e)) from e

            debug_obj("post_inference", result, debug)
    except Exception as e:
        logger.exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result
    return jsonify(result)
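
For reference, the request handler module that Examples #1 and #2 load only has to expose pre_inference and post_inference with the two-argument signatures called above. The following is a minimal, illustrative sketch; the input name lookup and the returned field names are assumptions, not taken from the examples.

import numpy as np

def pre_inference(sample, input_metadata):
    # sample is the parsed request JSON; input_metadata is sess.get_inputs().
    # Whatever is returned here is handed to convert_to_onnx_input().
    input_name = input_metadata[0].name  # assumes a single named input
    return {input_name: np.asarray(sample[input_name], dtype=np.float32)}

def post_inference(result, output_metadata):
    # result is the list built from the model outputs; output_metadata is
    # sess.get_outputs(). The return value is what gets jsonify()'d.
    return {"prediction": result[0]}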
Example #3
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except Exception as e:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("payload", payload, debug)

        prepared_payload = payload
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            try:
                prepared_payload = request_handler.pre_inference(
                    payload, input_metadata, api["onnx"]["metadata"])
                debug_obj("pre_inference", prepared_payload, debug)
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "pre_inference request handler",
                                           str(e)) from e

        inference_input = convert_to_onnx_input(prepared_payload,
                                                input_metadata)
        model_output = sess.run([], inference_input)

        debug_obj("inference", model_output, debug)
        result = model_output
        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    model_output, output_metadata, api["onnx"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "post_inference request handler",
                                           str(e)) from e

            debug_obj("post_inference", result, debug)
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result
    return jsonify(result)
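
Either predict endpoint can be exercised with a plain JSON POST; adding debug=true as a query parameter turns on the debug_obj logging shown above. A minimal client sketch, where the URL, port, and payload fields are placeholders rather than values from the examples:

import requests

url = "http://localhost:8888/predict"  # hypothetical route; the real path depends on the deployment
payload = {"feature_1": 1.0, "feature_2": "a"}  # shape depends on the model / request handler

resp = requests.post(url, json=payload, params={"debug": "true"})
resp.raise_for_status()
print(resp.json())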
Example #4
def run_predict(sample):
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")

    logger.info("sample: " + util.pp_str_flat(sample))

    prepared_sample = sample
    if request_handler is not None and util.has_function(
            request_handler, "pre_inference"):
        prepared_sample = request_handler.pre_inference(
            sample, local_cache["metadata"]["signatureDef"])
        logger.info("pre_inference: " + util.pp_str_flat(prepared_sample))

    validate_sample(sample)

    if util.is_resource_ref(local_cache["api"]["model"]):
        for column in local_cache["required_inputs"]:
            column_type = ctx.get_inferred_column_type(column["name"])
            prepared_sample[column["name"]] = util.upcast(
                prepared_sample[column["name"]], column_type)

        transformed_sample = transform_sample(prepared_sample)
        logger.info("transformed_sample: " +
                    util.pp_str_flat(transformed_sample))

        prediction_request = create_prediction_request(transformed_sample)
        response_proto = local_cache["stub"].Predict(prediction_request,
                                                     timeout=10.0)
        result = parse_response_proto(response_proto)

        result["transformed_sample"] = transformed_sample
        logger.info("inference: " + util.pp_str_flat(result))
    else:
        prediction_request = create_raw_prediction_request(prepared_sample)
        response_proto = local_cache["stub"].Predict(prediction_request,
                                                     timeout=10.0)
        result = parse_response_proto_raw(response_proto)

        logger.info("inference: " + util.pp_str_flat(result))

    if request_handler is not None and util.has_function(
            request_handler, "post_inference"):
        result = request_handler.post_inference(
            result, local_cache["metadata"]["signatureDef"])
        logger.info("post_inference: " + util.pp_str_flat(result))

    return result
Example #5
def run_predict(payload, debug=False):
    ctx = local_cache["ctx"]
    api = local_cache["api"]
    request_handler = local_cache.get("request_handler")

    prepared_payload = payload

    debug_obj("payload", payload, debug)
    if request_handler is not None and util.has_function(
            request_handler, "pre_inference"):
        try:
            prepared_payload = request_handler.pre_inference(
                payload,
                local_cache["model_metadata"]["signatureDef"],
                api["tensorflow"]["metadata"],
            )
            debug_obj("pre_inference", prepared_payload, debug)
        except Exception as e:
            raise UserRuntimeException(api["tensorflow"]["request_handler"],
                                       "pre_inference request handler",
                                       str(e)) from e

    validate_payload(prepared_payload)

    prediction_request = create_prediction_request(prepared_payload)
    response_proto = local_cache["stub"].Predict(prediction_request,
                                                 timeout=300.0)
    result = parse_response_proto(response_proto)
    debug_obj("inference", result, debug)

    if request_handler is not None and util.has_function(
            request_handler, "post_inference"):
        try:
            result = request_handler.post_inference(
                result, local_cache["model_metadata"]["signatureDef"],
                api["tensorflow"]["metadata"])
            debug_obj("post_inference", result, debug)
        except Exception as e:
            raise UserRuntimeException(api["tensorflow"]["request_handler"],
                                       "post_inference request handler",
                                       str(e)) from e

    return result
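
The TensorFlow variant in Example #5 passes two extra arguments to each hook: the model's signatureDef and the user-supplied metadata from the api's tensorflow section. A request handler compatible with those calls would look roughly as follows; the function bodies are illustrative assumptions.

def pre_inference(payload, signature_def, metadata):
    # payload is the parsed request JSON, signature_def is
    # local_cache["model_metadata"]["signatureDef"], and metadata is
    # api["tensorflow"]["metadata"]. Return whatever create_prediction_request()
    # should receive, e.g. after renaming or reshaping fields.
    return payload

def post_inference(result, signature_def, metadata):
    # result is whatever parse_response_proto() produced; the return value
    # replaces it in the final response.
    return result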
Example #6
def run_predict(sample):
    request_handler = local_cache.get("request_handler")

    prepared_sample = sample
    if request_handler is not None and util.has_function(
            request_handler, "pre_inference"):
        prepared_sample = request_handler.pre_inference(
            sample, local_cache["metadata"]["signatureDef"])

    if util.is_resource_ref(local_cache["api"]["model"]):
        transformed_sample = transform_sample(prepared_sample)
        prediction_request = create_prediction_request(transformed_sample)
        response_proto = local_cache["stub"].Predict(prediction_request,
                                                     timeout=10.0)
        result = parse_response_proto(response_proto)

        util.log_indent("Raw sample:", indent=4)
        util.log_pretty_flat(sample, indent=6)
        util.log_indent("Transformed sample:", indent=4)
        util.log_pretty_flat(transformed_sample, indent=6)
        util.log_indent("Prediction:", indent=4)
        util.log_pretty_flat(result, indent=6)

        result["transformed_sample"] = transformed_sample

    else:
        prediction_request = create_raw_prediction_request(prepared_sample)
        response_proto = local_cache["stub"].Predict(prediction_request,
                                                     timeout=10.0)
        result = parse_response_proto_raw(response_proto)
        util.log_indent("Sample:", indent=4)
        util.log_pretty_flat(sample, indent=6)
        util.log_indent("Prediction:", indent=4)
        util.log_pretty_flat(result, indent=6)

    if request_handler is not None and util.has_function(
            request_handler, "post_inference"):
        result = request_handler.post_inference(
            result, local_cache["metadata"]["signatureDef"])

    return result
Example #7
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("predictor") is None:
            raise CortexException(api["name"], "predictor key not configured")

        cx_logger().info("loading the predictor from {}".format(
            api["predictor"]["path"]))
        local_cache["predictor"] = ctx.get_predictor_impl(
            api["name"], args.project_dir)

        if util.has_function(local_cache["predictor"], "init"):
            try:
                model_path = None
                if api["predictor"].get("model") is not None:
                    _, prefix = ctx.storage.deconstruct_s3_path(
                        api["predictor"]["model"])
                    model_path = os.path.join(
                        args.model_dir,
                        os.path.basename(os.path.normpath(prefix)))

                cx_logger().info("calling the predictor's init() function")
                local_cache["predictor"].init(model_path,
                                              api["predictor"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["predictor"]["path"], "init",
                                           str(e)) from e
            finally:
                refresh_logger()
    except:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    cx_logger().info("{} api is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
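
The predictor loaded by Example #7 via ctx.get_predictor_impl() only needs to expose the init() hook that start() calls here (the serving code will also invoke a prediction entry point, which this example does not show). A minimal, purely illustrative module, with a placeholder instead of a real model loader:

model = None

def init(model_path, metadata):
    # model_path is the local directory start() derived from
    # api["predictor"]["model"], or None if no model key was configured;
    # metadata is api["predictor"]["metadata"].
    global model
    if model_path is not None:
        model = model_path  # placeholder: load the actual model artifacts here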
Example #8
def start(args):
    api = None
    try:
        ctx = Context(s3_path=args.context,
                      cache_dir=args.cache_dir,
                      workload_id=args.workload_id)
        api = ctx.apis_id_map[args.api]
        local_cache["api"] = api
        local_cache["ctx"] = ctx

        if api.get("tensorflow") is None:
            raise CortexException(api["name"], "tensorflow key not configured")

        if api["tensorflow"].get("request_handler") is not None:
            cx_logger().info("loading the request handler from {}".format(
                api["tensorflow"]["request_handler"]))
            local_cache["request_handler"] = ctx.get_request_handler_impl(
                api["name"], args.project_dir)
        request_handler = local_cache.get("request_handler")

        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            cx_logger().info(
                "using pre_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]))
        else:
            cx_logger().info("pre_inference request handler not defined")

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            cx_logger().info(
                "using post_inference request handler defined in {}".format(
                    api["tensorflow"]["request_handler"]))
        else:
            cx_logger().info("post_inference request handler not defined")

    except Exception as e:
        cx_logger().exception("failed to start api")
        sys.exit(1)

    try:
        validate_model_dir(args.model_dir)
    except Exception as e:
        cx_logger().exception("failed to validate model")
        sys.exit(1)

    if api.get("tracker") is not None and api["tracker"].get(
            "model_type") == "classification":
        try:
            local_cache["class_set"] = api_utils.get_classes(ctx, api["name"])
        except Exception as e:
            cx_logger().warn(
                "an error occurred while attempting to load classes",
                exc_info=True)

    channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
    local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(
        channel)

    # wait a bit for tf serving to start before querying metadata
    limit = 60
    for i in range(limit):
        try:
            local_cache["model_metadata"] = run_get_model_metadata()
            break
        except Exception as e:
            if i > 6:
                cx_logger().warn(
                    "unable to read model metadata - model is still loading. Retrying..."
                )
            if i == limit - 1:
                cx_logger().exception("retry limit exceeded")
                sys.exit(1)

        time.sleep(5)

    signature_key, parsed_signature = extract_signature(
        local_cache["model_metadata"]["signatureDef"],
        api["tensorflow"]["signature_key"])

    local_cache["signature_key"] = signature_key
    local_cache["parsed_signature"] = parsed_signature
    cx_logger().info("model_signature: {}".format(
        local_cache["parsed_signature"]))

    cx_logger().info("{} API is live".format(api["name"]))
    serve(app, listen="*:{}".format(args.port))
Example #9
def predict(app_name, api_name):
    try:
        payload = request.get_json()
    except Exception as e:
        return "Malformed JSON", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    response = {}

    if not util.is_dict(payload) or "samples" not in payload:
        util.log_pretty_flat(payload, logging_func=logger.error)
        return prediction_failed(
            payload, "top level `samples` key not found in request")

    predictions = []
    samples = payload["samples"]
    if not util.is_list(samples):
        util.log_pretty_flat(samples, logging_func=logger.error)
        return prediction_failed(
            payload,
            "expected the value of key `samples` to be a list of json objects")

    for i, sample in enumerate(payload["samples"]):
        try:
            logger.info("sample: " + util.pp_str_flat(sample))
            prepared_sample = sample
            if request_handler is not None and util.has_function(
                    request_handler, "pre_inference"):
                prepared_sample = request_handler.pre_inference(
                    sample, input_metadata)
                logger.info("pre_inference: " +
                            util.pp_str_flat(prepared_sample))

            inference_input = convert_to_onnx_input(prepared_sample,
                                                    input_metadata)
            model_outputs = sess.run([], inference_input)
            result = []
            for model_output in model_outputs:
                if type(model_output) is np.ndarray:
                    result.append(model_output.tolist())
                else:
                    result.append(model_output)

            logger.info("inference: " + util.pp_str_flat(result))
            if request_handler is not None and util.has_function(
                    request_handler, "post_inference"):
                result = request_handler.post_inference(
                    result, output_metadata)
                logger.info("post_inference: " + util.pp_str_flat(result))

            prediction = {"prediction": result}
        except CortexException as e:
            e.wrap("error", "sample {}".format(i + 1))
            logger.error(str(e))
            logger.exception(
                "An error occurred, see `cx logs -v api {}` for more details.".
                format(api["name"]))
            return prediction_failed(sample, str(e))
        except Exception as e:
            logger.exception(
                "An error occurred, see `cx logs -v api {}` for more details.".
                format(api["name"]))
            return prediction_failed(sample, str(e))

        predictions.append(prediction)

    response["predictions"] = predictions
    response["resource_id"] = api["id"]

    return jsonify(response)
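
Example #9 expects a batch request: a JSON object with a top-level samples list, and it replies with a predictions list plus the api's resource_id. A hedged client sketch, where the URL and sample fields are placeholders:

import requests

url = "http://localhost:8888/example-app/my-api/predict"  # hypothetical path built from the app and api names
body = {"samples": [{"feature_1": 1.0}, {"feature_1": 2.0}]}

resp = requests.post(url, json=body)
response = resp.json()
for prediction in response["predictions"]:
    print(prediction["prediction"])
print("resource_id:", response["resource_id"])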