Example #1
0
    def __init__(self, parms_func=None, parms_opter=None, parms_log=None):
        '''
        Store objective-function, optimizer and logging settings, and reset
        the fields that hold the optimisation process and results.

        Parameters
        ----------
        parms_func : dict, optional
            Objective-function information. Missing keys are filled with
            defaults: 'func_name' (''), 'x_lb' (None), 'x_ub' (None),
            'dim' (None) and 'kwargs' ({}).
        parms_opter : dict, optional
            Optimizer settings. Missing keys are filled with defaults:
            'opter_name' (''), 'PopSize' (20) and 'Niter' (100).
        parms_log : dict, optional
            Printing / logging settings used during optimisation. Missing
            keys are filled with defaults: 'logger' (simple_logger()) and
            'nshow' (10).

        Notes
        -----
        Every argument is shallow-copied before defaults are filled in, so
        the caller's dicts are left untouched and no state is shared between
        instances created with the default arguments.
        '''

        def _with_defaults(given, defaults):
            # Copy first: the original implementation updated the argument
            # in place, which leaked state through the shared `{}` default.
            merged = {} if given is None else dict(given)
            for key, value in defaults.items():
                merged.setdefault(key, value)
            return merged

        # Objective-function information
        self.parms_func = _with_defaults(parms_func, {
            'func_name': '',
            'x_lb': None,
            'x_ub': None,
            'dim': None,
            'kwargs': {},
        })

        # Optimizer parameters
        self.parms_opter = _with_defaults(parms_opter, {
            'opter_name': '',
            'PopSize': 20,
            'Niter': 100,
        })

        # Logging parameters; the default logger is only created when the
        # caller did not supply one.
        log_settings = {} if parms_log is None else dict(parms_log)
        if 'logger' not in log_settings:
            log_settings['logger'] = simple_logger()
        log_settings.setdefault('nshow', 10)
        self.parms_log = log_settings

        # Optimisation process and results
        self.__best_val = None  # global best value
        self.__best_x = []  # global best solution
        self.__convergence_curve = []  # convergence curve (best per round)
        self.__convergence_curve_mean = []  # convergence curve (mean per round)
        self.__startTime = None  # start time
        self.__endTime = None  # end time
        self.__exeTime = None  # elapsed optimisation time (seconds)
Example #2
0
def check_parms_mdl(parms_mdl, objective, num_class, logger=None):
    '''
    Validate the task-related model parameters and return the resolved values.

    Checked items (more may be added later):
        - the task type `objective`
        - `num_class` for the 'multiclass' task

    Parameters
    ----------
    parms_mdl : dict or None
        User-supplied model parameters. When it is a dict and disagrees with
        the explicit `objective` / `num_class` arguments, the explicit
        arguments win and the dict is updated in place.
    objective : str or None
        Task type; one of 'multiclass', 'binary', 'regression'. When None it
        is read from parms_mdl['objective'].
    num_class : int or None
        Number of classes, only examined for 'multiclass'. When it is not an
        int it is read from parms_mdl['num_class'].
    logger : optional
        Logger used for warnings; defaults to simple_logger().

    Returns
    -------
    (objective, num_class) : the resolved values.

    Raises
    ------
    ValueError
        If the task type cannot be determined or is not supported, or if a
        multiclass task has no usable num_class.

    Note: if the set of checked items changes, every caller of this function
    must be adapted accordingly.
    '''

    logger = simple_logger() if logger is None else logger
    mdl_is_dict = isinstance(parms_mdl, dict)

    # objective check
    if objective is None:
        if not mdl_is_dict or 'objective' not in parms_mdl:
            raise ValueError('必须设置任务类型objective或在parms_mdl中设置!')
        objective = parms_mdl['objective']
    elif mdl_is_dict and 'objective' in parms_mdl \
            and parms_mdl['objective'] != objective:
        logger.warning('objective与parms_mdl设置不一样,以前者为准!')
        parms_mdl['objective'] = objective

    if objective not in ['multiclass', 'binary', 'regression']:
        raise ValueError('{}不是允许的任务类型!'.format(objective))

    # num_class check for multiclass
    if objective == 'multiclass':
        if not isinstance(num_class, int):
            # parms_mdl may be None here, so it must be type-checked before
            # looking up keys; otherwise an AttributeError would mask the
            # intended ValueError.
            if not mdl_is_dict or 'num_class' not in parms_mdl:
                raise ValueError('多分类任务必须指定类别数num_class, int!')
            num_class = parms_mdl['num_class']
        elif mdl_is_dict and 'num_class' in parms_mdl \
                and num_class != parms_mdl['num_class']:
            logger.warning('num_class与parms_mdl设置不一样,以前者为准!')
            parms_mdl['num_class'] = num_class

    return objective, num_class
Example #3
0
def plot_Series(data, cols_styl_up_left, cols_styl_up_right=None,
                cols_styl_low_left=None, cols_styl_low_right=None,
                cols_to_label_info = {}, xparls_info={},
                yparls_info_up=None, yparls_info_low=None, ylabels=None,
                grids=False, figsize=(12, 9), title=None, nXticks=8,
                fontsize=15, markersize=10, fig_save_path=None, logger=None):
    '''
    Plot several columns of a pd.DataFrame as series on up to four y-axes
    (upper-left, upper-right, lower-left, lower-right).

    TODO: set markersize separately; avoid drawing points twice (normal
          drawing plus special marking); columns used for the lines parallel
          to the x-axis need not be restricted to the plotted columns;
          draw the parallel lines beneath the main lines; draw the marker
          layer above the line layer (layers should follow the input order
          rather than the order of the axes areas).

    Parameters
    ----------
    data:
        DataFrame to plot. It is copied; its index is turned into a column
        and supplies the x tick labels.
    cols_styl_up_left, cols_styl_up_right, cols_styl_low_left,
    cols_styl_low_right:
        Series to draw on the upper-left, upper-right, lower-left and
        lower-right axes, with their line styles and legends, e.g.
        {'col1': ('.-b', 'lbl1'), 'col2': ...} or {'col1': '.-b', 'col2': ...}.
        In the first form lbl sets the legend: None means use the column
        name, False means no legend. The second form only sets the line
        style and the legend defaults to the column name.
        If only cols_styl_low_right is given for the lower panel, it is
        drawn on the lower-left axis instead (a warning is logged).
    cols_to_label_info: columns that need special marking, in the form
        {col1: [[col_lbl1, (v1, v2, ..), (styl1, styl2, ..), (lbl1, lbl2, ..)],
                [col_lbl2, (v1, v2, ..), ...]],
         col2: ..}, where col is the column to be marked and col_lbl is the
        label column; v selects which label values have their rows drawn;
        styl sets the style; lbl sets the legend label: None means use v,
        False means no legend label.
    xparls_info: lines parallel to the x-axis (horizontal lines), in the form
        {col1: [(yval1, clor1, styl1, width1), (yval2, ...)], col2:, ...},
        where yval is the y position, clor the colour, styl the line style
        and width the line width. None entries fall back to 'r', '--', 2.
    yparls_info_up, yparls_info_low: lines parallel to the y-axis (vertical
        lines) for the upper and lower panel, in the form
        [(xval1, clor1, styl1, width1), (xval2, clor2, style2, width2), ...],
        where xval is an index value of `data` giving the x position, clor
        the colour, styl the line style and width the line width.
    ylabels: None or a list with the label text of the four y-axes. None
        means no label text; an entry that is False removes both the label
        text and the y ticks of that axis.
    grids: grid switches of the four axes. True draws the grid on the
        upper-left and lower-left axes; False (or None) draws no grid; a
        list sets the four axes individually.
    figsize: figure size passed to plt.figure.
    title: optional title, drawn on the upper panel.
    nXticks: controls the number of x ticks (nXticks evenly spaced row
        positions plus the last row).
    fontsize: font size for title, y labels and legends.
    markersize: marker size passed on for the specially marked points.
    fig_save_path: if set, the figure is saved to this path.
    logger: logger for warnings; defaults to simple_logger().
    '''
    
    df = data.copy()    
    logger = simple_logger() if logger is None else logger
    
    # Grid settings: grids holds the switches for upper-left, upper-right,
    # lower-left and lower-right, in that order
    if grids == True:
        grids = [True, False, True, False]
    elif grids == False or grids is None:
        grids = [False, False, False, False]
        
    # y-axis label settings
    if ylabels is None:
        ylabels = [None, None, None, None]        
    
    # Index handling: give the index a name, drop a column that clashes with
    # it, then move the index into a column so rows are positioned 0..n-1
    if df.index.name is None:
        df.index.name = 'idx'        
    idx_name = df.index.name
    if idx_name in df.columns:
        df.drop(idx_name, axis=1, inplace=True)
    df.reset_index(inplace=True)
    
    if cols_styl_low_left is None and cols_styl_low_right is not None:
        logger.warning('当底部图只指定右边坐标轴时,默认绘制在左边坐标轴!')
        cols_styl_low_left, cols_styl_low_right = cols_styl_low_right, None
        
    # Axes preparation
    plt.figure(figsize=figsize)
    if cols_styl_low_left is not None:
        gs = GridSpec(3, 1)
        axUpLeft = plt.subplot(gs[:2, :]) # upper panel is the main plot, two thirds of the height
        axLowLeft = plt.subplot(gs[2, :])
    else:
        gs = GridSpec(1, 1)
        axUpLeft = plt.subplot(gs[:, :])
        
        
    def get_cols_to_label_info(cols_to_label_info, col):
        '''Collect the plot settings for the special point marking of `col`.

        Returns a list of [series, (line_style, legend_label)] items, one per
        marked label value.
        '''
        
        to_plots = []
        for label_infos in cols_to_label_info[col]:
        
            lbl_col = label_infos[0]
            
            # No styles given: use None for every value (the default style is
            # chosen later by the plotting helper)
            if label_infos[2] is None:
                label_infos = [lbl_col, label_infos[1], [None]*len(label_infos[1]),
                               label_infos[3]]
            
            # Legend labels: False disables them for every value; missing
            # labels fall back to the label values themselves
            if label_infos[3] == False:
                label_infos = [lbl_col, label_infos[1], label_infos[2],
                               [False]*len(label_infos[1])]
            elif isnull(label_infos[3]) or \
                                        all([isnull(x) for x in label_infos[3]]):
                label_infos = [lbl_col, label_infos[1], label_infos[2],
                               label_infos[1]]
            
            vals = label_infos[1]
            for k in range(len(vals)):
                # Rows of `col` whose label column equals the k-th value
                series = df[df[lbl_col] == vals[k]][col]
                ln_styl = label_infos[2][k]
                lbl_str = label_infos[3][k]
                to_plots.append([series, (ln_styl, lbl_str)])
            
        return to_plots    

    def get_xparls_info(parls_info, col, clor_default='r',
                        lnstyl_default='--', lnwidth_default=2):
        '''Collect the settings of the lines parallel to the x-axis for `col`,
        replacing None colour / style / width with the defaults.'''
        parls = parls_info[col]
        to_plots = []
        for val, clor, lnstyl, lnwidth in parls:
            clor = clor_default if clor is None else clor
            lnstyl = lnstyl_default if lnstyl is None else lnstyl
            lnwidth = lnwidth_default if lnwidth is None else lnwidth
            to_plots.append([val, clor, lnstyl, lnwidth])
        return to_plots
    
    def get_yparls_info(parls_info, clor_default='r', lnstyl_default='--',
                        lnwidth_default=2):
        '''Collect the settings of the lines parallel to the y-axis,
        replacing None colour / style / width with the defaults.'''
        to_plots = []
        for val, clor, lnstyl, lnwidth in parls_info:
            clor = clor_default if clor is None else clor
            lnstyl = lnstyl_default if lnstyl is None else lnstyl
            lnwidth = lnwidth_default if lnwidth is None else lnwidth
            # Translate the original index value into its row position,
            # which is the x coordinate used for plotting
            val = df[df[idx_name] == val].index[0]
            to_plots.append([val, clor, lnstyl, lnwidth])
        return to_plots
    
    
    # lns collects the legend entries of both y-axes of a panel
    # Legend for twin axes, see: https://www.cnblogs.com/Atanisi/p/8530693.html
    lns = []    
    # Upper-left axis
    for col, styl in cols_styl_up_left.items():
        ln = plot_series_with_styls_info(axUpLeft, df[col], styl)
        if ln is not None:
            lns.append(ln)            
           
        # Special point marking
        if col in cols_to_label_info.keys():
            to_plots = get_cols_to_label_info(cols_to_label_info, col)
            for series, styls_info in to_plots:
                ln = plot_series_with_styls_info(axUpLeft, series, styls_info,
                                    lnstyl_default='ko', markersize=markersize)
                if ln is not None:
                    lns.append(ln)
                
        # Lines parallel to the x-axis
        if col in xparls_info.keys():
            to_plots = get_xparls_info(xparls_info, col)
            for yval, clor, lnstyl, lnwidth in to_plots:
                axUpLeft.axhline(y=yval, c=clor, ls=lnstyl, lw=lnwidth)
                
    # Lines parallel to the y-axis
    if yparls_info_up is not None:
        to_plots = get_yparls_info(yparls_info_up)
        for xval, clor, lnstyl, lnwidth in to_plots:
            axUpLeft.axvline(x=xval, c=clor, ls=lnstyl, lw=lnwidth)
        
    # Grid of the upper-left axis
    axUpLeft.grid(grids[0])
    
    # The title is drawn on the upper panel
    if title is not None:
        axUpLeft.set_title(title, fontsize=fontsize)
        
    # y-axis label text
    if ylabels[0] is False:
        axUpLeft.set_ylabel(None)
        axUpLeft.set_yticks([])
    else:
        axUpLeft.set_ylabel(ylabels[0], fontsize=fontsize)
        
    # Upper-right axis
    if cols_styl_up_right is not None:
        axUpRight = axUpLeft.twinx()
        for col, styl in cols_styl_up_right.items():
            ln = plot_series_with_styls_info(axUpRight, df[col], styl,
                                             lbl_str_ext='(r)')
            if ln is not None:
                lns.append(ln)
            
            # Special point marking
            if col in cols_to_label_info.keys():
                to_plots = get_cols_to_label_info(cols_to_label_info, col)
                for series, styls_info in to_plots:
                    ln = plot_series_with_styls_info(axUpRight, series,
                                            styls_info, lnstyl_default='ko',
                                    markersize=markersize, lbl_str_ext='(r)')
                    if ln is not None:
                        lns.append(ln)
                    
            # Lines parallel to the x-axis
            if col in xparls_info.keys():
                to_plots = get_xparls_info(xparls_info, col)
                for yval, clor, lnstyl, lnwidth in to_plots:
                    axUpRight.axhline(y=yval, c=clor, ls=lnstyl, lw=lnwidth)
                    
        # Grid of the upper-right axis
        axUpRight.grid(grids[1])
        
        # y-axis label text
        if ylabels[1] is False:
            axUpRight.set_ylabel(None)
            axUpRight.set_yticks([])
        else:
            axUpRight.set_ylabel(ylabels[1], fontsize=fontsize)
        
    # Merged legend of the upper panel
    # NOTE(review): the entries are joined with `+`, so the plotting helper
    # presumably returns lists of line handles - confirm against the helper
    if len(lns) > 0:
        lnsAdd = lns[0]
        for ln in lns[1:]:
            lnsAdd = lnsAdd + ln
        labs = [l.get_label() for l in lnsAdd]
        axUpLeft.legend(lnsAdd, labs, loc=0, fontsize=fontsize)
    
    
    if cols_styl_low_left is not None:
        # Hide the x tick labels of the upper panel when a lower panel is drawn
        # axUpLeft.set_xticks([]) # this would remove the vertical grid lines
        axUpLeft.set_xticklabels([]) # this does not affect the grid
        lns = []
        
        # Lower-left axis
        for col, styl in cols_styl_low_left.items():
            ln = plot_series_with_styls_info(axLowLeft, df[col], styl)
            if ln is not None:
                lns.append(ln)
            
            # Special point marking
            if col in cols_to_label_info.keys():
                to_plots = get_cols_to_label_info(cols_to_label_info, col)
                for series, styls_info in to_plots:
                    ln = plot_series_with_styls_info(axLowLeft, series,
                        styls_info, lnstyl_default='ko', markersize=markersize)
                    if ln is not None:
                        lns.append(ln)
                    
            # Lines parallel to the x-axis
            if col in xparls_info.keys():
                to_plots = get_xparls_info(xparls_info, col)
                for yval, clor, lnstyl, lnwidth in to_plots:
                    axLowLeft.axhline(y=yval, c=clor, ls=lnstyl, lw=lnwidth)
                    
        # Lines parallel to the y-axis
        if yparls_info_low is not None:
            to_plots = get_yparls_info(yparls_info_low)
            for xval, clor, lnstyl, lnwidth in to_plots:
                axLowLeft.axvline(x=xval, c=clor, ls=lnstyl, lw=lnwidth)
            
        # Grid of the lower-left axis
        axLowLeft.grid(grids[2])    
        
        # y-axis label text
        if ylabels[2] is False:
            axLowLeft.set_ylabel(None)
            axLowLeft.set_yticks([])
        else:
            axLowLeft.set_ylabel(ylabels[2], fontsize=fontsize)
        
        # Lower-right axis
        if cols_styl_low_right is not None:
            axLowRight = axLowLeft.twinx()
            for col, styl in cols_styl_low_right.items():
                ln = plot_series_with_styls_info(axLowRight, df[col], styl,
                                                 lbl_str_ext='(r)')
                if ln is not None:
                    lns.append(ln)
                
                # Special point marking
                if col in cols_to_label_info.keys():
                    to_plots = get_cols_to_label_info(cols_to_label_info, col)
                    for series, styls_info in to_plots:
                        ln = plot_series_with_styls_info(axLowRight, series,
                                            styls_info, lnstyl_default='ko',
                                    markersize=markersize, lbl_str_ext='(r)')
                        if ln is not None:
                            lns.append(ln)
                       
                # Lines parallel to the x-axis
                if col in xparls_info.keys():
                    to_plots = get_xparls_info(xparls_info, col)
                    for yval, clor, lnstyl, lnwidth in to_plots:
                        axLowRight.axhline(y=yval, c=clor, ls=lnstyl,
                                          lw=lnwidth)
                       
            # Grid of the lower-right axis
            axLowRight.grid(grids[3]) 
            
            # y-axis label text
            if ylabels[3] is False:
                axLowRight.set_ylabel(None)
                axLowRight.set_yticks([])
            else:
                axLowRight.set_ylabel(ylabels[3], fontsize=fontsize)
                
        # Merged legend of the lower panel
        if len(lns) > 0:
            lnsAdd = lns[0]
            for ln in lns[1:]:
                lnsAdd = lnsAdd + ln
            labs = [l.get_label() for l in lnsAdd]
            axLowLeft.legend(lnsAdd, labs, loc=0, fontsize=fontsize)
        
    
    # x ticks: nXticks evenly spaced row positions plus the last row,
    # labelled with the original index values
    n = df.shape[0]
    xpos = [int(x*n/nXticks) for x in range(0, nXticks)] + [n-1]
    plt.xticks(xpos, [df.loc[x, idx_name] for x in xpos])
    
    plt.tight_layout()
        
    # Save the figure
    if fig_save_path:
        plt.savefig(fig_save_path)
        
    plt.show()
Example #4
0
def lgb_cv_hoo(X_train,
               y_train,
               objective=None,
               parms_mdl_list=None,
               parms_train_list=None,
               Nfold=5,
               shuffle=True,
               random_state=62,
               mdl_path_list=None,
               logger=None):
    '''
    Hand-rolled K-fold cross-validation for lightgbm.

    One model is trained per fold with lgb_train, using the held-out part of
    the fold as validation data.

    Parameters
    ----------
    X_train, y_train :
        Features and targets; pandas objects are sliced with .iloc, anything
        else is converted with np.asarray and indexed positionally.
    objective : str, optional
        Task type forwarded to lgb_train.
    parms_mdl_list, parms_train_list : dict, list or None
        Model / training parameters. A dict (or None) is used for every
        fold; a list supplies one entry per fold.
    Nfold : int
        Number of folds.
    shuffle : bool
        Whether to shuffle before splitting.
    random_state :
        Seed for the shuffling; ignored when shuffle is False.
    mdl_path_list : str, list or None
        Where to store the models. A str is treated as a directory (created
        if missing) in which 'mdl_kf<k>.bin' files are written; a list
        supplies one path per fold; None disables saving.
    logger : optional
        Logger; defaults to simple_logger().

    Returns
    -------
    (mdls, evals_results) : list of trained models and list of the
        evaluation results returned by lgb_train, one entry per fold.
    '''

    logger = simple_logger() if logger is None else logger

    def _take_rows(data, idxs):
        # Positional row selection that also works for lists/tuples and for
        # pandas objects of either kind.
        if isinstance(data, (pd.core.frame.DataFrame,
                             pd.core.series.Series)):
            return data.iloc[idxs]
        return np.asarray(data)[idxs]

    # Model parameters as a list (one entry per fold)
    if parms_mdl_list is None or isinstance(parms_mdl_list, dict):
        parms_mdl_list = [parms_mdl_list] * Nfold

    # Training parameters as a list (one entry per fold)
    if parms_train_list is None or isinstance(parms_train_list, dict):
        parms_train_list = [parms_train_list] * Nfold

    # If the model location is a str, create the folder and build the list
    # of save paths
    if mdl_path_list is None:
        mdl_path_list = [None] * Nfold
    elif isinstance(mdl_path_list, str):
        abs_dir = os.path.abspath(mdl_path_list)
        if not os.path.isdir(abs_dir):
            logger.warning('将创建模型存放文件夹{}!'.format(abs_dir))
            # makedirs also creates missing parent directories
            os.makedirs(abs_dir, exist_ok=True)
        mdl_path_list = [os.path.join(abs_dir, 'mdl_kf'+str(k)+'.bin')
                         for k in range(1, Nfold+1)]

    # Cross-validation. KFold rejects a random_state when shuffle is False,
    # so the seed is only passed on when shuffling is requested.
    mdls, evals_results = [], []
    folds = KFold(n_splits=Nfold, shuffle=shuffle,
                  random_state=random_state if shuffle else None)
    for Ikf, (trnIdxs, valIdxs) in enumerate(folds.split(X_train, y_train)):
        logger.info('{}/{}折交叉验证训练中...'.format(Ikf + 1, Nfold))

        mdl, evals_result = lgb_train(_take_rows(X_train, trnIdxs),
                                      _take_rows(y_train, trnIdxs),
                                      X_valid=_take_rows(X_train, valIdxs),
                                      y_valid=_take_rows(y_train, valIdxs),
                                      objective=objective,
                                      parms_mdl=parms_mdl_list[Ikf],
                                      parms_train=parms_train_list[Ikf],
                                      mdl_save_path=mdl_path_list[Ikf],
                                      logger=logger)
        mdls.append(mdl)
        evals_results.append(evals_result)

    return mdls, evals_results
Example #5
0
def lgb_cv_GridSearch(X_train,
                      y_train,
                      objective=None,
                      parms_mdl=None,
                      parms_to_opt=None,
                      parms_cv=None,
                      logger=None):
    '''
    Grid-search lightgbm parameters with cross-validation.

    Parameters
    ----------
    X_train, y_train :
        Training data forwarded to lgb_cv.
    objective : str, optional
        Task type; when None it is taken from parms_mdl['objective'].
    parms_mdl : dict, optional
        Base model parameters; defaults are filled in by get_parms_mdl.
    parms_to_opt : dict
        Maps each parameter name to the iterable of candidate values; the
        Cartesian product of all candidates is searched.
    parms_cv : dict, optional
        Cross-validation parameters forwarded to lgb_cv.
    logger : optional
        Logger; defaults to simple_logger().

    Returns
    -------
    (best_parms, {'best <metric>': best_value}), or (None, None) when
    parms_to_opt is empty or not a dict.

    Raises
    ------
    ValueError
        If the optimisation metric is not one this function knows how to
        rank (larger-is-better vs. smaller-is-better).
    '''

    logger = simple_logger() if logger is None else logger

    if not isinstance(parms_to_opt, dict) or len(parms_to_opt) == 0:
        logger.error('检测到待优化参数parms_to_opt为空!')
        return None, None

    # Check the task-related parameters (note: other task types may need
    # additional checks). The task type may come from parms_mdl, so resolve
    # it before inferring the number of classes; otherwise a multiclass task
    # configured only through parms_mdl would be given num_class=1.
    task = objective
    if task is None and isinstance(parms_mdl, dict):
        task = parms_mdl.get('objective')
    num_class = len(set(y_train)) if task == 'multiclass' else 1
    objective, num_class = check_parms_mdl(parms_mdl,
                                           objective,
                                           num_class,
                                           logger=logger)
    # Prepare the model parameters
    parms_mdl = get_parms_mdl(parms_mdl=parms_mdl,
                              objective=objective,
                              num_class=num_class,
                              logger=logger)

    # Reduce the metric setting to a single name to optimise.
    # Note: custom metrics are not supported here.
    metric = parms_mdl['metric']
    if not isinstance(metric, str):
        candidates = list(metric)
        if len(candidates) > 1:
            # Note: a set is unordered, so the "first" metric of a set is
            # not necessarily the one written first by the caller.
            logger.warning('发现多个metric,将以{}为优化目标!'
                           .format(candidates[0]))
        metric = candidates[0]
        parms_mdl['metric'] = metric

    # Decide whether larger or smaller metric values are better
    if metric in ['auc']:
        max_good = True
        best = -np.inf
    elif metric in [
            'l1', 'l2', 'mape', 'rmse', 'binary_error', 'multi_error',
            'binary_logloss', 'multi_logloss'
    ]:
        max_good = False
        best = np.inf
    else:
        raise ValueError('未识别的metric: {},请更改或在此函数中增加该支持项!'
                         .format(metric))

    # Expand the parameters to optimise into a grid
    keys, values = zip(*sorted(parms_to_opt.items()))
    grid_parms = [dict(zip(keys, combo)) for combo in product(*values)]

    # Cross-validated grid search
    best_parms = None
    for k, grid_parm in enumerate(grid_parms, start=1):
        logger.info('交叉验证网格搜索中:{} / {} ...'.format(k, len(grid_parms)))
        logger.info('当前参数:{}'.format(grid_parm))

        parms_mdl_now = parms_mdl.copy()
        parms_mdl_now.update(grid_parm)

        eval_hist = lgb_cv(X_train,
                           y_train,
                           objective=objective,
                           parms_mdl=parms_mdl_now,
                           parms_cv=parms_cv,
                           logger=logger)
        history = eval_hist[metric + '-mean']
        if max_good:
            best_now = max(history)
            improved = best_now > best
        else:
            best_now = min(history)
            improved = best_now < best
        if improved:
            best = best_now
            best_parms = grid_parm

    return best_parms, {'best ' + metric: best}
Example #6
0
def lgb_cv(X_train,
           y_train,
           objective=None,
           parms_mdl=None,
           parms_cv=None,
           logger=None):
    '''
    Run lightgbm's built-in cross-validation.

    Parameters
    ----------
    X_train, y_train :
        Training data in pandas or numpy format (pandas preferred).
    objective : str, optional
        Task type; supported values (more may be added): 'multiclass',
        'binary', 'regression'. When None it is taken from
        parms_mdl['objective'].
    parms_mdl : dict, optional
        Model parameters; defaults are filled in by get_parms_mdl.
    parms_cv : dict, optional
        Cross-validation parameters; completed by get_parms_TrainOrCV.
    logger : optional
        Logger; defaults to simple_logger().

    Returns
    -------
    The evaluation history returned by lgb.cv.
    '''

    logger = simple_logger() if logger is None else logger

    # Check the task-related parameters (note: other task types may need
    # additional checks). The task type may come from parms_mdl, so resolve
    # it before inferring the number of classes; otherwise a multiclass task
    # configured only through parms_mdl would be given num_class=1.
    task = objective
    if task is None and isinstance(parms_mdl, dict):
        task = parms_mdl.get('objective')
    num_class = len(set(y_train)) if task == 'multiclass' else 1
    objective, num_class = check_parms_mdl(parms_mdl,
                                           objective,
                                           num_class,
                                           logger=logger)
    # Prepare model and cross-validation parameters
    parms_mdl = get_parms_mdl(parms_mdl=parms_mdl,
                              objective=objective,
                              num_class=num_class,
                              logger=logger)
    parms_cv = get_parms_TrainOrCV(parms_TrainOrCV=parms_cv)
    if objective == 'regression':
        # Stratified sampling does not apply to regression; the flag is a
        # boolean, so switch it off explicitly.
        parms_cv['stratified'] = False

    # Dataset preparation
    datTrain = lgb.Dataset(X_train,
                           y_train,
                           categorical_feature=parms_cv['categorical_feature'])

    # Cross-validation
    logger.info('交叉验证...')
    eval_hist = lgb.cv(
        params=parms_mdl,
        train_set=datTrain,
        num_boost_round=parms_cv['num_boost_round'],
        folds=parms_cv['folds'],
        nfold=parms_cv['nfold'],
        stratified=parms_cv['stratified'],
        shuffle=parms_cv['shuffle'],
        metrics=parms_cv['metrics'],
        fobj=parms_cv['fobj'],
        feval=parms_cv['feval'],
        init_model=parms_cv['init_model'],
        feature_name=parms_cv['feature_name'],
        categorical_feature=parms_cv['categorical_feature'],
        early_stopping_rounds=parms_cv['early_stopping_rounds'],
        fpreproc=parms_cv['fpreproc'],
        verbose_eval=parms_cv['verbose_eval'],
        show_stdv=parms_cv['show_stdv'],
        seed=parms_cv['seed'],
        callbacks=parms_cv['callbacks'],
        eval_train_metric=parms_cv['eval_train_metric'],
        # return_cvbooster=parms_cv['return_cvbooster']
    )

    return eval_hist
Example #7
0
def lgb_train(X_train,
              y_train,
              X_valid=None,
              y_valid=None,
              objective=None,
              parms_mdl=None,
              parms_train=None,
              mdl_save_path=None,
              logger=None):
    '''
    Train a lightgbm model.

    Parameters
    ----------
    X_train, y_train, X_valid, y_valid :
        Data in pandas or numpy format (pandas preferred). X_valid and
        y_valid must be given together or both left as None.
    objective : str, optional
        Task type; supported values (more may be added): 'multiclass',
        'binary', 'regression'. When None it is taken from
        parms_mdl['objective'].
    parms_mdl : dict, optional
        Model parameters; defaults are filled in by get_parms_mdl.
    parms_train : dict, optional
        Training parameters; completed by get_parms_TrainOrCV.
    mdl_save_path : optional
        Path at which the trained model is persisted; nothing is saved when
        it is null.
    logger : optional
        Logger; defaults to simple_logger().

    Returns
    -------
    (mdl, evals_result) : the trained model and the dict with the metric
        curves recorded during training.

    Raises
    ------
    ValueError
        If only one of X_valid / y_valid is supplied.
    '''

    logger = simple_logger() if logger is None else logger

    # A validation set needs both features and targets; reject a half-set
    # early instead of failing inside lightgbm.
    if (X_valid is None) != (y_valid is None):
        raise ValueError('X_valid和y_valid必须同时设置或同时为None!')

    # Check the task-related parameters (note: other task types may need
    # additional checks). The task type may come from parms_mdl, so resolve
    # it before inferring the number of classes; otherwise a multiclass task
    # configured only through parms_mdl would be given num_class=1.
    task = objective
    if task is None and isinstance(parms_mdl, dict):
        task = parms_mdl.get('objective')
    num_class = len(set(y_train)) if task == 'multiclass' else 1
    objective, num_class = check_parms_mdl(parms_mdl,
                                           objective,
                                           num_class,
                                           logger=logger)
    # Prepare model and training parameters
    parms_mdl = get_parms_mdl(parms_mdl=parms_mdl,
                              objective=objective,
                              num_class=num_class,
                              logger=logger)
    parms_train = get_parms_TrainOrCV(parms_TrainOrCV=parms_train)

    # Dataset preparation; evaluation sets and their names are built
    # together so they can never get out of step.
    datTrain = lgb.Dataset(
        X_train,
        y_train,
        categorical_feature=parms_train['categorical_feature'])
    valid_sets, valid_names = [datTrain], ['train']
    if X_valid is not None:
        datValid = lgb.Dataset(
            X_valid,
            y_valid,
            categorical_feature=parms_train['categorical_feature'])
        valid_sets.append(datValid)
        valid_names.append('valid')
    evals_result = {}

    # Model training
    logger.info('模型训练中...')
    mdl = lgb.train(params=parms_mdl,
                    train_set=datTrain,
                    num_boost_round=parms_train['num_boost_round'],
                    valid_sets=valid_sets,
                    valid_names=valid_names,
                    fobj=parms_train['fobj'],
                    feval=parms_train['feval'],
                    init_model=parms_train['init_model'],
                    feature_name=parms_train['feature_name'],
                    categorical_feature=parms_train['categorical_feature'],
                    early_stopping_rounds=parms_train['early_stopping_rounds'],
                    evals_result=evals_result,
                    verbose_eval=parms_train['verbose_eval'],
                    learning_rates=parms_train['learning_rates'],
                    keep_training_booster=parms_train['keep_training_booster'],
                    callbacks=parms_train['callbacks'])

    # Persist the model at the requested path (the path variable itself,
    # not a literal string with its name)
    if not isnull(mdl_save_path):
        pickleFile(mdl, mdl_save_path)

    return mdl, evals_result
Example #8
0
def get_parms_mdl(parms_mdl=None, objective=None, num_class=None, logger=None):
    '''
    Build the lightgbm model parameters.

    parms_mdl holds the user-supplied model parameters; every key parameter
    that is missing from it is filled in with a default value (the dict is
    completed in place and returned). When parms_mdl is None the defaults
    themselves are returned.

    objective is the task type; supported values (more may be added):
        multiclass, binary, regression
    num_class is the number of classes of a multiclass task (only relevant
    there, where it is specially checked).

    Note: adding another task type may require additional parameters to be
    specially checked.
    Caution: lgb parameters have aliases, which can cause confusion or
    duplicated settings, so the names used in parms_mdl should match the
    default names used in this function.
    '''

    if logger is None:
        logger = simple_logger()

    objective, num_class = check_parms_mdl(parms_mdl,
                                           objective,
                                           num_class,
                                           logger=logger)

    # Evaluation metrics per task type (l1=mae, l2=mse)
    metric_by_task = {
        'multiclass': ['multi_logloss', 'multi_error'],
        'binary': ['binary_logloss', 'auc', 'binary_error'],
        'regression': ['l1', 'l2', 'mape'],
    }

    # Default parameters
    defaults = dict(
        objective=objective,
        num_class=num_class,
        metric=metric_by_task[objective],
        boosting='gbdt',
        extra_trees=False,
        max_depth=3,
        num_leaves=31,
        min_data_in_leaf=20,
        bagging_fraction=0.75,
        bagging_freq=5,
        feature_fraction=0.75,
        max_bin=255,
        lambda_l1=0.1,
        lambda_l2=0.1,
        min_gain_to_split=0.01,
        min_sum_hessian_in_leaf=0.01,
        path_smooth=0.0,
        learning_rate=0.05,
        is_unbalance=False,
        random_state=None,  # previously 62
        num_threads=4,
    )

    # Nothing supplied: the defaults are the result
    if parms_mdl is None:
        return defaults

    # Add only the defaults whose names the caller did not set
    already_set = set(parms_mdl.keys())
    parms_mdl.update((name, value) for name, value in defaults.items()
                     if name not in already_set)

    return parms_mdl