示例#1
0
class _MinMaxScalerParams(JavaWithParams, HasInputCol, HasOutputCol):
    """
    Params for :class:`MinMaxScaler`.

    Declares the ``min`` and ``max`` bounds of the output feature range on top
    of the shared ``input_col`` / ``output_col`` params.
    """

    # Lower bound of the output range; default 0.0, must not be None.
    MIN: Param[float] = FloatParam("min",
                                   "Lower bound of the output feature range.",
                                   0.0, ParamValidators.not_null())

    # Upper bound of the output range; default 1.0, must not be None.
    MAX: Param[float] = FloatParam("max",
                                   "Upper bound of the output feature range.",
                                   1.0, ParamValidators.not_null())

    def __init__(self, java_params):
        super(_MinMaxScalerParams, self).__init__(java_params)

    def set_min(self, value: float) -> "_MinMaxScalerParams":
        """Set the lower bound of the output feature range.

        Returns the result of ``set`` narrowed to this params type for typing.
        """
        return typing.cast(_MinMaxScalerParams, self.set(self.MIN, value))

    def set_max(self, value: float) -> "_MinMaxScalerParams":
        """Set the upper bound of the output feature range.

        Returns the result of ``set`` narrowed to this params type for typing.
        """
        return typing.cast(_MinMaxScalerParams, self.set(self.MAX, value))

    def get_min(self) -> float:
        """Return the lower bound of the output feature range."""
        # MIN is a FloatParam, so the stored value is a float, not a bool.
        return self.get(self.MIN)

    def get_max(self) -> float:
        """Return the upper bound of the output feature range."""
        return self.get(self.MAX)

    @property
    def min(self) -> float:
        """Read-only alias for :meth:`get_min`."""
        return self.get_min()

    @property
    def max(self) -> float:
        """Read-only alias for :meth:`get_max`."""
        return self.get_max()
示例#2
0
    def test_validators(self):
        def expect(validator, outcomes):
            # Validate each candidate in order; accepted values must be
            # truthy and rejected ones falsy.
            for candidate, should_pass in outcomes:
                verdict = validator.validate(candidate)
                if should_pass:
                    self.assertTrue(verdict)
                else:
                    self.assertFalse(verdict)

        # Strictly greater than 10.
        expect(ParamValidators.gt(10),
               [(None, False), (5, False), (10, False), (15, True)])

        # Greater than or equal to 10.
        expect(ParamValidators.gt_eq(10),
               [(None, False), (5, False), (10, True), (15, True)])

        # Strictly less than 10.
        expect(ParamValidators.lt(10),
               [(None, False), (5, True), (10, False), (15, False)])

        # Less than or equal to 10.
        expect(ParamValidators.lt_eq(10),
               [(None, False), (5, True), (10, True), (15, False)])

        # [5, 15] with both endpoints included (default bounds).
        expect(ParamValidators.in_range(5, 15),
               [(None, False), (0, False), (5, True), (10, True),
                (15, True), (20, False)])

        # (5, 15) with both endpoints excluded.
        expect(ParamValidators.in_range(5, 15, False, False),
               [(None, False), (0, False), (5, False), (10, True),
                (15, False), (20, False)])

        # Membership in an explicit list of allowed values.
        expect(ParamValidators.in_array([1, 2, 3]),
               [(None, False), (1, True), (0, False)])

        # Any non-None value is accepted.
        expect(ParamValidators.not_null(),
               [(5, True), (None, False)])
示例#3
0
class HasOutputCol(WithParams, ABC):
    """
    Mixin declaring the shared ``output_col`` param.

    The column name defaults to ``"output"`` and is validated as non-None.
    """
    OUTPUT_COL: Param[str] = StringParam(
        "output_col",
        "Output column name.",
        "output",
        ParamValidators.not_null(),
    )

    def set_output_col(self, col: str):
        """Assign ``col`` as the output column name; returns what ``set`` returns."""
        return self.set(self.OUTPUT_COL, col)

    def get_output_col(self) -> str:
        """Return the configured output column name."""
        return self.get(self.OUTPUT_COL)

    @property
    def output_col(self) -> str:
        """Read-only alias for :meth:`get_output_col`."""
        return self.get_output_col()
示例#4
0
class HasLabelCol(WithParams, ABC):
    """
    Mixin declaring the shared ``label_col`` param.

    The column name defaults to ``"label"`` and is validated as non-None.
    """
    LABEL_COL: Param[str] = StringParam(
        "label_col",
        "Label column name.",
        "label",
        ParamValidators.not_null(),
    )

    def set_label_col(self, col: str):
        """Assign ``col`` as the label column name; returns what ``set`` returns."""
        return self.set(self.LABEL_COL, col)

    def get_label_col(self) -> str:
        """Return the configured label column name."""
        return self.get(self.LABEL_COL)

    @property
    def label_col(self) -> str:
        """Read-only alias for :meth:`get_label_col`."""
        return self.get_label_col()
示例#5
0
class HasInputCol(WithParams, ABC):
    """
    Mixin declaring the shared ``input_col`` param.

    The column name defaults to ``"input"`` and is validated as non-None.
    """
    INPUT_COL: Param[str] = StringParam(
        "input_col",
        "Input column name.",
        "input",
        ParamValidators.not_null(),
    )

    def set_input_col(self, col: str):
        """Assign ``col`` as the input column name; returns what ``set`` returns."""
        return self.set(self.INPUT_COL, col)

    def get_input_col(self) -> str:
        """Return the configured input column name."""
        return self.get(self.INPUT_COL)

    @property
    def input_col(self) -> str:
        """Read-only alias for :meth:`get_input_col`."""
        return self.get_input_col()
示例#6
0
class HasFeaturesCol(WithParams, ABC):
    """
    Base class for the shared features_col param.

    The column name defaults to ``"features"`` and is validated as non-None.
    """
    FEATURES_COL: Param[str] = StringParam("features_col",
                                           "Features column name.", "features",
                                           ParamValidators.not_null())

    def set_features_col(self, col: str):
        """Assign ``col`` as the features column name; returns what ``set`` returns."""
        return self.set(self.FEATURES_COL, col)

    def get_features_col(self) -> str:
        """Return the configured features column name."""
        return self.get(self.FEATURES_COL)

    @property
    def features_col(self) -> str:
        """Read-only alias for :meth:`get_features_col`."""
        return self.get_features_col()
示例#7
0
class HasPredictionCol(WithParams, ABC):
    """
    Mixin declaring the shared ``prediction_col`` param.

    The column name defaults to ``"prediction"`` and is validated as non-None.
    """
    PREDICTION_COL: Param[str] = StringParam(
        "prediction_col",
        "Prediction column name.",
        "prediction",
        ParamValidators.not_null(),
    )

    def set_prediction_col(self, col: str):
        """Assign ``col`` as the prediction column name; returns what ``set`` returns."""
        return self.set(self.PREDICTION_COL, col)

    def get_prediction_col(self) -> str:
        """Return the configured prediction column name."""
        return self.get(self.PREDICTION_COL)

    @property
    def prediction_col(self) -> str:
        """Read-only alias for :meth:`get_prediction_col`."""
        return self.get_prediction_col()
示例#8
0
class _LinearSVCModelParams(JavaWithParams, HasFeaturesCol, HasPredictionCol,
                            HasRawPredictionCol, ABC):
    """
    Params for :class:`LinearSVCModel`.

    Adds the ``threshold`` param on top of the shared features, prediction
    and raw-prediction column params.
    """

    # Declared as a FloatParam (default 0.0, must not be None), so the
    # accessors below are typed as float.
    THRESHOLD: Param[float] = FloatParam(
        "threshold",
        "Threshold in binary classification prediction applied to rawPrediction.",
        0.0, ParamValidators.not_null())

    def __init__(self, java_params):
        super(_LinearSVCModelParams, self).__init__(java_params)

    def set_threshold(self, value: float) -> "_LinearSVCModelParams":
        """Set the decision threshold applied to the raw prediction.

        Returns the result of ``set`` narrowed to this params type for typing.
        """
        return typing.cast(_LinearSVCModelParams,
                           self.set(self.THRESHOLD, value))

    def get_threshold(self) -> float:
        """Return the decision threshold applied to the raw prediction."""
        return self.get(self.THRESHOLD)

    @property
    def threshold(self) -> float:
        """Read-only alias for :meth:`get_threshold`."""
        return self.get_threshold()
示例#9
0
# Test params; positional args appear to be (name, description, default
# [, validator]) — see PARAM_WITH_NONE_DEFAULT below.
BOOLEAN_PARAM = BooleanParam("boolean_param", "Description", False)
INT_PARAM = IntParam(
    "int_param", "Description", 1,
    ParamValidators.lt(100),
)
FLOAT_PARAM = FloatParam(
    "float_param", "Description", 3.0,
    ParamValidators.lt(100),
)
STRING_PARAM = StringParam("string_param", "Description", "5")

# Array-valued params use tuples as their defaults.
INT_ARRAY_PARAM = IntArrayParam("int_array_param", "Description", (6, 7))
FLOAT_ARRAY_PARAM = FloatArrayParam(
    "float_array_param", "Description", (10.0, 11.0),
)
STRING_ARRAY_PARAM = StringArrayParam(
    "string_array_param", "Description", ("14", "15"),
)

# Not declared by the stage's defaults in this view; accepts any value.
EXTRA_INT_PARAM = IntParam(
    "extra_int_param", "Description", 20,
    ParamValidators.always_true(),
)

# Default is None but the validator rejects None, so callers must set it.
PARAM_WITH_NONE_DEFAULT = IntParam(
    "param_with_none_default",
    "Must be explicitly set with a non-none value",
    None,
    ParamValidators.not_null(),
)


class StageTest(PyFlinkMLTestCase):
    def test_param_set_value_with_name(self):
        stage = MyStage()

        # Setting via the Param object directly.
        stage.set(INT_PARAM, 2)
        self.assertEqual(2, stage.get(INT_PARAM))

        # Setting via a Param looked up by its name string.
        for param_name, new_value in (("int_param", 3),
                                      ("extra_int_param", 50)):
            looked_up = stage.get_param(param_name)
            stage.set(looked_up, new_value)
            self.assertEqual(new_value, stage.get(looked_up))