Esempio n. 1
0
    def test_validators(self):
        """Check each ParamValidators factory against boundary and None inputs.

        Every expectation is a ``(value, should_pass)`` pair: the validator's
        ``validate(value)`` result must be truthy when ``should_pass`` is True
        and falsy otherwise.
        """

        def expect(validator, expectations):
            # Route through assertTrue/assertFalse (not assertEqual) so each
            # check keeps plain truthiness semantics.
            for value, should_pass in expectations:
                verdict = validator.validate(value)
                if should_pass:
                    self.assertTrue(verdict)
                else:
                    self.assertFalse(verdict)

        expect(ParamValidators.gt(10),
               [(None, False), (5, False), (10, False), (15, True)])

        expect(ParamValidators.gt_eq(10),
               [(None, False), (5, False), (10, True), (15, True)])

        expect(ParamValidators.lt(10),
               [(None, False), (5, True), (10, False), (15, False)])

        expect(ParamValidators.lt_eq(10),
               [(None, False), (5, True), (10, True), (15, False)])

        # With no extra arguments, both endpoints are expected to be accepted.
        expect(ParamValidators.in_range(5, 15),
               [(None, False), (0, False), (5, True),
                (10, True), (15, True), (20, False)])

        # Passing False, False is expected to reject both endpoints.
        expect(ParamValidators.in_range(5, 15, False, False),
               [(None, False), (0, False), (5, False),
                (10, True), (15, False), (20, False)])

        expect(ParamValidators.in_array([1, 2, 3]),
               [(None, False), (1, True), (0, False)])

        expect(ParamValidators.not_null(),
               [(5, True), (None, False)])
Esempio n. 2
0
class HasReg(WithParams, ABC):
    """
    Mixin contributing the shared ``reg`` (regularization) parameter.

    The parameter defaults to ``0.`` and is checked with
    ``ParamValidators.gt_eq(0.)``.
    """
    REG: Param[float] = FloatParam(
        "reg",
        "Regularization parameter.",
        0.,
        ParamValidators.gt_eq(0.),
    )

    def set_reg(self, value: float):
        """Store ``value`` for :attr:`REG`; returns the result of ``self.set``."""
        reg_param = self.REG
        return self.set(reg_param, value)

    def get_reg(self) -> float:
        """Return the current value of :attr:`REG` via ``self.get``."""
        reg_param = self.REG
        return self.get(reg_param)

    @property
    def reg(self) -> float:
        """Property alias that delegates to :meth:`get_reg`."""
        return self.get_reg()
Esempio n. 3
0
class HasTol(WithParams, ABC):
    """
    Mixin contributing the shared ``tol`` (convergence tolerance) parameter.

    The parameter defaults to ``1e-6`` and is checked with
    ``ParamValidators.gt_eq(0)``.
    """
    TOL: Param[float] = FloatParam(
        "tol",
        "Convergence tolerance for iterative algorithms.",
        1e-6,
        ParamValidators.gt_eq(0),
    )

    def set_tol(self, value: float):
        """Store ``value`` for :attr:`TOL`; returns the result of ``self.set``."""
        tol_param = self.TOL
        return self.set(tol_param, value)

    def get_tol(self) -> float:
        """Return the current value of :attr:`TOL` via ``self.get``."""
        tol_param = self.TOL
        return self.get(tol_param)

    @property
    def tol(self) -> float:
        """Property alias that delegates to :meth:`get_tol`."""
        return self.get_tol()
Esempio n. 4
0
class _NaiveBayesParams(_NaiveBayesModelParams, HasLabelCol):
    """
    Params for :class:`NaiveBayes`.

    Combines the NaiveBayes model params and the label-column param with a
    ``smoothing`` parameter that defaults to ``1.0`` and is checked with
    ``ParamValidators.gt_eq(0.0)``.
    """

    SMOOTHING: Param[float] = FloatParam(
        "smoothing",
        "The smoothing parameter.",
        1.0,
        ParamValidators.gt_eq(0.0),
    )

    def __init__(self, java_params):
        # Zero-argument super() performs the same MRO lookup as the explicit
        # super(_NaiveBayesParams, self) form.
        super().__init__(java_params)

    def set_smoothing(self, value: float):
        """Store ``value`` for :attr:`SMOOTHING` and return ``self.set``'s result.

        The result is cast to ``_NaiveBayesParams`` for static type checkers;
        ``typing.cast`` returns its argument unchanged at runtime.
        """
        updated = self.set(self.SMOOTHING, value)
        return typing.cast(_NaiveBayesParams, updated)

    def get_smoothing(self) -> float:
        """Return the current value of :attr:`SMOOTHING` via ``self.get``."""
        smoothing_param = self.SMOOTHING
        return self.get(smoothing_param)

    @property
    def smoothing(self) -> float:
        """Property alias that delegates to :meth:`get_smoothing`."""
        return self.get_smoothing()