Esempio n. 1
0
def parse_content_data(input_data, input_content_type):
    """Decode a raw request payload into an XGBoost ``DMatrix``.

    The content type is normalised with ``get_content_type`` and the payload
    is handed to the matching decoder. CSV and LIBSVM payloads are stripped
    and UTF-8 decoded first; RECORDIO-PROTOBUF payloads are passed to the
    decoder unchanged.

    :param input_data: raw request payload. For CSV and LIBSVM it must
        support ``.strip().decode("utf-8")``.
    :param input_content_type: content type string sent with the request.
    :return: tuple ``(dtest, content_type)`` with the decoded matrix and the
        normalised content type.
    :raises RuntimeError: if the payload cannot be decoded, or if the
        normalised content type is not CSV, LIBSVM or RECORDIO-PROTOBUF.
        Exceptions raised by ``get_content_type`` itself are not wrapped.
    """
    content_type = get_content_type(input_content_type)

    if content_type == CSV:
        try:
            decoded_payload = input_data.strip().decode("utf-8")
            # ``np.float`` was only an alias of the builtin ``float`` and no
            # longer exists in NumPy >= 1.24; the builtin is equivalent.
            dtest = encoder.csv_to_dmatrix(decoded_payload, dtype=float)
        except Exception as e:
            raise RuntimeError("Loading csv data failed with Exception, "
                               "please ensure data is in csv format:\n {}\n {}".format(type(e), e)) from e
    elif content_type == LIBSVM:
        try:
            decoded_payload = input_data.strip().decode("utf-8")
            dtest = xgb.DMatrix(_get_sparse_matrix_from_libsvm(decoded_payload))
        except Exception as e:
            raise RuntimeError("Loading libsvm data failed with Exception, "
                               "please ensure data is in libsvm format:\n {}\n {}".format(type(e), e)) from e
    elif content_type == RECORDIO_PROTOBUF:
        try:
            dtest = encoder.recordio_protobuf_to_dmatrix(input_data)
        except Exception as e:
            raise RuntimeError("Loading recordio-protobuf data failed with "
                               "Exception, please ensure data is in "
                               "recordio-protobuf format: {} {}".format(type(e), e)) from e
    else:
        raise RuntimeError("Content-type {} is not supported.".format(input_content_type))

    return dtest, content_type
    def predict(cls,
                data,
                content_type='text/x-libsvm',
                model_format='pkl_format'):
        """Validate the feature count of ``data`` and predict with ``cls.booster``.

        The content type is always resolved first. The feature-count check is
        only applied when ``model_format`` is ``'pkl_format'``:

        * LIBSVM data may have at most one feature more than the model.
        * CSV / RECORDIO-PROTOBUF data must have the same number of features
          as the model, or exactly one fewer.

        :param data: matrix to predict on; must expose ``feature_names``.
        :param content_type: content type string of the request.
        :param model_format: format the model was stored in.
        :return: the result of ``cls.booster.predict``.
        :raises ValueError: if the content type cannot be resolved, is not
            covered by the check, or the feature counts are incompatible.
        """
        try:
            resolved_type = get_content_type(content_type)
        except Exception:
            raise ValueError(
                'Content type {} is not supported'.format(content_type))

        if model_format == 'pkl_format':
            model_width = len(cls.booster.feature_names)
            data_width = len(data.feature_names)

            if resolved_type == LIBSVM:
                exceeds_model = data_width > model_width + 1
                if exceeds_model:
                    raise ValueError(
                        'Feature size of libsvm inference data {} is larger than '
                        'feature size of trained model {}.'.format(
                            data_width, model_width))
            elif resolved_type in (CSV, RECORDIO_PROTOBUF):
                if model_width not in (data_width, data_width + 1):
                    raise ValueError(
                        'Feature size of {} inference data {} is not consistent '
                        'with feature size of trained model {}.'.format(
                            content_type, data_width, model_width))
            else:
                raise ValueError(
                    'Content type {} is not supported'.format(content_type))

        # Fall back to 0 when the booster carries no ``best_ntree_limit``.
        tree_limit = getattr(cls.booster, "best_ntree_limit", 0)
        return cls.booster.predict(data,
                                   ntree_limit=tree_limit,
                                   validate_features=False)
        def default_input_fn(self, input_data, input_content_type):
            """Take request data and de-serialize it into an object for prediction.

            When an InvokeEndpoint operation is made against an Endpoint running
            the SageMaker model server, the model server receives the request
            Content-Type and the request data. This function pre-processes that
            data before prediction.

            Args:
                input_data (obj): the request data. CSV and LIBSVM payloads are
                    UTF-8 decoded and stripped; RECORDIO-PROTOBUF payloads are
                    passed to the decoder unchanged.
                input_content_type (str): the request Content-Type. XGBoost
                    accepts CSV, LIBSVM, and RECORDIO-PROTOBUF.

            Returns:
                tuple: ``(dtest, content_type)`` with the decoded DMatrix and
                the normalised content type.

            Raises:
                NoContentInferenceError: if ``input_data`` is empty.
                UnsupportedMediaTypeInferenceError: if the content type cannot
                    be resolved, is not one of the three supported types, or
                    the payload cannot be decoded.
            """
            if len(input_data) == 0:
                raise NoContentInferenceError()

            try:
                content_type = get_content_type(input_content_type)
            except Exception:
                raise UnsupportedMediaTypeInferenceError(
                    "Content type must be csv, libsvm, or recordio-protobuf.")

            if content_type == CSV:
                try:
                    payload = input_data.decode('utf-8').strip()
                    # ``np.float`` was only an alias of the builtin ``float``
                    # and no longer exists in NumPy >= 1.24.
                    dtest = xgb_encoder.csv_to_dmatrix(payload, dtype=float)
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading csv data failed with "
                        "Exception, please ensure data "
                        "is in csv format: {} {}".format(type(e), e)) from e
            elif content_type == LIBSVM:
                try:
                    payload = input_data.decode('utf-8').strip()
                    dtest = xgb.DMatrix(_get_sparse_matrix_from_libsvm(payload))
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading libsvm data failed with "
                        "Exception, please ensure data "
                        "is in libsvm format: {} {}".format(type(e), e)) from e
            elif content_type == RECORDIO_PROTOBUF:
                try:
                    dtest = xgb_encoder.recordio_protobuf_to_dmatrix(input_data)
                except Exception as e:
                    raise UnsupportedMediaTypeInferenceError(
                        "Loading recordio-protobuf data failed with "
                        "Exception, please ensure data is in "
                        "recordio-protobuf format: {} {}".format(type(e), e)) from e
            else:
                raise UnsupportedMediaTypeInferenceError(
                    "Content type must be csv, libsvm, or recordio-protobuf.")

            return dtest, content_type
Esempio n. 4
0
    def test_get_content_type(self):
        """Known libsvm/csv aliases are normalised; an unknown name raises UserError."""
        alias_to_expected = (
            ('libsvm', 'libsvm'),
            ('text/libsvm', 'libsvm'),
            ('text/x-libsvm', 'libsvm'),
            ('csv', 'csv'),
            ('text/csv', 'csv'),
        )
        for alias, expected in alias_to_expected:
            self.assertEqual(expected, data_utils.get_content_type(alias))

        with self.assertRaises(exc.UserError):
            data_utils.get_content_type('incorrect_format')
Esempio n. 5
0
def predict(model, model_format, dtest, input_content_type, objective=None):
    """Predict with a single booster or with an ensemble of boosters.

    When ``model`` is a list, every booster in it predicts on ``dtest`` and the
    results are combined: a majority vote per sample for the
    ``multi:softmax`` / ``binary:hinge`` objectives, the mean otherwise.

    The feature-count check is only applied for ``PKL_FORMAT`` models. For an
    ensemble it uses the first booster and the first entry of
    ``model_format``.

    :param model: a booster, or a non-empty list of boosters.
    :param model_format: the model format, or a list of formats aligned with
        ``model`` when ``model`` is a list.
    :param dtest: matrix to predict on; must expose ``feature_names``.
    :param input_content_type: content type string of the request.
    :param objective: training objective; selects how ensemble predictions
        are combined.
    :return: the prediction of the single booster, or the combined ensemble
        prediction.
    :raises ValueError: if the content type is not supported or the feature
        counts of data and model are incompatible.
    """
    # Use one test everywhere so list subclasses are treated consistently.
    is_ensemble = isinstance(model, list)
    bst, bst_format = (model[0], model_format[0]) if is_ensemble else (
        model, model_format)

    if bst_format == PKL_FORMAT:
        x = len(bst.feature_names)
        y = len(dtest.feature_names)

        try:
            content_type = get_content_type(input_content_type)
        except Exception:
            raise ValueError(
                'Content type {} is not supported'.format(input_content_type))

        if content_type == LIBSVM:
            if y > x + 1:
                raise ValueError(
                    'Feature size of libsvm inference data {} is larger than '
                    'feature size of trained model {}.'.format(y, x))
        elif content_type in [CSV, RECORDIO_PROTOBUF]:
            if not ((x == y) or (x == y + 1)):
                raise ValueError(
                    'Feature size of {} inference data {} is not consistent '
                    'with feature size of trained model {}.'.format(
                        content_type, y, x))
        else:
            raise ValueError(
                'Content type {} is not supported'.format(content_type))

    if not is_ensemble:
        return model.predict(dtest,
                             ntree_limit=getattr(model, "best_ntree_limit", 0),
                             validate_features=False)

    ensemble = [
        booster.predict(dtest,
                        ntree_limit=getattr(booster, "best_ntree_limit", 0),
                        validate_features=False) for booster in model
    ]

    if objective in ["multi:softmax", "binary:hinge"]:
        logging.info(
            f"Vote ensemble prediction of {objective} with {len(model)} models"
        )
        votes = np.asarray(ensemble)
        winners = np.asarray(stats.mode(votes, axis=0).mode)
        # Older SciPy keeps the reduced model axis (shape (1, ...)), while
        # SciPy >= 1.11 drops it by default. Only strip the axis when it is
        # present; an unconditional ``[0]`` would return just the first
        # sample's vote on newer SciPy.
        if winners.ndim == votes.ndim:
            winners = winners[0]
        return winners

    logging.info(
        f"Average ensemble prediction of {objective} with {len(model)} models"
    )
    return np.mean(ensemble, axis=0)
Esempio n. 6
0
def predict(booster, model_format, dtest, input_content_type):
    """Check feature-count compatibility, then predict with ``booster``.

    The check is only applied when ``model_format`` equals ``PKL_FORMAT``:

    * LIBSVM data may have at most one feature more than the model.
    * CSV / RECORDIO-PROTOBUF data must have the same number of features as
      the model, or exactly one fewer.

    :param booster: trained booster used for prediction.
    :param model_format: format the model was stored in.
    :param dtest: matrix to predict on; must expose ``feature_names``.
    :param input_content_type: content type string of the request.
    :return: the result of ``booster.predict``.
    :raises ValueError: if the content type is not supported or the feature
        counts are incompatible.
    """
    if model_format == PKL_FORMAT:
        model_width = len(booster.feature_names)
        data_width = len(dtest.feature_names)

        try:
            resolved_type = get_content_type(input_content_type)
        except Exception:
            raise ValueError('Content type {} is not supported'.format(input_content_type))

        if resolved_type == LIBSVM:
            if data_width > model_width + 1:
                raise ValueError('Feature size of libsvm inference data {} is larger than '
                                 'feature size of trained model {}.'.format(data_width, model_width))
        elif resolved_type in (CSV, RECORDIO_PROTOBUF):
            if model_width not in (data_width, data_width + 1):
                raise ValueError('Feature size of {} inference data {} is not consistent '
                                 'with feature size of trained model {}.'.
                                 format(resolved_type, data_width, model_width))
        else:
            raise ValueError('Content type {} is not supported'.format(resolved_type))

    # Fall back to 0 when the booster carries no ``best_ntree_limit``.
    tree_limit = getattr(booster, "best_ntree_limit", 0)
    return booster.predict(dtest, ntree_limit=tree_limit, validate_features=False)
def sagemaker_train(train_config, data_config, train_path, val_path, model_dir,
                    sm_hosts, sm_current_host, checkpoint_config):
    """Run XGBoost training inside a SageMaker training environment.

    Hyperparameters and data channels are validated before any data is
    loaded. The training and validation matrices are then built and
    train_job() is run directly for one host, or through
    distributed.rabit_run for several hosts. In the multi-host case a host
    whose training matrix is None is passed to rabit_run with
    ``include_in_training=False``.

    :param train_config: raw hyperparameters; validated here.
    :param data_config: raw data channel configuration; validated here. The
        validated result must contain a 'train' channel.
    :param train_path: path of the training channel data.
    :param val_path: path of the validation channel data.
    :param model_dir: forwarded to train_job as ``model_dir``.
    :param sm_hosts: hosts taking part in the job; its length selects single
        node or distributed training.
    :param sm_current_host: name of the host running this function.
    :param checkpoint_config: mapping whose "LocalPath" entry, if present, is
        forwarded to train_job as ``checkpoint_dir``.
    :raises exc.UserError: on a single host, if there is no training data, or
        if a validation channel is configured but holds no data.
    :raises exc.PlatformError: if ``sm_hosts`` is empty.
    """
    hyperparameters = hpv.initialize(metrics_mod.initialize())
    train_cfg = hyperparameters.validate(train_config)
    updaters = train_cfg.get("updater")
    if updaters:
        # Store the updater sequence as one comma-separated string.
        train_cfg["updater"] = ",".join(updaters)

    data_cfg = cv.initialize().validate(data_config)

    logging.debug("hyperparameters {}".format(train_cfg))
    logging.debug("channels {}".format(data_cfg))

    # Get Training and Validation Data Matrices
    train_channel = data_cfg['train']
    file_type = get_content_type(train_channel.get("ContentType"))
    csv_weights = train_cfg.get("csv_weights", 0)
    is_pipe = train_channel.get("TrainingInputMode") == Channel.PIPE_MODE
    validation_channel = data_cfg.get('validation', None)

    train_dmatrix, val_dmatrix = get_validated_dmatrices(
        train_path, val_path, file_type, csv_weights, is_pipe)

    train_args = {
        'train_cfg': train_cfg,
        'train_dmatrix': train_dmatrix,
        'val_dmatrix': val_dmatrix,
        'model_dir': model_dir,
        'checkpoint_dir': checkpoint_config.get("LocalPath", None),
    }

    # The number of hosts decides whether Rabit has to be set up.
    num_hosts = len(sm_hosts)

    if num_hosts < 1:
        raise exc.PlatformError(
            "Number of hosts should be an int greater than or equal to 1")

    if num_hosts == 1:
        if not train_dmatrix:
            raise exc.UserError(
                "No data in training channel path {}".format(train_path))
        if validation_channel and not val_dmatrix:
            raise exc.UserError(
                "No data in validation channel path {}".format(val_path))
        logging.info("Single node training.")
        train_args['is_master'] = True
        train_job(**train_args)
        return

    # Wait for hosts to find each other
    logging.info("Distributed node training with {} hosts: {}".format(
        num_hosts, sm_hosts))
    distributed.wait_hostname_resolution(sm_hosts)

    if not train_dmatrix:
        logging.warning(
            "Host {} does not have data. Will broadcast to cluster and will not be used in distributed"
            " training.".format(sm_current_host))
    distributed.rabit_run(exec_fun=train_job,
                          args=train_args,
                          include_in_training=(train_dmatrix is not None),
                          hosts=sm_hosts,
                          current_host=sm_current_host,
                          update_rabit_args=True)
Esempio n. 8
0
    def test_get_content_type(self):
        """Accepted content-type strings are normalised; malformed ones raise UserError."""
        accepted = (
            ('libsvm', 'libsvm'),
            ('text/libsvm', 'libsvm'),
            ('text/x-libsvm', 'libsvm'),
            ('csv', 'csv'),
            ('text/csv', 'csv'),
            ('text/csv; label_size=1', 'csv'),
            ('text/csv;label_size = 1', 'csv'),
            ('text/csv; charset=utf-8', 'csv'),
            ('text/csv; label_size=1; charset=utf-8', 'csv'),
            ('parquet', 'parquet'),
            ('application/x-parquet', 'parquet'),
            ('recordio-protobuf', 'recordio-protobuf'),
            ('application/x-recordio-protobuf', 'recordio-protobuf'),
        )
        for raw_value, expected in accepted:
            self.assertEqual(expected, data_utils.get_content_type(raw_value))

        rejected = (
            'incorrect_format',
            'text/csv; label_size=5',
            'text/csv; label_size=1=1',
            'text/csv; label_size=1; label_size=2',
            'label_size=1; text/csv',
        )
        for raw_value in rejected:
            with self.assertRaises(exc.UserError):
                data_utils.get_content_type(raw_value)