Esempio n. 1
0
def parse_content_data(input_data, input_content_type):
    """Deserialize a raw payload into an ``xgb.DMatrix``.

    Args:
        input_data: the raw payload. The CSV and LIBSVM branches call
            ``.strip().decode("utf-8")`` on it, so it must be bytes-like for
            those content types; for recordio-protobuf it is handed to the
            encoder unchanged.
        input_content_type (str): the content type of the payload; it is
            resolved through ``get_content_type`` before dispatching.

    Returns:
        tuple: ``(dtest, content_type)`` where ``dtest`` is the parsed data
        and ``content_type`` is the value returned by ``get_content_type``.

    Raises:
        RuntimeError: if the resolved content type is not CSV, LIBSVM or
            RECORDIO_PROTOBUF, or if decoding/parsing the payload fails (the
            underlying exception is chained as ``__cause__``). Exceptions
            raised by ``get_content_type`` itself are not caught here.
    """
    content_type = get_content_type(input_content_type)

    if content_type == CSV:
        try:
            decoded_payload = input_data.strip().decode("utf-8")
            # ``np.float`` was only an alias of the builtin ``float`` and has
            # been removed from NumPy (1.24+). Using it here would raise an
            # AttributeError that the handler below would misreport as a
            # malformed-CSV error.
            dtest = encoder.csv_to_dmatrix(decoded_payload, dtype=float)
        except Exception as e:
            raise RuntimeError("Loading csv data failed with Exception, "
                               "please ensure data is in csv format:\n {}\n {}".format(type(e), e)) from e
    elif content_type == LIBSVM:
        try:
            decoded_payload = input_data.strip().decode("utf-8")
            dtest = xgb.DMatrix(_get_sparse_matrix_from_libsvm(decoded_payload))
        except Exception as e:
            raise RuntimeError("Loading libsvm data failed with Exception, "
                               "please ensure data is in libsvm format:\n {}\n {}".format(type(e), e)) from e
    elif content_type == RECORDIO_PROTOBUF:
        try:
            dtest = encoder.recordio_protobuf_to_dmatrix(input_data)
        except Exception as e:
            raise RuntimeError("Loading recordio-protobuf data failed with "
                               "Exception, please ensure data is in "
                               "recordio-protobuf format: {} {}".format(type(e), e)) from e
    else:
        raise RuntimeError("Content-type {} is not supported.".format(input_content_type))

    return dtest, content_type
        def default_input_fn(self, input_data, input_content_type):
            """Take request data and de-serialize it into an object for prediction.

            When an InvokeEndpoint operation is made against an Endpoint running
            SageMaker model server, the model server receives two pieces of
            information:
                - The request Content-Type, for example "application/json"
                - The request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size.
            The input_fn is responsible for taking the request data and
            pre-processing it before prediction.

            Args:
                input_data (obj): the request data. The CSV and LIBSVM branches
                    call ``.decode('utf-8')`` on it, so it must be bytes-like
                    for those content types.
                input_content_type (str): the request Content-Type. XGBoost
                    accepts CSV, LIBSVM, and RECORDIO-PROTOBUF.

            Returns:
                tuple: ``(dtest, content_type)`` - the data ready for prediction
                together with the content type resolved by ``get_content_type``.

            Raises:
                NoContentInferenceError: if ``input_data`` has length zero.
                UnsupportedMediaTypeInferenceError: if the content type cannot
                    be resolved or is not one of the supported ones, or if the
                    payload cannot be decoded/parsed (the underlying exception
                    is chained as ``__cause__``).
            """
            if len(input_data) == 0:
                raise NoContentInferenceError()

            try:
                content_type = get_content_type(input_content_type)
            except Exception as e:
                raise UnsupportedMediaTypeInferenceError(
                    "Content type must be csv, libsvm, or recordio-protobuf.") from e

            if content_type == CSV:
                try:
                    payload = input_data.decode('utf-8').strip()
                    # ``np.float`` was only an alias of the builtin ``float``
                    # and has been removed from NumPy (1.24+). Using it here
                    # would make every CSV request fail with a misleading
                    # "ensure data is in csv format" error.
                    dtest = xgb_encoder.csv_to_dmatrix(payload, dtype=float)
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading csv data failed with "
                        "Exception, please ensure data "
                        "is in csv format: {} {}".format(type(e), e)) from e
            elif content_type == LIBSVM:
                try:
                    payload = input_data.decode('utf-8').strip()
                    dtest = xgb.DMatrix(_get_sparse_matrix_from_libsvm(payload))
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading libsvm data failed with "
                        "Exception, please ensure data "
                        "is in libsvm format: {} {}".format(type(e), e)) from e
            elif content_type == RECORDIO_PROTOBUF:
                try:
                    # Binary format: passed to the encoder without decoding.
                    dtest = xgb_encoder.recordio_protobuf_to_dmatrix(input_data)
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading recordio-protobuf data failed with "
                        "Exception, please ensure data is in "
                        "recordio-protobuf format: {} {}".format(type(e), e)) from e
            else:
                raise UnsupportedMediaTypeInferenceError(
                    "Content type must be csv, libsvm, or recordio-protobuf.")

            return dtest, content_type
Esempio n. 3
0
def test_sparse_recordio_protobuf_to_dmatrix():
    """Decode every sparse edge-case fixture and check the result type.

    Fixtures are read from ``resources/data/recordio_protobuf/sparse_edge_cases``
    under the directory two levels above this file. Each file's raw bytes must
    convert to an object whose type is exactly ``xgb.DMatrix``.
    """
    base_dir = Path(os.path.abspath(__file__)).parent.parent
    fixtures_dir = os.path.join(
        str(base_dir), 'resources', 'data', 'recordio_protobuf', 'sparse_edge_cases')

    for entry in os.listdir(fixtures_dir):
        with open(os.path.join(fixtures_dir, entry), 'rb') as handle:
            raw_bytes = handle.read()

        # The payload is fully read, so the conversion does not need the open file.
        dmatrix = encoder.recordio_protobuf_to_dmatrix(raw_bytes)
        assert type(dmatrix) is xgb.DMatrix
Esempio n. 4
0
def test_recordio_protobuf_to_dmatrix(target):
    """Check that ``target`` converts to an object of type exactly ``xgb.DMatrix``.

    ``target`` is supplied by the test harness (its source is not visible in
    this snippet) and is passed unchanged to the recordio-protobuf encoder.
    """
    converted_type = type(encoder.recordio_protobuf_to_dmatrix(target))
    assert converted_type is xgb.DMatrix