コード例 #1
0
def invocations():
    """Score the request body and return predictions as an HTTP response.

    Reads the raw body from ``flask.request``, parses it, loads the model,
    predicts, and serializes the predictions.

    Status codes returned:
        * 204 NO_CONTENT -- the request body is empty.
        * 415 UNSUPPORTED_MEDIA_TYPE -- ``serve_utils.parse_content_data``
          raised while parsing the body.
        * 500 INTERNAL_SERVER_ERROR -- ``ScoringService.load_model`` raised.
        * 400 BAD_REQUEST -- ``ScoringService.predict`` raised.
        * 406 NOT_ACCEPTABLE -- selectable inference output is enabled and
          the Accept header could not be parsed.
        * 200 OK -- otherwise.

    When selectable inference output is enabled, the response is built by
    ``_handle_selectable_inference_response``. Otherwise it is ``text/csv``:
    one prediction per line (with a trailing newline) when
    ``SAGEMAKER_BATCH`` is truthy, or a single comma-separated line.
    """
    payload = flask.request.data
    # An empty body is answered with 204 before any parsing or model loading.
    if not payload:
        return flask.Response(response="", status=http.client.NO_CONTENT)

    try:
        dtest, content_type = serve_utils.parse_content_data(
            payload, flask.request.content_type)
    except Exception as e:
        logging.exception(e)
        return flask.Response(response=str(e),
                              status=http.client.UNSUPPORTED_MEDIA_TYPE)

    try:
        # Named ``model_format`` rather than ``format`` so the builtin
        # ``format`` is not shadowed inside this function.
        model_format = ScoringService.load_model()
    except Exception as e:
        logging.exception(e)
        return flask.Response(response="Unable to load model: %s" % e,
                              status=http.client.INTERNAL_SERVER_ERROR)

    try:
        preds = ScoringService.predict(data=dtest,
                                       content_type=content_type,
                                       model_format=model_format)
    except Exception as e:
        logging.exception(e)
        return flask.Response(
            response="Unable to evaluate payload provided: %s" % e,
            status=http.client.BAD_REQUEST)

    if serve_utils.is_selectable_inference_output():
        try:
            accept = _parse_accept(flask.request)
        except Exception as e:
            logging.exception(e)
            return flask.Response(response=str(e),
                                  status=http.client.NOT_ACCEPTABLE)

        return _handle_selectable_inference_response(preds, accept)

    # ``preds`` is presumably a numpy array (``.tolist()`` is called on it);
    # each element is stringified for the CSV body.
    if SAGEMAKER_BATCH:
        # Batch transform: one prediction per line, newline-terminated.
        return_data = "\n".join(map(str, preds.tolist())) + '\n'
    else:
        return_data = ",".join(map(str, preds.tolist()))

    return flask.Response(response=return_data,
                          status=http.client.OK,
                          mimetype="text/csv")
コード例 #2
0
 def default_input_fn(self, input_data, input_content_type):
     """Deserialize the request data into an object ready for prediction.

     When an InvokeEndpoint operation is made against an Endpoint running the
     SageMaker model server, the model server receives two pieces of information:
         - the request Content-Type, for example "application/json";
         - the request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size.
     The input_fn is responsible for taking the request data and pre-processing
     it before prediction.

     Args:
         input_data (obj): the request data.
         input_content_type (str): the request Content-Type. XGBoost accepts CSV,
             LIBSVM, and RECORDIO-PROTOBUF.

     Returns:
         tuple: ``(dtest, content_type)`` as produced by
             ``serve_utils.parse_content_data``. ``dtest`` is the data ready for
             prediction (for XGBoost, a DMatrix); ``content_type`` is the content
             type the payload was parsed as.

     Raises:
         NoContentInferenceError: if ``input_data`` is empty.
         Any exception raised by ``serve_utils.parse_content_data`` propagates
             unchanged (presumably ``exc.UserError`` for an unsupported
             content type -- confirm against serve_utils).
     """
     if len(input_data) == 0:
         # Reject an empty body up front instead of handing it to the parser.
         raise NoContentInferenceError()
     dtest, content_type = serve_utils.parse_content_data(input_data, input_content_type)
     return dtest, content_type
コード例 #3
0
def test_parse_libsvm_data(libsvm_content_type):
    """A LIBSVM payload is parsed and reported as ``data_utils.LIBSVM``.

    ``libsvm_content_type`` is presumably a fixture supplying a LIBSVM
    Content-Type string.
    """
    data_payload = b'0:1'
    # Only the resolved content type is checked; the parsed payload is unused.
    _, parsed_content_type = serve_utils.parse_content_data(data_payload, libsvm_content_type)
    assert parsed_content_type == data_utils.LIBSVM
コード例 #4
0
def test_incorrect_content_type(incorrect_content_type):
    """An unsupported Content-Type makes ``parse_content_data`` raise ``exc.UserError``.

    ``incorrect_content_type`` is presumably a fixture supplying a Content-Type
    the parser does not accept.
    """
    # Bytes, matching the other parsing tests, so that the content type is the
    # only thing that differs from a valid call.
    data_payload = b'0'
    with pytest.raises(exc.UserError):
        serve_utils.parse_content_data(data_payload, incorrect_content_type)
コード例 #5
0
def test_parse_csv_data(csv_content_type):
    """A CSV payload is parsed and reported as ``data_utils.CSV``.

    ``csv_content_type`` is presumably a fixture supplying a CSV Content-Type
    string.
    """
    data_payload = b'1,1'
    # Only the resolved content type is checked; the parsed payload is unused.
    _, parsed_content_type = serve_utils.parse_content_data(data_payload, csv_content_type)
    assert parsed_content_type == data_utils.CSV