def invocations():
    """Parse the request payload, score it with the model, and return predictions.

    Each stage maps its failure to a distinct HTTP status:

    - empty request body                    -> 204 NO_CONTENT
    - payload / content type fails to parse -> 415 UNSUPPORTED_MEDIA_TYPE
    - model fails to load                   -> 500 INTERNAL_SERVER_ERROR
    - prediction fails                      -> 400 BAD_REQUEST
    - Accept header fails to parse (only when selectable inference output
      is enabled)                           -> 406 NOT_ACCEPTABLE

    Returns:
        flask.Response: on success, either the response built by
        ``_handle_selectable_inference_response`` (selectable-output mode) or
        the predictions as ``text/csv`` -- newline-separated with a trailing
        newline when ``SAGEMAKER_BATCH`` is truthy, comma-separated otherwise.
    """
    payload = flask.request.data
    # flask.request.data is bytes, so an empty body is falsy.
    if not payload:
        return flask.Response(response="", status=http.client.NO_CONTENT)

    try:
        dtest, content_type = serve_utils.parse_content_data(
            payload, flask.request.content_type)
    except Exception as e:
        logging.exception(e)
        return flask.Response(response=str(e), status=http.client.UNSUPPORTED_MEDIA_TYPE)

    try:
        # Named model_format (not `format`) to avoid shadowing the builtin.
        model_format = ScoringService.load_model()
    except Exception as e:
        logging.exception(e)
        return flask.Response(response="Unable to load model: %s" % e,
                              status=http.client.INTERNAL_SERVER_ERROR)

    try:
        preds = ScoringService.predict(data=dtest, content_type=content_type,
                                       model_format=model_format)
    except Exception as e:
        logging.exception(e)
        return flask.Response(
            response="Unable to evaluate payload provided: %s" % e,
            status=http.client.BAD_REQUEST)

    if serve_utils.is_selectable_inference_output():
        try:
            accept = _parse_accept(flask.request)
        except Exception as e:
            logging.exception(e)
            return flask.Response(response=str(e), status=http.client.NOT_ACCEPTABLE)
        return _handle_selectable_inference_response(preds, accept)

    # assumes preds is array-like with .tolist() (e.g. numpy) -- TODO confirm
    values = map(str, preds.tolist())
    if SAGEMAKER_BATCH:
        # Batch transform output: one prediction per line, trailing newline.
        return_data = "\n".join(values) + '\n'
    else:
        return_data = ",".join(values)
    return flask.Response(response=return_data, status=http.client.OK, mimetype="text/csv")
def default_input_fn(self, input_data, input_content_type):
    """Take request data and de-serialize it into an object for prediction.

    When an InvokeEndpoint operation is made against an Endpoint running
    SageMaker model server, the model server receives two pieces of
    information:

        - The request Content-Type, for example "application/json"
        - The request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size.

    The input_fn is responsible for taking the request data and
    pre-processing it before prediction.

    Args:
        input_data (obj): the request data.
        input_content_type (str): the request Content-Type. XGBoost accepts
            CSV, LIBSVM, and RECORDIO-PROTOBUF.

    Returns:
        tuple: ``(dtest, content_type)`` as produced by
        ``serve_utils.parse_content_data``. ``dtest`` is the data ready for
        prediction (for XGBoost, this defaults to DMatrix). ``content_type``
        is the content type reported by the parser.

    Raises:
        NoContentInferenceError: if ``input_data`` is empty.
    """
    # Reject an empty body explicitly instead of handing it to the parser.
    if len(input_data) == 0:
        raise NoContentInferenceError()
    # Unpacking enforces that the parser returned exactly two values.
    dtest, content_type = serve_utils.parse_content_data(input_data, input_content_type)
    return dtest, content_type
def test_parse_libsvm_data(libsvm_content_type):
    """A single LIBSVM-formatted feature is reported as LIBSVM content."""
    result = serve_utils.parse_content_data(b'0:1', libsvm_content_type)
    _, detected_type = result
    assert detected_type == data_utils.LIBSVM
def test_incorrect_content_type(incorrect_content_type):
    """An unsupported content type is rejected with a UserError."""
    with pytest.raises(exc.UserError):
        serve_utils.parse_content_data('0', incorrect_content_type)
def test_parse_csv_data(csv_content_type):
    """A simple comma-separated row is reported as CSV content."""
    result = serve_utils.parse_content_data(b'1,1', csv_content_type)
    _, detected_type = result
    assert detected_type == data_utils.CSV