class HasMultiClass(WithParams, ABC):
    """
    Base class for the shared multi class param.

    Supported options:
    <ul>
        <li>auto: selects the classification type based on the number of classes: If the number
        of unique label values from the input data is one or two, set to "binomial". Otherwise,
        set to "multinomial".
        <li>binomial: binary logistic regression.
        <li>multinomial: multinomial logistic regression.
    </ul>
    """
    MULTI_CLASS: Param[str] = StringParam(
        "multi_class",
        "Classification type. Supported options: 'auto', 'binomial' and 'multinomial'.",
        'auto',
        ParamValidators.in_array(['auto', 'binomial', 'multinomial']))

    def set_multi_class(self, class_type: str):
        """
        Sets the ``multi_class`` param.

        :param class_type: the classification type. The param's validator only accepts
                           'auto', 'binomial' or 'multinomial'.
        :return: whatever ``self.set`` returns (presumably ``self`` for call chaining --
                 confirm against WithParams).
        """
        return self.set(self.MULTI_CLASS, class_type)

    def get_multi_class(self) -> str:
        """
        Returns the current value of the ``multi_class`` param.
        """
        return self.get(self.MULTI_CLASS)

    @property
    def multi_class(self) -> str:
        """
        Read-only attribute view of :meth:`get_multi_class`.
        """
        return self.get_multi_class()
class _StringIndexerParams(_StringIndexerModelParams):
    """
    Params for :class:`StringIndexer`.

    Extends the model params with ``string_order_type``, which controls how the strings of each
    column are ordered. The param's validator accepts only 'arbitrary', 'frequencyDesc',
    'frequencyAsc', 'alphabetDesc' and 'alphabetAsc'.
    """
    STRING_ORDER_TYPE: Param[str] = StringParam(
        "string_order_type",
        "How to order strings of each column.",
        "arbitrary",
        ParamValidators.in_array([
            'arbitrary',
            'frequencyDesc',
            'frequencyAsc',
            'alphabetDesc',
            'alphabetAsc'
        ]))

    def __init__(self, java_params):
        # Zero-argument super() resolves to the same MRO lookup as the explicit two-argument form.
        super().__init__(java_params)

    def set_string_order_type(self, value: str):
        """
        Sets the ``string_order_type`` param.

        :param value: one of the order types accepted by the param's validator.
        :return: the result of ``self.set``, cast to :class:`_StringIndexerParams` for
                 static type checkers (``typing.cast`` has no runtime effect).
        """
        return typing.cast(_StringIndexerParams, self.set(self.STRING_ORDER_TYPE, value))

    def get_string_order_type(self) -> str:
        """
        Returns the current value of the ``string_order_type`` param.
        """
        return self.get(self.STRING_ORDER_TYPE)

    @property
    def string_order_type(self) -> str:
        """
        Read-only attribute view of :meth:`get_string_order_type`.
        """
        return self.get_string_order_type()
def test_validators(self):
    """
    Exercises every ParamValidators factory on values it should accept and reject, including
    None, which every validator except not_null's positive case must reject.
    """

    def check(validator, cases):
        # Each case pairs a candidate with whether the validator should accept it; cases are
        # checked in the listed order.
        for candidate, accepted in cases:
            if accepted:
                self.assertTrue(validator.validate(candidate))
            else:
                self.assertFalse(validator.validate(candidate))

    check(ParamValidators.gt(10),
          [(None, False), (5, False), (10, False), (15, True)])

    check(ParamValidators.gt_eq(10),
          [(None, False), (5, False), (10, True), (15, True)])

    check(ParamValidators.lt(10),
          [(None, False), (5, True), (10, False), (15, False)])

    check(ParamValidators.lt_eq(10),
          [(None, False), (5, True), (10, True), (15, False)])

    # Two-argument form: both bounds are accepted.
    check(ParamValidators.in_range(5, 15),
          [(None, False), (0, False), (5, True), (10, True), (15, True), (20, False)])

    # Both bounds explicitly excluded.
    check(ParamValidators.in_range(5, 15, False, False),
          [(None, False), (0, False), (5, False), (10, True), (15, False), (20, False)])

    check(ParamValidators.in_array([1, 2, 3]),
          [(None, False), (1, True), (0, False)])

    check(ParamValidators.not_null(),
          [(5, True), (None, False)])
class HasBatchStrategy(WithParams, ABC):
    """
    Base class for the shared batch strategy param.

    The param's validator accepts only the value 'count'. This mixin defines a getter and a
    read-only property; it does not define a ``set_batch_strategy`` helper.
    """
    BATCH_STRATEGY: Param[str] = StringParam(
        "batch_strategy",
        "Strategy to create mini batch from online train data.",
        "count",
        ParamValidators.in_array(["count"]))

    def get_batch_strategy(self) -> str:
        """
        Returns the current value of the ``batch_strategy`` param.
        """
        return self.get(self.BATCH_STRATEGY)

    @property
    def batch_strategy(self) -> str:
        """
        Read-only attribute view of :meth:`get_batch_strategy`.
        """
        return self.get_batch_strategy()
class HasDistanceMeasure(WithParams, ABC):
    """
    Base class for the shared distance_measure param.

    The param's validator accepts only 'euclidean' and 'cosine'.
    """
    DISTANCE_MEASURE: Param[str] = StringParam(
        "distance_measure",
        "Distance measure. Supported options: 'euclidean' and 'cosine'.",
        "euclidean",
        ParamValidators.in_array(['euclidean', 'cosine']))

    @property
    def distance_measure(self) -> str:
        """
        Read-only attribute view of :meth:`get_distance_measure`.
        """
        return self.get_distance_measure()

    def get_distance_measure(self) -> str:
        """
        Returns the current value of the ``distance_measure`` param.
        """
        return self.get(self.DISTANCE_MEASURE)

    def set_distance_measure(self, distance_measure: str):
        """
        Sets the ``distance_measure`` param.

        :param distance_measure: 'euclidean' or 'cosine'.
        :return: whatever ``self.set`` returns (presumably ``self`` -- confirm in WithParams).
        """
        return self.set(self.DISTANCE_MEASURE, distance_measure)
class _KMeansParams(_KMeansModelParams, HasSeed, HasMaxIter):
    """
    Params for :class:`KMeans`.

    Combines the model params with the seed and max-iter mixins, and adds ``init_mode``, whose
    validator accepts only 'random'.
    """
    INIT_MODE: Param[str] = StringParam(
        "init_mode",
        "The initialization algorithm. Supported options: 'random'.",
        "random",
        ParamValidators.in_array(["random"]))

    def __init__(self, java_params):
        # Zero-argument super() follows the same MRO as the explicit two-argument form.
        super().__init__(java_params)

    def set_init_mode(self, value: str):
        """
        Sets the ``init_mode`` param.

        :param value: the initialization algorithm; only 'random' passes the validator.
        :return: whatever ``self.set`` returns (presumably ``self`` -- confirm in WithParams).
        """
        return self.set(self.INIT_MODE, value)

    def get_init_mode(self) -> str:
        """
        Returns the current value of the ``init_mode`` param.
        """
        return self.get(self.INIT_MODE)

    @property
    def init_mode(self) -> str:
        """
        Read-only attribute view of :meth:`get_init_mode`.
        """
        return self.get_init_mode()
class _NaiveBayesModelParams(JavaWithParams, HasFeaturesCol, HasPredictionCol, ABC):
    """
    Params for :class:`NaiveBayesModel`.

    Mixes in the features-column and prediction-column params and adds ``model_type``, whose
    validator accepts only 'multinomial'.
    """
    MODEL_TYPE: Param[str] = StringParam(
        "model_type",
        "The model type.",
        "multinomial",
        ParamValidators.in_array(["multinomial"]))

    def __init__(self, java_params):
        super().__init__(java_params)

    @property
    def model_type(self) -> str:
        """
        Read-only attribute view of :meth:`get_model_type`.
        """
        return self.get_model_type()

    def get_model_type(self) -> str:
        """
        Returns the current value of the ``model_type`` param.
        """
        return self.get(self.MODEL_TYPE)

    def set_model_type(self, value: str):
        """
        Sets the ``model_type`` param.

        :param value: the model type; only 'multinomial' passes the validator.
        :return: whatever ``self.set`` returns (presumably ``self`` -- confirm in WithParams).
        """
        return self.set(self.MODEL_TYPE, value)
class HasHandleInvalid(WithParams, ABC):
    """
    Base class for the shared handle_invalid param.

    Supported options and the corresponding behavior for invalid entries:
    <ul>
        <li>error: raise an exception.
        <li>skip: filter out rows with bad values.
    </ul>
    """
    HANDLE_INVALID: Param[str] = StringParam(
        "handle_invalid",
        "Strategy to handle invalid entries.",
        "error",
        ParamValidators.in_array(['error', 'skip']))

    @property
    def handle_invalid(self) -> str:
        """
        Read-only attribute view of :meth:`get_handle_invalid`.
        """
        return self.get_handle_invalid()

    def get_handle_invalid(self) -> str:
        """
        Returns the current value of the ``handle_invalid`` param.
        """
        return self.get(self.HANDLE_INVALID)

    def set_handle_invalid(self, value: str):
        """
        Sets the ``handle_invalid`` param.

        :param value: 'error' or 'skip'.
        :return: whatever ``self.set`` returns (presumably ``self`` -- confirm in WithParams).
        """
        return self.set(self.HANDLE_INVALID, value)