forked from jaredly/decision-tree
-
Notifications
You must be signed in to change notification settings - Fork 0
/
crossvalidation.py
47 lines (38 loc) · 1.16 KB
/
crossvalidation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python
from scipy.io.arff import loadarff
import numpy as np
from dtree import Node
from pandas import DataFrame, concat
def nfold_arff(fname, num=10, target=None):
    """Load an ARFF file and run ``num``-fold cross-validation on it.

    Parameters
    ----------
    fname :
        Passed straight to ``scipy.io.arff.loadarff``.
    num : int, optional
        Number of folds, forwarded to ``nfold``.  Defaults to 10, matching
        ``nfold``'s own default.
    target : str, optional
        Name of the target attribute.  When omitted, the last attribute
        declared in the file (``meta.names()[-1]``) is used.

    Returns
    -------
    tuple
        Whatever ``nfold`` returns for the loaded data.
    """
    data, meta = loadarff(fname)
    if target is None:
        # Default to the last declared attribute as the target.
        target = meta.names()[-1]
    return nfold(meta, DataFrame(data), target, num)
def nfold(meta, data, target, num=10, seed=None, verbose=True):
    """Run ``num``-fold cross-validation of a ``dtree.Node`` on ``data``.

    The rows are shuffled and split into ``num`` folds whose sizes differ by
    at most one, so every row lands in exactly one validation fold.  This
    includes the ``len(data) % num`` leftover rows, which a fixed
    ``len(data) // num`` fold size would never validate.  For each fold a
    fresh ``Node`` is built on the remaining rows, run, and validated on the
    held-out fold.  Finally a ``Node`` is built and run on all of ``data``.

    Parameters
    ----------
    meta :
        Attribute metadata, passed unchanged to ``Node``.
    data : pandas.DataFrame
        The full data set.  Rows are selected by position (``.iloc``), so
        duplicate index labels are harmless.
    target :
        Target column identifier, passed unchanged to ``Node``.
    num : int, optional
        Number of folds; must satisfy ``2 <= num <= len(data)``.
    seed : int, optional
        Seed for a private ``numpy.random.RandomState`` so the fold
        assignment is reproducible.  ``None`` (default) draws from the
        global numpy RNG, as earlier versions did.
    verbose : bool, optional
        When true (default), print per-fold progress and the final average.

    Returns
    -------
    tuple
        ``(avg, wrong, node, test, tries, nodes)``, where:

        - ``avg`` is the mean of the per-fold validation scores.
        - ``wrong`` is the return value of ``run()`` on the full-data node.
        - ``node`` is the ``Node`` built on all of ``data``.
        - ``test`` is the list of per-fold ``run()`` results.
        - ``tries`` is the list of per-fold scores, i.e. element ``[1]`` of
          ``Node.validate()``'s return.  NOTE(review): presumably an error
          measure; confirm against ``dtree.Node.validate``.
        - ``nodes`` is the list of per-fold ``Node`` objects.

    Raises
    ------
    ValueError
        If ``num < 2`` (a fold would train on no rows) or
        ``num > len(data)`` (some validation folds would be empty).
    """
    total = len(data)
    if num < 2 or num > total:
        raise ValueError('num must be between 2 and %d (the number of rows), '
                         'got %r' % (total, num))

    # Shuffle row *positions* rather than index labels so duplicate labels
    # cannot make label-based selection return extra rows.
    rng = np.random if seed is None else np.random.RandomState(seed)
    order = rng.permutation(total)
    # array_split spreads the remainder over the first folds, so fold sizes
    # differ by at most one and every row is validated exactly once.
    folds = np.array_split(order, num)

    test = []   # per-fold Node.run() results
    tries = []  # per-fold validation scores
    nodes = []  # per-fold trained Nodes
    for i, val_pos in enumerate(folds):
        if verbose:
            # Single-argument print: same output on Python 2 and Python 3.
            print('val with %d %d' % (len(val_pos), total))
        # Training rows are all other folds, kept in shuffled order.
        train_pos = np.concatenate(folds[:i] + folds[i + 1:])
        training = data.iloc[train_pos]
        validating = data.iloc[val_pos]
        node = Node(meta, training, target)
        tests = node.run()
        vals = node.validate(validating)[1]
        if verbose:
            print('%s %s' % (tests, vals))
        test.append(tests)
        tries.append(vals)
        nodes.append(node)

    # Dividing by a float avoids Python 2 floor division when scores are ints.
    avg = sum(tries) / float(len(tries))
    if verbose:
        print(avg)

    node = Node(meta, data, target)
    wrong = node.run()
    return avg, wrong, node, test, tries, nodes
# vim: et sw=4 sts=4