def test_default_behavior(self):
    """Check that ``Params.get`` raises ``ValueError`` when no value is available.

    Two lookups are made against a freshly created, empty ``Params``:
    one for a parameter declared with ``is_optional=False``, and one for
    a parameter declared without that flag and without a default value.
    """
    empty_params = Params()

    # Parameter explicitly declared as non-optional and never set.
    required_info = ParamInfo("a", "", is_optional=False)
    with self.assertRaises(ValueError):
        empty_params.get(required_info)

    # Optional parameter that has no default value to fall back on.
    optional_no_default_info = ParamInfo("a", "")
    with self.assertRaises(ValueError):
        empty_params.get(optional_no_default_info)
def test_remove_contains_size_clear_is_empty(self):
    """Exercise ``size``, ``is_empty``, ``set``/``get``, JSON round trip and ``clear``.

    NOTE(review): the test name mentions "remove" and "contains", but
    neither operation is called in this body - confirm whether coverage
    for them lives elsewhere.
    """
    info = ParamInfo(
        "key",
        "",
        has_default_value=True,
        default_value=None,
        type_converter=TypeConverters.to_string)

    store = Params()
    # A new container reports no entries.
    self.assertEqual(store.size(), 0)
    self.assertTrue(store.is_empty())

    stored_value = "3"
    store.set(info, stored_value)
    # Exactly one entry after a single set.
    self.assertEqual(store.size(), 1)
    self.assertFalse(store.is_empty())

    # Serialize and rebuild; both the source and the copy must return the value.
    restored = Params.from_json(store.to_json())
    self.assertEqual(store.get(info), stored_value)
    self.assertEqual(restored.get(info), stored_value)

    # Clearing brings the container back to the empty state.
    store.clear()
    self.assertEqual(store.size(), 0)
    self.assertTrue(store.is_empty())
class HasVectorCol(WithParams):
    """
    Trait for classes that carry the ``vectorCol`` parameter.
    """

    # Non-optional parameter registered with the to_string type converter.
    vector_col = ParamInfo(
        "vectorCol",
        "Name of a vector column",
        is_optional=False,
        type_converter=TypeConverters.to_string)

    def set_vector_col(self, v: str) -> 'HasVectorCol':
        """Store ``v`` under ``vector_col`` via the inherited ``set``.

        :param v: name of the vector column.
        :return: whatever the inherited ``set`` returns.
        """
        info = self.vector_col
        return super().set(info, v)

    def get_vector_col(self) -> str:
        """Look up ``vector_col`` via the inherited ``get``.

        :return: whatever the inherited ``get`` returns for this parameter.
        """
        info = self.vector_col
        return super().get(info)
class HasPredictionCol(WithParams):
    """
    Trait for classes that carry the ``predictionCol`` parameter, i.e. the
    column name of the prediction.
    """

    # Non-optional parameter registered with the to_string type converter.
    prediction_col = ParamInfo(
        "predictionCol",
        "Column name of prediction.",
        is_optional=False,
        type_converter=TypeConverters.to_string)

    def set_prediction_col(self, v: str) -> 'HasPredictionCol':
        """Store ``v`` under ``prediction_col`` via the inherited ``set``.

        :param v: column name of the prediction.
        :return: whatever the inherited ``set`` returns.
        """
        info = self.prediction_col
        return super().set(info, v)

    def get_prediction_col(self) -> str:
        """Look up ``prediction_col`` via the inherited ``get``.

        :return: whatever the inherited ``get`` returns for this parameter.
        """
        info = self.prediction_col
        return super().get(info)
class HasOutputCol(WithParams):
    """
    Trait for classes that carry the ``outputCol`` parameter, i.e. the name
    of the output column.
    """

    # Non-optional parameter registered with the to_string type converter.
    output_col = ParamInfo(
        "outputCol",
        "Name of the output column",
        is_optional=False,
        type_converter=TypeConverters.to_string)

    def set_output_col(self, v: str) -> 'HasOutputCol':
        """Store ``v`` under ``output_col`` via the inherited ``set``.

        :param v: name of the output column.
        :return: whatever the inherited ``set`` returns.
        """
        info = self.output_col
        return super().set(info, v)

    def get_output_col(self) -> str:
        """Look up ``output_col`` via the inherited ``get``.

        :return: whatever the inherited ``get`` returns for this parameter.
        """
        info = self.output_col
        return super().get(info)
class HasSelectedCols(WithParams):
    """
    Trait for classes that carry the ``selectedCols`` parameter, i.e. the
    names of multiple table columns.
    """

    # Non-optional parameter registered with the to_list_string type converter.
    selected_cols = ParamInfo(
        "selectedCols",
        "Names of the columns used for processing",
        is_optional=False,
        type_converter=TypeConverters.to_list_string)

    def set_selected_cols(self, v: list) -> 'HasSelectedCols':
        """Store ``v`` under ``selected_cols`` via the inherited ``set``.

        :param v: names of the columns used for processing.
        :return: whatever the inherited ``set`` returns.
        """
        info = self.selected_cols
        return super().set(info, v)

    def get_selected_cols(self) -> list:
        """Look up ``selected_cols`` via the inherited ``get``.

        :return: whatever the inherited ``get`` returns for this parameter.
        """
        info = self.selected_cols
        return super().get(info)
def test_get_optional_param(self):
    """Check ``get`` for a parameter whose declared default value is ``None``.

    The lookup must yield ``None`` before anything is set, the stored
    value after a ``set``, and ``None`` again once ``None`` is stored
    explicitly.
    """
    info = ParamInfo(
        "key",
        "",
        has_default_value=True,
        default_value=None,
        type_converter=TypeConverters.to_string)
    store = Params()

    # Nothing stored yet: the lookup yields None.
    self.assertIsNone(store.get(info))

    # A stored value is returned as-is.
    expected = "3"
    store.set(info, expected)
    self.assertEqual(expected, store.get(info))

    # Explicitly storing None makes the lookup yield None again.
    store.set(info, None)
    self.assertIsNone(store.get(info))
def test_to_from_json(self):
    """Check serialization round trips for ``ParamInfo`` and ``Params``.

    A ``ParamInfo`` is encoded and decoded with jsonpickle and must compare
    equal to the original. A ``Params`` holding one value is sent through
    ``to_json`` / ``Params.from_json`` and must still return that value.
    """
    import jsonpickle

    info = ParamInfo(
        "key",
        "",
        has_default_value=True,
        default_value=None,
        type_converter=TypeConverters.to_string)

    # ParamInfo round trip through jsonpickle.
    encoded_info = jsonpickle.encode(info)
    decoded_info = jsonpickle.decode(encoded_info)
    self.assertEqual(decoded_info, info)

    # Params round trip through its own JSON helpers.
    source = Params()
    expected = "3"
    source.set(info, expected)
    serialized = source.to_json()
    restored = Params.from_json(serialized)
    self.assertEqual(restored.get(info), expected)