Example #1
0
def predict(app_name, api_name):
    """Flask handler: run model predictions for each sample in the request payload.

    Expects a JSON body of the form {"samples": [<sample dict>, ...]}. Each
    sample is validated, its required input columns are upcast to their declared
    types, and inference is run via run_predict. Returns a JSON response keyed
    by model type ("regression_predictions" or "classification_predictions"),
    or a prediction_failed error response when validation or inference fails.
    """
    try:
        payload = request.get_json()
    except Exception:
        # The exception detail is irrelevant to the client; the body is simply not JSON.
        return "Malformed JSON", status.HTTP_400_BAD_REQUEST

    model = local_cache["model"]
    api = local_cache["api"]

    response = {}

    if not util.is_dict(payload) or "samples" not in payload:
        util.log_pretty(payload, logging_func=logger.error)
        return prediction_failed(payload, "top level `samples` key not found in request")

    logger.info("Predicting " + util.pluralize(len(payload["samples"]), "sample", "samples"))

    predictions = []
    samples = payload["samples"]
    if not util.is_list(samples):
        util.log_pretty(samples, logging_func=logger.error)
        return prediction_failed(
            payload, "expected the value of key `samples` to be a list of json objects"
        )

    for i, sample in enumerate(samples):
        util.log_indent("sample {}".format(i + 1), 2)

        is_valid, reason = is_valid_sample(sample)
        if not is_valid:
            return prediction_failed(sample, reason)

        # Upcast each required input column to its declared type before inference.
        for column in local_cache["required_inputs"]:
            sample[column["name"]] = util.upcast(sample[column["name"]], column["type"])

        try:
            result = run_predict(sample)
        except CortexException as e:
            e.wrap("error", "sample {}".format(i + 1))
            logger.error(str(e))
            logger.exception(
                "An error occurred, see `cx logs api {}` for more details.".format(api["name"])
            )
            return prediction_failed(sample, str(e))
        except Exception as e:
            logger.exception(
                "An error occurred, see `cx logs api {}` for more details.".format(api["name"])
            )
            return prediction_failed(sample, str(e))

        predictions.append(result)

    # The model types are mutually exclusive, so elif (not two independent ifs).
    if model["type"] == "regression":
        response["regression_predictions"] = predictions
    elif model["type"] == "classification":
        response["classification_predictions"] = predictions

    response["resource_id"] = api["id"]

    return jsonify(response)
Example #2
0
def run_predict(sample):
    """Run inference for a single sample, applying optional pre/post request handlers.

    If the API's model is a resource reference, the sample is transformed and the
    parsed result also carries the transformed sample under "transformed_sample";
    otherwise the (optionally pre-processed) sample is sent raw to the serving stub.
    """
    request_handler = local_cache.get("request_handler")

    prepared_sample = sample
    if request_handler is not None and util.has_function(
            request_handler, "pre_inference"):
        prepared_sample = request_handler.pre_inference(
            sample, local_cache["metadata"]["signatureDef"])

    is_model_ref = util.is_resource_ref(local_cache["api"]["model"])

    if is_model_ref:
        transformed_sample = transform_sample(prepared_sample)
        prediction_request = create_prediction_request(transformed_sample)
    else:
        prediction_request = create_raw_prediction_request(prepared_sample)

    # The serving call is identical on both paths, so it is issued once here
    # instead of being duplicated inside each branch.
    response_proto = local_cache["stub"].Predict(prediction_request, timeout=10.0)

    if is_model_ref:
        result = parse_response_proto(response_proto)

        util.log_indent("Raw sample:", indent=4)
        util.log_pretty(sample, indent=6)
        util.log_indent("Transformed sample:", indent=4)
        util.log_pretty(transformed_sample, indent=6)
        util.log_indent("Prediction:", indent=4)
        util.log_pretty(result, indent=6)

        result["transformed_sample"] = transformed_sample
    else:
        result = parse_response_proto_raw(response_proto)
        util.log_indent("Sample:", indent=4)
        util.log_pretty(sample, indent=6)
        util.log_indent("Prediction:", indent=4)
        util.log_pretty(result, indent=6)

    if request_handler is not None and util.has_function(
            request_handler, "post_inference"):
        result = request_handler.post_inference(
            result, local_cache["metadata"]["signatureDef"])

    return result
Example #3
0
def run_predict(sample):
    """Transform the sample, query the serving stub, and return the parsed prediction."""
    transformed = transform_sample(sample)
    proto = local_cache["stub"].Predict(
        create_prediction_request(transformed), timeout=10.0
    )
    prediction = parse_response_proto(proto)

    # Log each stage of the pipeline at the same indent levels.
    for label, value in (
        ("Raw sample:", sample),
        ("Transformed sample:", transformed),
        ("Prediction:", prediction),
    ):
        util.log_indent(label, indent=4)
        util.log_pretty(value, indent=6)

    return prediction
Example #4
0
def run_predict(raw_features):
    """Transform the raw features, run them through the serving stub, and return the parsed prediction."""
    transformed_features = transform_features(raw_features)
    request_proto = create_prediction_request(transformed_features)
    prediction = parse_response_proto(
        local_cache["stub"].Predict(request_proto, timeout=10.0)
    )

    util.log_indent("Raw features:", indent=4)
    util.log_pretty(raw_features, indent=6)
    util.log_indent("Transformed features:", indent=4)
    util.log_pretty(transformed_features, indent=6)
    util.log_indent("Prediction:", indent=4)
    util.log_pretty(prediction, indent=6)

    return prediction
Example #5
0
def predict(deployment_name, api_name):
    """Flask handler: run predictions for each sample in the request payload.

    Expects a JSON body of the form {"samples": [<sample dict>, ...]}. For
    resource-ref models the samples are validated and their required input
    columns upcast before inference; for external models the samples are passed
    through as-is, and the model's signature def is appended to error messages
    (since no input validation is performed). Returns a JSON response with a
    "predictions" list, or a prediction_failed error response.
    """
    try:
        payload = request.get_json()
    except Exception:
        # Detail is irrelevant to the client: the request body is not valid JSON.
        return "Malformed JSON", status.HTTP_400_BAD_REQUEST

    ctx = local_cache["ctx"]
    api = local_cache["api"]

    response = {}

    if not util.is_dict(payload) or "samples" not in payload:
        util.log_pretty(payload, logging_func=logger.error)
        return prediction_failed(
            payload, "top level `samples` key not found in request")

    logger.info("Predicting " +
                util.pluralize(len(payload["samples"]), "sample", "samples"))

    predictions = []
    samples = payload["samples"]
    if not util.is_list(samples):
        util.log_pretty(samples, logging_func=logger.error)
        return prediction_failed(
            payload,
            "expected the value of key `samples` to be a list of json objects")

    for i, sample in enumerate(samples):
        util.log_indent("sample {}".format(i + 1), 2)

        if util.is_resource_ref(api["model"]):
            is_valid, reason = is_valid_sample(sample)
            if not is_valid:
                return prediction_failed(sample, reason)

            # Upcast each required input column to its inferred type before inference.
            for column in local_cache["required_inputs"]:
                column_type = ctx.get_inferred_column_type(column["name"])
                sample[column["name"]] = util.upcast(sample[column["name"]],
                                                     column_type)

        try:
            result = run_predict(sample)
        except CortexException as e:
            e.wrap("error", "sample {}".format(i + 1))
            logger.error(str(e))
            logger.exception(
                "An error occurred, see `cortex logs -v api {}` for more details."
                .format(api["name"]))
            return prediction_failed(sample, str(e))
        except Exception as e:
            logger.exception(
                "An error occurred, see `cortex logs -v api {}` for more details."
                .format(api["name"]))

            # Show signature def for external models (since we don't validate input)
            schema_str = ""
            signature_def = local_cache["metadata"]["signatureDef"]
            if (not util.is_resource_ref(api["model"])
                    and signature_def.get("predict") is not None  # Just to be safe
                    and signature_def["predict"].get("inputs") is not None):  # Just to be safe
                # Typo fixed in the user-facing message: "shema" -> "schema".
                schema_str = "\n\nExpected schema:\n" + util.pp_str(
                    signature_def["predict"]["inputs"])

            return prediction_failed(sample, str(e) + schema_str)

        predictions.append(result)

    response["predictions"] = predictions
    response["resource_id"] = api["id"]

    return jsonify(response)
Example #6
0
def predict(app_name, api_name):
    """Flask handler: run ONNX-session predictions for each sample in the payload.

    Expects a JSON body of the form {"samples": [<sample dict>, ...]}. Each
    sample is optionally pre-processed by the request handler, converted to an
    ONNX input, run through the session, converted to JSON-serializable lists,
    and optionally post-processed. Returns a JSON response whose "predictions"
    list holds {"prediction": ...} entries, or a prediction_failed error response.
    """
    try:
        payload = request.get_json()
    except Exception:
        # Detail is irrelevant to the client: the request body is not valid JSON.
        return "Malformed JSON", status.HTTP_400_BAD_REQUEST

    sess = local_cache["sess"]
    api = local_cache["api"]
    request_handler = local_cache.get("request_handler")
    input_metadata = local_cache["input_metadata"]
    output_metadata = local_cache["output_metadata"]

    response = {}

    if not util.is_dict(payload) or "samples" not in payload:
        util.log_pretty(payload, logging_func=logger.error)
        return prediction_failed(
            payload, "top level `samples` key not found in request")

    logger.info("Predicting " +
                util.pluralize(len(payload["samples"]), "sample", "samples"))

    predictions = []
    samples = payload["samples"]
    if not util.is_list(samples):
        util.log_pretty(samples, logging_func=logger.error)
        return prediction_failed(
            payload,
            "expected the value of key `samples` to be a list of json objects")

    for i, sample in enumerate(samples):
        util.log_indent("sample {}".format(i + 1), 2)
        try:
            util.log_indent("Raw sample:", indent=4)
            util.log_pretty(sample, indent=6)

            if request_handler is not None and util.has_function(
                    request_handler, "pre_inference"):
                sample = request_handler.pre_inference(sample, input_metadata)

            inference_input = convert_to_onnx_input(sample, input_metadata)
            # An empty output-name list asks onnxruntime for all outputs.
            model_outputs = sess.run([], inference_input)
            # ndarrays are not JSON-serializable, so convert them to lists;
            # isinstance (rather than `type(...) is`) also covers subclasses.
            result = [
                output.tolist() if isinstance(output, np.ndarray) else output
                for output in model_outputs
            ]

            if request_handler is not None and util.has_function(
                    request_handler, "post_inference"):
                result = request_handler.post_inference(
                    result, output_metadata)
            util.log_indent("Prediction:", indent=4)
            util.log_pretty(result, indent=6)
            prediction = {"prediction": result}
        except CortexException as e:
            e.wrap("error", "sample {}".format(i + 1))
            logger.error(str(e))
            logger.exception(
                "An error occurred, see `cx logs -v api {}` for more details.".
                format(api["name"]))
            return prediction_failed(sample, str(e))
        except Exception as e:
            logger.exception(
                "An error occurred, see `cx logs -v api {}` for more details.".
                format(api["name"]))
            return prediction_failed(sample, str(e))

        predictions.append(prediction)

    response["predictions"] = predictions
    response["resource_id"] = api["id"]

    return jsonify(response)