class _MinMaxScalerParams(JavaWithParams, HasInputCol, HasOutputCol):
    """
    Params for :class:`MinMaxScaler`.

    On top of the shared input/output column params, this adds the two bounds of the
    output feature range: ``min`` (default 0.0) and ``max`` (default 1.0). Both are float
    params validated with ``ParamValidators.not_null()``.
    """

    MIN: Param[float] = FloatParam(
        "min",
        "Lower bound of the output feature range.",
        0.0,
        ParamValidators.not_null())

    MAX: Param[float] = FloatParam(
        "max",
        "Upper bound of the output feature range.",
        1.0,
        ParamValidators.not_null())

    def __init__(self, java_params):
        """Forward ``java_params`` unchanged to the parent initializer."""
        super(_MinMaxScalerParams, self).__init__(java_params)

    def set_min(self, value: float) -> '_MinMaxScalerParams':
        """Set the lower bound of the output range; returns the result of ``set``."""
        return typing.cast(_MinMaxScalerParams, self.set(self.MIN, value))

    def set_max(self, value: float) -> '_MinMaxScalerParams':
        """Set the upper bound of the output range; returns the result of ``set``."""
        return typing.cast(_MinMaxScalerParams, self.set(self.MAX, value))

    def get_min(self) -> float:
        """Return the lower bound of the output range (MIN is a float param)."""
        return self.get(self.MIN)

    def get_max(self) -> float:
        """Return the upper bound of the output range (MAX is a float param)."""
        return self.get(self.MAX)

    @property
    def min(self) -> float:
        """Read-only accessor delegating to :meth:`get_min`."""
        return self.get_min()

    @property
    def max(self) -> float:
        """Read-only accessor delegating to :meth:`get_max`."""
        return self.get_max()
def test_validators(self):
    """Check each ParamValidators factory against accepted and rejected sample values."""
    # Each row: (validator factory, factory arguments, ((value, should_pass), ...)).
    # The validator for a row is built right before that row's checks run.
    table = (
        (ParamValidators.gt, (10,),
         ((None, False), (5, False), (10, False), (15, True))),
        (ParamValidators.gt_eq, (10,),
         ((None, False), (5, False), (10, True), (15, True))),
        (ParamValidators.lt, (10,),
         ((None, False), (5, True), (10, False), (15, False))),
        (ParamValidators.lt_eq, (10,),
         ((None, False), (5, True), (10, True), (15, False))),
        # Two-argument form: both bounds are expected to be inclusive.
        (ParamValidators.in_range, (5, 15),
         ((None, False), (0, False), (5, True), (10, True), (15, True), (20, False))),
        # Both bounds explicitly exclusive.
        (ParamValidators.in_range, (5, 15, False, False),
         ((None, False), (0, False), (5, False), (10, True), (15, False), (20, False))),
        (ParamValidators.in_array, ([1, 2, 3],),
         ((None, False), (1, True), (0, False))),
        (ParamValidators.not_null, (),
         ((5, True), (None, False))),
    )

    for factory, args, expectations in table:
        validator = factory(*args)
        for value, should_pass in expectations:
            # Truthiness checks, exactly like assertTrue / assertFalse on the raw result.
            check = self.assertTrue if should_pass else self.assertFalse
            check(validator.validate(value))
class HasOutputCol(WithParams, ABC):
    """
    Mixin contributing the shared ``output_col`` string param.

    Defaults to ``"output"`` and is validated with ``ParamValidators.not_null()``.
    """

    OUTPUT_COL: Param[str] = StringParam(
        "output_col",
        "Output column name.",
        "output",
        ParamValidators.not_null())

    def set_output_col(self, col: str):
        """Set the output column name; returns whatever ``set`` returns."""
        param = self.OUTPUT_COL
        return self.set(param, col)

    def get_output_col(self) -> str:
        """Return the configured output column name."""
        param = self.OUTPUT_COL
        return self.get(param)

    @property
    def output_col(self) -> str:
        """Read-only accessor delegating to :meth:`get_output_col`."""
        return self.get_output_col()
class HasLabelCol(WithParams, ABC):
    """
    Mixin contributing the shared ``label_col`` string param.

    Defaults to ``"label"`` and is validated with ``ParamValidators.not_null()``.
    """

    LABEL_COL: Param[str] = StringParam(
        "label_col",
        "Label column name.",
        "label",
        ParamValidators.not_null())

    def set_label_col(self, col: str):
        """Set the label column name; returns whatever ``set`` returns."""
        param = self.LABEL_COL
        return self.set(param, col)

    def get_label_col(self) -> str:
        """Return the configured label column name."""
        param = self.LABEL_COL
        return self.get(param)

    @property
    def label_col(self) -> str:
        """Read-only accessor delegating to :meth:`get_label_col`."""
        return self.get_label_col()
class HasInputCol(WithParams, ABC):
    """
    Mixin contributing the shared ``input_col`` string param.

    Defaults to ``"input"`` and is validated with ``ParamValidators.not_null()``.
    """

    INPUT_COL: Param[str] = StringParam(
        "input_col",
        "Input column name.",
        "input",
        ParamValidators.not_null())

    def set_input_col(self, col: str):
        """Set the input column name; returns whatever ``set`` returns."""
        param = self.INPUT_COL
        return self.set(param, col)

    def get_input_col(self) -> str:
        """Return the configured input column name."""
        param = self.INPUT_COL
        return self.get(param)

    @property
    def input_col(self) -> str:
        """Read-only accessor delegating to :meth:`get_input_col`."""
        return self.get_input_col()
class HasFeaturesCol(WithParams, ABC):
    """
    Base class for the shared ``features_col`` param.

    A string param defaulting to ``"features"``, validated with
    ``ParamValidators.not_null()``.
    """

    FEATURES_COL: Param[str] = StringParam(
        "features_col",
        "Features column name.",
        "features",
        ParamValidators.not_null())

    def set_features_col(self, col: str):
        """Set the features column name; returns whatever ``set`` returns."""
        return self.set(self.FEATURES_COL, col)

    def get_features_col(self) -> str:
        """Return the configured features column name."""
        return self.get(self.FEATURES_COL)

    @property
    def features_col(self) -> str:
        """Read-only accessor delegating to :meth:`get_features_col`."""
        return self.get_features_col()
class HasPredictionCol(WithParams, ABC):
    """
    Mixin contributing the shared ``prediction_col`` string param.

    Defaults to ``"prediction"`` and is validated with ``ParamValidators.not_null()``.
    """

    PREDICTION_COL: Param[str] = StringParam(
        "prediction_col",
        "Prediction column name.",
        "prediction",
        ParamValidators.not_null())

    def set_prediction_col(self, col: str):
        """Set the prediction column name; returns whatever ``set`` returns."""
        param = self.PREDICTION_COL
        return self.set(param, col)

    def get_prediction_col(self) -> str:
        """Return the configured prediction column name."""
        param = self.PREDICTION_COL
        return self.get(param)

    @property
    def prediction_col(self) -> str:
        """Read-only accessor delegating to :meth:`get_prediction_col`."""
        return self.get_prediction_col()
class _LinearSVCModelParams(
        JavaWithParams, HasFeaturesCol, HasPredictionCol, HasRawPredictionCol, ABC):
    """
    Params for :class:`LinearSVCModel`.

    Adds a float ``threshold`` param (default 0.0, validated with
    ``ParamValidators.not_null()``) that is applied to rawPrediction in binary
    classification, on top of the shared features/prediction/raw-prediction column params.
    """

    THRESHOLD: Param[float] = FloatParam(
        "threshold",
        "Threshold in binary classification prediction applied to rawPrediction.",
        0.0,
        ParamValidators.not_null())

    def __init__(self, java_params):
        """Forward ``java_params`` unchanged to the parent initializer."""
        super(_LinearSVCModelParams, self).__init__(java_params)

    def set_threshold(self, value: float) -> '_LinearSVCModelParams':
        """Set the decision threshold; returns the result of ``set``."""
        return typing.cast(_LinearSVCModelParams, self.set(self.THRESHOLD, value))

    def get_threshold(self) -> float:
        """Return the decision threshold (THRESHOLD is a float param)."""
        return self.get(self.THRESHOLD)

    @property
    def threshold(self) -> float:
        """Read-only accessor delegating to :meth:`get_threshold`."""
        return self.get_threshold()
# Param definitions of every supported kind, used by the stage tests below.
BOOLEAN_PARAM = BooleanParam("boolean_param", "Description", False)

INT_PARAM = IntParam("int_param", "Description", 1, ParamValidators.lt(100))

FLOAT_PARAM = FloatParam("float_param", "Description", 3.0, ParamValidators.lt(100))

STRING_PARAM = StringParam("string_param", "Description", "5")

INT_ARRAY_PARAM = IntArrayParam("int_array_param", "Description", (6, 7))

FLOAT_ARRAY_PARAM = FloatArrayParam("float_array_param", "Description", (10.0, 11.0))

STRING_ARRAY_PARAM = StringArrayParam("string_array_param", "Description", ("14", "15"))

EXTRA_INT_PARAM = IntParam(
    "extra_int_param", "Description", 20, ParamValidators.always_true())

# Default is None while the validator is not_null, so a value must be set explicitly.
PARAM_WITH_NONE_DEFAULT = IntParam(
    "param_with_none_default",
    "Must be explicitly set with a non-none value",
    None,
    ParamValidators.not_null())


class StageTest(PyFlinkMLTestCase):
    def test_param_set_value_with_name(self):
        """Values set via a Param object, or via a Param looked up by name, can be read back."""
        stage = MyStage()

        # Round trip through the module-level Param object itself.
        stage.set(INT_PARAM, 2)
        self.assertEqual(2, stage.get(INT_PARAM))

        # Round trips through Param objects resolved from their registered names.
        for name, expected in (("int_param", 3), ("extra_int_param", 50)):
            looked_up = stage.get_param(name)
            stage.set(looked_up, expected)
            self.assertEqual(expected, stage.get(looked_up))