Ejemplo n.º 1
0
def main():
    """Build a decision tree from ``lenses.txt`` and plot it.

    Each line of the file is stripped and split on tabs into one sample;
    the four feature names below label the feature columns passed to
    ``trees.createTree``.
    """
    # A context manager closes the file handle even if reading fails.
    with open('lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = trees.createTree(lenses, lensesLabels)

    # print(...) with a single argument behaves the same on Python 2 and 3.
    print(lensesTree)

    plotTree.createPlot(lensesTree)
Ejemplo n.º 2
0
    # NOTE(review): fragment -- the enclosing function's header and the code
    # that defines dataSet, bestFeat, bestFeatLabel, labels, uniqueVals,
    # classList and myTree are not visible here.
    # For string-valued features, collect every value the feature takes in
    # the full data set (data_full / labels_full), not only in the current
    # subset, so values missing from this subset can still get a branch.
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        # Locate the feature's column in the full data via its name.
        currentlabel = labels_full.index(labels[bestFeat])
        featValuesFull = [example[currentlabel] for example in data_full]
        uniqueValsFull = set(featValuesFull)
    # Deletes bestFeat's entry from `labels` in place (mutates that list).
    del (labels[bestFeat])
    # Build one subtree for each value in uniqueVals (presumably the values
    # of bestFeat seen in the current dataSet -- defined above, not visible).
    for value in uniqueVals:
        # Each recursive call gets its own copy of the remaining labels.
        subLabels = labels[:]
        if type(dataSet[0][bestFeat]).__name__ == 'str':
            # Mark this value as covered by a real branch.
            uniqueValsFull.remove(value)
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, data_full,
            labels_full)
    # Any value left in uniqueValsFull had no branch built by the loop above;
    # give it a leaf labelled majorityCnt(classList) (presumably the majority
    # class at this node -- majorityCnt is not visible here).
    if type(dataSet[0][bestFeat]).__name__ == 'str':
        for value in uniqueValsFull:
            myTree[bestFeatLabel][value] = majorityCnt(classList)
    return myTree


# Load the watermelon data: drop the first column (presumably a row id --
# verify against the CSV) and keep only the first 11 rows.
df = pd.read_csv('watermelon_2.csv')
data = df.values[:11, 1:].tolist()
# Feature names are every column except the first and the last (the last is
# presumably the class column -- confirm against the CSV header).
labels = df.columns.values[1:-1].tolist()

# createTree appears to delete entries from the label list it is given, so
# keep untouched (shallow) copies of both the samples and the feature names.
data_full = list(data)
labels_full = list(labels)

myTree = createTree(data, labels, data_full, labels_full)
print(myTree)

import plotTree

plotTree.createPlot(myTree)
    # NOTE(review): fragment -- the function header and the code that sets
    # FeatureLabel, bestFeat, labels and data are not visible here.
    # The node for this level is keyed by the chosen feature's label.
    tree = {FeatureLabel: {}}
    # Deletes bestFeat's entry from `labels` in place (mutates that list).
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in data]
    featValues = set(featValues)  # Remove duplicate values
    # Recurse once per distinct value of the chosen feature.
    for value in featValues:
        # Each recursive call gets its own copy of the remaining labels.
        newLabels = labels[:]
        tree[FeatureLabel][value] = buildTree(splitData(data, bestFeat, value), newLabels)
    return tree


# Build the tree from the module-level `data` and `labels`.
# buildTree deletes the chosen feature's entry from the label list it
# receives, so hand it a copy and keep `labels` intact for later lookups
# such as classify(tree, labels, ...).
tree = buildTree(data, labels[:])
# print(...) works on Python 2 and 3; a bare `print tree` statement is a
# SyntaxError on Python 3.
print(tree)

import plotTree

plotTree.createPlot(tree)


def classify(tree, featLabels, test):
    """Return the class label the decision tree assigns to ``test``.

    Args:
        tree: nested dict ``{featureLabel: {featureValue: subtree_or_label}}``;
            any non-dict value reached is treated as the leaf class label.
        featLabels: feature names, in the same order as the values in ``test``.
        test: the feature vector to classify.

    Returns:
        The leaf value reached by following ``test`` down the tree.

    Raises:
        ValueError: if ``tree`` is empty, or if ``test`` holds a value for
            some feature that has no branch in the tree.
    """
    if not tree:
        raise ValueError("cannot classify with an empty tree")
    # next(iter(...)) works on Python 2 and 3; dict.keys()[0] fails on
    # Python 3 because dict views are not indexable.
    firstKey = next(iter(tree))
    secondDict = tree[firstKey]
    featIndex = featLabels.index(firstKey)
    value = test[featIndex]
    for key, subtree in secondDict.items():
        if value == key:
            if isinstance(subtree, dict):
                return classify(subtree, featLabels, test)
            return subtree
    # No branch matched: fail loudly instead of hitting an unbound variable.
    raise ValueError(
        "no branch for feature %r with value %r" % (firstKey, value))
Ejemplo n.º 4
0
        # NOTE(review): fragment of a loop inside a dataset reader (presumably
        # readDataset, called below); rawData, label2Index, classifications,
        # index and dataset are set up in code not visible here.
        # Encode fields 0-5 as integers: even positions are letters turned
        # into 0-based offsets from 'a', odd positions are digits turned into
        # 0-based offsets from '1'.
        data = [ord(rawData[0]) - ord('a'),ord(rawData[1]) - ord('1'),\
                ord(rawData[2]) - ord('a'),ord(rawData[3]) - ord('1'),\
                ord(rawData[4]) - ord('a'),ord(rawData[5]) - ord('1')]
        # Replace each class name (field 6) with an integer id, assigned in
        # order of first appearance; strings were assumed to be costlier.
        if rawData[6] not in label2Index:
            label2Index[rawData[6]] = index
            classifications.append(rawData[6])
            index += 1
        # The class id is stored as the last column of the row.
        data.append(label2Index[rawData[6]])
        dataset.append(data)
    # Split rows into feature vectors (first 6 columns) and class ids (last).
    return [data[0:6] for data in dataset], [data[-1] for data in dataset]


def cutTree(tree, depth, maxDepth=4):
    """Return a copy of ``tree`` truncated for display.

    Args:
        tree: nested-dict decision tree, or a leaf value.
        depth: nesting level of ``tree`` (callers start at 0). Every dict
            level counts, i.e. both the feature-name level and the
            feature-value level.
        maxDepth: nesting level at which any remaining subtree is replaced
            by the placeholder string ``"more"`` (default 4, the value that
            used to be hard-coded).

    Returns:
        A new nested dict (the input is not modified), or the leaf itself.
    """
    # Leaves are returned unchanged at any depth, so the "more" placeholder
    # only ever stands for a subtree that was actually hidden.
    if not isinstance(tree, dict):
        return tree
    # >= so a caller that starts past the limit still gets a truncated tree.
    if depth >= maxDepth:
        return "more"
    return {node: cutTree(child, depth + 1, maxDepth)
            for node, child in tree.items()}


if __name__ == '__main__':
    trainset, trainlabel = readDataset("trainset.csv")
    testset, testlabel = readDataset("testset.csv")
    tree = createTree(trainset, trainlabel, testset, testlabel)
    # Plot only a depth-limited copy of the tree produced by cutTree.
    preview = cutTree(tree, 0)
    plotTree.createPlot(preview)
def createDataSet(filename):
    """Read a whitespace-separated lenses data file.

    The first four whitespace-separated fields of each line are the feature
    values. All remaining fields are joined with single spaces to form the
    class label, so multi-word labels such as "no lenses" survive the split.
    Blank lines are skipped.

    Args:
        filename: path of the data file.

    Returns:
        ``(dataSet, labels, features)``: the list of feature-value lists,
        the list of class labels, and the four feature names.
    """
    features = ['age', 'prescript', 'astigmatic', 'tearRate']
    dataSet = []
    labels = []
    # `with` closes the file even if reading fails.
    with open(filename) as fh:
        for line in fh:
            fields = line.split()
            # A blank line (e.g. a trailing newline at end of file) would
            # otherwise add an empty sample and an empty label.
            if not fields:
                continue
            dataSet.append(fields[:4])
            labels.append(' '.join(fields[4:]))
    return dataSet, labels, features


# Train on the training split and plot the resulting tree.
dataSet, labels, features = createDataSet('data/lenses/training.txt')
tree = decisionTreeID3.buildTree(dataSet, labels)
plotTree.createPlot(tree, features)

# Evaluate on the test split: count samples whose predicted label matches.
testData, testLabels, _ = createDataSet('data/lenses/test.txt')
correct = 0
for sample, expected in zip(testData, testLabels):
    if decisionTreeID3.evaluate(tree, sample) == expected:
        correct += 1

print('Correctness: %d/%d' % (correct, len(testData)))
    # NOTE(review): orphan line from a function not visible here; it reports
    # a test-vector feature value that has no branch in the trained tree.
    print("error: the testVect feat:%s, has the value: %s, but there is no the value in trainingDataSet" % (str(firstStr), str(testVect[featIndex])))      
    

"""
使用pickle模块存储决策树
"""
def storeTree(inputTree, fileName):
    """Serialize a decision tree to ``fileName`` with pickle.

    Args:
        inputTree: the tree to store (any picklable value).
        fileName: destination path; an existing file is overwritten.
    """
    # pickle requires binary mode; `with` guarantees the file is flushed
    # and closed even if dumping raises.
    with open(fileName, 'wb') as fw:
        pickle.dump(inputTree, fw)

"""
使用pickle模块读取决策树
"""
def grabTree(fileName):
    """Load and return a tree previously written by ``storeTree``.

    Args:
        fileName: path of a pickle file.

    Returns:
        The unpickled object.

    Security: ``pickle.load`` can execute arbitrary code while loading;
    only load files you created yourself or otherwise trust.
    """
    # Binary mode for pickle; `with` ensures the handle is closed after
    # loading instead of lingering until garbage collection.
    with open(fileName, 'rb') as fr:
        return pickle.load(fr)

if __name__ == '__main__':
    # Read the tab-separated lenses samples; `with` closes the file afterwards.
    with open(r"lenses.txt") as fr:
        dataSet = [line.strip().split('\t') for line in fr.readlines()]
    # Names of the four feature columns passed to createTree.
    classLables = ['age','prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(dataSet, classLables)
    print(lensesTree)
    import plotTree as pt
    pt.createPlot(lensesTree)
    # To persist the tree and read it back:
    #storeTree(lensesTree, r"lensesTree.txt")
    #print(grabTree(r"lensesTree.txt"))
Ejemplo n.º 7
0
import decideTree
import storeTree
import plotTree

if __name__ == '__main__':
    # The first line of lenses.txt is used as the label list; the remaining
    # lines are tab-separated samples. `with` closes the file afterwards.
    with open('lenses.txt') as fr:
        labels = fr.readline().strip().split('\t')
        lenses = [example.strip().split('\t') for example in fr.readlines()]
    # Bind the result to its own name: rebinding `decideTree` would shadow
    # the imported module of the same name.
    lensesTree = decideTree.createDecideTree(lenses, labels)
    print(lensesTree)
    plotTree.createPlot(lensesTree)
Ejemplo n.º 8
0
import plotTree
import createDecisionTree
import computePara

if __name__ == '__main__':
    data_set, labels = computePara.createDataset()
    # classify() gets an untouched copy of the feature names, while
    # createTree receives `labels` itself (NOTE(review): presumably because
    # createTree mutates it -- confirm).
    origin_labels = labels.copy()
    decision_tree = createDecisionTree.createTree(data_set, labels)

    test_vector = [1, 0, 1]
    test_label = createDecisionTree.classify(
        decision_tree, origin_labels, test_vector)
    print('该测试向量属于的类别是: ', test_label)
    plotTree.createPlot(decision_tree)
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 06 14:21:43 2015

@author: Herbert
"""
import tree
import plotTree
import saveTree

# Load the tab-separated lenses samples; `with` closes the file afterwards.
with open('lenses.txt') as fr:
    lensesData = [data.strip().split('\t') for data in fr.readlines()]
lensesLabel = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = tree.buildTree(lensesData, lensesLabel)
# print(...) works on Python 2 and 3; `print x` is a SyntaxError on Python 3.
print(lensesTree)

# createPlot is called for its drawing side effect. Its return value is not
# used anywhere else in this file, so it is no longer printed.
plotTree.createPlot(lensesTree)
Ejemplo n.º 10
0
#!/usr/bin/python
# -*- coding: utf-8 -*-

import decision_tree
import plotTree

# Toy data set: two binary feature columns plus a 'yes'/'no' class column.
dataSet = [
    [1, 1, 'yes'],
    [1, 0, 'no'],
    [0, 1, 'no'],
    [0, 1, 'no'],
    [1, 1, 'yes'],
]
labels = ['one', 'two', 'answer']

tree = decision_tree.createTree(dataSet, labels)
# NOTE(review): `tree` is built but never handed to createPlot(); confirm this
# plotTree version takes no arguments, or pass `tree` explicitly.
plotTree.createPlot()
Ejemplo n.º 11
0
#shannonEnt = calcShannonEnt(dataset)
#print shannonEnt
#reduDataset = splitDataset(dataset, 0, 0)
#print reduDataset
#bestFeatIdx = chooseBestFeatureToSplitV2(dataset)
#print bestFeatIdx
#print dataset
# createTree presumably removes entries from the label list it is given, so
# pass it a copy and keep `labels` intact for classify().
copyLabels = list(labels)
tree = createTree(dataset, copyLabels)
# print(...) works on Python 2 and 3; `print tree` is a SyntaxError on Python 3.
print(tree)
#classLabel = classify(tree, labels, [1, 0])
#print classLabel
#classLabel = classify(tree, labels, [1, 1])
#print classLabel
plotTree.createPlot(tree)
#treeFile = 'storeTree.txt'
#storeTree(tree, treeFile)
#rebuildTree = grabTree(treeFile)
#print rebuildTree