def __init__(self, parms_func=None, parms_opter=None, parms_log=None):
    '''
    Store objective-function, optimizer and logging settings (completing any
    missing key with its default) and reset the optimization result fields.

    Parameters
    ----------
    parms_func : dict, optional
        Objective function information. Missing keys are filled with:
        'func_name' -> '', 'x_lb' -> None, 'x_ub' -> None, 'dim' -> None,
        'kwargs' -> {}.
    parms_opter : dict, optional
        Parameters used by the optimizer. Missing keys are filled with:
        'opter_name' -> '', 'PopSize' -> 20, 'Niter' -> 100.
    parms_log : dict, optional
        Parameters controlling printing/logging during the search. Missing
        keys are filled with: 'logger' -> simple_logger(), 'nshow' -> 10.

    Notes
    -----
    A dict supplied by the caller is completed in place and stored by
    reference, exactly as before. When an argument is omitted a fresh dict
    is created for this instance; the previous `{}` defaults were shared
    between every instance created without arguments.
    '''
    # Objective function information
    parms_func = {} if parms_func is None else parms_func
    for key, default in (('func_name', ''), ('x_lb', None), ('x_ub', None),
                         ('dim', None), ('kwargs', {})):
        parms_func.setdefault(key, default)
    self.parms_func = parms_func

    # Optimizer parameters
    parms_opter = {} if parms_opter is None else parms_opter
    for key, default in (('opter_name', ''), ('PopSize', 20),
                         ('Niter', 100)):
        parms_opter.setdefault(key, default)
    self.parms_opter = parms_opter

    # Logging parameters. The default logger is only built when the caller
    # did not provide one, so a user-configured logger is never touched.
    parms_log = {} if parms_log is None else parms_log
    if 'logger' not in parms_log:
        parms_log['logger'] = simple_logger()
    parms_log.setdefault('nshow', 10)
    self.parms_log = parms_log

    # Optimization process and results
    self.__best_val = None  # global best objective value
    self.__best_x = []  # global best solution
    self.__convergence_curve = []  # convergence curve (best of each round)
    self.__convergence_curve_mean = []  # convergence curve (mean of each round)
    self.__startTime = None  # start time
    self.__endTime = None  # end time
    self.__exeTime = None  # elapsed optimization time (seconds)
def check_parms_mdl(parms_mdl, objective, num_class, logger=None):
    '''
    Check that the task-related model parameters are set consistently.

    Checked items (new ones may be added):
        - the task type `objective`
        - the `num_class` parameter of multiclass tasks

    Parameters
    ----------
    parms_mdl : dict or None
        Model parameters set by the user. When it is a dict and its
        'objective' / 'num_class' entries disagree with the explicit
        arguments, the entries are overwritten in place (a warning is logged).
    objective : str or None
        Task type; if None it is read from parms_mdl['objective'].
    num_class : int or other
        Number of classes; only checked for the 'multiclass' objective. If it
        is not an int, it is read from parms_mdl['num_class'].
    logger : optional
        Object used for warnings; defaults to simple_logger().

    Returns
    -------
    (objective, num_class) : the checked values.

    Raises
    ------
    ValueError
        If no objective can be determined, the objective is not one of
        'multiclass', 'binary', 'regression', or a multiclass task has no
        usable num_class.

    Note: if the set of checked items changes, the call sites of this
    function must be updated accordingly.
    '''
    logger = simple_logger() if logger is None else logger
    has_dict = isinstance(parms_mdl, dict)

    # objective check
    if objective is None:
        if not has_dict or 'objective' not in parms_mdl:
            raise ValueError('必须设置任务类型objective或在parms_mdl中设置!')
        objective = parms_mdl['objective']
    elif has_dict and 'objective' in parms_mdl \
            and parms_mdl['objective'] != objective:
        # The explicit argument takes precedence over the dict entry.
        logger.warning('objective与parms_mdl设置不一样,以前者为准!')
        parms_mdl['objective'] = objective
    if objective not in ['multiclass', 'binary', 'regression']:
        raise ValueError('{}不是允许的任务类型!'.format(objective))

    # num_class check (multiclass only)
    if objective == 'multiclass':
        if not isinstance(num_class, int):
            # parms_mdl may be None here; it must not be dereferenced blindly.
            if not has_dict or 'num_class' not in parms_mdl:
                raise ValueError('多分类任务必须指定类别数num_class, int!')
            num_class = parms_mdl['num_class']
        elif has_dict and 'num_class' in parms_mdl \
                and num_class != parms_mdl['num_class']:
            # The explicit argument takes precedence over the dict entry.
            logger.warning('num_class与parms_mdl设置不一样,以前者为准!')
            parms_mdl['num_class'] = num_class

    return objective, num_class
def plot_Series(data, cols_styl_up_left, cols_styl_up_right=None,
                cols_styl_low_left=None, cols_styl_low_right=None,
                cols_to_label_info=None, xparls_info=None,
                yparls_info_up=None, yparls_info_low=None, ylabels=None,
                grids=False, figsize=(12, 9), title=None, nXticks=8,
                fontsize=15, markersize=10, fig_save_path=None, logger=None):
    '''
    Plot several columns of a pd.DataFrame on up to four y-axes
    (top-left, top-right, bottom-left, bottom-right).

    todo: set markersize separately; regular points and specially labelled
    points are drawn twice; the columns used for x-axis-parallel lines need
    not be among the columns drawn in the main chart; draw the parallel-line
    layer beneath the main lines; draw the annotation layer above the line
    layer (i.e. draw layers in input order rather than per axis region).

    Parameters
    ----------
    data : pd.DataFrame
        Data to plot; it is copied, the caller's frame is not modified.
    cols_styl_up_left, cols_styl_up_right, cols_styl_low_left,
    cols_styl_low_right : dict
        Series to draw on the top-left, top-right, bottom-left and
        bottom-right axes, with line style and legend, e.g.
        {'col1': ('.-b', 'lbl1'), 'col2': ...} or {'col1': '.-b', ...}.
        In the first form lbl sets the legend: None means "use the column
        name", False means "no legend entry". The second form only sets the
        line style and the legend defaults to the column name.
        If only the bottom-right dict is given it is drawn on the bottom-left
        axis (a warning is logged).
    cols_to_label_info : dict, optional
        Columns whose points need special marking, in the form
        {col1: [[col_lbl1, (v1, v2, ..), (styl1, styl2, ..), (lbl1, lbl2, ..)],
                [col_lbl2, (v1, v2, ..), ...]], col2: ..}
        where col is the column to mark and col_lbl the label column; v
        selects which label values are plotted; styl sets the style (None for
        the default); lbl sets the legend label: None means "use v", False
        means "no legend label".
    xparls_info : dict, optional
        Lines parallel to the x-axis, in the form
        {col1: [(yval1, clor1, styl1, width1), (yval2, ...)], col2: ...}
        where yval is the y position, clor the colour, styl the line style
        and width the line width (None selects 'r', '--' and 2).
    yparls_info_up, yparls_info_low : list, optional
        Lines parallel to the y-axis for the top and bottom panels, as
        [(xval1, clor1, styl1, width1), (xval2, ...), ...], where xval is a
        value of the index column locating the line.
    ylabels : None, False or list
        Label texts of the four y-axes. None sets no label text. An element
        equal to False removes both the label and the ticks of that axis;
        passing False for the whole argument does so for all four axes.
    grids : bool, None or list
        True draws a grid on the top-left and bottom-left axes, False/None
        draws none; a list sets the grid of each of the four axes.
    figsize, title, fontsize, markersize :
        Figure size, title of the top panel, font size of title/labels/legend
        and marker size of the specially marked points.
    nXticks : int
        Number of evenly spaced x ticks (the last row is always added).
    fig_save_path : str, optional
        If truthy, the figure is saved to this path before being shown.
    logger : optional
        Object used for warnings; defaults to simple_logger().
    '''
    df = data.copy()
    logger = simple_logger() if logger is None else logger
    cols_to_label_info = {} if cols_to_label_info is None \
        else cols_to_label_info
    xparls_info = {} if xparls_info is None else xparls_info

    # Grid flags, ordered: top-left, top-right, bottom-left, bottom-right.
    # Anything that is not a list/tuple is treated as a single on/off flag.
    if not isinstance(grids, (list, tuple)):
        grids = [True, False, True, False] if grids else [False] * 4

    # y-axis labels, same ordering as the grid flags
    if ylabels is None:
        ylabels = [None] * 4
    elif ylabels is False:
        ylabels = [False] * 4

    # Turn the index into a regular column so rows get 0..n-1 positions
    if df.index.name is None:
        df.index.name = 'idx'
    idx_name = df.index.name
    if idx_name in df.columns:
        df.drop(idx_name, axis=1, inplace=True)
    df.reset_index(inplace=True)

    if cols_styl_low_left is None and cols_styl_low_right is not None:
        logger.warning('当底部图只指定右边坐标轴时,默认绘制在左边坐标轴!')
        cols_styl_low_left, cols_styl_low_right = cols_styl_low_right, None
    has_low = cols_styl_low_left is not None

    # Axes preparation
    plt.figure(figsize=figsize)
    if has_low:
        gs = GridSpec(3, 1)
        axUpLeft = plt.subplot(gs[:2, :])  # main (top) panel: 2/3 of height
        axLowLeft = plt.subplot(gs[2, :])
    else:
        gs = GridSpec(1, 1)
        axUpLeft = plt.subplot(gs[:, :])

    def get_cols_to_label_info(cols_to_label_info, col):
        '''Collect [series, (style, label)] items for the marked points.'''
        to_plots = []
        for label_infos in cols_to_label_info[col]:
            lbl_col, vals = label_infos[0], label_infos[1]
            styls, lbls = label_infos[2], label_infos[3]
            if styls is None:
                styls = [None] * len(vals)
            if lbls is False:
                lbls = [False] * len(vals)
            elif isnull(lbls) or all([isnull(x) for x in lbls]):
                lbls = vals  # default legend label is the label value
            for k, val in enumerate(vals):
                series = df[df[lbl_col] == val][col]
                to_plots.append([series, (styls[k], lbls[k])])
        return to_plots

    def get_xparls_info(parls_info, col, clor_default='r',
                        lnstyl_default='--', lnwidth_default=2):
        '''Collect settings of the x-axis-parallel lines of one column.'''
        to_plots = []
        for val, clor, lnstyl, lnwidth in parls_info[col]:
            clor = clor_default if clor is None else clor
            lnstyl = lnstyl_default if lnstyl is None else lnstyl
            lnwidth = lnwidth_default if lnwidth is None else lnwidth
            to_plots.append([val, clor, lnstyl, lnwidth])
        return to_plots

    def get_yparls_info(parls_info, clor_default='r', lnstyl_default='--',
                        lnwidth_default=2):
        '''Collect settings of the y-axis-parallel lines.'''
        to_plots = []
        for val, clor, lnstyl, lnwidth in parls_info:
            clor = clor_default if clor is None else clor
            lnstyl = lnstyl_default if lnstyl is None else lnstyl
            lnwidth = lnwidth_default if lnwidth is None else lnwidth
            # Map the index value to its row position (first match).
            val = df[df[idx_name] == val].index[0]
            to_plots.append([val, clor, lnstyl, lnwidth])
        return to_plots

    def _draw_columns(ax, cols_styl, handles, **ext_kwargs):
        '''Draw every column of cols_styl on ax, then its marks and hlines.'''
        for col, styl in cols_styl.items():
            ln = plot_series_with_styls_info(ax, df[col], styl, **ext_kwargs)
            if ln is not None:
                handles.append(ln)
            # Specially marked points
            if col in cols_to_label_info:
                to_plots = get_cols_to_label_info(cols_to_label_info, col)
                for series, styls_info in to_plots:
                    ln = plot_series_with_styls_info(
                        ax, series, styls_info, lnstyl_default='ko',
                        markersize=markersize, **ext_kwargs)
                    if ln is not None:
                        handles.append(ln)
            # Lines parallel to the x-axis
            if col in xparls_info:
                to_plots = get_xparls_info(xparls_info, col)
                for yval, clor, lnstyl, lnwidth in to_plots:
                    ax.axhline(y=yval, c=clor, ls=lnstyl, lw=lnwidth)

    def _draw_vlines(ax, yparls_info):
        '''Draw the lines parallel to the y-axis on ax (if any).'''
        if yparls_info is not None:
            for xval, clor, lnstyl, lnwidth in get_yparls_info(yparls_info):
                ax.axvline(x=xval, c=clor, ls=lnstyl, lw=lnwidth)

    def _set_ylabel(ax, ylabel):
        '''Set the y label of ax; False removes label and ticks.'''
        if ylabel is False:
            ax.set_ylabel(None)
            ax.set_yticks([])
        else:
            ax.set_ylabel(ylabel, fontsize=fontsize)

    def _merged_legend(ax, handle_groups):
        '''Show one legend on ax for the handles of both twin axes.'''
        if len(handle_groups) > 0:
            handles = [h for group in handle_groups for h in group]
            labs = [h.get_label() for h in handles]
            ax.legend(handles, labs, loc=0, fontsize=fontsize)

    # lns collects the legend handles of both twin axes, see
    # https://www.cnblogs.com/Atanisi/p/8530693.html
    lns = []

    # Top-left axis
    _draw_columns(axUpLeft, cols_styl_up_left, lns)
    _draw_vlines(axUpLeft, yparls_info_up)
    axUpLeft.grid(grids[0])
    if title is not None:
        axUpLeft.set_title(title, fontsize=fontsize)  # title on top panel
    _set_ylabel(axUpLeft, ylabels[0])

    # Top-right axis
    if cols_styl_up_right is not None:
        axUpRight = axUpLeft.twinx()
        _draw_columns(axUpRight, cols_styl_up_right, lns, lbl_str_ext='(r)')
        axUpRight.grid(grids[1])
        _set_ylabel(axUpRight, ylabels[1])

    # Merged legend of the top panel
    _merged_legend(axUpLeft, lns)

    if has_low:
        # Hide the x tick labels of the top panel when a bottom panel exists.
        # set_xticks([]) is avoided because it would also remove the vertical
        # grid lines; clearing only the labels keeps the grid intact.
        axUpLeft.set_xticklabels([])

        lns = []

        # Bottom-left axis
        _draw_columns(axLowLeft, cols_styl_low_left, lns)
        _draw_vlines(axLowLeft, yparls_info_low)
        axLowLeft.grid(grids[2])
        _set_ylabel(axLowLeft, ylabels[2])

        # Bottom-right axis
        if cols_styl_low_right is not None:
            axLowRight = axLowLeft.twinx()
            _draw_columns(axLowRight, cols_styl_low_right, lns,
                          lbl_str_ext='(r)')
            axLowRight.grid(grids[3])
            _set_ylabel(axLowRight, ylabels[3])

        # Merged legend of the bottom panel
        _merged_legend(axLowLeft, lns)

    # x ticks: nXticks evenly spaced row positions plus the last row,
    # labelled with the original index values (skipped for an empty frame)
    n = df.shape[0]
    if n > 0:
        xpos = [int(x * n / nXticks) for x in range(0, nXticks)] + [n - 1]
        plt.xticks(xpos, [df.loc[x, idx_name] for x in xpos])

    plt.tight_layout()

    # Save the figure
    if fig_save_path:
        plt.savefig(fig_save_path)

    plt.show()
def lgb_cv_hoo(X_train, y_train, objective=None, parms_mdl_list=None,
               parms_train_list=None, Nfold=5, shuffle=True, random_state=62,
               mdl_path_list=None, logger=None):
    '''
    Hand-rolled lightgbm K-fold cross validation: one model is trained per
    fold with lgb_train, using the held-out part as validation set.

    Parameters
    ----------
    X_train, y_train :
        Features and targets; pandas objects are sliced with `.iloc`, lists
        and tuples are converted to numpy arrays, anything else is sliced
        with an integer index array.
    objective : str, optional
        Task type, forwarded to lgb_train.
    parms_mdl_list, parms_train_list : None, dict or list
        Model / training parameters. None or a single dict is used for every
        fold (the same object is passed to each fold); a list gives one
        entry per fold.
    Nfold, shuffle, random_state :
        KFold settings. random_state is only forwarded when shuffle is True.
    mdl_path_list : None, str or list
        None: models are not saved. str: a directory (created if missing) in
        which 'mdl_kf<k>.bin' is written for fold k. list: one path per fold.
    logger : optional
        Logging object; defaults to simple_logger().

    Returns
    -------
    (mdls, evals_results) : the list of trained models and the list of
    evaluation results, both in fold order.
    '''
    logger = simple_logger() if logger is None else logger

    def _take_rows(data, idxs):
        '''Select the rows at integer positions idxs from data.'''
        if hasattr(data, 'iloc'):
            return data.iloc[idxs]
        if isinstance(data, (list, tuple)):
            data = np.asarray(data)
        return data[idxs]

    # Model parameters as a list (one entry per fold)
    if parms_mdl_list is None or isinstance(parms_mdl_list, dict):
        parms_mdl_list = [parms_mdl_list] * Nfold

    # Training parameters as a list (one entry per fold)
    if parms_train_list is None or isinstance(parms_train_list, dict):
        parms_train_list = [parms_train_list] * Nfold

    # A str model path is a directory: create it and derive per-fold paths
    if mdl_path_list is None:
        mdl_path_list = [None] * Nfold
    elif isinstance(mdl_path_list, str):
        abs_dir = os.path.abspath(mdl_path_list)
        if not os.path.isdir(abs_dir):
            logger.warning('将创建模型存放文件夹{}!'.format(abs_dir))
            # makedirs also creates missing parent directories
            os.makedirs(abs_dir, exist_ok=True)
        mdl_path_list = [os.path.join(abs_dir, 'mdl_kf' + str(k) + '.bin')
                         for k in range(1, Nfold + 1)]

    # Cross validation. scikit-learn rejects a non-None random_state when
    # shuffle is False, so the seed is only passed along with shuffling.
    mdls, evals_results = [], []
    folds = KFold(n_splits=Nfold, shuffle=shuffle,
                  random_state=random_state if shuffle else None)
    for Ikf, (trnIdxs, valIdxs) in enumerate(folds.split(X_train, y_train)):
        logger.info('{}/{}折交叉验证训练中...'.format(Ikf + 1, Nfold))
        mdl, evals_result = lgb_train(
            _take_rows(X_train, trnIdxs), _take_rows(y_train, trnIdxs),
            X_valid=_take_rows(X_train, valIdxs),
            y_valid=_take_rows(y_train, valIdxs),
            objective=objective,
            parms_mdl=parms_mdl_list[Ikf],
            parms_train=parms_train_list[Ikf],
            mdl_save_path=mdl_path_list[Ikf],
            logger=logger)
        mdls.append(mdl)
        evals_results.append(evals_result)

    return mdls, evals_results
def lgb_cv_GridSearch(X_train, y_train, objective=None, parms_mdl=None,
                      parms_to_opt=None, parms_cv=None, logger=None):
    '''
    Grid search of lightgbm parameters scored by cross validation (lgb_cv).

    Parameters
    ----------
    X_train, y_train :
        Training data, forwarded to lgb_cv.
    objective : str, optional
        Task type; if None it is taken from parms_mdl['objective'].
    parms_mdl : dict, optional
        Base model parameters. NOTE: this dict is completed with defaults and
        its 'metric' entry is narrowed to a single metric in place.
    parms_to_opt : dict
        {parameter name: iterable of candidate values}; every combination is
        evaluated.
    parms_cv : dict, optional
        Cross-validation parameters, forwarded to lgb_cv.
    logger : optional
        Logging object; defaults to simple_logger().

    Returns
    -------
    (best_parms, {'best <metric>': best_value}), or (None, None) when
    parms_to_opt is not a non-empty dict.

    Raises
    ------
    ValueError
        If the metric to optimize is not one whose direction is known here.
    '''
    logger = simple_logger() if logger is None else logger

    if not isinstance(parms_to_opt, dict) or len(parms_to_opt) == 0:
        logger.error('检测到待优化参数parms_to_opt为空!')
        return None, None

    # Check task parameters (if other tasks are added, more checks may be
    # needed). The objective is resolved first: otherwise a multiclass task
    # configured only through parms_mdl would get num_class=1, which then
    # overrides the caller's num_class inside check_parms_mdl.
    if objective is None and isinstance(parms_mdl, dict):
        objective = parms_mdl.get('objective')
    num_class = len(set(y_train)) if objective == 'multiclass' else 1
    objective, num_class = check_parms_mdl(parms_mdl, objective, num_class,
                                           logger=logger)

    # Model parameters
    parms_mdl = get_parms_mdl(parms_mdl=parms_mdl, objective=objective,
                              num_class=num_class, logger=logger)

    # Only one metric can be optimized. Custom metrics are not handled here.
    metric = parms_mdl['metric']
    if not isinstance(metric, str):
        candidates = list(metric)
        if len(candidates) > 1:
            # Sets are unordered, so "the first" metric of a set is arbitrary.
            logger.warning(
                '发现多个metric,将以{}为优化目标!'.format(candidates[0]))
        metric = candidates[0]
        parms_mdl['metric'] = metric

    # Is a larger or a smaller metric value better?
    if metric in ['auc']:
        max_good = True
    elif metric in ['l1', 'l2', 'mape', 'rmse', 'binary_error',
                    'multi_error', 'binary_logloss', 'multi_logloss']:
        max_good = False
    else:
        raise ValueError('未识别的metric: {},请更改或在此函数中增加该支持项!'
                         .format(metric))
    best = -np.inf if max_good else np.inf

    # Expand the parameters to optimize into a grid (keys in sorted order)
    keys, values = zip(*sorted(parms_to_opt.items()))
    grid_parms = [dict(zip(keys, combo)) for combo in product(*values)]

    # Cross-validated grid search
    best_parms = None
    n_grid = len(grid_parms)
    for k, grid_parm in enumerate(grid_parms, start=1):
        logger.info('交叉验证网格搜索中:{} / {} ...'.format(k, n_grid))
        logger.info('当前参数:{}'.format(grid_parm))
        parms_mdl_now = parms_mdl.copy()
        parms_mdl_now.update(grid_parm)
        eval_hist = lgb_cv(X_train, y_train, objective=objective,
                           parms_mdl=parms_mdl_now, parms_cv=parms_cv,
                           logger=logger)
        scores = eval_hist[metric + '-mean']
        best_now = max(scores) if max_good else min(scores)
        improved = best_now > best if max_good else best_now < best
        if improved:
            best = best_now
            best_parms = grid_parm

    return best_parms, {'best ' + metric: best}
def lgb_cv(X_train, y_train, objective=None, parms_mdl=None, parms_cv=None,
           logger=None):
    '''
    Run lightgbm's built-in cross validation (lgb.cv).

    Parameters
    ----------
    X_train, y_train :
        Training data in pandas or numpy format (pandas preferred).
    objective : str, optional
        Task type; supported: 'multiclass', 'binary', 'regression' (others
        may be added). If None it is taken from parms_mdl['objective'].
    parms_mdl, parms_cv : dict, optional
        Model parameters and cross-validation parameters; missing entries
        are completed by get_parms_mdl / get_parms_TrainOrCV.
    logger : optional
        Logging object; defaults to simple_logger().

    Returns
    -------
    The evaluation history returned by lgb.cv.
    '''
    logger = simple_logger() if logger is None else logger

    # Check task parameters (if other tasks are added, more checks may be
    # needed). The objective is resolved first: otherwise a multiclass task
    # configured only through parms_mdl would get num_class=1, which then
    # overrides the caller's num_class inside check_parms_mdl.
    if objective is None and isinstance(parms_mdl, dict):
        objective = parms_mdl.get('objective')
    num_class = len(set(y_train)) if objective == 'multiclass' else 1
    objective, num_class = check_parms_mdl(parms_mdl, objective, num_class,
                                           logger=logger)

    # Model parameters and cross-validation parameters
    parms_mdl = get_parms_mdl(parms_mdl=parms_mdl, objective=objective,
                              num_class=num_class, logger=logger)
    parms_cv = get_parms_TrainOrCV(parms_TrainOrCV=parms_cv)
    if objective == 'regression':
        # Stratified sampling does not apply to regression tasks
        parms_cv['stratified'] = None

    # Dataset
    datTrain = lgb.Dataset(
        X_train, y_train,
        categorical_feature=parms_cv['categorical_feature'])

    # Cross validation
    logger.info('交叉验证...')
    eval_hist = lgb.cv(
        params=parms_mdl,
        train_set=datTrain,
        num_boost_round=parms_cv['num_boost_round'],
        folds=parms_cv['folds'],
        nfold=parms_cv['nfold'],
        stratified=parms_cv['stratified'],
        shuffle=parms_cv['shuffle'],
        metrics=parms_cv['metrics'],
        fobj=parms_cv['fobj'],
        feval=parms_cv['feval'],
        init_model=parms_cv['init_model'],
        feature_name=parms_cv['feature_name'],
        categorical_feature=parms_cv['categorical_feature'],
        early_stopping_rounds=parms_cv['early_stopping_rounds'],
        fpreproc=parms_cv['fpreproc'],
        verbose_eval=parms_cv['verbose_eval'],
        show_stdv=parms_cv['show_stdv'],
        seed=parms_cv['seed'],
        callbacks=parms_cv['callbacks'],
        eval_train_metric=parms_cv['eval_train_metric'],
        # NOTE: return_cvbooster is intentionally not forwarded
        # (it was disabled in the original call).
    )

    return eval_hist
def lgb_train(X_train, y_train, X_valid=None, y_valid=None, objective=None,
              parms_mdl=None, parms_train=None, mdl_save_path=None,
              logger=None):
    '''
    Train a lightgbm model.

    Parameters
    ----------
    X_train, y_train, X_valid, y_valid :
        Training and (optional) validation data in pandas or numpy format
        (pandas preferred). X_valid and y_valid must be given together.
    objective : str, optional
        Task type; supported: 'multiclass', 'binary', 'regression' (others
        may be added). If None it is taken from parms_mdl['objective'].
    parms_mdl, parms_train : dict, optional
        Model parameters and training parameters; missing entries are
        completed by get_parms_mdl / get_parms_TrainOrCV.
    mdl_save_path : str, optional
        Where to persist the trained model (via pickleFile); nothing is
        saved when it is null.
    logger : optional
        Logging object; defaults to simple_logger().

    Returns
    -------
    (mdl, evals_result) : the trained model and the recorded metric curves.

    Raises
    ------
    ValueError
        If only one of X_valid / y_valid is provided.
    '''
    logger = simple_logger() if logger is None else logger

    # A validation set needs both features and targets
    if (X_valid is None) != (y_valid is None):
        raise ValueError('X_valid和y_valid必须同时设置或同时为None!')

    # Check task parameters (if other tasks are added, more checks may be
    # needed). The objective is resolved first: otherwise a multiclass task
    # configured only through parms_mdl would get num_class=1, which then
    # overrides the caller's num_class inside check_parms_mdl.
    if objective is None and isinstance(parms_mdl, dict):
        objective = parms_mdl.get('objective')
    num_class = len(set(y_train)) if objective == 'multiclass' else 1
    objective, num_class = check_parms_mdl(parms_mdl, objective, num_class,
                                           logger=logger)

    # Model parameters and training parameters
    parms_mdl = get_parms_mdl(parms_mdl=parms_mdl, objective=objective,
                              num_class=num_class, logger=logger)
    parms_train = get_parms_TrainOrCV(parms_TrainOrCV=parms_train)

    # Datasets
    datTrain = lgb.Dataset(
        X_train, y_train,
        categorical_feature=parms_train['categorical_feature'])
    if X_valid is None:
        valid_sets, valid_names = [datTrain], ['train']
    else:
        datValid = lgb.Dataset(
            X_valid, y_valid,
            categorical_feature=parms_train['categorical_feature'])
        valid_sets, valid_names = [datTrain, datValid], ['train', 'valid']
    evals_result = {}

    # Training
    logger.info('模型训练中...')
    mdl = lgb.train(
        params=parms_mdl,
        train_set=datTrain,
        num_boost_round=parms_train['num_boost_round'],
        valid_sets=valid_sets,
        valid_names=valid_names,
        fobj=parms_train['fobj'],
        feval=parms_train['feval'],
        init_model=parms_train['init_model'],
        feature_name=parms_train['feature_name'],
        categorical_feature=parms_train['categorical_feature'],
        early_stopping_rounds=parms_train['early_stopping_rounds'],
        evals_result=evals_result,
        verbose_eval=parms_train['verbose_eval'],
        learning_rates=parms_train['learning_rates'],
        keep_training_booster=parms_train['keep_training_booster'],
        callbacks=parms_train['callbacks'])

    # Persist the model at the requested path (the path variable itself,
    # not a literal string).
    if not isnull(mdl_save_path):
        pickleFile(mdl, mdl_save_path)

    return mdl, evals_result
def get_parms_mdl(parms_mdl=None, objective=None, num_class=None,
                  logger=None):
    '''
    Build the lightgbm model parameters, completing missing key parameters
    with default values.

    Parameters
    ----------
    parms_mdl : dict, optional
        Model parameters set by the user. Keys that are missing are added in
        place; if None, the dict of defaults is returned.
    objective : str, optional
        Task type; supported: 'multiclass', 'binary', 'regression' (others
        may be added).
    num_class : int, optional
        Number of classes; only relevant (and specially checked) for
        multiclass tasks.
    logger : optional
        Logging object; defaults to simple_logger().

    Returns
    -------
    dict : the completed model parameters.

    Notes
    -----
    If a new task type is added, the parameters needing a special check may
    have to be extended as well. lightgbm parameters have aliases, which can
    lead to confusion or duplicated settings, so the names used in parms_mdl
    should match the default names used in this function.
    '''
    if logger is None:
        logger = simple_logger()

    objective, num_class = check_parms_mdl(parms_mdl, objective, num_class,
                                           logger=logger)

    # Default evaluation metrics per task (l1 = mae, l2 = mse)
    metrics_by_task = {
        'multiclass': ['multi_logloss', 'multi_error'],
        'binary': ['binary_logloss', 'auc', 'binary_error'],
        'regression': ['l1', 'l2', 'mape'],
    }
    metric = metrics_by_task[objective]

    # Default values of the key parameters
    defaults = {
        'objective': objective,
        'num_class': num_class,
        'metric': metric,
        'boosting': 'gbdt',
        'extra_trees': False,
        'max_depth': 3,
        'num_leaves': 31,
        'min_data_in_leaf': 20,
        'bagging_fraction': 0.75,
        'bagging_freq': 5,
        'feature_fraction': 0.75,
        'max_bin': 255,
        'lambda_l1': 0.1,
        'lambda_l2': 0.1,
        'min_gain_to_split': 0.01,
        'min_sum_hessian_in_leaf': 0.01,
        'path_smooth': 0.0,
        'learning_rate': 0.05,
        'is_unbalance': False,
        'random_state': None,
        'num_threads': 4,
    }

    # Nothing supplied: the defaults are the result
    if parms_mdl is None:
        return defaults

    # Add only the keys the caller did not set
    for name, value in defaults.items():
        parms_mdl.setdefault(name, value)

    return parms_mdl