示例#1
0
class HasMultiClass(WithParams, ABC):
    """
    Base class for the shared multi class param.

    Supported options:

    <ul>
        <li>auto: chooses the classification type from the number of classes.
            When the input data has one or two unique label values, "binomial"
            is used; otherwise "multinomial" is used.
        <li>binomial: binary logistic regression.
        <li>multinomial: multinomial logistic regression.
    </ul>
    """

    # Defaults to 'auto'; the validator only admits the three listed options.
    MULTI_CLASS: Param[str] = StringParam(
        "multi_class",
        "Classification type. Supported options: 'auto', 'binomial' and 'multinomial'.",
        "auto",
        ParamValidators.in_array(["auto", "binomial", "multinomial"]),
    )

    @property
    def multi_class(self) -> str:
        """Property form of :meth:`get_multi_class`."""
        return self.get_multi_class()

    def get_multi_class(self) -> str:
        """Return the result of ``self.get`` for ``MULTI_CLASS``."""
        current = self.get(self.MULTI_CLASS)
        return current

    def set_multi_class(self, class_type: str):
        """
        Pass ``class_type`` to ``self.set`` for ``MULTI_CLASS``.

        :param class_type: 'auto', 'binomial' or 'multinomial'.
        :return: Whatever ``self.set`` returns.
        """
        return self.set(self.MULTI_CLASS, class_type)
示例#2
0
class _StringIndexerParams(_StringIndexerModelParams):
    """
    Params for :class:`StringIndexer`.

    Adds ``STRING_ORDER_TYPE`` on top of what
    ``_StringIndexerModelParams`` declares.
    """

    # Defaults to 'arbitrary'; the validator admits only the five listed orders.
    STRING_ORDER_TYPE: Param[str] = StringParam(
        "string_order_type",
        "How to order strings of each column.",
        "arbitrary",
        ParamValidators.in_array(
            [
                "arbitrary",
                "frequencyDesc",
                "frequencyAsc",
                "alphabetDesc",
                "alphabetAsc",
            ]
        ),
    )

    def __init__(self, java_params):
        """Forward ``java_params`` unchanged to the parent initializer."""
        super().__init__(java_params)

    @property
    def string_order_type(self) -> str:
        """Property form of :meth:`get_string_order_type`."""
        return self.get_string_order_type()

    def get_string_order_type(self) -> str:
        """Return the result of ``self.get`` for ``STRING_ORDER_TYPE``."""
        order_type = self.get(self.STRING_ORDER_TYPE)
        return order_type

    def set_string_order_type(self, value: str):
        """
        Pass ``value`` to ``self.set`` for ``STRING_ORDER_TYPE``.

        :param value: One of the order types accepted by the validator.
        :return: The result of ``self.set``, cast to ``_StringIndexerParams``
            for type checkers.
        """
        updated = self.set(self.STRING_ORDER_TYPE, value)
        return typing.cast(_StringIndexerParams, updated)
示例#3
0
    def test_validators(self):
        """
        Run each ``ParamValidators`` factory against ``None``, boundary values
        and values on either side of the boundary, in a fixed order.
        """

        def check(validator, expectations):
            # Each pair is (candidate value, whether validate() must be truthy);
            # pairs are asserted in the order they are listed.
            for candidate, should_pass in expectations:
                outcome = validator.validate(candidate)
                if should_pass:
                    self.assertTrue(outcome)
                else:
                    self.assertFalse(outcome)

        # Strictly greater than 10.
        check(
            ParamValidators.gt(10),
            [(None, False), (5, False), (10, False), (15, True)],
        )

        # Greater than or equal to 10.
        check(
            ParamValidators.gt_eq(10),
            [(None, False), (5, False), (10, True), (15, True)],
        )

        # Strictly less than 10.
        check(
            ParamValidators.lt(10),
            [(None, False), (5, True), (10, False), (15, False)],
        )

        # Less than or equal to 10.
        check(
            ParamValidators.lt_eq(10),
            [(None, False), (5, True), (10, True), (15, False)],
        )

        # Two-argument in_range: both bounds are expected to be accepted.
        check(
            ParamValidators.in_range(5, 15),
            [
                (None, False),
                (0, False),
                (5, True),
                (10, True),
                (15, True),
                (20, False),
            ],
        )

        # in_range with both flags False: both bounds are expected to be rejected.
        check(
            ParamValidators.in_range(5, 15, False, False),
            [
                (None, False),
                (0, False),
                (5, False),
                (10, True),
                (15, False),
                (20, False),
            ],
        )

        # Membership in an explicit list of values.
        check(
            ParamValidators.in_array([1, 2, 3]),
            [(None, False), (1, True), (0, False)],
        )

        # Anything except None passes.
        check(
            ParamValidators.not_null(),
            [(5, True), (None, False)],
        )
示例#4
0
class HasBatchStrategy(WithParams, ABC):
    """
    Base class for the shared batch strategy param.

    Only a getter and a property are defined here; this class declares no
    setter for ``BATCH_STRATEGY``.
    """

    # Defaults to "count", which is also the only value the validator admits.
    BATCH_STRATEGY: Param[str] = StringParam(
        "batch_strategy",
        "Strategy to create mini batch from online train data.",
        "count",
        ParamValidators.in_array(["count"]),
    )

    @property
    def batch_strategy(self) -> str:
        """Property form of :meth:`get_batch_strategy`."""
        return self.get_batch_strategy()

    def get_batch_strategy(self) -> str:
        """Return the result of ``self.get`` for ``BATCH_STRATEGY``."""
        strategy = self.get(self.BATCH_STRATEGY)
        return strategy
示例#5
0
class HasDistanceMeasure(WithParams, ABC):
    """
    Base class for the shared distance_measure param.
    """

    # Defaults to "euclidean"; the validator admits 'euclidean' and 'cosine'.
    DISTANCE_MEASURE: Param[str] = StringParam(
        "distance_measure",
        "Distance measure. Supported options: 'euclidean' and 'cosine'.",
        "euclidean",
        ParamValidators.in_array(["euclidean", "cosine"]),
    )

    @property
    def distance_measure(self) -> str:
        """Property form of :meth:`get_distance_measure`."""
        return self.get_distance_measure()

    def get_distance_measure(self) -> str:
        """Return the result of ``self.get`` for ``DISTANCE_MEASURE``."""
        measure = self.get(self.DISTANCE_MEASURE)
        return measure

    def set_distance_measure(self, distance_measure: str):
        """
        Pass ``distance_measure`` to ``self.set`` for ``DISTANCE_MEASURE``.

        :param distance_measure: 'euclidean' or 'cosine'.
        :return: Whatever ``self.set`` returns.
        """
        return self.set(self.DISTANCE_MEASURE, distance_measure)
示例#6
0
文件: kmeans.py 项目: apache/flink-ml
class _KMeansParams(_KMeansModelParams, HasSeed, HasMaxIter):
    """
    Params for :class:`KMeans`.

    Adds ``INIT_MODE`` to what the parent classes declare.
    """

    # Defaults to "random", which is also the only value the validator admits.
    INIT_MODE: Param[str] = StringParam(
        "init_mode",
        "The initialization algorithm. Supported options: 'random'.",
        "random",
        ParamValidators.in_array(["random"]),
    )

    def __init__(self, java_params):
        """Forward ``java_params`` unchanged to the parent initializer."""
        super().__init__(java_params)

    @property
    def init_mode(self) -> str:
        """Property form of :meth:`get_init_mode`."""
        return self.get_init_mode()

    def get_init_mode(self) -> str:
        """Return the result of ``self.get`` for ``INIT_MODE``."""
        mode = self.get(self.INIT_MODE)
        return mode

    def set_init_mode(self, value: str):
        """
        Pass ``value`` to ``self.set`` for ``INIT_MODE``.

        :param value: The initialization algorithm name; only 'random' passes
            the validator.
        :return: Whatever ``self.set`` returns.
        """
        return self.set(self.INIT_MODE, value)
示例#7
0
class _NaiveBayesModelParams(
    JavaWithParams, HasFeaturesCol, HasPredictionCol, ABC
):
    """
    Params for :class:`NaiveBayesModel`.
    """

    # Defaults to "multinomial", which is also the only value the validator admits.
    MODEL_TYPE: Param[str] = StringParam(
        "model_type",
        "The model type.",
        "multinomial",
        ParamValidators.in_array(["multinomial"]),
    )

    def __init__(self, java_params):
        """Forward ``java_params`` unchanged to the parent initializer."""
        super().__init__(java_params)

    @property
    def model_type(self) -> str:
        """Property form of :meth:`get_model_type`."""
        return self.get_model_type()

    def get_model_type(self) -> str:
        """Return the result of ``self.get`` for ``MODEL_TYPE``."""
        kind = self.get(self.MODEL_TYPE)
        return kind

    def set_model_type(self, value: str):
        """
        Pass ``value`` to ``self.set`` for ``MODEL_TYPE``.

        :param value: The model type; only 'multinomial' passes the validator.
        :return: Whatever ``self.set`` returns.
        """
        return self.set(self.MODEL_TYPE, value)
示例#8
0
class HasHandleInvalid(WithParams, ABC):
    """
    Base class for the shared handle_invalid param.

    The supported options, and how each one treats invalid entries, are:

    <ul>
        <li>error: raise an exception.
        <li>skip: filter out rows with bad values.
    </ul>
    """

    # Defaults to "error"; the validator admits 'error' and 'skip'.
    HANDLE_INVALID: Param[str] = StringParam(
        "handle_invalid",
        "Strategy to handle invalid entries.",
        "error",
        ParamValidators.in_array(["error", "skip"]),
    )

    @property
    def handle_invalid(self) -> str:
        """Property form of :meth:`get_handle_invalid`."""
        return self.get_handle_invalid()

    def get_handle_invalid(self) -> str:
        """Return the result of ``self.get`` for ``HANDLE_INVALID``."""
        strategy = self.get(self.HANDLE_INVALID)
        return strategy

    def set_handle_invalid(self, value: str):
        """
        Pass ``value`` to ``self.set`` for ``HANDLE_INVALID``.

        :param value: 'error' or 'skip'.
        :return: Whatever ``self.set`` returns.
        """
        return self.set(self.HANDLE_INVALID, value)