Example #1
    def __split(self):
        """split current node"""
        max_gain, attr_index, partition_value = self.__information_gain()

        self.leaf = False
        self.attr_index = attr_index
        self.partition_value = partition_value

        tn_le = DTreeGain(DTreeSample(), self.param)
        tn_le.sample.m = self.sample.m
        tn_gt = DTreeGain(DTreeSample(), self.param)
        tn_gt.sample.m = self.sample.m

        for i in range(0, self.sample.n):
            x = self.sample.X[i]
            y = self.sample.Y[i]
            attr_value = x[attr_index]
            if attr_value <= partition_value:
                tn_le.sample.add_xy(x, y)
            else:
                tn_gt.sample.add_xy(x, y)

        left_n = len(tn_le.sample)
        right_n = len(tn_gt.sample)
        print('[split]level=%d, max_gain=%f, attr_index=%d, partition_value=%f, left/right=%d/%d' % (
            self.level, max_gain, attr_index, partition_value, left_n, right_n), file=sys.stderr)

        if left_n == 0:
            tn_le = None
        if right_n == 0:
            tn_gt = None
        return tn_le, tn_gt
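
Example #1's __split relies on __information_gain() to choose the attribute and threshold, but that method is not part of the snippet. As a rough standalone sketch of the underlying idea (the function names below are hypothetical and not part of this library), an entropy-based gain for a single candidate threshold on one attribute can be computed like this:

import math

def entropy(labels):
    """Shannon entropy of a list of class labels."""
    n = float(len(labels))
    counts = {}
    for y in labels:
        counts[y] = counts.get(y, 0) + 1
    return -sum(c / n * math.log(c / n, 2) for c in counts.values())

def information_gain(values, labels, threshold):
    """Gain of splitting on attribute value <= threshold vs. > threshold."""
    left = [y for v, y in zip(values, labels) if v <= threshold]
    right = [y for v, y in zip(values, labels) if v > threshold]
    if not left or not right:
        return 0.0
    n = float(len(labels))
    return entropy(labels) - (len(left) / n) * entropy(left) \
                           - (len(right) / n) * entropy(right)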
Example #2
    def __split(self):
        """split current node"""
        max_gain, attr_index, partition_value, y_value_left, y_value_right = self.__gain()

        self.leaf = False
        self.attr_index = attr_index
        self.partition_value = partition_value
        self.value = None

        tn_le = DTreeLoss(DTreeSample(), self.param)
        tn_le.sample.m = self.sample.m
        tn_le.level = self.level + 1
        tn_le.value = y_value_left
        tn_gt = DTreeLoss(DTreeSample(), self.param)
        tn_gt.sample.m = self.sample.m
        tn_gt.level = self.level + 1
        tn_gt.value = y_value_right

        for i in range(0, self.sample.n):
            x = self.sample.X[i]
            y = self.sample.Y[i]
            y_residual = self.sample.Y_residual[i]
            attr_value = x[attr_index]
            if attr_value <= partition_value:
                tn_le.sample.add_xyr(x, y, y_residual)
            else:
                tn_gt.sample.add_xyr(x, y, y_residual)
        tn_le.__residual_2_response()
        tn_gt.__residual_2_response()

        left_n = len(tn_le.sample)
        right_n = len(tn_gt.sample)
        print('[split]level=%d, max_gain=%f, attr_index=%d, partition_value=%f, left/right=%d/%d' % (
            self.level, max_gain, attr_index, partition_value, left_n, right_n), file=sys.stderr)

        if left_n == 0:
            tn_le = None
        if right_n == 0:
            tn_gt = None
        return tn_le, tn_gt
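
Example #2 finishes each child by calling __residual_2_response(), which is also not shown. For squared-error loss the usual choice is to make a node's response the mean residual of the samples routed to it; a minimal standalone sketch of that step (hypothetical name, not the library's method) might be:

def residual_to_response(residuals):
    """For squared loss, the best constant response of a node is the
    mean of the residuals of the samples that fall into it."""
    if not residuals:
        return 0.0
    return sum(residuals) / float(len(residuals))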
Example #3
# -*- coding: utf-8 -*-
#
# Predict real-estate prices
#
# author: yafei([email protected])
#
import sys
import codecs
import locale
from dtree_loss import DTreeLoss
from dtree_parameter import DTreeParameter
from dtree_sample import DTreeSample

if __name__ == '__main__':
    param = DTreeParameter()
    sample = DTreeSample()
    sample.load('real-estate.txt')
    dt = DTreeLoss(sample, param)
    dt.train(None)

    feature_map = {
        0: u'结构',      # structure
        1: u'装修',      # interior finish
        2: u'周边',      # surroundings
        3: u'地段',      # location
        4: u'绿化',      # green space
        5: u'交通',      # transportation
        6: u'户均车位',  # parking spaces per household
    }
    # so that the Chinese feature names print correctly
    locale.setlocale(locale.LC_ALL, '')
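
A natural next step after training is to query the fitted tree. These snippets do not show a predict method on DTreeLoss, so the following continuation of the __main__ block is only a sketch: it assumes DTreeLoss exposes the same predict(x) interface that GBDT relies on in Example #6, and the feature vector is made up purely for illustration.

    # hypothetical house, features in the same order as feature_map above
    house = [3, 2, 4, 5, 1, 3, 0.5]
    price = dt.predict(house)  # assumed interface, see Example #6
    print(u'predicted price: %f' % price)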
Example #5
#
# Predict whether a Weibo follower is a zombie (fake) account
#
# author: yafei([email protected])
#
import sys
import codecs
import locale
from dtree_gain import DTreeGain
from dtree_parameter import DTreeParameter
from dtree_sample import DTreeSample

if __name__ == '__main__':
    param = DTreeParameter()
    param.split_threshold = 0.93
    sample = DTreeSample()
    sample.load('weibo.txt')
    dt = DTreeGain(sample, param)
    dt.train()

    feature_map = {
            0: u'注册天数',   # days since registration
            1: u'加V',        # verified ("V") account
            2: u'关注',       # accounts followed
            3: u'粉丝',       # followers
            4: u'微博',       # number of posts
            5: u'收藏',       # favorites
            6: u'互粉',       # mutual follows
            7: u'共同好友',   # common friends
            8: u'tag数',      # number of tags
            9: u'等级',       # account level
Example #6
                last_tree = self.trees[i-1]
                residual = last_tree.next_residual()

            print('training tree #%d' % (i), file=sys.stderr)
            tree.train(residual)
            self.trees.append(tree)


    def predict(self, x):
        y = self.F0
        for tree in self.trees:
            y += tree.predict(x)
        return y


if __name__ == '__main__':
    param = DTreeParameter()
    param.max_level = 4
    param.split_threshold = 0.8
    param.max_attr_try_time = 1000
    param.tree_number = 20
    param.learning_rate = 0.5

    sample = DTreeSample()
    sample.load_liblinear('heart_scale.txt')

    gbdt = GBDT(sample)
    gbdt.train(param)
    print(gbdt.predict([0.708333, 1, 1, -0.320755, -0.105023, -1, 1, -0.419847, -1, -0.225806, 0, 1, -1]))
    print(gbdt.predict([0.583333, -1, 0.333333, -0.603774, 1, -1, 1, 0.358779, -1, -0.483871, 0, -1, 1]))
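
Example #6 feeds each new tree the residuals left by the previous one via next_residual(), which is not shown in the snippet. For squared-error loss the residual is simply the difference between the target and the current ensemble prediction (the negative gradient); a standalone sketch of that computation (hypothetical function, not the library's method) is:

def squared_loss_residuals(targets, predictions):
    """Negative gradient of squared loss: r_i = y_i - F(x_i)."""
    return [y - f for y, f in zip(targets, predictions)]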
Example #7
            self.trees.append(tree)

    def predict(self, x):
        y = self.F0
        for tree in self.trees:
            y += tree.predict(x)
        return y


if __name__ == '__main__':
    param = DTreeParameter()
    param.max_level = 4
    param.split_threshold = 0.8
    param.max_attr_try_time = 1000
    param.tree_number = 20
    param.learning_rate = 0.5

    sample = DTreeSample()
    sample.load_liblinear('heart_scale.txt')

    gbdt = GBDT(sample)
    gbdt.train(param)
    print(gbdt.predict([
        0.708333, 1, 1, -0.320755, -0.105023, -1, 1, -0.419847, -1, -0.225806,
        0, 1, -1
    ]))
    print(gbdt.predict([
        0.583333, -1, 0.333333, -0.603774, 1, -1, 1, 0.358779, -1, -0.483871,
        0, -1, 1
    ]))