Example #1
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except Exception:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    api = local_cache["api"]
    predictor = local_cache["predictor"]

    try:
        try:
            debug_obj("payload", payload, debug)
            output = predictor.predict(payload, api["predictor"]["metadata"])
            debug_obj("prediction", output, debug)
        except Exception as e:
            raise UserRuntimeException(api["predictor"]["path"], "predict",
                                       str(e)) from e
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = output
    return jsonify(output)
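Note: debug_obj, called throughout these examples, is a Cortex-internal helper whose source isn't shown here. A minimal sketch of what such a helper might look like, assuming it simply logs a labeled snapshot of an object when the request carries ?debug=true:

import json
import logging

logger = logging.getLogger(__name__)

def debug_obj(name, obj, debug):
    # Hypothetical stand-in for Cortex's debug_obj helper: log a labeled
    # snapshot of an intermediate object only when debugging is enabled.
    if not debug:
        return
    try:
        rendered = json.dumps(obj, indent=2, default=str)
    except (TypeError, ValueError):
        rendered = repr(obj)
    logger.info("%s: %s", name, rendered)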
Example #2
def predict(
        request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)
    debug_obj("prediction", prediction, debug)

    try:
        json_string = json.dumps(prediction)
    except Exception:
        # fall back to an encoder that can handle types json.dumps rejects
        json_string = util.json_tricks_encoder().encode(prediction)

    response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (api.tracker.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except Exception:
            cx_logger().warning("unable to record prediction metric",
                                exc_info=True)

    return response
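The json.dumps fallback in Example #2 exists because json.dumps raises TypeError on numpy arrays and scalars, which predictors commonly return; util.json_tricks_encoder presumably wraps the json_tricks library to handle those types. A rough, self-contained equivalent using a custom encoder (the encoder class here is an illustration, not Cortex's actual implementation):

import json
import numpy as np

class NumpyJSONEncoder(json.JSONEncoder):
    # Illustrative replacement for util.json_tricks_encoder(): convert
    # numpy containers and scalars to plain Python types json accepts.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)

json_string = json.dumps({"scores": np.array([0.1, 0.9])}, cls=NumpyJSONEncoder)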
Example #3
def predict(
        request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)

    try:
        json_string = json.dumps(prediction)
    except Exception as e:
        raise UserRuntimeException(
            "the return value of predict() or one of its nested values is not JSON serializable",
            str(e),
        ) from e

    debug_obj("prediction", json_string, debug)

    response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (api.tracker.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except Exception:
            cx_logger().warning("unable to record prediction metric",
                                exc_info=True)

    return response
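UserRuntimeException is also Cortex-internal; the pattern in Examples #1 and #3 is to wrap any error escaping user code so the message points back at the user's predict() rather than at the serving framework. A minimal sketch, assuming the exception simply joins its arguments into one readable message:

import json

class UserRuntimeException(Exception):
    # Hypothetical sketch: join the pieces of context (file path, function
    # name, underlying error) into a single readable message.
    def __init__(self, *messages):
        super().__init__(": ".join(str(m) for m in messages))

try:
    json.dumps(object())  # not JSON serializable
except Exception as e:
    raise UserRuntimeException(
        "the return value of predict() or one of its nested values is not JSON serializable",
        str(e),
    ) from e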
Example #4
def predict(app_name, api_name):
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        sample = request.get_json()
    except Exception:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("sample", sample, debug)

        prepared_sample = sample
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            try:
                prepared_sample = request_handler.pre_inference(
                    sample, input_metadata)
                debug_obj("pre_inference", prepared_sample, debug)
            except Exception as e:
                raise UserRuntimeException(api["request_handler"],
                                           "pre_inference request handler",
                                           str(e)) from e

        inference_input = convert_to_onnx_input(prepared_sample,
                                                input_metadata)
        model_outputs = sess.run([], inference_input)
        result = []
        for model_output in model_outputs:
            if isinstance(model_output, np.ndarray):
                result.append(model_output.tolist())
            else:
                result.append(model_output)

        debug_obj("inference", result, debug)

        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    result, output_metadata)
            except Exception as e:
                raise UserRuntimeException(api["request_handler"],
                                           "post_inference request handler",
                                           str(e)) from e

            debug_obj("post_inference", result, debug)
    except Exception as e:
        logger.exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result
    return jsonify(result)
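Examples #4 and #5 run inference through an onnxruntime InferenceSession stored in local_cache["sess"]. A minimal sketch of how such a session is created and fed, assuming convert_to_onnx_input builds an {input_name: numpy_array} dict ("model.onnx" and the zeroed input are placeholders):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_meta = sess.get_inputs()[0]

# Build an input feed keyed by the model's input name; dynamic dimensions
# (reported as None or strings) are pinned to 1 for this illustration.
shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
inference_input = {input_meta.name: np.zeros(shape, dtype=np.float32)}

# An empty output list asks the session for every output, matching the
# sess.run([], inference_input) calls in the examples above.
model_outputs = sess.run([], inference_input)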
Example #5
def predict():
    debug = request.args.get("debug", "false").lower() == "true"

    try:
        payload = request.get_json()
    except Exception:
        return "malformed json", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    ctx = local_cache["ctx"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    try:
        debug_obj("payload", payload, debug)

        prepared_payload = payload
        if request_handler is not None and util.has_function(
                request_handler, "pre_inference"):
            try:
                prepared_payload = request_handler.pre_inference(
                    payload, input_metadata, api["onnx"]["metadata"])
                debug_obj("pre_inference", prepared_payload, debug)
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "pre_inference request handler",
                                           str(e)) from e

        inference_input = convert_to_onnx_input(prepared_payload,
                                                input_metadata)
        model_output = sess.run([], inference_input)

        debug_obj("inference", model_output, debug)
        result = model_output
        if request_handler is not None and util.has_function(
                request_handler, "post_inference"):
            try:
                result = request_handler.post_inference(
                    model_output, output_metadata, api["onnx"]["metadata"])
            except Exception as e:
                raise UserRuntimeException(api["onnx"]["request_handler"],
                                           "post_inference request handler",
                                           str(e)) from e

            debug_obj("post_inference", result, debug)
    except Exception as e:
        cx_logger().exception("prediction failed")
        return prediction_failed(str(e))

    g.prediction = result
    return jsonify(result)
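util.has_function gates the optional pre_inference/post_inference hooks in Examples #4, #5, and #7. Its implementation isn't shown, but a plausible stand-in is a callable-attribute check:

def has_function(module, name):
    # Hypothetical stand-in for util.has_function: the hooks are optional,
    # so only call one if the user's module defines a callable by that name.
    return callable(getattr(module, name, None))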
Example #6
def predict(
        request: Any = Body(..., media_type="application/json"), debug=False):
    api = local_cache["api"]
    predictor_impl = local_cache["predictor_impl"]

    debug_obj("payload", request, debug)
    prediction = predictor_impl.predict(request)
    debug_obj("prediction", prediction, debug)

    if isinstance(prediction, bytes):
        response = Response(content=prediction,
                            media_type="application/octet-stream")
    elif isinstance(prediction, str):
        response = Response(content=prediction, media_type="text/plain")
    elif isinstance(prediction, Response):
        response = prediction
    else:
        try:
            json_string = json.dumps(prediction)
        except Exception as e:
            raise UserRuntimeException(
                str(e),
                "please return an object that is JSON serializable (including its nested fields), a bytes object, a string, or a starlette.response.Response object",
            ) from e
        response = Response(content=json_string, media_type="application/json")

    if api.tracker is not None:
        try:
            predicted_value = api.tracker.extract_predicted_value(prediction)
            api.post_tracker_metrics(predicted_value)
            if (api.tracker.model_type == "classification"
                    and predicted_value not in local_cache["class_set"]):
                tasks = BackgroundTasks()
                tasks.add_task(api.upload_class, class_name=predicted_value)
                local_cache["class_set"].add(predicted_value)
                response.background = tasks
        except Exception:
            cx_logger().warning("unable to record prediction metric",
                                exc_info=True)

    return response
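The tracker block in Examples #2, #3, and #6 attaches a Starlette BackgroundTasks object to the response so that uploading a newly seen class label runs after the response body is sent, keeping the prediction path fast. A self-contained illustration of that mechanism (upload_class here is a placeholder for api.upload_class):

from starlette.background import BackgroundTasks
from starlette.responses import Response

def upload_class(class_name):
    # Placeholder for api.upload_class: persist a newly observed label.
    print(f"recording new class: {class_name}")

tasks = BackgroundTasks()
tasks.add_task(upload_class, class_name="cat")

response = Response(content='{"label": "cat"}', media_type="application/json")
# Work assigned to response.background runs only after the response is sent.
response.background = tasks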
Example #7
def run_predict(payload, debug=False):
    ctx = local_cache["ctx"]
    api = local_cache["api"]
    request_handler = local_cache.get("request_handler")

    prepared_payload = payload

    debug_obj("payload", payload, debug)
    if request_handler is not None and util.has_function(
            request_handler, "pre_inference"):
        try:
            prepared_payload = request_handler.pre_inference(
                payload,
                local_cache["model_metadata"]["signatureDef"],
                api["tensorflow"]["metadata"],
            )
            debug_obj("pre_inference", prepared_payload, debug)
        except Exception as e:
            raise UserRuntimeException(api["tensorflow"]["request_handler"],
                                       "pre_inference request handler",
                                       str(e)) from e

    validate_payload(prepared_payload)

    prediction_request = create_prediction_request(prepared_payload)
    response_proto = local_cache["stub"].Predict(prediction_request,
                                                 timeout=300.0)
    result = parse_response_proto(response_proto)
    debug_obj("inference", result, debug)

    if request_handler is not None and util.has_function(
            request_handler, "post_inference"):
        try:
            result = request_handler.post_inference(
                result, local_cache["model_metadata"]["signatureDef"],
                api["tensorflow"]["metadata"])
            debug_obj("post_inference", result, debug)
        except Exception as e:
            raise UserRuntimeException(api["tensorflow"]["request_handler"],
                                       "post_inference request handler",
                                       str(e)) from e

    return result
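Example #7 talks to TensorFlow Serving over gRPC: local_cache["stub"] is a PredictionServiceStub, and create_prediction_request builds a PredictRequest from the prepared payload. A minimal sketch of those two pieces using the tensorflow-serving-api package (the address, model name, signature, and tensor name are placeholders):

import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

channel = grpc.insecure_channel("localhost:9000")
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

# Roughly what create_prediction_request might do: name the model and
# signature, then attach each input tensor from the prepared payload.
request = predict_pb2.PredictRequest()
request.model_spec.name = "model"
request.model_spec.signature_name = "serving_default"
request.inputs["input"].CopyFrom(tf.make_tensor_proto([[1.0, 2.0]], dtype=tf.float32))

response_proto = stub.Predict(request, timeout=300.0)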