Example #1
0
    def test_validators(self):
        """Exercise every ParamValidators factory on None, boundary and
        out-of-range inputs, checking each verdict in the order listed."""

        def expect(validator, outcomes):
            # Each entry is (value, should_pass). Only the truthiness of
            # validate() is asserted, exactly as assertTrue/assertFalse do.
            for value, should_pass in outcomes:
                verdict = validator.validate(value)
                if should_pass:
                    self.assertTrue(verdict)
                else:
                    self.assertFalse(verdict)

        # Strict lower bound: the bound 10 itself is rejected.
        expect(ParamValidators.gt(10),
               [(None, False), (5, False), (10, False), (15, True)])

        # Inclusive lower bound: the bound 10 itself is accepted.
        expect(ParamValidators.gt_eq(10),
               [(None, False), (5, False), (10, True), (15, True)])

        # Strict upper bound.
        expect(ParamValidators.lt(10),
               [(None, False), (5, True), (10, False), (15, False)])

        # Inclusive upper bound.
        expect(ParamValidators.lt_eq(10),
               [(None, False), (5, True), (10, True), (15, False)])

        # Two-argument in_range: both endpoints are expected to pass.
        expect(ParamValidators.in_range(5, 15),
               [(None, False), (0, False), (5, True), (10, True),
                (15, True), (20, False)])

        # in_range with the two extra False flags: endpoints are expected
        # to be rejected.
        expect(ParamValidators.in_range(5, 15, False, False),
               [(None, False), (0, False), (5, False), (10, True),
                (15, False), (20, False)])

        # Membership in a fixed list of allowed values.
        expect(ParamValidators.in_array([1, 2, 3]),
               [(None, False), (1, True), (0, False)])

        # None is rejected, an ordinary value is accepted.
        expect(ParamValidators.not_null(),
               [(5, True), (None, False)])
Example #2
0
class HasGlobalBatchSize(WithParams, ABC):
    """
    Base class contributing the shared ``global_batch_size`` param.

    The param is an ``IntParam`` built with 32 as its third argument
    (presumably the default value; confirm against ``IntParam``) and
    validated by ``ParamValidators.gt(0)``, i.e. it must be greater than 0.
    """

    # Class-level declaration, so every subclass refers to this same Param.
    GLOBAL_BATCH_SIZE: Param[int] = IntParam(
        "global_batch_size",
        "Global batch size of training algorithms.",
        32,
        ParamValidators.gt(0),
    )

    @property
    def global_batch_size(self) -> int:
        """Current ``global_batch_size`` value (read-only accessor)."""
        return self.get_global_batch_size()

    def get_global_batch_size(self) -> int:
        """Return the value stored for ``global_batch_size``."""
        return self.get(self.GLOBAL_BATCH_SIZE)

    def set_global_batch_size(self, global_batch_size: int):
        """Store ``global_batch_size`` and return the result of ``self.set``
        (presumably ``self`` for chaining; verify in ``WithParams``)."""
        return self.set(self.GLOBAL_BATCH_SIZE, global_batch_size)
Example #3
0
class HasMaxIter(WithParams, ABC):
    """
    Base class contributing the shared ``max_iter`` param.

    The param is an ``IntParam`` built with 20 as its third argument
    (presumably the default value; confirm against ``IntParam``) and
    validated by ``ParamValidators.gt(0)``, i.e. it must be greater than 0.
    """

    # Class-level declaration, so every subclass refers to this same Param.
    MAX_ITER: Param[int] = IntParam(
        "max_iter",
        "Maximum number of iterations.",
        20,
        ParamValidators.gt(0),
    )

    @property
    def max_iter(self) -> int:
        """Current ``max_iter`` value (read-only accessor)."""
        return self.get_max_iter()

    def get_max_iter(self) -> int:
        """Return the value stored for ``max_iter``."""
        return self.get(self.MAX_ITER)

    def set_max_iter(self, max_iter: int):
        """Store ``max_iter`` and return the result of ``self.set``
        (presumably ``self`` for chaining; verify in ``WithParams``)."""
        return self.set(self.MAX_ITER, max_iter)
Example #4
0
class HasLearningRate(WithParams, ABC):
    """
    Base class contributing the shared ``learning_rate`` param.

    The param is a ``FloatParam`` built with 0.1 as its third argument
    (presumably the default value; confirm against ``FloatParam``) and
    validated by ``ParamValidators.gt(0)``, i.e. it must be greater than 0.
    """

    # Class-level declaration, so every subclass refers to this same Param.
    LEARNING_RATE: Param[float] = FloatParam(
        "learning_rate",
        "Learning rate of optimization method.",
        0.1,
        ParamValidators.gt(0),
    )

    @property
    def learning_rate(self) -> float:
        """Current ``learning_rate`` value (read-only accessor)."""
        return self.get_learning_rate()

    def get_learning_rate(self) -> float:
        """Return the value stored for ``learning_rate``."""
        return self.get(self.LEARNING_RATE)

    def set_learning_rate(self, learning_rate: float):
        """Store ``learning_rate`` and return the result of ``self.set``
        (presumably ``self`` for chaining; verify in ``WithParams``)."""
        return self.set(self.LEARNING_RATE, learning_rate)
Example #5
0
class _KNNModelParams(JavaWithParams, HasFeaturesCol, HasPredictionCol, ABC):
    """
    Params for :class:`KNNModel`.

    Adds the ``k`` param (number of nearest neighbors), an ``IntParam``
    built with 5 as its third argument (presumably the default) and
    validated by ``ParamValidators.gt(0)``, so it must be greater than 0.
    """

    K: Param[int] = IntParam(
        "k",
        "The number of nearest neighbors",
        5,
        ParamValidators.gt(0),
    )

    def __init__(self, java_params):
        # Zero-argument super() resolves to the same MRO lookup as
        # super(_KNNModelParams, self).
        super().__init__(java_params)

    @property
    def k(self) -> int:
        """Current ``k`` value (read-only accessor)."""
        return self.get_k()

    def get_k(self) -> int:
        """Return the value stored for ``k``."""
        return self.get(self.K)

    def set_k(self, value: int):
        """Store ``k`` and return the result of ``self.set``."""
        updated = self.set(self.K, value)
        # typing.cast does nothing at runtime; it only narrows the static type.
        return typing.cast(_KNNModelParams, updated)
Example #6
0
class _KMeansModelParams(JavaWithParams, HasDistanceMeasure, HasFeaturesCol,
                         HasPredictionCol, ABC):
    """
    Params for :class:`KMeansModel`.

    Adds the ``k`` param (max number of clusters), an ``IntParam`` built
    with 2 as its third argument (presumably the default) and validated by
    ``ParamValidators.gt(1)``, so it must be greater than 1.
    """

    K: Param[int] = IntParam(
        "k",
        "The max number of clusters to create.",
        2,
        ParamValidators.gt(1),
    )

    def __init__(self, java_params):
        # Zero-argument super() resolves to the same MRO lookup as
        # super(_KMeansModelParams, self).
        super().__init__(java_params)

    @property
    def k(self) -> int:
        """Current ``k`` value (read-only accessor)."""
        return self.get_k()

    def get_k(self) -> int:
        """Return the value stored for ``k``."""
        return self.get(self.K)

    def set_k(self, value: int):
        """Store ``k`` and return the result of ``self.set``."""
        updated = self.set(self.K, value)
        # typing.cast does nothing at runtime; it only narrows the static type.
        return typing.cast(_KMeansModelParams, updated)