-
Notifications
You must be signed in to change notification settings - Fork 0
/
decoder.py
77 lines (67 loc) · 2.31 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/python
from __future__ import division
import viterbi
from framework import read_data, get_maps
import sys
def read_weights(weightsfile):
    """Load a sparse feature-weight vector from a text file.

    Each non-blank line holds a feature index and its weight separated by
    whitespace, e.g. ``17 -0.25``. Blank lines are ignored.

    :param weightsfile: path of the weights file to read.
    :returns: dict mapping int feature index -> float weight.
    :raises ValueError: if a non-blank line does not have exactly two fields,
        or a field is not a valid int/float.
    """
    weights = {}
    # ``with`` closes the file even when a malformed line raises mid-read.
    with open(weightsfile, 'r') as feats:
        for line in feats:
            # split() with no argument tolerates repeated spaces/tabs and the
            # trailing newline; split(' ') broke on either.
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing to unpack an empty field list.
                continue
            f, wt = fields
            weights[int(f)] = float(wt)
    return weights
def decode(sents, goldtagseqs, postagseqs, info, weights):
    """Viterbi-decode every test sentence and report B/I/O tagging scores.

    For each sentence, ``viterbi.execute`` predicts a tag sequence over the
    label set ``B``/``I``/``O``/``*``.

    Output goes to two streams:

    * stdout: one line per token, ``word<TAB>postag<TAB>gold<TAB>predicted``,
      with a blank line after each sentence.
    * stderr: the sentence count, a progress counter, token accuracy, and
      recall/precision for the merged B/I class and for the O class. A
      metric whose denominator is zero is omitted.

    :param sents: sentences, each an indexable sequence of token strings.
    :param goldtagseqs: gold tag sequence per sentence, parallel to ``sents``.
    :param postagseqs: POS tag sequence per sentence, parallel to ``sents``.
    :param info: feature maps, passed through unchanged to ``viterbi.execute``.
    :param weights: feature weights, passed through unchanged to
        ``viterbi.execute``.
    """
    labelset = ['B', 'I', 'O', '*']
    chunk_tags = ('B', 'I')

    def report(name, num, den, end="\n"):
        # Write "name = num/den" to stderr. Skip the line when the denominator
        # is zero (no tokens, or no tokens of that class) rather than raising
        # ZeroDivisionError after all the decoding work is done.
        if den > 0:
            sys.stderr.write(name + " = " + str(float(num) / den) + end)

    tp_bi = 0        # tokens whose gold and predicted tags are both in {B, I}
    tp_o = 0         # tokens whose gold and predicted tags are both O
    acc = 0.0        # tokens whose predicted tag equals the gold tag exactly
    tot = 0          # total tokens scored
    tot_rec_bi = 0   # gold B/I tokens (B/I recall denominator)
    tot_rec_o = 0    # gold tokens not in {B, I} (O recall denominator)
    tot_prec_bi = 0  # predicted B/I tokens (B/I precision denominator)
    tot_prec_o = 0   # predicted tokens not in {B, I} (O precision denominator)

    sys.stderr.write("total test sentences = " + str(len(sents)) + "\n")
    for i, sent in enumerate(sents):
        # Carriage return so the next index overwrites this one on a terminal.
        sys.stderr.write(str(i) + "\r")
        postags = postagseqs[i]
        gold = goldtagseqs[i]
        tags = viterbi.execute(sent, labelset, postags, weights, info)
        for j, pred in enumerate(tags):
            gold_bi = gold[j] in chunk_tags
            pred_bi = pred in chunk_tags
            if pred == gold[j]:
                acc += 1
            # B and I are merged into one class here, so predicting B for a
            # gold I (or vice versa) still counts as a B/I hit.
            if gold_bi and pred_bi:
                tp_bi += 1
            elif gold[j] == "O" and pred == "O":
                tp_o += 1
            if gold_bi:
                tot_rec_bi += 1
            else:
                tot_rec_o += 1
            if pred_bi:
                tot_prec_bi += 1
            else:
                tot_prec_o += 1
            sys.stdout.write("\t".join((sent[j], postags[j], gold[j], pred)) + "\n")
        sys.stdout.write("\n")  # blank line separates sentences
        tot += len(tags)

    report("accuracy", acc, tot)
    report("BI recall", tp_bi, tot_rec_bi)
    report("BI precision", tp_bi, tot_prec_bi)
    report("O recall", tp_o, tot_rec_o)
    report("O precision", tp_o, tot_prec_o, end="\n\n")
if __name__ == "__main__":
    # Command line: decoder.py TESTFILE WEIGHTSFILE GAZFILE BROWNFILE
    # Tagged tokens go to stdout; progress and scores go to stderr.
    if len(sys.argv) < 5:
        # Exit with a usage message (exit status 1) instead of an IndexError
        # traceback. Extra trailing arguments are still ignored, as before.
        sys.exit("usage: " + sys.argv[0] + " testfile weightsfile gazfile brownfile")
    testfile, weightsfile, gazfile, brownfile = sys.argv[1:5]
    sents, goldtagseqs, postagseqs = read_data(testfile)
    info = get_maps(sents, postagseqs, gazfile, brownfile)
    weights = read_weights(weightsfile)
    decode(sents, goldtagseqs, postagseqs, info, weights)