def train(args):
    """Run a boosting-tree job (train, test, or eval) end to end from CLI args.

    Loads a CSV dataset and builds a ``BoostingTreeEnsamble``. The booster gets
    a ``Bridge`` to the peer unless the role is ``'local'``. An optional saved
    model is loaded. The booster is then fitted, tested, or used for
    prediction, depending on ``args.mode``. Finally the model is optionally
    exported.

    NOTE(review): ``run`` in this file calls ``train(args, booster)`` with two
    arguments, which this one-argument definition cannot accept. Confirm
    which ``train`` definition is actually in effect.

    Args:
        args: Parsed command-line namespace. Fields read here: ``verbosity``,
            ``role``, ``mode``, ``data_path``, ``local_addr``, ``peer_addr``,
            ``application_id``, ``learning_rate``, ``max_iters``,
            ``max_depth``, ``l2_regularization``, ``max_bins``,
            ``num_parallel``, ``load_model_path``, ``checkpoint_path`` and
            ``export_path``.

    Raises:
        AssertionError: If ``args.role`` or ``args.mode`` is not recognized.
        ValueError: If ``args.data_path`` does not end in ``.csv``.
    """
    if args.verbosity == 0:
        logging.basicConfig(level=logging.WARNING)
    elif args.verbosity == 1:
        logging.basicConfig(level=logging.INFO)
    else:
        logging.basicConfig(level=logging.DEBUG)

    assert args.role in ['leader', 'follower', 'local'], \
        "role must be leader, follower, or local"
    assert args.mode in ['train', 'test', 'eval'], \
        "mode must be train, test, or eval"

    if not args.data_path.endswith('.csv'):
        raise ValueError("Unsupported data type %s" % args.data_path)
    with open(args.data_path, 'rb') as fin:
        # ndmin=2 keeps a single-row file as a (1, n) matrix. Without it,
        # loadtxt returns a 1-D array and the column slicing below fails.
        data = np.loadtxt(fin, delimiter=',', ndmin=2)

    # In train/test mode the leader (or a local run) has the label in the
    # last column. The follower, and every role in eval mode, uses all
    # columns as features and has no labels.
    if args.mode in ('train', 'test') and args.role in ('leader', 'local'):
        X = data[:, :-1]
        y = data[:, -1]
    else:
        X = data
        y = None

    if args.role != 'local':
        # The listen port is the part after ':' in local_addr.
        bridge = Bridge(args.role, int(args.local_addr.split(':')[1]),
                        args.peer_addr, args.application_id, 0)
    else:
        bridge = None

    try:
        booster = BoostingTreeEnsamble(
            bridge,
            learning_rate=args.learning_rate,
            max_iters=args.max_iters,
            max_depth=args.max_depth,
            l2_regularization=args.l2_regularization,
            max_bins=args.max_bins,
            num_parallel=args.num_parallel)

        if args.load_model_path:
            booster.load_saved_model(args.load_model_path)

        if args.mode == 'train':
            booster.fit(X, y, args.checkpoint_path)
        elif args.mode == 'test':
            pred = booster.batch_predict(X)
            # Only the label holder can score predictions. Without this
            # guard, the follower (y is None) crashed on len(None), and an
            # empty label set divided by zero.
            if y is not None and len(y) > 0:
                acc = np.mean((np.asarray(pred) > 0.5) == y)
                logging.info("Test accuracy: %f", acc)
            else:
                logging.info("No labels available; skipping test accuracy")
        else:  # eval: print one prediction per line
            pred = booster.batch_predict(X)
            for i in pred:
                print(i)

        if args.export_path:
            booster.save_model(args.export_path)
    finally:
        # Release the peer connection whether or not the job succeeded.
        if bridge is not None:
            bridge.terminate()
def run(args):
    """Configure logging, build the booster, and dispatch on ``args.mode``.

    For the ``'leader'`` and ``'follower'`` roles, a ``Bridge`` to the peer is
    created. A ``'local'`` run has no bridge. The ``BoostingTreeEnsamble`` is
    built and a saved model is optionally loaded. Then ``train`` runs in
    ``'train'`` mode and ``test`` runs in ``'test'``/``'eval'`` mode. The model
    is optionally exported. The bridge, if any, is always terminated.

    Args:
        args: Parsed command-line namespace. Fields read here: ``verbosity``,
            ``role``, ``mode``, ``local_addr``, ``peer_addr``,
            ``application_id``, ``use_streaming``, ``learning_rate``,
            ``max_iters``, ``max_depth``, ``l2_regularization``,
            ``max_bins``, ``num_parallel``, ``loss_type``,
            ``send_scores_to_follower``, ``send_metrics_to_follower``,
            ``load_model_path`` and ``export_path``.

    Raises:
        AssertionError: If ``args.role`` or ``args.mode`` is not recognized.
        Exception: Any exception raised while building the booster, loading
            the model, training/testing, or exporting. It is logged with its
            traceback and then re-raised unchanged.
    """
    if args.verbosity == 0:
        logging.basicConfig(level=logging.WARNING)
    elif args.verbosity == 1:
        logging.basicConfig(level=logging.INFO)
    else:
        logging.basicConfig(level=logging.DEBUG)

    assert args.role in ['leader', 'follower', 'local'], \
        "role must be leader, follower, or local"
    assert args.mode in ['train', 'test', 'eval'], \
        "mode must be train, test, or eval"

    # Leader or follower: connect to the peer. A local run has no peer.
    # The listen port is the part after ':' in local_addr.
    if args.role != 'local':
        bridge = Bridge(args.role, int(args.local_addr.split(':')[1]),
                        args.peer_addr, args.application_id, 0,
                        streaming_mode=args.use_streaming)
    else:
        bridge = None

    try:
        # Build the boosting-tree ensemble around the (possibly None) bridge.
        booster = BoostingTreeEnsamble(
            bridge,
            learning_rate=args.learning_rate,
            max_iters=args.max_iters,
            max_depth=args.max_depth,
            l2_regularization=args.l2_regularization,
            max_bins=args.max_bins,
            num_parallel=args.num_parallel,
            loss_type=args.loss_type,
            send_scores_to_follower=args.send_scores_to_follower,
            send_metrics_to_follower=args.send_metrics_to_follower)

        # Load a previously saved model, if one was given.
        if args.load_model_path:
            booster.load_saved_model(args.load_model_path)

        # train() is passed only the booster (which was constructed with the
        # bridge). test() also receives the bridge directly.
        # NOTE(review): the original comment asked why training does not
        # need the bridge. Presumably it reaches the peer through the
        # booster. Confirm against train().
        if args.mode == 'train':
            train(args, booster)
        else:  # args.mode in ('test', 'eval')
            test(args, bridge, booster)

        # Persist the model.
        if args.export_path:
            booster.save_model(args.export_path)
    except Exception:
        logging.fatal(
            'Exception raised during training: %s', traceback.format_exc())
        # A bare raise re-raises the active exception with its original
        # traceback, which is the idiomatic alternative to `raise e`.
        raise
    finally:
        # Shut down the peer connection.
        if bridge:
            bridge.terminate()