/
lab1.py
80 lines (71 loc) · 2.25 KB
/
lab1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python
import dtree as d
import monkdata as m
import drawtree as dt
from texttable import Texttable
# Datasets the decision trees are built from (reported as "Training" in
# assignment 3), ordered Monk-1, Monk-2, Monk-3.
monkdata = [m.monk1,m.monk2,m.monk3]
# Matching test sets (reported as "Test"), index-aligned with monkdata.
testdata = [m.monk1test,m.monk2test,m.monk3test]
# Assignment 1
def assignment1():
    """Print a table with the entropy of every dataset in ``monkdata``.

    Each row is labelled "Monk-1", "Monk-2", ... and holds the value returned
    by ``d.entropy`` for that dataset.  Nothing is returned; the table is
    written to standard output followed by a blank line.
    """
    # Single-argument print(...) behaves identically on Python 2 and 3,
    # whereas the print statement is a SyntaxError on Python 3.
    print("--- Assignment 1 ---")
    print("Initial entropy of the datasets")
    table = Texttable(max_width=100)
    table.add_row(["Dataset", "Entropy"])
    # Iterate the datasets directly instead of hard-coding their count.
    for number, dataset in enumerate(monkdata, 1):
        table.add_row(["Monk-" + str(number), d.entropy(dataset)])
    print(table.draw())
    print("")  # blank separator line between assignments
# Assignment 2
def assignment2():
    """Print the average gain of every attribute for every dataset.

    For each dataset in ``monkdata`` one row is emitted containing
    ``d.averageGain(dataset, attribute)`` for each entry of ``m.attributes``.
    Nothing is returned; the table is written to standard output followed by
    a blank line.
    """
    print("--- Assignment 2 ---")
    print("Selecting the root of the decision tree")
    table = Texttable(max_width=100)
    # NOTE(review): the header assumes m.attributes has six entries -- confirm.
    table.add_row(["Dataset", "a1", "a2", "a3", "a4", "a5", "a6"])
    for number, dataset in enumerate(monkdata, 1):
        # A list comprehension yields a real list on Python 3 as well (map
        # returns an iterator there, which cannot be concatenated to a list).
        gains = [d.averageGain(dataset, att) for att in m.attributes]
        table.add_row(["Monk-" + str(number)] + gains)
    print(table.draw())
    print("")  # blank separator line between assignments
# Assignment 3
def assignment3():
    """Print how full decision trees score on their training and test sets.

    For each pair of ``monkdata`` / ``testdata`` entries a tree is built with
    ``d.buildTree`` from the training set and scored with ``d.check`` against
    both the training set and the test set.  Nothing is returned; the table
    is written to standard output followed by a blank line.
    """
    print("--- Assignment 3 ---")
    print("Performance of the decision trees")
    table = Texttable(max_width=100)
    table.add_row(["Dataset", "Training", "Test"])
    # zip keeps each training set paired with its test set without relying
    # on a hard-coded dataset count.
    for number, (train, test) in enumerate(zip(monkdata, testdata), 1):
        tree = d.buildTree(train, m.attributes)
        perf = [d.check(tree, train), d.check(tree, test)]
        table.add_row(["Monk-" + str(number)] + perf)
    print(table.draw())
    print("")  # blank separator line between assignments
# Return the best tree, pruned or otherwise, for the given validation set
def best_pruned(base, valid_set):
    """Choose the best-scoring tree among ``base`` and ``d.allPruned(base)``.

    Every candidate is scored with ``d.check`` on ``valid_set``.  A pruned
    candidate replaces the current choice whenever its score is at least as
    high, so on ties the latest candidate wins and ``base`` is kept only when
    it strictly beats every pruned candidate.

    Returns a ``(tree, score)`` tuple.
    """
    # NOTE(review): only the candidates from a single d.allPruned call are
    # considered; if that helper returns one-step prunings only, repeated
    # application may be needed for a fully pruned tree -- confirm.
    candidates = d.allPruned(base)
    chosen_tree = base
    chosen_score = d.check(base, valid_set)
    for candidate in candidates:
        score = d.check(candidate, valid_set)
        if chosen_score <= score:
            chosen_tree = candidate
            chosen_score = score
    return (chosen_tree, chosen_score)
# Assignment 4
def assignment4():
    """Print test-set scores of pruned trees for several partition fractions.

    For each dataset and each fraction, the training data is split with
    ``m.partition``, a tree is built from the first part, the best candidate
    is selected with ``best_pruned`` using the second part, and that tree is
    scored with ``d.check`` on the dataset's test set.  The last column
    ("Benchmark") is the score of an unpruned tree built from the whole
    training set.  Nothing is returned; the table is written to standard
    output followed by a blank line.
    """
    print("--- Assignment 4 ---")
    print("Selecting the best fraction to divide training and validation sets for pruning")
    # One list drives both the header and the loop so they cannot drift
    # apart.  Literals are used because x * 0.1 carries float noise
    # (3 * 0.1 == 0.30000000000000004), which did not match the labels.
    fractions = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
    table = Texttable(max_width=100)
    table.add_row(["Dataset"] + [str(frac) for frac in fractions] + ["Benchmark"])
    for number, (train, test) in enumerate(zip(monkdata, testdata), 1):
        row = ["Monk-" + str(number)]
        for frac in fractions:
            train_set, valid_set = m.partition(train, frac)
            base = d.buildTree(train_set, m.attributes)
            best_tree = best_pruned(base, valid_set)[0]
            # Report performance on the test set, not the validation set
            # that was used to choose the tree.
            row.append(d.check(best_tree, test))
        # Benchmark: unpruned tree built from the complete training set.
        row.append(d.check(d.buildTree(train, m.attributes), test))
        table.add_row(row)
    print(table.draw())
    print("")  # blank separator line between assignments
# Run all four assignments in order; each prints its own table to stdout.
# NOTE(review): these calls execute on import as well as when run as a
# script -- consider an `if __name__ == "__main__":` guard if this module is
# ever imported.
assignment1()
assignment2()
assignment3()
assignment4()