예제 #1
0
파일: startup.py 프로젝트: yuanjungod/GBDT
def train_model(data_file='./data/feature_data.csv', train_size=1200,
                sample_ids=(1, 2, 3, 4, 402)):
    """Train a regression GBDT on a subset of a CSV dataset and spot-check it.

    The fitted model is saved with ``GBDT.save_model(gbdt, "./", "test")``,
    then the prediction and stored ``'label'`` of a few instances are printed,
    separated by a line of ``#`` characters.

    Args:
        data_file: Path of the CSV file loaded with ``DataSet``.
        train_size: Number of instance ids (taken in sorted order) used for
            fitting.
        sample_ids: Instance ids whose prediction/label pair is printed.
    """
    dataset = DataSet(data_file)

    gbdt = GBDT(max_iter=80,
                sample_rate=0.8,
                learn_rate=0.1,
                max_depth=7,
                loss_type='regression')
    # Slicing an unordered set yields an arbitrary subset; sort first so the
    # training split is reproducible. Assumes ids are mutually comparable
    # (they are indexed with ints below) -- TODO confirm.
    train_ids = set(sorted(dataset.get_instances_idset())[:train_size])
    gbdt.fit(dataset, train_ids)

    GBDT.save_model(gbdt, "./", "test")

    # NOTE(review): the default sample ids may fall inside the training subset,
    # so these are likely in-sample predictions.
    for position, instance_id in enumerate(sample_ids):
        if position:
            print("#########################")
        predict = gbdt.predict(dataset.instances[instance_id])
        label = dataset.get_instance(instance_id)['label']
        # One pre-formatted argument prints the same text under the Python 2
        # print statement and the Python 3 print function.
        print("predict %s %s" % (predict, label))
예제 #2
0
파일: startup.py 프로젝트: buptss/GBDT
iter9 : train loss=0.035695
iter10 : train loss=0.030581
iter11 : train loss=0.027034
iter12 : train loss=0.024570
iter13 : train loss=0.019227
iter14 : train loss=0.015794
iter15 : train loss=0.013484
iter16 : train loss=0.010941
iter17 : train loss=0.009879
iter18 : train loss=0.008619
iter19 : train loss=0.007306
iter20 : train loss=0.005610
"""
from gbdt.data import DataSet
from gbdt.model import GBDT

if __name__ == '__main__':
    # Fit a binary-classification GBDT on every instance of the credit data
    # and print the model's prediction for each instance.
    data_file = './data/credit.data.csv'
    dataset = DataSet(data_file)
    gbdt = GBDT(max_iter=20,
                sample_rate=0.8,
                learn_rate=0.5,
                max_depth=7,
                loss_type='binary-classification')
    print(dataset.get_instances_idset())
    gbdt.fit(dataset, dataset.get_instances_idset())
    # `instance_id` rather than `id` to avoid shadowing the builtin.
    for instance_id in dataset.get_instances_idset():
        instance = dataset.get_instance(instance_id)
        prediction = gbdt.predict(instance)
        # Report the result instead of discarding it; a single pre-formatted
        # argument keeps output identical on Python 2 and Python 3.
        print("%s %s" % (instance_id, prediction))
예제 #3
0
from gbdt.model import GBDT
from gbdt.data import DataSet

# Fit a GBDT (depth-3 trees, learning rate 0.01, 2000 iterations) on the ages
# CSV and print its prediction for one hand-written sample.
model = GBDT(tree_depth=3, learning_rate=0.01, max_iter=2000)
# 'Age' is passed as the second DataSet argument -- presumably the target
# column; TODO confirm against gbdt.data.DataSet.
dataset = DataSet('data/ages.csv', 'Age')
model.fit(dataset)
x = {'LikesGardening': False, 'PlaysVideoGames': True, 'LikesHats': False}
# Parenthesised single argument: valid on Python 3 (the bare print statement
# was a SyntaxError there) and prints the same value on Python 2.
print(model.predict(x))