def test_validators(self):
    """Exercise every ParamValidators factory against fixed probe values.

    Each probe is a ``(value, expected)`` pair. ``expected`` says whether
    ``validate(value)`` should be truthy (checked with ``assertTrue``) or
    falsy (checked with ``assertFalse``). Each validator is built right before
    its probes run, and the probes run in the listed order.
    """

    def check(validator, probes):
        # Dispatch to assertTrue/assertFalse (not assertEqual) so the
        # truthiness semantics of the hand-written assertions are preserved.
        for value, should_pass in probes:
            verdict = validator.validate(value)
            if should_pass:
                self.assertTrue(verdict)
            else:
                self.assertFalse(verdict)

    # Strict lower bound: only values above 10 pass; None is rejected.
    check(ParamValidators.gt(10),
          [(None, False), (5, False), (10, False), (15, True)])
    # Inclusive lower bound: 10 itself passes.
    check(ParamValidators.gt_eq(10),
          [(None, False), (5, False), (10, True), (15, True)])
    # Strict upper bound: only values below 10 pass.
    check(ParamValidators.lt(10),
          [(None, False), (5, True), (10, False), (15, False)])
    # Inclusive upper bound: 10 itself passes.
    check(ParamValidators.lt_eq(10),
          [(None, False), (5, True), (10, True), (15, False)])
    # Two-argument in_range: both endpoints are accepted.
    check(ParamValidators.in_range(5, 15),
          [(None, False), (0, False), (5, True), (10, True),
           (15, True), (20, False)])
    # in_range with the two extra arguments set to False: endpoints rejected.
    check(ParamValidators.in_range(5, 15, False, False),
          [(None, False), (0, False), (5, False), (10, True),
           (15, False), (20, False)])
    # Membership in a fixed list.
    check(ParamValidators.in_array([1, 2, 3]),
          [(None, False), (1, True), (0, False)])
    # Anything except None passes.
    check(ParamValidators.not_null(),
          [(5, True), (None, False)])
class HasReg(WithParams, ABC):
    """
    Mixin that contributes the shared ``reg`` (regularization) parameter.

    ``FloatParam`` receives ``0.`` (presumably the default value) and the
    validator ``ParamValidators.gt_eq(0.)``. That validator accepts values
    >= 0 and rejects ``None``.
    """

    REG: Param[float] = FloatParam(
        "reg",
        "Regularization parameter.",
        0.,
        ParamValidators.gt_eq(0.))

    def set_reg(self, value: float):
        """
        Store ``value`` as the regularization parameter.

        Returns whatever ``self.set`` returns. This is presumably ``self``, to
        allow call chaining; confirm against ``WithParams.set``.
        """
        param = self.REG
        return self.set(param, value)

    def get_reg(self) -> float:
        """Return the current value of the regularization parameter."""
        param = self.REG
        return self.get(param)

    @property
    def reg(self) -> float:
        """Read-only shortcut that delegates to :meth:`get_reg`."""
        return self.get_reg()
class HasTol(WithParams, ABC):
    """
    Mixin that contributes the shared ``tol`` (convergence tolerance)
    parameter for iterative algorithms.

    ``FloatParam`` receives ``1e-6`` (presumably the default value) and the
    validator ``ParamValidators.gt_eq(0)``. That validator accepts values
    >= 0 and rejects ``None``.
    """

    TOL: Param[float] = FloatParam(
        "tol",
        "Convergence tolerance for iterative algorithms.",
        1e-6,
        ParamValidators.gt_eq(0))

    def set_tol(self, value: float):
        """
        Store ``value`` as the convergence tolerance.

        Returns whatever ``self.set`` returns. This is presumably ``self``, to
        allow call chaining; confirm against ``WithParams.set``.
        """
        param = self.TOL
        return self.set(param, value)

    def get_tol(self) -> float:
        """Return the current value of the convergence tolerance."""
        param = self.TOL
        return self.get(param)

    @property
    def tol(self) -> float:
        """Read-only shortcut that delegates to :meth:`get_tol`."""
        return self.get_tol()
class _NaiveBayesParams(
    _NaiveBayesModelParams,
    HasLabelCol,
):
    """
    Params for :class:`NaiveBayes`.

    Adds the ``HasLabelCol`` mixin and the ``smoothing`` parameter to the
    model params. ``FloatParam`` receives ``1.0`` for ``smoothing`` (presumably
    its default). The validator is ``ParamValidators.gt_eq(0.0)``, so negative
    values and ``None`` are rejected.
    """

    SMOOTHING: Param[float] = FloatParam(
        "smoothing",
        "The smoothing parameter.",
        1.0,
        ParamValidators.gt_eq(0.0))

    def __init__(self, java_params):
        # Zero-argument super() performs the same MRO lookup as the explicit
        # super(_NaiveBayesParams, self) form.
        super().__init__(java_params)

    def set_smoothing(self, value: float):
        """
        Set the smoothing parameter and return the result of ``self.set``.

        The cast to ``_NaiveBayesParams`` only informs static type checkers;
        ``typing.cast`` returns its argument unchanged at runtime.
        """
        updated = self.set(self.SMOOTHING, value)
        return typing.cast(_NaiveBayesParams, updated)

    def get_smoothing(self) -> float:
        """Return the current value of the smoothing parameter."""
        return self.get(self.SMOOTHING)

    @property
    def smoothing(self) -> float:
        """Read-only shortcut that delegates to :meth:`get_smoothing`."""
        return self.get_smoothing()