コード例 #1
0
def parse_content_data(input_data, input_content_type):
    """Deserialize a request payload into an ``xgb.DMatrix``.

    Args:
        input_data (bytes or str): raw request body. For the text formats
            (CSV, LIBSVM) it is stripped of surrounding whitespace and, when
            it is bytes, decoded as UTF-8. RECORDIO_PROTOBUF payloads are
            passed through unchanged.
        input_content_type (str): request Content-Type, normalized via
            ``get_content_type``.

    Returns:
        tuple: ``(dtest, content_type)`` -- the parsed DMatrix and the
        normalized content type.

    Raises:
        RuntimeError: if the payload cannot be parsed in the declared format,
            or the content type is not CSV, LIBSVM or RECORDIO_PROTOBUF.
    """

    def _as_text(payload):
        # Strip, then decode only real bytes: a body that the server already
        # decoded to str is accepted instead of failing on a missing .decode().
        payload = payload.strip()
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode("utf-8")
        return payload

    content_type = get_content_type(input_content_type)

    if content_type == CSV:
        try:
            # Builtin float: the np.float alias was removed in NumPy 1.24, and
            # the AttributeError it caused here was misreported as a CSV error.
            dtest = encoder.csv_to_dmatrix(_as_text(input_data), dtype=float)
        except Exception as e:
            raise RuntimeError("Loading csv data failed with Exception, "
                               "please ensure data is in csv format:\n {}\n {}".format(type(e), e)) from e
    elif content_type == LIBSVM:
        try:
            dtest = xgb.DMatrix(_get_sparse_matrix_from_libsvm(_as_text(input_data)))
        except Exception as e:
            raise RuntimeError("Loading libsvm data failed with Exception, "
                               "please ensure data is in libsvm format:\n {}\n {}".format(type(e), e)) from e
    elif content_type == RECORDIO_PROTOBUF:
        try:
            # Binary format: hand the raw payload to the decoder untouched.
            dtest = encoder.recordio_protobuf_to_dmatrix(input_data)
        except Exception as e:
            raise RuntimeError("Loading recordio-protobuf data failed with "
                               "Exception, please ensure data is in "
                               "recordio-protobuf format: {} {}".format(type(e), e)) from e
    else:
        raise RuntimeError("Content-type {} is not supported.".format(input_content_type))

    return dtest, content_type
コード例 #2
0
        def default_input_fn(self, input_data, input_content_type):
            """Take request data and de-serialize it into an object for prediction.

            When an InvokeEndpoint operation is made against an Endpoint running SageMaker model server,
            the model server receives two pieces of information:
                - The request Content-Type, for example "application/json"
                - The request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size.
            The input_fn is responsible to take the request data and pre-process it before prediction.

            Args:
                input_data (obj): the request data. For CSV and LIBSVM it is decoded as
                    UTF-8 when it is bytes; RECORDIO-PROTOBUF data is passed through as-is.
                input_content_type (str): the request Content-Type. XGBoost accepts CSV, LIBSVM, and RECORDIO-PROTOBUF.

            Returns:
                tuple: ``(dtest, content_type)`` -- the DMatrix ready for prediction and
                the normalized content type returned by ``get_content_type``.

            Raises:
                NoContentInferenceError: if ``input_data`` is ``None`` or empty.
                UnsupportedMediaTypeInferenceError: if the content type is not supported
                    or the payload cannot be parsed in that format.
            """
            # Treat a missing body like an empty one instead of crashing in len(None).
            if input_data is None or len(input_data) == 0:
                raise NoContentInferenceError()

            try:
                content_type = get_content_type(input_content_type)
            except Exception as e:
                raise UnsupportedMediaTypeInferenceError(
                    "Content type must be csv, libsvm, or recordio-protobuf.") from e

            if content_type == CSV:
                try:
                    payload = input_data
                    if isinstance(payload, (bytes, bytearray)):
                        payload = payload.decode('utf-8')
                    # Builtin float: the np.float alias was removed in NumPy 1.24, and
                    # the AttributeError it caused was misreported as a CSV error.
                    dtest = xgb_encoder.csv_to_dmatrix(payload.strip(), dtype=float)
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading csv data failed with "
                        "Exception, please ensure data "
                        "is in csv format: {} {}".format(type(e), e)) from e
            elif content_type == LIBSVM:
                try:
                    payload = input_data
                    if isinstance(payload, (bytes, bytearray)):
                        payload = payload.decode('utf-8')
                    dtest = xgb.DMatrix(_get_sparse_matrix_from_libsvm(payload.strip()))
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading libsvm data failed with "
                        "Exception, please ensure data "
                        "is in libsvm format: {} {}".format(type(e), e)) from e
            elif content_type == RECORDIO_PROTOBUF:
                try:
                    # Binary format: the raw payload goes to the decoder untouched.
                    dtest = xgb_encoder.recordio_protobuf_to_dmatrix(input_data)
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading recordio-protobuf data failed with "
                        "Exception, please ensure data is in "
                        "recordio-protobuf format: {} {}".format(type(e), e)) from e
            else:
                raise UnsupportedMediaTypeInferenceError(
                    "Content type must be csv, libsvm, or recordio-protobuf.")

            return dtest, content_type
コード例 #3
0
def input_fn(request_body, request_content_type):
    """
    The SageMaker XGBoost model server receives the request data body and the content type,
    and invokes the `input_fn`.

    Return a DMatrix (an object that can be passed to predict_fn).

    Args:
        request_body (bytes or str): raw request payload; bytes are decoded as UTF-8.
        request_content_type (str): request Content-Type. Only ``text/csv`` is
            supported; media-type parameters (e.g. ``; charset=utf-8``) and
            letter case are ignored.

    Raises:
        ValueError: if the content type is not ``text/csv``.
    """
    # Compare only the media type: "text/csv; charset=utf-8" is still CSV.
    media_type = (request_content_type or "").split(";")[0].strip().lower()
    if media_type != "text/csv":
        raise ValueError(
            "Content type {} is not supported.".format(request_content_type))

    # str.strip('\n') on a bytes body raises TypeError, so decode first.
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode("utf-8")
    # Drop only leading/trailing newlines; other whitespace is left as sent.
    return xgb_encoders.csv_to_dmatrix(request_body.strip("\n"))
コード例 #4
0
def _parse_content_data(request):
    """Deserialize a request body into an ``xgb.DMatrix``.

    Args:
        request: incoming request exposing ``.data`` (the raw body -- presumably
            bytes, e.g. a Flask request; TODO confirm) and a content type that
            ``serve_utils.get_content_type`` can read.

    Returns:
        tuple: ``(dtest, content_type)``. ``dtest`` is ``None`` when the
        content type is neither ``text/csv`` nor ``text/x-libsvm`` /
        ``text/libsvm``; callers must check for that.

    Raises:
        RuntimeError: if a CSV or LIBSVM body cannot be decoded or parsed.
    """

    def _as_text(data):
        # Decode only real bytes; an already-decoded str is used as-is.
        return data.decode("utf-8") if isinstance(data, (bytes, bytearray)) else data

    dtest = None
    content_type = serve_utils.get_content_type(request)
    payload = request.data.strip()
    if content_type == "text/csv":
        try:
            # Builtin float: the np.float alias was removed in NumPy 1.24, and
            # the AttributeError it caused was misreported as a CSV error.
            dtest = encoder.csv_to_dmatrix(_as_text(payload), dtype=float)
        except Exception as e:
            raise RuntimeError(
                "Loading csv data failed with Exception, "
                "please ensure data is in csv format:\n {}\n {}".format(
                    type(e), e)) from e
    elif content_type in ("text/x-libsvm", "text/libsvm"):
        try:
            dtest = xgb.DMatrix(_get_sparse_matrix_from_libsvm(_as_text(payload)))
        except Exception as e:
            raise RuntimeError(
                "Loading libsvm data failed with Exception, "
                "please ensure data is in libsvm format:\n {}\n {}".format(
                    type(e), e)) from e

    return dtest, content_type
コード例 #5
0
def test_csv_to_dmatrix_error(target):
    """``encoder.csv_to_dmatrix`` must reject ``target`` with exactly ``ValueError``."""
    try:
        encoder.csv_to_dmatrix(target)
    except Exception as e:
        # Exact type match (not isinstance): ValueError subclasses do not pass.
        assert type(e) is ValueError, "expected ValueError, got {!r}".format(e)
    else:
        # Fail outside the try so the except clause above cannot swallow it,
        # and raise rather than assert so the check survives `python -O`.
        raise AssertionError(
            "csv_to_dmatrix({!r}) did not raise".format(target))
コード例 #6
0
def test_csv_to_dmatrix(target):
    """Parsing a valid ``target`` must yield an object whose exact type is ``xgb.DMatrix``."""
    dmatrix = encoder.csv_to_dmatrix(target)
    result_type = type(dmatrix)
    assert result_type is xgb.DMatrix
コード例 #7
0
def transform_fn(model, request_body, content_type, accept_type):
    """Build a DMatrix from the request, run the requested prediction, return it.

    The SageMaker XGBoost model server passes the request body and its content
    type. The body has the form ``'<dataset>, <prediction_method>'``; the
    trailing method is optional and defaults to ``"predict"``. Examples:

    * ``'sklearn.datasets.fetch_california_housing(), pred_contribs'``
    * ``'[[var1, var2], [var3, var4]], pred_contribs'`` (several rows, parsed as JSON)
    * ``'var1, var2, var3, var4, pred_contribs'`` (a single CSV row)

    Args:
        model: trained model exposing ``predict(dmatrix, **kwargs)``
            (presumably an ``xgb.Booster`` -- TODO confirm).
        request_body (bytes or str): raw request payload; bytes are decoded.
        content_type (str): request Content-Type. ``"text/csv"`` and
            ``"text/libsvm"`` are parsed; a ``sklearn.datasets`` body is
            accepted under any content type.
        accept_type (str): requested response type (not used).

    Returns:
        str: ``str(list)`` of the prediction array.

    Raises:
        ValueError: if the content type is unsupported or the
            ``sklearn.datasets`` expression fails to evaluate.
    """
    multiple_predictions_flag = False

    # Decode based on the body's actual type, not the content type: depending on
    # the model server, text/* bodies may already arrive as str (no .decode()),
    # and a bytes CSV body would otherwise fail on the str split below.
    if isinstance(request_body, (bytes, bytearray)):
        request_body = request_body.decode()

    prediction_methods = ("predict", "pred_contribs", "pred_interactions")
    parts = request_body.split(', ')
    if parts[-1] in prediction_methods:
        predict = parts[-1]
        dataset = ", ".join(parts[:-1])
        if "[[" in request_body:
            # '[[var1, var2], [var3, var4], ...], method' -> list of rows
            multiple_predictions_flag = True
            dataset = json.loads(dataset)
    else:
        dataset = request_body
        predict = "predict"

    if "sklearn.datasets" in dataset:
        import sklearn.datasets  # noqa: F401 -- referenced by the eval'd expression

        # SECURITY(review): `dataset` comes straight from the request body and
        # is passed to eval(). The substring check above is not a sandbox: any
        # body that merely mentions "sklearn.datasets" executes arbitrary Python
        # in the serving process. Replace with an allow-list of loader names
        # (e.g. getattr(sklearn.datasets, name)()) before exposing this
        # endpoint to untrusted callers.
        try:
            data = eval(dataset)
        except Exception as err:
            raise ValueError(
                "Function {} is not supported. Try something like 'sklearn.datasets.fetch_california_housing()'"
                .format(dataset)) from err

        input_data = xgb.DMatrix(data.data, data.target)

    elif content_type == "text/libsvm":
        input_data = xgb_encoders.libsvm_to_dmatrix(dataset)
    elif content_type == "text/csv":
        if multiple_predictions_flag:
            from pandas import DataFrame

            frame = DataFrame(dataset)
            # Name columns f0..f{n-1} (XGBoost's default feature names) so they
            # line up with the model's features; derived from the actual width
            # so inputs other than the 17-column NYC Taxi layout also work.
            frame.columns = ["f{}".format(i) for i in range(frame.shape[1])]
            input_data = xgb.DMatrix(frame)
        else:
            input_data = xgb_encoders.csv_to_dmatrix(dataset)
    else:
        raise ValueError(
            "Content type {} is not supported.".format(content_type))

    # Now that we have the DMatrix and a prediction method, run the prediction.
    predict_kwargs = {
        "predict": {},
        "pred_contribs": {"pred_contribs": True},
        "pred_interactions": {"pred_interactions": True},
    }
    if predict not in predict_kwargs:
        raise ValueError(
            "Prediction parameter {} is not supported.".format(predict))
    predictions = model.predict(input_data, **predict_kwargs[predict])
    return str(predictions.tolist())