def get_content_type(content_type_cfg_val):
    """Map a configured ContentType string onto one of the supported short names.

    Assumes that training and validation data have the same content type.

    ['libsvm', 'text/libsvm ;charset=utf8', 'text/x-libsvm'] will return 'libsvm'
    ['csv', 'text/csv', 'text/csv; label_size=1'] will return 'csv'

    :param content_type_cfg_val: Raw ContentType from the data config, or None.
    :return: Parsed content type (LIBSVM when no value is configured).
    :raises exc.UserError: If a CSV type carries label_size other than 1, or the
        type is not recognised.
    """
    if content_type_cfg_val is None:
        return LIBSVM

    # cgi.parse_header separates the MIME type from its ';'-delimited key=value
    # parameters, e.g. 'text/csv;label_size=1' -> ('text/csv', {'label_size': '1'}).
    mime_type, mime_params = cgi.parse_header(content_type_cfg_val.lower())

    if mime_type in (CSV, _content_types.CSV):
        # XGBoost only supports a single label column in CSV input.
        label_size = mime_params.get('label_size')
        if label_size is not None and label_size != '1':
            msg = ("{} is not an accepted csv ContentType. "
                   "Optional parameter label_size must be equal to 1").format(content_type_cfg_val)
            raise exc.UserError(msg)
        return CSV

    if mime_type in (LIBSVM, xgb_content_types.LIBSVM, xgb_content_types.X_LIBSVM):
        return LIBSVM

    if mime_type in (PARQUET, xgb_content_types.X_PARQUET):
        return PARQUET

    if mime_type in (RECORDIO_PROTOBUF, xgb_content_types.X_RECORDIO_PROTOBUF):
        return RECORDIO_PROTOBUF

    raise exc.UserError(_get_invalid_content_type_error_msg(content_type_cfg_val))
def objective_validator(value, dependencies):
    """Cross-check the 'objective' hyperparameter against 'num_class'.

    :param value: Parsed 'objective' value (may be None).
    :param dependencies: Mapping of already-validated hyperparameters; 'num_class'
        is read from it when present.
    :raises exc.UserError: If a multi-class objective is used without 'num_class',
        or if 'num_class' is set while no objective is given.
    """
    num_class = dependencies.get("num_class")
    if num_class is None:
        if value in ("multi:softmax", "multi:softprob"):
            raise exc.UserError("Require input for parameter 'num_class' for multi-classification")
    elif value is None:
        # NOTE(review): only a missing objective is rejected here; non-multiclass
        # objectives combined with num_class are allowed through.
        raise exc.UserError("Do not need to setup parameter 'num_class' for learning task other than "
                            "multi-classification.")
def updater_validator(value, dependencies): valid_tree_plugins = ['grow_colmaker', 'distcol', 'grow_histmaker', 'grow_local_histmaker', 'grow_skmaker', 'sync', 'refresh', 'prune'] valid_tree_build_plugins = ['grow_colmaker', 'distcol', 'grow_histmaker', 'grow_local_histmaker', 'grow_colmaker'] valid_linear_plugins = ['shotgun', 'coord_descent'] valid_process_update_plugins = ['refresh', 'prune'] if dependencies.get('booster') == 'gblinear': # validate only one linear updater is selected if not (len(value) == 1 and value[0] in valid_linear_plugins): raise exc.UserError("Linear updater should be one of these options: {}.".format( ', '.join("'{0}'".format(valid_updater for valid_updater in valid_linear_plugins)) )) elif dependencies.get('process_type') == 'update': if not all(x in valid_process_update_plugins for x in value): raise exc.UserError("process_type 'update' can only be used with updater 'refresh' and 'prune'") else: if not all(x in valid_tree_plugins for x in value): raise exc.UserError( "Tree updater should be selected from these options: 'grow_colmaker', 'distcol', 'grow_histmaker', " "'grow_local_histmaker', 'grow_skmaker', 'sync', 'refresh', 'prune', 'shortgun', 'coord_descent'.") # validate only one tree updater is selected counter = 0 for tmp in value: if tmp in valid_tree_build_plugins: counter += 1 if counter > 1: raise exc.UserError("Only one tree grow plugin can be selected. Choose one from the" "following: 'grow_colmaker', 'distcol', 'grow_histmaker', " "'grow_local_histmaker', 'grow_skmaker'")
def eval_metric_dep_validator(value, dependencies):
    """Check that the requested eval metrics are compatible with the objective.

    :param value: Collection of requested evaluation metric names.
    :param dependencies: Mapping that must contain the validated 'objective'
        (a KeyError propagates otherwise).
    :raises exc.UserError: If 'auc' is used with an objective that does not start
        with 'binary:' or 'rank:', or 'aft-nloglik' with anything other than
        'survival:aft'.
    """
    objective = dependencies["objective"]
    if "auc" in value and not objective.startswith(('binary:', 'rank:')):
        raise exc.UserError("Metric 'auc' can only be applied for classification and ranking problems.")
    if "aft-nloglik" in value and objective != "survival:aft":
        raise exc.UserError("Metric 'aft-nloglik' can only be applied for 'survival:aft' objective.")
def validate(self, user_hyperparameters):
    """Validate user-supplied hyperparameters and return their parsed values.

    Runs four passes in order: required/default filling, parsing, range checks,
    and dependency checks. Note that defaults are written back into
    ``user_hyperparameters`` in place.

    :param user_hyperparameters: Mapping of hyperparameter name -> raw value.
    :return: Dict of hyperparameter name -> parsed value.
    :raises exc.UserError: On a missing required, extraneous, unparsable,
        out-of-range, or dependency-violating hyperparameter.
    :raises exc.AlgorithmError: If a range check fails with an unexpected error.
    """
    specs = self.hyperparameters

    # Pass 0: required hyperparameters must be supplied; optional ones with a
    # non-None default are filled in.
    for name in specs:
        if name in user_hyperparameters:
            continue
        spec = specs[name]
        if spec.required:
            raise exc.UserError(
                "Missing required hyperparameter: {}".format(name))
        if spec.default is not None:
            user_hyperparameters[name] = spec.default

    # Pass 1: reject unknown names and parse raw values.
    parsed = {}
    for name, raw_value in user_hyperparameters.items():
        try:
            spec = specs[name]
        except KeyError:
            raise exc.UserError(
                "Extraneous hyperparameter found: {}".format(name))
        try:
            parsed[name] = spec.parse(raw_value)
        except ValueError as e:
            raise exc.UserError(
                "Hyperparameter {}: could not parse value".format(name), caused_by=e)

    # Pass 2: range checks. UserErrors propagate as-is; anything else is wrapped
    # as an internal AlgorithmError.
    for name, parsed_value in parsed.items():
        try:
            specs[name].validate_range(parsed_value)
        except exc.UserError:
            raise
        except Exception as e:
            raise exc.AlgorithmError(
                "Hyperparameter {}: unexpected failure when validating {}".
                format(name, parsed_value),
                caused_by=e)

    # Pass 3: dependency checks. The ordering from _sort_dependencies (presumably
    # a list) is consumed from its tail; each hyperparameter only sees the
    # dependencies that were validated before it.
    pending = self._sort_dependencies(parsed.keys())
    validated = {}
    while pending:
        name = pending.pop()
        parsed_value = parsed[name]
        if specs[name].dependencies:
            visible_dependencies = {
                dep_name: validated[dep_name]
                for dep_name in specs[name].dependencies
                if dep_name in validated
            }
            specs[name].validate_dependencies(parsed_value, visible_dependencies)
        validated[name] = parsed_value
    return validated
def eval_metric_dep_validator(value, dependencies):
    """Reject the 'auc' metric for objectives that are not classification or ranking.

    :param value: Collection of requested evaluation metric names.
    :param dependencies: Mapping containing the validated 'objective'; it is only
        read when 'auc' is requested.
    :raises exc.UserError: If 'auc' is requested and the objective does not start
        with 'binary:' or 'rank:'.
    """
    if "auc" not in value:
        return
    objective = dependencies["objective"]
    if not objective.startswith(('binary:', 'rank:')):
        raise exc.UserError(
            "Metric 'auc' can only be applied for classification and ranking problems."
        )
def interaction_constraints_validator(value, dependencies):
    """Ensure interaction_constraints is only set with a compatible tree_method.

    :param value: The 'interaction_constraints' value, or None when unset.
    :param dependencies: Mapping of already-validated hyperparameters; 'tree_method'
        is read from it when present.
    :raises exc.UserError: If constraints are given and tree_method is not one of
        'exact', 'hist' or 'approx'.
    """
    supported_methods = ("exact", "hist", "approx")
    tree_method = dependencies.get("tree_method")
    if value is None or tree_method in supported_methods:
        return
    raise exc.UserError(
        "interaction_constraints can be used only when the tree_method parameter is set to "
        "either 'exact', 'hist' or 'approx'.")
def get_dmatrix(data_path, content_type, csv_weights=0):
    """Create Data Matrix from CSV or LIBSVM file.

    Assumes that sanity validation for content type has been done.

    :param data_path: Either directory or file. For a directory, the first leaf
        directory found by a top-down os.walk is loaded.
    :param content_type: Parsed content type; only CSV and LIBSVM are handled here.
    :param csv_weights: Only used if content type is CSV. 1 if the instance weights are in the second
        column of csv file; otherwise, 0
    :return: xgb.DMatrix, or None if data_path does not exist
    :raises exc.UserError: If the content type is not CSV/LIBSVM, or the loaded
        data has no labels.
    """
    if not os.path.exists(data_path):
        return None

    # Fall back to data_path itself so files_path is always bound, even if the
    # walk yields nothing (e.g. data_path is neither a regular file nor a directory).
    files_path = data_path
    if not os.path.isfile(data_path):
        for root, dirs, _files in os.walk(data_path):
            if not dirs:
                files_path = root
                break

    normalized_type = content_type.lower()
    if normalized_type == CSV:
        dmatrix = get_csv_dmatrix(files_path, csv_weights)
    elif normalized_type == LIBSVM:
        dmatrix = get_libsvm_dmatrix(files_path)
    else:
        # Previously this fell through with `dmatrix` unbound (UnboundLocalError).
        raise exc.UserError("Content type {} is not supported for loading data.".format(content_type))

    if dmatrix is not None and dmatrix.get_label().size == 0:
        raise exc.UserError(
            "Got input data without labels. Please check the input data set. "
            "If training job is running on multiple instances, please switch "
            "to using single instance if number of records in the data set "
            "is less than number of workers (16 * number of instance) in the cluster."
        )

    return dmatrix
def validate(self, value):
    """Check this channel's (content type, input mode, S3 distribution) triple is supported.

    :param value: Channel configuration mapping; must contain the CONTENT_TYPE,
        TRAINING_INPUT_MODE and S3_DIST_TYPE keys (KeyError otherwise).
    :raises exc.UserError: If the combination is not in ``self.supported``.
    """
    requested = (value[CONTENT_TYPE], value[TRAINING_INPUT_MODE], value[S3_DIST_TYPE])
    if requested in self.supported:
        return
    raise exc.UserError(
        "Channel configuration for '{}' channel is not supported: {}".format(self.name, value))
def _get_parquet_dmatrix_pipe_mode(pipe_path):
    """Build a Data Matrix from parquet records streamed through a SageMaker pipe.

    The first column of the stacked data is used as the label and the remaining
    columns as features.

    :param pipe_path: SageMaker pipe path where parquet formatted training data is piped
    :return: xgb.DMatrix, or None when the pipe yields no records
    :raises exc.UserError: Wrapping any failure while reading or converting the data.
    """
    try:
        pipe = mlio.SageMakerPipe(pipe_path)
        chunks = []
        with pipe.open_read() as stream:
            for record in mlio.ParquetRecordReader(stream):
                frame = pq.read_table(as_arrow_file(record)).to_pandas()
                chunks.append(frame.to_numpy() if type(frame) is pd.DataFrame else frame)

        if not chunks:
            return None

        data = np.vstack(chunks)
        # Drop the per-record pieces before XGBoost makes its own copy.
        del chunks
        return xgb.DMatrix(data[:, 1:], label=data[:, 0])
    except Exception as e:
        raise exc.UserError("Failed to load parquet data with exception:\n{}".format(e))
def _get_csv_dmatrix_file_mode(files_path, csv_weights):
    """Get Data Matrix from CSV data in file mode.

    Infer the delimiter of data from first line of first data file.

    :param files_path: File path where CSV formatted training data resides, either directory or file
    :param csv_weights: 1 if instance weights are in second column of CSV data; else 0
    :return: xgb.DMatrix
    :raises exc.UserError: If no data file is found, the delimiter cannot be
        inferred, or XGBoost fails to load the data.
    """
    if os.path.isfile(files_path):
        sample_file_path = files_path
    else:
        # Sample the first regular file in os.listdir order. Join with the
        # directory here; joining a relative *file* path with itself (as before)
        # produced an invalid 'x.csv/x.csv' path.
        sample_file_path = next(
            (os.path.join(files_path, name) for name in os.listdir(files_path)
             if os.path.isfile(os.path.join(files_path, name))),
            None)
        if sample_file_path is None:
            raise exc.UserError("No CSV data files found in {}".format(files_path))

    with open(sample_file_path) as read_file:
        sample_csv_line = read_file.readline()
    delimiter = _get_csv_delimiter(sample_csv_line)

    data_uri = '{}?format=csv&label_column=0&delimiter={}'.format(files_path, delimiter)
    if csv_weights == 1:
        data_uri += '&weight_column=1'

    try:
        dmatrix = xgb.DMatrix(data_uri)
    except Exception as e:
        raise exc.UserError("Failed to load csv data with exception:\n{}".format(e))

    return dmatrix
def _get_csv_delimiter(sample_csv_line): try: delimiter = csv.Sniffer().sniff(sample_csv_line).delimiter logging.info("Determined delimiter of CSV input is \'{}\'".format(delimiter)) except Exception as e: raise exc.UserError("Could not determine delimiter on line {}:\n{}".format(sample_csv_line[:50], e)) return delimiter
def validate_data_file_path(data_path, content_type):
    """Run a lightweight format check on the data files under data_path.

    Note: This is not a comprehensive validation. XGBoost has its own content validation.

    :param data_path: A data file, or a directory whose first leaf directory (top-down
        os.walk order) holds the data files.
    :param content_type: Raw content type; normalised via get_content_type.
    :raises exc.UserError: If data_path does not exist, or a file fails its format check.
    """
    parsed_content_type = get_content_type(content_type)
    if not os.path.exists(data_path):
        raise exc.UserError("{} is not a valid path!".format(data_path))

    if os.path.isfile(data_path):
        data_files = [data_path]
    else:
        leaf_dir = next((root for root, dirs, _ in os.walk(data_path) if dirs == []), None)
        data_files = [
            os.path.join(leaf_dir, file_name)
            for file_name in os.listdir(leaf_dir)
            if _is_data_file(leaf_dir, file_name)
        ]

    normalized_type = parsed_content_type.lower()
    if normalized_type == CSV:
        check_file = _validate_csv_format
    elif normalized_type == LIBSVM:
        check_file = _validate_libsvm_format
    else:
        # PARQUET and RECORDIO_PROTOBUF are not checked here.
        return

    for data_file_path in data_files:
        check_file(data_file_path)
def _validate_libsvm_format(file_path):
    """Validate that data file is LIBSVM format.

    XGBoost expects the following LIBSVM format:
        <label>(:<instance weight>) <index>:<value> <index>:<value> <index>:<value> ...

    Note: This only validates lines up to the first one that has a feature. This is not a
    comprehensive file check, as XGBoost will have its own data validation.

    :param file_path: Path of the data file to check.
    :raises exc.UserError: If a line before the first featured line is reported
        as invalid LIBSVM.
    """
    file_name = os.path.basename(file_path)
    with open(file_path, 'r', errors='ignore') as read_file:
        for line_to_validate in read_file:
            # NOTE(review): assumes the helper returns the number of well-formed
            # features on the line, or a negative value for an invalid line.
            num_sparse_libsvm_features = _get_num_valid_libsvm_features(line_to_validate)

            if num_sparse_libsvm_features > 0:
                # Accept the file at the first line with at least one feature
                # (was `> 1`, which skipped single-feature lines).
                return
            if num_sparse_libsvm_features < 0:
                raise exc.UserError(
                    _get_invalid_libsvm_error_msg(
                        line_snippet=line_to_validate[:50], file_name=file_name))

    logging.warning(
        "File {} is not an invalid LIBSVM file but has no features. Accepting simple validation."
        .format(file_name))
def validate(self, value):
    """Check the channel's ContentType / TrainingInputMode / S3DistributionType combination.

    :param value: Channel configuration mapping with the three keys above
        (KeyError if any is missing).
    :raises exc.UserError: If the combination is not in ``self.supported``.
    """
    channel_config = tuple(
        value[key] for key in ("ContentType", "TrainingInputMode", "S3DistributionType"))
    if channel_config not in self.supported:
        raise exc.UserError(
            "Channel configuration for '{}' channel is not supported: {}".format(self.name, value))
def get_size(data_path, is_pipe=False):
    """Return the total size in bytes of the data at data_path.

    :param data_path: Either directory or file
    :param is_pipe: Boolean to indicate if data is being read in pipe mode
    :return: 1 if in pipe mode and '<data_path>_0' exists; 0 if data_path does
        not exist; otherwise the file size, or the summed size of every file
        under the directory tree.
    :raises exc.UserError: If a file whose name starts with '.' is found while walking.
    """
    if is_pipe and os.path.exists(data_path + '_0'):
        logging.info('Pipe path {} found.'.format(data_path))
        return 1

    if not os.path.exists(data_path):
        logging.info('Path {} does not exist!'.format(data_path))
        return 0

    if os.path.isfile(data_path):
        return os.path.getsize(data_path)

    size_in_bytes = 0
    for dir_name, _subdirs, file_names in os.walk(data_path):
        for file_name in file_names:
            if file_name.startswith('.'):
                raise exc.UserError("Hidden file found in the data path! Remove that before training.")
            size_in_bytes += os.path.getsize(os.path.join(dir_name, file_name))
    return size_in_bytes
def train_job(train_cfg, train_dmatrix, val_dmatrix, model_dir, is_master):
    """Train and save XGBoost model using data on current node.

    If doing distributed training, XGBoost will use rabit to sync the trained model between each boosting iteration.
    Trained model is only saved if 'is_master' is True.

    Note: ``train_cfg`` is modified in place ('eval_metric' is replaced or removed).

    :param train_cfg: Training hyperparameter configurations
    :param train_dmatrix: Training Data Matrix
    :param val_dmatrix: Validation Data Matrix (may be None)
    :param model_dir: Directory where model will be saved
    :param is_master: True if single node training, or the current node is the master node in distributed training.
    :raises exc.UserError: If xgb.train fails with a message listed in CUSTOMER_ERRORS.
    :raises exc.AlgorithmError: If xgb.train fails for any other reason.
    """
    # Parse arguments for train() API
    early_stopping_rounds = train_cfg.get('early_stopping_rounds')
    num_round = train_cfg["num_round"]

    # Evaluation metrics to use with train() API
    tuning_objective_metric_param = train_cfg.get("_tuning_objective_metric")
    eval_metric = train_cfg.get("eval_metric")
    cleaned_eval_metric, configured_feval = train_utils.get_eval_metrics_and_feval(
        tuning_objective_metric_param, eval_metric)
    if cleaned_eval_metric:
        train_cfg['eval_metric'] = cleaned_eval_metric
    else:
        train_cfg.pop('eval_metric', None)

    # Set callback evals
    watchlist = [(train_dmatrix, 'train')]
    if val_dmatrix is not None:
        watchlist.append((val_dmatrix, 'validation'))

    logging.info("Train matrix has {} rows".format(train_dmatrix.num_row()))
    if val_dmatrix is not None:
        logging.info("Validation matrix has {} rows".format(val_dmatrix.num_row()))

    try:
        logging.info(train_cfg)
        bst = xgb.train(train_cfg, train_dmatrix, num_boost_round=num_round, evals=watchlist, feval=configured_feval,
                        early_stopping_rounds=early_stopping_rounds)
    except Exception as e:
        # Known user-caused failures are surfaced as UserError; everything else is internal.
        for customer_error_message in CUSTOMER_ERRORS:
            if customer_error_message in str(e):
                raise exc.UserError(str(e))

        exception_prefix = "XGB train call failed with exception"
        raise exc.AlgorithmError("{}:\n {}".format(exception_prefix, str(e)))

    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(model_dir, exist_ok=True)

    if is_master:
        model_location = os.path.join(model_dir, 'xgboost-model')
        # Close (and flush) the file deterministically instead of leaking the handle.
        with open(model_location, 'wb') as model_file:
            pkl.dump(bst, model_file)
        logging.debug("Stored trained model at {}".format(model_location))
def eval_metric_range_validator(SUPPORTED_METRIC, metric):
    """Validate a single 'eval_metric' entry.

    :param SUPPORTED_METRIC: Collection of metric names accepted verbatim.
    :param metric: Metric string, optionally of the form '<name>@<threshold>' for
        'error', 'ndcg' and 'map'. For 'ndcg' and 'map' a trailing '-' on the
        threshold (e.g. 'ndcg@5-') is also accepted, as XGBoost supports it.
    :return: True for an accepted '@' metric; otherwise whether ``metric`` is in
        ``SUPPORTED_METRIC``.
    :raises exc.UserError: For user-defined function metrics, unsupported metric
        names with a threshold, or a non-numeric threshold.
    """
    if "<function" in metric:
        raise exc.UserError("User defined evaluation metric {} is not supported yet.".format(metric))

    if "@" in metric:
        metric_parts = metric.split('@')
        metric_name = metric_parts[0].strip()
        metric_threshold = metric_parts[1].strip()

        if metric_name not in ["error", "ndcg", "map"]:
            raise exc.UserError(
                "Metric '{}' is not supported. Parameter 'eval_metric' with customized threshold should "
                "be one of these options: 'error', 'ndcg', 'map'.".format(metric))

        # XGBoost's 'ndcg@n-' / 'map@n-' variants score lists without positive
        # samples as 0; strip that marker before checking the numeric part.
        if metric_name in ("ndcg", "map") and metric_threshold.endswith('-'):
            metric_threshold = metric_threshold[:-1]

        try:
            float(metric_threshold)
        except ValueError:
            raise exc.UserError("Threshold value 't' in '{}@t' expects float input.".format(metric_name))

        return True

    return metric in SUPPORTED_METRIC
def get_libsvm_dmatrix(files_path, is_pipe=False):
    """Load a LIBSVM file or directory into an xgb.DMatrix.

    Pipe mode is not supported for LIBSVM.

    :param files_path: File path where LIBSVM formatted training data resides, either directory or file
    :param is_pipe: Boolean to indicate if data is being read in pipe mode
    :return: xgb.DMatrix
    :raises exc.UserError: If is_pipe is set, or XGBoost fails to load the data.
    """
    if is_pipe:
        raise exc.UserError("Pipe mode not supported for LibSVM.")

    try:
        return xgb.DMatrix(files_path)
    except Exception as e:
        raise exc.UserError("Failed to load libsvm data with exception:\n{}".format(e))
def eval_metric_dep_validator(value, dependencies):
    """Reject 'auc' unless the objective is in a fixed whitelist.

    :param value: Collection of requested evaluation metric names.
    :param dependencies: Mapping containing the validated 'objective'; it is only
        read when 'auc' is requested.
    :raises exc.UserError: If 'auc' is requested with a non-whitelisted objective.
    """
    auc_objectives = (
        "binary:logistic", "binary:logitraw", "binary:hinge",
        "multi:softmax", "multi:softprob",
        "reg:logistic",
        "rank:pairwise",
    )
    if "auc" in value and dependencies["objective"] not in auc_objectives:
        raise exc.UserError(
            "Metric 'auc' can only be applied for classification and ranking problem."
        )
def validate(self, user_channels):
    """Validate user-specified channels against the declared channels at runtime.

    Note that this adds default content type for channels if a default exists;
    the user's channel mappings are modified in place.

    :param user_channels: dictionary of channels formatted like so
        {
            "channel_name": {
                "ContentType": <content_type>.
                "TrainingInputMode": <training_input_mode>,
                "S3DistributionType": <s3_dist_type>,
                ...
            },
            "channel_name": {...
            }
        }
    :return: Dict of channel name -> (possibly updated) channel configuration.
    :raises exc.UserError: On a missing required channel, an unknown channel, or
        a missing content type with no default available.
    """
    for spec in self.channels:
        if spec.name not in user_channels and spec.required:
            raise exc.UserError("Missing required channel: {}".format(
                spec.name))

    channels_by_name = {spec.name: spec for spec in self.channels}
    validated = {}
    for name, config in user_channels.items():
        try:
            spec = channels_by_name[name]
        except KeyError:
            raise exc.UserError(
                "Extraneous channel found: {}".format(name))

        if CONTENT_TYPE not in config:
            if not self.default_content_type:
                raise exc.UserError(
                    "Missing content type for channel: {}".format(name))
            config[CONTENT_TYPE] = self.default_content_type

        spec.validate(config)
        validated[name] = config
    return validated
def get_recordio_protobuf_dmatrix(path, is_pipe=False, subsample_ratio_on_read=None):
    """Get Data Matrix from recordio-protobuf data.

    :param path: Path where recordio-protobuf formatted training data resides, either directory, file, or SageMaker pipe
    :param is_pipe: Boolean to indicate if data is being read in pipe mode
    :param subsample_ratio_on_read: None or a value in (0, 1) to indicate how much of the dataset should
            be read into memory.
    :return: xgb.DMatrix or None
    :raises exc.UserError: Wrapping any failure while reading or converting the data.
    """
    try:
        dataset = [mlio.SageMakerPipe(path)] if is_pipe else mlio.list_files(path)
        reader = mlio.RecordIOProtobufReader(
            dataset=dataset,
            batch_size=BATCH_SIZE,
            subsample_ratio=subsample_ratio_on_read)

        exm = reader.peek_example()
        if exm is None:
            return None

        # Recordio-protobuf tensor may be dense (use numpy) or sparse (use scipy)
        if isinstance(exm['values'], mlio.DenseTensor):
            to_matrix = as_numpy
            vstack = np.vstack
        else:
            to_matrix = to_coo_matrix
            vstack = scipy_vstack

        all_values = []
        all_labels = []
        for example in reader:
            all_values.append(to_matrix(example['values']))
            all_labels.append(as_numpy(example['label_values']).squeeze())

        all_values = vstack(all_values)
        # axis=None flattens each batch first: squeeze() turns a one-example
        # batch into a 0-d array, which np.concatenate rejects along axis 0.
        all_labels = np.concatenate(all_labels, axis=None)
        return xgb.DMatrix(all_values, label=all_labels)
    except Exception as e:
        raise exc.UserError(
            "Failed to load recordio-protobuf data with exception:\n{}".format(
                e))
def validate(self, user_channels):
    """Validate user-specified channels at runtime against the channels' supported configuration.

    :param user_channels: Mapping of channel name -> channel configuration.
    :return: Dict of channel name -> channel configuration for the validated channels.
    :raises exc.UserError: On a missing required channel or an unknown channel
        name; individual channel checks may raise as well.
    """
    for channel in self.channels:
        if channel.name not in user_channels:
            if channel.required:
                raise exc.UserError("Missing required channel: {}".format(
                    channel.name))

    name_to_channel = {channel.name: channel for channel in self.channels}
    validated_channels = {}
    for channel, value in user_channels.items():
        try:
            channel_obj = name_to_channel[channel]
        except KeyError:
            # Report the user-supplied name. channel_obj is unbound here on the
            # first iteration (or stale from a previous one), which the old
            # message referenced by mistake.
            raise exc.UserError(
                "Extraneous channel found: {}".format(channel))
        channel_obj.validate(value)
        validated_channels[channel] = value
    return validated_channels
def get_content_type(request):
    """Extract the bare MIME type from a request's content type.

    Defaults to "text/csv" when the request has no content type. The value is
    lower-cased and anything after the first ';' is discarded.

    :param request: Object with a ``content_type`` attribute (may be empty/None).
    :return: One of 'text/csv', 'text/libsvm', 'text/x-libsvm'.
    :raises exceptions.UserError: For any other content type.
    """
    raw_content_type = request.content_type or "text/csv"
    mime_type = raw_content_type.lower().split(";")[0].strip()
    if mime_type in ('text/csv', 'text/libsvm', 'text/x-libsvm'):
        return mime_type
    raise exceptions.UserError(
        "Content-type {} not supported. "
        "Supported content-type is text/csv, text/libsvm".format(
            mime_type))
def get_libsvm_dmatrix(files_path):
    """Load LIBSVM data from a file or directory into an xgb.DMatrix.

    :param files_path: File path where LIBSVM formatted training data resides, either directory or file
    :return: xgb.DMatrix
    :raises exc.UserError: If XGBoost fails to load the data.
    """
    try:
        return xgb.DMatrix(files_path)
    except Exception as e:
        raise exc.UserError(
            "Failed to load libsvm data with exception:\n{}".format(e))
def get_recordio_protobuf_dmatrix(path, is_pipe=False):
    """Build a Data Matrix from recordio-protobuf data.

    :param path: Path where recordio-protobuf formatted training data resides, either directory, file,
        or SageMaker pipe (a single pipe path or a list of them when is_pipe is set)
    :param is_pipe: Boolean to indicate if data is being read in pipe mode
    :return: xgb.DMatrix, or None when the source yields no examples
    :raises exc.UserError: Wrapping any failure while reading or converting the data.
    """
    try:
        if is_pipe:
            pipe_paths = path if isinstance(path, list) else [path]
            dataset = [mlio.SageMakerPipe(pipe_path) for pipe_path in pipe_paths]
        else:
            dataset = mlio.list_files(path)
        reader = mlio.RecordIOProtobufReader(
            mlio.DataReaderParams(dataset=dataset, batch_size=BATCH_SIZE))

        first_example = reader.peek_example()
        if first_example is None:
            return None

        # recordio-protobuf tensor may be dense (use numpy) or sparse (use scipy)
        dense = type(first_example['values']) is mlio.DenseTensor
        to_matrix = as_numpy if dense else to_coo_matrix

        feature_batches = []
        label_batches = []
        for example in reader:
            feature_batches.append(to_matrix(example['values']))
            label_batches.append(as_numpy(example['label_values']))

        if dense:
            features = np.vstack(feature_batches)
        else:
            features = scipy_vstack(feature_batches).tocsr()
        labels = np.concatenate(label_batches, axis=None)
        return xgb.DMatrix(features, label=labels)
    except Exception as e:
        raise exc.UserError(
            "Failed to load recordio-protobuf data with exception:\n{}".format(
                e))
def _get_csv_dmatrix_pipe_mode(pipe_path, csv_weights):
    """Build a Data Matrix from CSV data streamed through SageMaker pipe(s).

    Column 0 is the label; when csv_weights == 1, column 1 holds instance weights
    and features start at column 2, otherwise features start at column 1.

    :param pipe_path: SageMaker pipe path (or list of paths) where CSV formatted training data is piped
    :param csv_weights: 1 if instance weights are in second column of CSV data; else 0
    :return: xgb.DMatrix, or None when the pipes yield no examples
    :raises exc.UserError: Wrapping any failure while reading or converting the data.
    """
    try:
        pipe_paths = pipe_path if isinstance(pipe_path, list) else [pipe_path]
        reader = mlio.CsvReader(
            mlio.DataReaderParams(dataset=[mlio.SageMakerPipe(path) for path in pipe_paths],
                                  batch_size=BATCH_SIZE),
            mlio.CsvParams(header_row_index=None))

        if reader.peek_example() is None:
            return None

        batches = []
        for example in reader:
            # One tensor per column; stacking them gives a (num_columns, rows) array.
            columns = np.array([as_numpy(feature).squeeze() for feature in example])
            if columns.ndim > 1:
                batches.append(columns.T)
            else:
                # 1-D result (presumably a single-row batch, where squeeze collapsed
                # each column to a scalar): turn it into one row.
                batches.append(columns.reshape(1, columns.shape[0]))

        data = np.vstack(batches)
        del batches

        if csv_weights == 1:
            return xgb.DMatrix(data[:, 2:], label=data[:, 0], weight=data[:, 1])
        return xgb.DMatrix(data[:, 1:], label=data[:, 0])
    except Exception as e:
        raise exc.UserError(
            "Failed to load csv data with exception:\n{}".format(e))
def _validate_csv_format(file_path):
    """Validate that data file is CSV format.

    Check that delimiter can be inferred, and that the line does not look like LIBSVM.

    Note: This only validates the first line in the file. This is not a comprehensive file check,
    as XGBoost will have its own data validation.

    :param file_path: Path of the data file to check.
    :raises exc.UserError: If no delimiter can be inferred, or the first line
        parses as a LIBSVM line.
    """
    with open(file_path, 'r', errors='ignore') as read_file:
        line_to_validate = read_file.readline()

    _get_csv_delimiter(line_to_validate)

    if _get_num_valid_libsvm_features(line_to_validate) > 0:
        # Throw error if this line can be parsed as LIBSVM formatted line.
        # Truncate the snippet (as the LIBSVM validator does) so very wide rows
        # do not flood the error message.
        raise exc.UserError(
            _get_invalid_csv_error_msg(line_snippet=line_to_validate[:50],
                                       file_name=os.path.basename(file_path)))
def get_dmatrix(data_path, content_type, csv_weights=0, is_pipe=False):
    """Create Data Matrix from CSV, LIBSVM, Parquet or RecordIO-protobuf data.

    Assumes that sanity validation for content type has been done.

    :param data_path: Either directory or file (or pipe path when is_pipe is set)
    :param content_type: Parsed content type.
    :param csv_weights: Only used if content type is CSV. 1 if the instance weights are in the second
        column of csv file; otherwise, 0
    :param is_pipe: Boolean to indicate if data is being read in pipe mode
    :return: xgb.DMatrix or None
    :raises exc.UserError: If the content type is unsupported, or the loaded
        data has no labels.
    """
    # In pipe mode, '<data_path>_0' is also accepted as evidence that data exists
    # (presumably the first SageMaker pipe for this channel).
    if not (os.path.exists(data_path) or (is_pipe and os.path.exists(data_path + '_0'))):
        return None

    # Default to data_path so files_path is always bound, even if the walk below
    # yields nothing.
    files_path = data_path
    if not (os.path.isfile(data_path) or is_pipe):
        # Use the first leaf directory of a top-down walk.
        for root, dirs, _files in os.walk(data_path):
            if not dirs:
                files_path = root
                break

    normalized_type = content_type.lower()
    if normalized_type == CSV:
        dmatrix = get_csv_dmatrix(files_path, csv_weights, is_pipe)
    elif normalized_type == LIBSVM:
        dmatrix = get_libsvm_dmatrix(files_path, is_pipe)
    elif normalized_type == PARQUET:
        dmatrix = get_parquet_dmatrix(files_path, is_pipe)
    elif normalized_type == RECORDIO_PROTOBUF:
        dmatrix = get_recordio_protobuf_dmatrix(files_path, is_pipe)
    else:
        # Previously this fell through with `dmatrix` unbound (UnboundLocalError).
        raise exc.UserError("Content type {} is not supported for loading data.".format(content_type))

    if dmatrix and dmatrix.get_label().size == 0:
        raise exc.UserError(
            "Got input data without labels. Please check the input data set. "
            "If training job is running on multiple instances, please switch "
            "to using single instance if number of records in the data set "
            "is less than number of workers (16 * number of instance) in the cluster."
        )

    return dmatrix
def _user_module_transformer(user_module):
    """Build a Transformer from the inference hooks defined in a user module.

    Either ``transform_fn`` alone, or any combination of ``input_fn`` /
    ``predict_fn`` / ``output_fn``, may be provided; missing hooks fall back to
    the module defaults. ``model_fn`` defaults to ``default_model_fn``.

    :param user_module: The user's inference module (hooks read via getattr).
    :return: A configured transformer.Transformer.
    :raises exc.UserError: If transform_fn is combined with input_fn,
        predict_fn and/or output_fn.
    """
    model_fn = getattr(user_module, "model_fn", default_model_fn)
    input_fn = getattr(user_module, "input_fn", None)
    predict_fn = getattr(user_module, "predict_fn", None)
    output_fn = getattr(user_module, "output_fn", None)
    transform_fn = getattr(user_module, "transform_fn", None)

    if transform_fn and (input_fn or predict_fn or output_fn):
        raise exc.UserError(
            "Cannot use transform_fn implementation with input_fn, predict_fn, and/or output_fn"
        )

    if transform_fn is not None:
        return transformer.Transformer(model_fn=model_fn, transform_fn=transform_fn)

    return transformer.Transformer(
        model_fn=model_fn,
        input_fn=input_fn or default_input_fn,
        # Honor a user-supplied predict_fn; it was previously ignored in favor
        # of the default.
        predict_fn=predict_fn or default_predict_fn,
        output_fn=output_fn or default_output_fn,
    )