Example #1
0
 def fit(self, data):
     """Fit the multi-class GBDT model on *data*.

     :param data: pandas.DataFrame whose first column is an id, whose last
         column is dropped from the feature list, and which contains a
         'label' column holding each sample's class. The frame is modified
         in place (one-hot indicator and residual columns are added).
     """
     # Strip head and tail: drop the id and label columns to get the feature names.
     self.features = list(data.columns)[1:-1]
     # Collect every distinct class (normalized to str for column naming).
     self.classes = data['label'].unique().astype(str)
     # Initialize the multi-class loss parameter K (number of classes).
     self.loss.init_classes(self.classes)
     # One-hot encode the 'label' column: one 0/1 indicator column per class.
     for class_name in self.classes:
         label_name = 'label_' + class_name
         # Vectorized comparison instead of a per-row apply(lambda ...) — same 0/1 result.
         data[label_name] = (data['label'].astype(str) == class_name).astype(int)
         # Initialize f_0(x) for this class.
         self.f_0[class_name] = self.loss.initialize_f_0(data, class_name)
     # For m = 1, 2, ..., M: build one tree per class per boosting round.
     logger.handlers[0].setLevel(logging.INFO if self.is_log else logging.CRITICAL)
     for tree_idx in range(1, self.n_trees + 1):  # renamed from `iter` to avoid shadowing the builtin
         # Keep exactly one per-tree file handler: drop the previous round's handler.
         if len(logger.handlers) > 1:
             logger.removeHandler(logger.handlers[-1])
         fh = logging.FileHandler('results/NO.{}_tree.log'.format(tree_idx), mode='w', encoding='utf-8')
         fh.setLevel(logging.DEBUG)
         logger.addHandler(fh)
         logger.info(('-----------------------------构建第%d颗树-----------------------------' % tree_idx))
         # Compute the negative gradients for all classes at once so p_sum stays consistent.
         self.loss.calculate_residual(data, tree_idx)
         self.trees[tree_idx] = {}
         for class_name in self.classes:
             target_name = 'res_' + class_name + '_' + str(tree_idx)
             self.trees[tree_idx][class_name] = Tree(data, self.max_depth, self.min_samples_split,
                                                     self.features, self.loss, target_name, logger)
             self.loss.update_f_m(data, self.trees, tree_idx, class_name, self.learning_rate, logger)
         if self.is_plot:
             plot_multi(self.trees[tree_idx], max_depth=self.max_depth, iter=tree_idx)
     if self.is_plot:
         plot_all_trees(self.n_trees)
Example #2
0
 def fit(self, data):
     """Fit the regression GBDT model on *data*.

     :param data: pandas.DataFrame of the training set; the first column is
         an id and the last is the target, leaving the middle columns as
         features. The frame is modified in place (residual columns added).
     """
     # Strip head and tail: drop the id and label columns to get the feature names.
     self.features = list(data.columns)[1:-1]
     # Initialize f_0(x).
     # For squared loss, f_0(x) is simply the mean of y.
     self.f_0 = self.loss.initialize_f_0(data)
     # For m = 1, 2, ..., M: build one tree per boosting round.
     logger.handlers[0].setLevel(logging.INFO if self.is_log else logging.CRITICAL)
     for tree_idx in range(1, self.n_trees + 1):  # renamed from `iter` to avoid shadowing the builtin
         # Keep exactly one per-tree file handler: drop the previous round's handler.
         if len(logger.handlers) > 1:
             logger.removeHandler(logger.handlers[-1])
         fh = logging.FileHandler('results/NO.{}_tree.log'.format(tree_idx), mode='w', encoding='utf-8')
         fh.setLevel(logging.DEBUG)
         logger.addHandler(fh)
         # Compute the negative gradient — for squared error this is the residual.
         logger.info(('-----------------------------构建第%d颗树-----------------------------' % tree_idx))
         self.loss.calculate_residual(data, tree_idx)
         target_name = 'res_' + str(tree_idx)
         self.trees[tree_idx] = Tree(data, self.max_depth, self.min_samples_split,
                                     self.features, self.loss, target_name, logger)
         self.loss.update_f_m(data, self.trees, tree_idx, self.learning_rate, logger)
         if self.is_plot:
             plot_tree(self.trees[tree_idx], max_depth=self.max_depth, iter=tree_idx)
     if self.is_plot:
         plot_all_trees(self.n_trees)
Example #3
0
 def fit(self, data):
     '''
     Fit the regression GBDT model on the training frame.

     :param data: pandas.DataFrame of the training set; the first column is
         an id and the last is the label, leaving the middle columns as
         features. Modified in place (residual columns are added).
     '''
     # Strip head and tail: drop the id and label columns to get the feature names.
     self.features = list(data.columns)[1:-1]
     # Initialize f_0(x).
     # For squared loss, f_0(x) is simply the mean of y.
     self.f_0 = self.loss_function.initialize_f_0(data)
     # For m = 1, 2, ..., M: build one tree per boosting round.
     # NOTE(review): `iter` shadows the builtin of the same name.
     logger.setLevel(logging.INFO if self.is_log else logging.CRITICAL)
     for iter in range(1, self.n_trees + 1):
         # Compute the negative gradient — for squared error this is the residual.
         logger.info((
             '-----------------------------构建第%d颗树-----------------------------'
             % iter))
         self.loss_function.calculate_residual(data, iter)
         # Fit the m-th tree to the residuals, then update f_m(x) with the
         # learning-rate-scaled tree predictions.
         self.trees[iter] = Tree(data, self.max_depth, self.features, iter,
                                 logger)
         self.loss_function.update_f_m(data, self.trees, iter,
                                       self.learning_rate, logger)