def _build_model(args):
    """Instantiate the gradient-boosting model selected by ``args.model``.

    Args:
        args: Parsed options. Reads ``model``, ``lr``, ``trees``, ``depth``,
            ``log``, ``plot`` and, for ``'regression'`` only, ``count``
            (passed as ``min_samples_split``).

    Returns:
        A ``GradientBoostingRegressor``, ``GradientBoostingBinaryClassifier``
        or ``GradientBoostingMultiClassifier``.

    Raises:
        ValueError: if ``args.model`` is not one of ``'regression'``,
            ``'binary_cf'`` or ``'multi_cf'``.
    """
    common = dict(learning_rate=args.lr, n_trees=args.trees,
                  max_depth=args.depth, is_log=args.log, is_plot=args.plot)
    if args.model == 'regression':
        # Only the regressor is given a minimum split size.
        return GradientBoostingRegressor(min_samples_split=args.count, **common)
    if args.model == 'binary_cf':
        return GradientBoostingBinaryClassifier(**common)
    if args.model == 'multi_cf':
        return GradientBoostingMultiClassifier(**common)
    # Fail loudly instead of crashing later with "NoneType has no attribute fit".
    raise ValueError('unknown model type: {!r}'.format(args.model))


def run(args):
    """Train the model chosen by ``args.model``, predict on test data, and log results.

    Side effects:
        * Deletes and recreates the ``results`` directory in the current
          working directory.
        * Replaces the logger's most recently added handler with a
          ``FileHandler`` writing to ``results/result.log`` (mode ``'w'``).
        * Sets the logger level to INFO before logging predictions.

    Args:
        args: Parsed options. See ``_build_model`` for the attributes read.

    Raises:
        ValueError: if ``args.model`` is not a supported model type.
    """
    # Fetch the dataset once so train and test come from the same call.
    # get_data is assumed to return (train, test) as items [0] and [1].
    dataset = get_data(args.model)
    data = dataset[0]
    test_data = dataset[1]

    # Start every run from an empty 'results' directory.
    if os.path.exists('results'):
        shutil.rmtree('results')
    os.makedirs('results')

    # Build and train the model.
    model = _build_model(args)
    model.fit(data)

    # Redirect logging: drop the most recently added handler (presumably the
    # console handler, or a previous run's file handler) and close it so any
    # file it holds is released, then log to results/result.log instead.
    if logger.handlers:
        old_handler = logger.handlers[-1]
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.addHandler(
        logging.FileHandler('results/result.log', mode='w', encoding='utf-8'))
    logger.info(data)

    # model.predict is expected to add prediction columns to test_data in place.
    model.predict(test_data)

    # Log the prediction columns produced for this model type.
    logger.setLevel(logging.INFO)
    prediction_columns = {
        'regression': ('predict_value',),
        'binary_cf': ('predict_proba', 'predict_label'),
        'multi_cf': ('predict_label',),
    }
    for column in prediction_columns[args.model]:
        logger.info(test_data[column])
# Console handler: StreamHandler() writes to sys.stderr; accept DEBUG and above.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
logger.addHandler(ch)

if __name__ == '__main__':
    import os

    # Tiny hand-made regression demo: 'id', two features, and the 'label' target.
    data = pd.DataFrame(data=[
        [1, 5, 20, 1.1],
        [2, 7, 30, 1.3],
        [3, 21, 70, 1.7],
        [4, 30, 60, 1.8],
    ], columns=['id', 'age', 'weight', 'label'])

    # FileHandler opens its file immediately, so 'results' must exist first.
    # Create it up front in case fitting/plotting also writes there.
    os.makedirs('results', exist_ok=True)

    model = GradientBoostingRegressor(learning_rate=0.1, n_trees=10, max_depth=3,
                                      min_samples_split=2, is_log=False,
                                      is_plot=True)
    model.fit(data)

    # Swap the console handler added above for a file handler.
    logger.removeHandler(logger.handlers[-1])
    logger.addHandler(
        logging.FileHandler('results/result.log', mode='w', encoding='utf-8'))
    logger.info(data)

    # Predict one unseen row; predict() is expected to add 'predict_value' in place.
    test_data = pd.DataFrame(data=[[5, 25, 65]], columns=['id', 'age', 'weight'])
    model.predict(test_data)
    logger.setLevel(logging.INFO)
    logger.info(test_data['predict_value'])