######### Tagger.py #########
from __future__ import division # use floating point division
from nltk import ConditionalFreqDist # for frequency distributions
from Helper import msg # for logging
from HMM import HMM # our Hidden Markov Model class
from Treebank import Treebank # our corpus class
from PennTags import PennTags # our tag list
import time # for timing various processes

class Tagger:
"""
A class for POS-tagging text and evaluating the result
"""

    ######### CLASS VARIABLES #########
# a fake START tag to add to the beginning of sentences to help with tagging
start_tag = '^'
    # minimum number of times a tagging mistake must occur before it is shown to the user
    mistake_threshold = 50
    # number of train/test cycles for x-fold cross-validation
    # (not currently used by run_test_cycles, which runs a single cycle)
    test_cycles = 2

    def __init__(self, corpus_path, corpus_files, test_files):
        """
        Construct a Tagger object

        :param corpus_path: path to corpus files
        :param corpus_files: list of training corpus files
        :param test_files: list of testing corpus files
        """
# object for working with training data
self.training = Treebank(corpus_path, corpus_files)
# object for working with testing data
self.testing = Treebank(corpus_path, test_files)
        # will contain a list of tags found in the training corpus
        self.pos_tags = None
        # will hold the Hidden Markov Model object used for tagging
        self.hmm = None
        # use PennTags as the tag list
        self.tags = PennTags
        # will hold the conditional frequency distribution for P(Wi|Ck)
        self.words_given_pos = None
        # will hold the same distribution with original capitalization preserved
        self.words_given_pos_upper = None
        # will hold the conditional frequency distribution for P(Ci+1|Ci)
        self.pos2_given_pos1 = None

    ######### `PUBLIC' FUNCTIONS #########
    def run_test_cycles(self):
        """
        Run the training and testing cycle for the tagger: train on the
        training corpus, tag the held-out testing corpus, and report
        accuracy statistics.
        """
total_time_start = time.time() # keep track of time
rights = [] # array to hold number of correctly-tagged words for each test
wrongs = [] # array to hold number of incorrectly-tagged words for each test
totals = [] # array to hold number of total words for each test
all_missed = [] # array to hold incorrect tag information for each test
        sep = "-" * 50 + "\n" # logging separator
        # get tagged sentences from the training corpus and train the tagger
        training_sents = self.training.tagged_sents
        self.train(training_sents)
        # get untagged and gold-standard tagged versions of the test sentences
        testing_tagged_sents = self.testing.tagged_sents
        testing_untagged_sents = self.testing.sents
        testing_sents = (testing_untagged_sents, testing_tagged_sents)
(right, wrong, missed) = self.test(testing_sents)
# gather accuracy statistics for this test
total = right + wrong
rights.append(right) # store the correct count for this test cycle
wrongs.append(wrong) # store the incorrect count for this test cycle
totals.append(total) # store the total words tested for this test cycle
all_missed += missed # add incorrect tag information from this cycle
msg("Total words: %d\n" % total)
msg("Correct tags: %d (%0.2f%%)\n" % (right, right / total * 100))
msg("Incorrect tags: %d (%0.2f%%)\n" % (wrong, wrong / total * 100))
msg("%s%s" % (sep,sep))
# calculate and output statistics for the entire test
print "Total tests run: %d" % len(totals)
print "Total time taken: %0.2f seconds" % (time.time() - total_time_start)
print "Average correct tags: %0.2f%%" % (sum(rights) / sum(totals) * 100)
print "Average incorrect tags: %0.2f%%" % (sum(wrongs) / sum(totals) * 100)
print
# give the option of inspecting incorrect tags
if raw_input("Examine bad tags? ") in ['y','Y']:
self.inspect(all_missed)

    def train(self, sents):
"""
Train the tagger on a set of tagged sentences
:param sents: list of tagged sentences
"""
# collect POS tags from our corpus
self.pos_tags = self.training.pos_tags()
# add start markers to help with bigram tagging
msg("Adjusting POS tags...")
sents = self._adjust_pos(sents)
msg("done\n")
        # create 2 conditional frequency distributions (from NLTK) that store
        # the observed frequency with which a given word receives a given POS
        # tag, one for lowercase-normalized words and one for words as they
        # appear in the text
msg("Training (Wi|Ck)...")
# create a CFD for words normalized to lowercase
self.words_given_pos = ConditionalFreqDist((wp[1], wp[0].lower()) for \
sent in sents for wp in sent)
# create a CFD for words left in their original capitalization
self.words_given_pos_upper = ConditionalFreqDist((wp[1], wp[0]) for \
sent in sents for wp in sent)
msg("done\n")
        # create another CFD that stores the observed frequency with which
        # one POS follows another POS
msg("Training (Ci+1|Ci)...")
self.pos2_given_pos1 = ConditionalFreqDist((sent[i-1][1], sent[i][1]) for \
sent in sents for i in range(1,len(sent)))
msg("done\n")

    def test(self, sent_set):
"""
Use a Hidden Markov Model to tag a set of sentences, and evaluate accuracy.
:param sent_set: tuple like (untagged sentences, gold standard sentences)
"""
untagged_sents = sent_set[0] # recover untagged sentences
gold_tagged_sents = sent_set[1] # recover gold standard tagged sentences
# initialize an HMM object with necessary parameters
self.hmm = HMM(untagged_sents, self.pos_tags, self.words_given_pos, \
self.words_given_pos_upper, self.pos2_given_pos1, Tagger.start_tag)
# get HMM-tagged sentences
hmm_tagged_sents = self.hmm.tag()
# evaluate against gold standard and return accuracy data
return self.evaluate(hmm_tagged_sents, gold_tagged_sents)

    def evaluate(self, hmm_tagged_sents, gold_tagged_sents):
"""
Evaluate one set of tagged sentences against another set
:param hmm_tagged_sents: list of tagged sentences
:param gold_tagged_sents: list of tagged sentences used as gold standard
"""
# ensure our sentence sets have the same length
        if len(hmm_tagged_sents) != len(gold_tagged_sents):
            raise Exception("HMM-tagged sentence set did not match gold "
                "standard sentence set!")
right = 0 # initialize counter of correct tags
wrong = 0 # initialize counter of incorrect tags
missed = [] # initialize array of tagged words we didn't get right
# loop through sentence sets
for i in range(len(gold_tagged_sents)):
# ensure our sentences have the same length
            if len(hmm_tagged_sents[i]) != len(gold_tagged_sents[i]):
                raise Exception("HMM-tagged sentence did not match gold "
                    "standard sentence!")
# loop through words in sentence
for j in range(len(gold_tagged_sents[i])):
gold_tagged_word = gold_tagged_sents[i][j]
hmm_tagged_word = hmm_tagged_sents[i][j]
# ensure the words are the same between the sets
                if gold_tagged_word[0] != hmm_tagged_word[0]:
                    raise Exception("HMM-tagged word did not match gold "
                        "standard word!")
# increment counters based on tag correctness
if gold_tagged_word[1] == hmm_tagged_word[1]:
right += 1
else:
missed.append((hmm_tagged_word, gold_tagged_word, \
hmm_tagged_sents[i], gold_tagged_sents[i]))
wrong += 1
# end words loop
# end sentences loop
# return a tuple of correct vs incorrect tags
return (right, wrong, missed)

    def inspect(self, missed):
"""
Inspect a testing session, and print data about tag accuracy
:param missed: list of tuples of missed tags like:
(hmm_tagged_word, gold_tagged_word, hmm_context, gold_context)
"""
# create a CFD so we can examine a matrix of incorrect vs correct tags
# ms[1][1] = tag of a gold_tagged_word
# ms[0][1] = tag of an hmm_tagged_word
cfd = ConditionalFreqDist((ms[1][1], ms[0][1]) for ms in missed)
# initialize a hash to store mistakes by frequency
mistakes = {}
# print a table showing mistake frequency
cfd.tabulate()
msg("\n")
# loop through mistake frequencies by gold standard tag, i.e., if we are
# examining gold-standard 'IN', count what we incorrectly tagged it as
conds = cfd.conditions()
for g_tag in conds:
for hmm_tag in cfd[g_tag].keys():
# how many times did we incorrectly say g_tag was hmm_tag?
count = cfd[g_tag][hmm_tag]
# add these mistakes to the count
                if count not in mistakes:
mistakes[count] = []
mistakes[count].append((hmm_tag, g_tag))
# get a list of all mistake types that occurred over a threshold, worst first
        mistake_counts = set(count for count in mistakes \
            if count > Tagger.mistake_threshold)
        mistake_counts = sorted(mistake_counts, reverse=True)
# now create a list of mistake types to show the user, i.e., loop
# through all types and if they are of a high-frequency type, add to list
mistakes_to_halt = []
for count in mistake_counts:
mistake_set = mistakes[count]
for mistake_tuple in mistake_set:
mistakes_to_halt.append(mistake_tuple)
msg("%d\t%s\twas really\t%s\n" % (count, mistake_tuple[0], \
mistake_tuple[1]))
msg("\n")
# create separators used when outputting missed word contexts
sep_big = "---------------------------------------------------\n"
sep_small = "\n-----------------------------------------\n"
# loop through individual mistakes and, if they match the kind of error
# we want to halt for, show the user the mistake as well as the sentence
# context for both the gold-standard sentence and the hmm-tagged sentence
        response = None
        for missed_set in missed:
            # stop looping through mistakes once the user chooses to quit
            if response in ['q','Q']:
                break
            (hmm_tagged_word, gold_tagged_word, hmm_tagged_sent, \
                gold_tagged_sent) = missed_set
            should_halt = False
            # determine whether the current mistake matches a mistake type
            # we want to halt for
            for pair in mistakes_to_halt:
                if hmm_tagged_word[1] == pair[0] and \
                    gold_tagged_word[1] == pair[1]:
                    should_halt = True
            if should_halt:
                msg("%sTagged '%s' with %s when it should have been %s.%s" % \
                    (sep_big, hmm_tagged_word[0], hmm_tagged_word[1], \
                    gold_tagged_word[1], sep_small))
                msg("Gold: " + (' '.join([(w[0] + "/" + w[1]) for w in \
                    gold_tagged_sent])))
                msg(sep_small)
                msg("Mine: " + (' '.join([(w[0] + "/" + w[1]) for w in \
                    hmm_tagged_sent])))
                # get user input to decide whether to keep going
                response = raw_input("\n\nEnter to continue, Q to quit: ")

    ######### `PRIVATE' FUNCTIONS #########

    def _adjust_pos(self, sents):
        """
        Insert a start marker (word/tag tuple) at the beginning of each
        sentence in a list, and add any missing tags to the POS tag list.

        :param sents: list of tagged sentences
        """
        new_sents = [] # initialize array of start-marked sentences
        # loop through tagged sentences
        for sent in sents:
            # add a new start-marked sentence to our array
            new_sents.append([(Tagger.start_tag, Tagger.start_tag)] + sent)
        # make sure our start marker tag gets added to the POS list
        self.pos_tags.append(Tagger.start_tag)
        # also take the opportunity to add tags to the list which appear in
        # the tag set but which we may not have encountered in training
        for tag in self.tags.rare_tags:
            if tag not in self.pos_tags:
                self.pos_tags.append(tag)
        return new_sents
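
######### EXAMPLE USAGE #########
# A minimal sketch of how this class might be driven, assuming a directory of
# Penn-Treebank-style corpus files. The path and file names below are
# hypothetical placeholders, not files shipped with this project.
if __name__ == '__main__':
    corpus_path = 'corpora/treebank/' # hypothetical corpus directory
    corpus_files = ['wsj_0001.mrg', 'wsj_0002.mrg'] # hypothetical training files
    test_files = ['wsj_0003.mrg'] # hypothetical held-out test file
    # construct the tagger, then train, tag, evaluate, and report in one call
    tagger = Tagger(corpus_path, corpus_files, test_files)
    tagger.run_test_cycles()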