-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
118 lines (107 loc) · 3.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
import os.path
import os
from NaiveBayes import NaiveBayesClassifier
from nbio import load_words, store_classifier
import random
from stats import calc_F1
def train( data_dir, outfile , test_pct = 0.3, verbose = True):
    '''
    Train a naive Bayes classifier using ML estimates.

    Each subdirectory of data_dir is treated as one class label; every
    file inside that subdirectory is one training document.

    If test percentage (test_pct) > 0, roughly that fraction of the files
    from each class is randomly held out and used to report
    F1 / precision / recall statistics.

    data_dir -- directory whose subdirectories name the class labels
    outfile  -- path the trained classifier is stored to
    test_pct -- fraction (0.0..1.0) of files to hold out for validation
    verbose  -- when True, print the label list and split sizes
    '''
    labels = os.listdir( data_dir )
    if verbose:
        print(labels)
    nbc = NaiveBayesClassifier( labels )
    def load_label_dir( ddir ):
        '''
        Load all datafiles within a single class directory.
        Returns a list of (filename, words) pairs.
        '''
        label_path = os.path.join( data_dir, ddir )
        return [ (fname, load_words( os.path.join( label_path, fname ) ))
                 for fname in os.listdir( label_path ) ]
    te_dl = []   # held-out test documents: (filename, words, label)
    n_tr = 0     # count of documents actually used for training
    for label in labels:
        for (f, ws) in load_label_dir( label ):
            # per-file random split: ~test_pct of each class is held out
            if random.random() < test_pct:
                te_dl.append( (f, ws, label) )
            else:
                nbc.add_example( label, ws )
                n_tr += 1
    if verbose:
        # report how the random split actually came out
        print('training examples: %d, held-out: %d' % (n_tr, len(te_dl)))
    #if we've picked a test set, use it
    def show_stats(s, lbl = None):
        ''' display F1 stats (scaled to percentages) for one stats dict '''
        f1 = s['F1'] * 100.0
        pr = s['precision'] * 100.0
        rc = s['recall'] * 100.0
        # fixed: compare against None with 'is not', not '!='
        if lbl is not None:
            print('%s : F1=%f precision=%f recall=%f' % ( lbl, f1, pr, rc) )
        else:
            print('F1=%f precision=%f recall=%f' % ( f1, pr, rc) )
    if test_pct > 0.0:
        preds = []
        obs = []
        for (f, ws, l) in te_dl:
            obs.append(l)
            preds.append( nbc.classify( ws ) )
        sts = calc_F1( preds, obs )
        show_stats( sts['overall'] )
        for l in labels:
            show_stats( sts[l], l )
    #store trained classifier
    store_classifier( outfile, nbc )
if __name__ == '__main__':
    import getopt
    import sys
    def usage():
        ''' show usage message and exit with a nonzero status '''
        print('train (txtcat-nb)')
        print('options:')
        print('\t-h --help\t\t\tshow this usage information')
        print('\t-v --verbose\t\t\tverbose processing')
        print('\t-o --output=\t\t\tclassification model file')
        print('\t-d --datadir=\t\t\tdata directory')
        print('\t-p --pct=\t\t\tpercentage of data to use for validation')
        sys.exit(1)
    # command line maps onto train(data_dir, outfile, test_pct, verbose):
    # -d --datadir %a, -o --output %a, -p --pct %f , -v --verbose
    try:
        opts, args = getopt.getopt( sys.argv[1:], 'hvd:o:p:',['help','verbose','datadir=','output=','pct='])
    except getopt.GetoptError as err:
        print(err)
        usage()
    # no positional arguments are accepted
    if args:
        usage()
    verbose = False
    tpct = 0.3
    ddir = None
    ofile = None
    for o,a in opts:
        if o in ('-h','--help'):
            usage()
        elif o in ('-v', '--verbose'):
            verbose = True
        elif o in ('-d','--datadir'):
            ddir = a
        elif o in ('-o','--output'):
            ofile = a
        elif o in ('-p','--pct'):
            # fixed: a malformed number now shows usage instead of a traceback
            try:
                tpct = float( a )
            except ValueError:
                print('invalid percentage %s' % a)
                usage()
        else:
            print('unrecognized option %s (argument %s)' % (o,a) )
            usage()
    # both the data directory and the output file are required
    if ddir is None or ofile is None:
        usage()
    # fixed: explicit check instead of assert, which is stripped under 'python -O'
    if not ( 0.0 <= tpct <= 1.0 ):
        print('test percentage out of range')
        usage()
    if not os.path.exists( ddir ):
        usage()
    #finally time to do work
    train( ddir, ofile, tpct, verbose )