Пример #1
0
def train(args):
    """Run the CSV-based boosting-tree workflow described by ``args``.

    The function proceeds in these steps:

    1. Configure logging.
    2. Validate the role and mode.
    3. Load the CSV dataset.
    4. For the ``leader``/``follower`` roles, open a ``Bridge`` to the peer.
    5. Build a ``BoostingTreeEnsamble``, optionally restoring a saved model,
       then train, test or evaluate it according to ``args.mode``.
    6. Optionally export the model.

    The bridge, if one was created, is terminated on exit, including when
    an error is raised.

    Args:
        args: parsed command-line namespace. Attributes read here:
            verbosity, role, mode, data_path, local_addr, peer_addr,
            application_id, learning_rate, max_iters, max_depth,
            l2_regularization, max_bins, num_parallel, load_model_path,
            checkpoint_path, export_path.

    Raises:
        AssertionError: if ``args.role`` or ``args.mode`` is unsupported.
        ValueError: if ``args.data_path`` is not a ``.csv`` file.
    """
    # verbosity 0 -> WARNING, 1 -> INFO, anything else -> DEBUG.
    level = {0: logging.WARNING, 1: logging.INFO}.get(args.verbosity,
                                                      logging.DEBUG)
    logging.basicConfig(level=level)

    assert args.role in ['leader', 'follower', 'local'], \
        "role must be leader, follower, or local"
    assert args.mode in ['train', 'test', 'eval'], \
        "mode must be train, test, or eval"

    X, y = _load_csv_dataset(args)

    if args.role != 'local':
        # rsplit keeps only the trailing port, so hosts containing ':'
        # (e.g. "[::1]:50051") are parsed correctly.
        port = int(args.local_addr.rsplit(':', 1)[-1])
        bridge = Bridge(args.role, port, args.peer_addr,
                        args.application_id, 0)
    else:
        bridge = None

    try:
        booster = BoostingTreeEnsamble(bridge,
                                       learning_rate=args.learning_rate,
                                       max_iters=args.max_iters,
                                       max_depth=args.max_depth,
                                       l2_regularization=args.l2_regularization,
                                       max_bins=args.max_bins,
                                       num_parallel=args.num_parallel)

        if args.load_model_path:
            booster.load_saved_model(args.load_model_path)

        if args.mode == 'train':
            booster.fit(X, y, args.checkpoint_path)
        elif args.mode == 'test':
            pred = booster.batch_predict(X)
            if y is None or len(y) == 0:
                # Followers load no label column (see _load_csv_dataset),
                # so accuracy cannot be computed on this side.
                logging.info("No labels available; skipping test accuracy")
            else:
                acc = np.mean((np.asarray(pred) > 0.5) == y)
                logging.info("Test accuracy: %f", acc)
        else:  # eval
            pred = booster.batch_predict(X)
            for i in pred:
                print(i)

        if args.export_path:
            booster.save_model(args.export_path)
    finally:
        # Release the peer connection even if training/prediction failed.
        if bridge is not None:
            bridge.terminate()


def _load_csv_dataset(args):
    """Load ``args.data_path`` and split it into ``(X, y)``.

    The last CSV column is used as the label only in ``train``/``test``
    mode on the ``leader``/``local`` roles. Otherwise every column is a
    feature and ``y`` is ``None``.

    Raises:
        ValueError: if ``args.data_path`` does not end with ``.csv``.
    """
    if not args.data_path.endswith('.csv'):
        raise ValueError("Unsupported data type %s" % args.data_path)
    with open(args.data_path, 'rb') as fin:
        # ndmin=2 keeps a single-row file 2-D so the column slicing works.
        data = np.loadtxt(fin, delimiter=',', ndmin=2)
    has_labels = (args.mode in ('train', 'test')
                  and args.role in ('leader', 'local'))
    if has_labels:
        return data[:, :-1], data[:, -1]
    return data, None
Пример #2
0
def run(args):
    """Set up logging and an optional peer bridge, then run the booster.

    In ``train`` mode it delegates to ``train(args, booster)``. In ``test``
    and ``eval`` modes it delegates to ``test(args, bridge, booster)``.
    The model is optionally restored from ``args.load_model_path`` and
    exported to ``args.export_path``. Any exception is logged at CRITICAL
    level and re-raised, and the bridge (if any) is always terminated.

    Args:
        args: parsed command-line namespace. Attributes read here:
            verbosity, role, mode, local_addr, peer_addr, application_id,
            use_streaming, learning_rate, max_iters, max_depth,
            l2_regularization, max_bins, num_parallel, loss_type,
            send_scores_to_follower, send_metrics_to_follower,
            load_model_path, export_path.

    Raises:
        AssertionError: if ``args.role`` or ``args.mode`` is unsupported.
    """
    # verbosity 0 -> WARNING, 1 -> INFO, anything else -> DEBUG.
    level = {0: logging.WARNING, 1: logging.INFO}.get(args.verbosity,
                                                      logging.DEBUG)
    logging.basicConfig(level=level)

    assert args.role in ['leader', 'follower', 'local'], \
        "role must be leader, follower, or local"
    assert args.mode in ['train', 'test', 'eval'], \
        "mode must be train, test, or eval"

    # Only the leader/follower roles talk to a peer; local runs alone.
    if args.role != 'local':
        # rsplit keeps only the trailing port, so hosts containing ':'
        # (e.g. "[::1]:50051") are parsed correctly.
        port = int(args.local_addr.rsplit(':', 1)[-1])
        bridge = Bridge(args.role, port,
                        args.peer_addr, args.application_id, 0,
                        streaming_mode=args.use_streaming)
    else:
        bridge = None

    try:
        # The booster is constructed with the bridge and keeps it.
        booster = BoostingTreeEnsamble(
            bridge,
            learning_rate=args.learning_rate,
            max_iters=args.max_iters,
            max_depth=args.max_depth,
            l2_regularization=args.l2_regularization,
            max_bins=args.max_bins,
            num_parallel=args.num_parallel,
            loss_type=args.loss_type,
            send_scores_to_follower=args.send_scores_to_follower,
            send_metrics_to_follower=args.send_metrics_to_follower)
        # Restore a previously saved model, if requested.
        if args.load_model_path:
            booster.load_saved_model(args.load_model_path)
        # Training only needs the booster, because the booster was already
        # given the bridge above. test() also receives the bridge directly.
        if args.mode == 'train':
            # NOTE(review): this expects a two-argument train(args, booster);
            # the train(args) shown earlier in this listing takes one
            # argument -- confirm which definition is in scope.
            train(args, booster)
        else:  # args.mode is 'test' or 'eval'
            test(args, bridge, booster)
        # Persist the model, if requested.
        if args.export_path:
            booster.save_model(args.export_path)
    except Exception:
        logging.fatal(
            'Exception raised during training: %s',
            traceback.format_exc())
        # Bare raise re-raises the active exception without adding a frame.
        raise
    finally:
        # Always shut down the peer connection.
        if bridge is not None:
            bridge.terminate()