def check_label(data_inst):
    """
    Validate the labels of ``data_inst`` for FTL and return a table whose
    labels are 1 / -1.

    Exactly two classes are required. If the class set already contains both
    1 and -1 the input table is returned untouched. Otherwise a new table is
    produced through ``mapValues`` in which the larger class (in sorted order)
    maps to 1 and the smaller one to -1, and the schema of the input table is
    deep-copied onto it.

    Raises:
        ValueError: if the number of classes reported by the checker is not 2.
    """
    LOGGER.debug('checking label')

    num_class, class_set = ClassifyLabelChecker().validate_label(data_inst)
    if num_class != 2:
        raise ValueError(
            'ftl only support binary classification, however {} labels are provided.'
            .format(num_class))

    # Already in the expected encoding: hand the input back unchanged.
    if 1 in class_set and -1 in class_set:
        return data_inst

    ordered_classes = sorted(class_set)
    label_mapping = {ordered_classes[1]: 1, ordered_classes[0]: -1}
    relabel = functools.partial(FTL.reset_label, mapping=label_mapping)

    relabeled_table = data_inst.mapValues(relabel)
    relabeled_table.schema = copy.deepcopy(data_inst.schema)
    return relabeled_table
def setUp(self):
    """Start the session and build the classification / regression fixtures."""
    session.init("test_label_checker")

    # 100 instances whose labels cycle through 0..4.
    cyclic_labels = [Instance(label=idx % 5) for idx in range(100)]
    self.small_label_set = cyclic_labels
    self.classify_inst = session.parallelize(cyclic_labels, include_key=False)

    # 100 instances with random float labels.
    float_labels = [Instance(label=random.random()) for _ in range(100)]
    self.regression_label = float_labels
    self.regression_inst = session.parallelize(float_labels)

    self.classify_checker = ClassifyLabelChecker()
    self.regression_checker = RegressionLabelChecker()
def check_label(self):
    """Validate the labels of ``self.data_bin`` and normalise class labels.

    For a classification task the labels are checked with
    ``ClassifyLabelChecker.validate_label``. If the classes are not all
    ``int`` values inside ``[0, len(classes_))``, the values of ``self.y``
    are remapped to ``0 .. num_classes - 1`` following the sorted class
    order. For any other task type the labels are checked with
    ``RegressionLabelChecker.validate_label``.

    Returns:
        tuple: ``(classes_, num_classes, booster_dim)``. ``classes_`` is the
        sorted class list for classification and ``[]`` otherwise.
        ``num_classes`` and ``booster_dim`` start at 1; ``booster_dim`` is
        set to ``num_classes`` only when there are more than two classes.
    """
    LOGGER.info("check label")
    classes_ = []
    num_classes, booster_dim = 1, 1

    if self.task_type == consts.CLASSIFICATION:
        num_classes, classes_ = ClassifyLabelChecker.validate_label(
            self.data_bin)
        if num_classes > 2:
            booster_dim = num_classes

        # NOTE(review): this block was reconstructed from whitespace-collapsed
        # code; confirm the range check / remapping is meant to run for
        # binary labels too and not only when num_classes > 2.
        #
        # The type is tested before the comparison so that labels which
        # cannot be ordered against an int (e.g. str) simply count as
        # "not in range" instead of raising. This replaces a bare ``except:``
        # that also swallowed unrelated errors such as KeyboardInterrupt.
        range_from_zero = all(
            isinstance(_class, int) and 0 <= _class < len(classes_)
            for _class in classes_
        )

        classes_ = sorted(classes_)
        if not range_from_zero:
            class_mapping = dict(zip(classes_, range(num_classes)))
            self.y = self.y.mapValues(lambda _class: class_mapping[_class])
    else:
        RegressionLabelChecker.validate_label(self.data_bin)

    return classes_, num_classes, booster_dim
def check_label(self):
    """Validate the labels of ``self.data_bin`` and configure the loss.

    For a classification task ``self.num_classes`` and ``self.classes_`` are
    taken from ``ClassifyLabelChecker.validate_label``. With more than two
    classes, ``self.classify_target`` becomes ``"multinomial"``,
    ``self.tree_dim`` becomes the class count, ``self.classes_`` is sorted,
    and, unless every class is an ``int`` inside ``[0, num_classes)``,
    the values of ``self.y`` are remapped to ``0 .. num_classes - 1``
    following the sorted class order. For any other task type the labels are
    checked with ``RegressionLabelChecker.validate_label``.

    In every case ``self.set_loss(self.objective_param)`` is called last.
    """
    LOGGER.info("check label")

    if self.task_type == consts.CLASSIFICATION:
        self.num_classes, self.classes_ = ClassifyLabelChecker.validate_label(
            self.data_bin)
        # NOTE(review): nesting reconstructed from whitespace-collapsed code;
        # confirm the remapping is meant to apply to the multi-class case only.
        if self.num_classes > 2:
            self.classify_target = "multinomial"
            self.tree_dim = self.num_classes

            # The type is tested before the comparison so that labels which
            # cannot be ordered against an int (e.g. str) simply count as
            # "not in range" instead of raising. This replaces a bare
            # ``except:`` that also swallowed unrelated errors such as
            # KeyboardInterrupt.
            range_from_zero = all(
                isinstance(_class, int) and 0 <= _class < self.num_classes
                for _class in self.classes_
            )

            self.classes_ = sorted(self.classes_)
            if not range_from_zero:
                class_mapping = dict(
                    zip(self.classes_, range(self.num_classes)))
                self.y = self.y.mapValues(
                    lambda _class: class_mapping[_class])
    else:
        RegressionLabelChecker.validate_label(self.data_bin)

    self.set_loss(self.objective_param)
class TeskClassifyLabelChecker(unittest.TestCase):
    """Tests for ClassifyLabelChecker and RegressionLabelChecker.

    Each test runs against tables of 100 instances created in ``setUp``:
    one with integer labels cycling through 0..4 and one with random float
    labels.
    """

    def setUp(self):
        """Start the session and build both fixture tables (16 partitions)."""
        session.init("test_label_checker")

        self.small_label_set = [Instance(label=i % 5) for i in range(100)]
        self.classify_inst = session.parallelize(self.small_label_set,
                                                 include_key=False,
                                                 partition=16)

        self.regression_label = [
            Instance(label=random.random()) for _ in range(100)
        ]
        self.regression_inst = session.parallelize(self.regression_label,
                                                   partition=16)

        self.classify_checker = ClassifyLabelChecker()
        self.regression_checker = RegressionLabelChecker()

    def test_classify_label_checkert(self):
        """The classification checker reports the five classes 0..4."""
        num_class, classes = self.classify_checker.validate_label(
            self.classify_inst)
        # assertEqual reports both operands on failure, unlike
        # assertTrue(a == b) which only says "False is not true".
        self.assertEqual(num_class, 5)
        self.assertEqual(sorted(classes), [0, 1, 2, 3, 4])

    def test_regression_label_checker(self):
        """The regression checker accepts float labels without raising."""
        self.regression_checker.validate_label(self.regression_inst)

    def tearDown(self):
        """Stop the session started in ``setUp``."""
        session.stop()
def get_data_classes(self, data_instances):
    """
    Return the classes known after synchronising with ``data_instances``.

    When ``self.has_label`` is truthy, the class list reported by
    ``ClassifyLabelChecker.validate_label`` is turned into a set; otherwise
    ``None`` is used. That value is handed to
    ``self._synchronize_classes_list`` and ``self.classes`` is returned.
    """
    if not self.has_label:
        local_classes = None
    else:
        _, label_values = ClassifyLabelChecker.validate_label(data_instances)
        local_classes = set(label_values)

    self._synchronize_classes_list(local_classes)
    return self.classes
def check_labels(self, data_inst, ) -> List[int]:
    """Validate the labels of ``data_inst`` according to ``self.task_type``.

    Returns:
        The class values reported by ``ClassifyLabelChecker.validate_label``
        for a classification task; ``None`` for any other task type, in
        which case the labels are only run through
        ``RegressionLabelChecker.validate_label``.
    """
    LOGGER.debug('checking labels')

    if self.task_type == consts.CLASSIFICATION:
        _, label_values = ClassifyLabelChecker.validate_label(data_inst)
        return label_values

    RegressionLabelChecker.validate_label(data_inst)
    return None
def _client_check_data(self, data_instances):
    """Run the client-side checks on ``data_instances`` before training.

    In order: ``self._abnormal_detection``, ``self.check_abnormal_values``
    and ``self.init_schema`` are called on the data, the labels are checked
    with ``ClassifyLabelChecker.validate_label``, and the resulting classes
    are passed to ``HomoLabelEncoderClient().label_alignment``.

    Raises:
        ValueError: if the aligned label collection has more than two
            entries, or fewer than two.
    """
    self._abnormal_detection(data_instances)
    self.check_abnormal_values(data_instances)
    self.init_schema(data_instances)

    _, classes_ = ClassifyLabelChecker.validate_label(data_instances)
    # Only the aligned labels are needed here; the returned mapping is unused.
    aligned_label, _ = HomoLabelEncoderClient().label_alignment(classes_)

    aligned_count = len(aligned_label)
    if aligned_count > 2:
        # The message previously contained masked text ("H**o LR ...").
        raise ValueError("Homo LR support binary classification only now")
    if aligned_count <= 1:
        raise ValueError("Number of classes should be equal to 2")