-
Notifications
You must be signed in to change notification settings - Fork 0
/
tagger.py
executable file
·377 lines (273 loc) · 12.9 KB
/
tagger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
#!/usr/bin/env python
import argparse;
import cPickle;
import logging;
import os;
import shutil;
import subprocess;
import sys;
import tempfile;
from joblib.parallel import Parallel, delayed;
from align import Aligner;
from chunk import BILOUChunkEncoder;
from features import OrthographicEncoder;
from io_ import load_doc, LTFDocument, LAFDocument, write_crfsuite_file;
from logger import configure_logger;
from util import get_ABG_value_sets
from math import log;
# Module-level logger shared by tag_file and the __main__ block below;
# configure_logger is a project helper that attaches handlers/formatting.
logger = logging.getLogger();
configure_logger(logger);
def tag_file(ltf, aligner, enc, chunker, modelf, tagged_dir, tagged_ext,
             threshold, A_vals, B_vals, G_vals):
    """Extract features for tokenization in LTF file and tag named entities.

    Runs CRFSuite over the features extracted from *ltf*, filters the
    hypothesized tags by their marginal probabilities, chunks the surviving
    tag sequence into mentions, and writes one LAF file into *tagged_dir*.
    Any failure is logged and the file is skipped; the temporary working
    directory is always removed.

    Inputs
    ------
    ltf : str
        LTF file.

    aligner : align.Aligner
        Aligner instance used to obtain character onsets/offsets of discovered
        mentions.

    enc : features.Encoder
        Encoder instance for feature extraction.

    chunker : chunk.ChunkEncoder
        ChunkEncoder instance for obtaining token onsets/offsets of discovered
        mentions from tag sequences.

    modelf : str
        CRFSuite model file.

    tagged_dir : str
        Directory to which to output LAF files.

    tagged_ext : str
        Extension to used for output LAF files.

    threshold : float
        Minimum marginal probability a non-O tag must meet to be retained.

    A_vals, B_vals, G_vals : collection
        Known value sets for the token A/B/G attributes, forwarded to the
        feature encoder.
    """
    # Create working directory.
    temp_dir = tempfile.mkdtemp()

    # Load LTF.
    ltf_doc = load_doc(ltf, LTFDocument, logger)
    if ltf_doc is None:
        shutil.rmtree(temp_dir)
        return

    # Attempt tagging.
    try:
        # Extract tokens, preferring the tokenization carrying the A/B/G/F/J
        # attributes; fall back to the plain tokenization when unavailable.
        try:
            (tokens, token_ids, token_onsets, token_offsets, token_nums,
             token_As, token_Bs, token_Gs, token_Fs,
             token_Js) = ltf_doc.tokenizedWithABG()
        except Exception:
            tokens, token_ids, token_onsets, token_offsets, token_nums = ltf_doc.tokenized()
            token_As = token_Bs = token_Gs = token_Fs = token_Js = None
        txt = ltf_doc.text()
        spans = aligner.align(txt, tokens)

        # Extract features.
        featsf = os.path.join(temp_dir, 'feats.txt')
        feats = enc.get_feats(tokens, token_nums, token_As, token_Bs,
                              token_Gs, token_Fs, token_Js,
                              A_vals, B_vals, G_vals)
        write_crfsuite_file(featsf, feats)
        shutil.copy(featsf, "featuresfile") #DEBUG

        # Tag.
        tagsf = os.path.join(temp_dir, 'tags.txt')
        cmd = ['crfsuite', 'tag',
               '--marginal',  # outputs probability of each tag as extra field in tagsfile
               # '--probability', # outputs probability of tag sequence at top of tagsfile
               '-m', modelf,
               featsf]
        with open(tagsf, 'w') as f:
            subprocess.call(cmd, stdout=f)
        shutil.copy(tagsf, "taggingprobs") #DEBUG

        # Look for NEs in the tagfile with marginal probs.
        # If the tag is 'O', keep it.
        # If the tag is anything else, keep if marginal prob is above threshold.
        tagsf2 = os.path.join(temp_dir, 'tags2.txt')

        def _check_BIL_sequence(tags, probs, threshold):
            """Helper for checking the tag sequence output below.

            Checks a full BI*L sequence, returning it unchanged when at least
            half of its marginal probabilities meet *threshold* and a
            sequence of O's of equal length otherwise.  A length-1 sequence
            is returned as a U tag (or O if below threshold).
            """
            nextpart = ''
            if len(tags) < 1:
                logger.warning("Empty tag sequence submitted as BI*L sequence.")
            elif len(tags) == 1:
                logger.warning("Tag sequence of length 1 submitted as BI*L sequence.")
                # Compare probs, not abs vals of logprobs, hence >= and not <=.
                if probs[0] >= threshold:
                    nextpart = 'U{}'.format(tags[0][1:])
                else:
                    nextpart = 'O\n'
            else:
                # Repair a sequence missing its B prefix and/or L suffix.
                if not (tags[0][0] == 'B' and tags[-1][0] == 'L'):
                    logger.warning('Incomplete BI*L sequence submitted.')
                    tags[0] = 'B{}'.format(tags[0][1:])
                    tags[-1] = 'L{}'.format(tags[-1][1:])
                count = sum(1 for prob in probs if prob >= threshold)
                if count >= len(probs) / 2.0:
                    nextpart = ''.join(tags)
                else:
                    # BUG FIX: previously len(NEtags) (an enclosing-scope
                    # variable) — use the sequence actually passed in.
                    nextpart = 'O\n' * len(tags)
            return nextpart

        # Retain or reject NE hypotheses based on probs and write new tags file.
        with open(tagsf2, 'w') as f_out:
            with open(tagsf, 'r') as f_in:
                NEtags = None
                NEprobs = None
                for line in f_in.read().split('\n'):
                    # Lines without a tag:prob pair (e.g. blank sentence
                    # separators) are passed over silently.
                    if ':' not in line:
                        continue
                    tag, prob = line.strip().split(':')
                    if tag[0] == 'O':
                        # Flush any open sequence, then keep the O tag.
                        if NEtags:
                            f_out.write(_check_BIL_sequence(NEtags, NEprobs, threshold))
                            NEtags = None
                            NEprobs = None
                        f_out.write(tag + '\n')
                    elif tag[0] == 'U':
                        # Flush any open sequence; keep the U tag only when
                        # its marginal probability meets the threshold
                        # (compare probs, not abs vals of logprobs).
                        if NEtags:
                            f_out.write(_check_BIL_sequence(NEtags, NEprobs, threshold))
                            NEtags = None
                            NEprobs = None
                        if float(prob) >= threshold:
                            f_out.write(tag + '\n')
                        else:
                            f_out.write('O\n')
                    elif tag[0] == 'B':
                        # Flush any open sequence, then start a new one.
                        if NEtags:
                            f_out.write(_check_BIL_sequence(NEtags, NEprobs, threshold))
                        NEtags = [tag + '\n']
                        NEprobs = [float(prob)]
                    elif tag[0] == 'I':
                        # Extend the open sequence, or start one (as B) when
                        # this I tag arrives out of sequence.
                        if NEtags:
                            NEtags.append(tag + '\n')
                            NEprobs.append(float(prob))
                        else:
                            logger.warning("Found an out of sequence I tag.")
                            tag = 'B{}'.format(tag[1:])
                            NEtags = [tag + '\n']
                            NEprobs = [float(prob)]
                    elif tag[0] == 'L':
                        # Close and check the open sequence, or start a new
                        # one (as B) when this L tag arrives out of sequence.
                        if NEtags:
                            NEtags.append(tag + '\n')
                            NEprobs.append(float(prob))
                            f_out.write(_check_BIL_sequence(NEtags, NEprobs, threshold))
                            NEtags = None
                            NEprobs = None
                        else:
                            logger.warning("Found an out of sequence L tag.")
                            tag = 'B{}'.format(tag[1:])
                            NEtags = [tag + '\n']
                            NEprobs = [float(prob)]
                # Necessary if tagsf ends with an incomplete BI*L sequence.
                if NEtags:
                    f_out.write(_check_BIL_sequence(NEtags, NEprobs, threshold))
                    NEtags = None
                    NEprobs = None
        tagsf = tagsf2  # Set the checked tag file as the new tag file

        # Continue.
        shutil.copy(tagsf, "tagsfile") #DEBUG

        # Load tagged output.
        with open(tagsf, 'r') as f:
            tags = [line.strip() for line in f]
        tags = tags[:len(tokens)]

        # Chunk tags.
        chunks = chunker.tags_to_chunks(tags)

        # Construct mentions.
        doc_id = ltf_doc.doc_id
        mentions = []
        n = 1
        for token_bi, token_ei, tag in chunks:
            if tag == 'O':
                continue
            # Assign entity id.
            entity_id = '%s-NE%d' % (doc_id, n)
            # Determine char onsets/offset for mention extent.
            start_char = token_onsets[token_bi]
            end_char = token_offsets[token_ei]
            # Finally, determine text of extent and append.
            extent_bi = spans[token_bi][0]
            extent_ei = spans[token_ei][1]
            extent = txt[extent_bi:extent_ei + 1]
            mentions.append([entity_id,   # entity id
                             tag,         # NE type
                             extent,      # extent text
                             start_char,  # extent char onset
                             end_char,    # extent char offset
                             ])
            n += 1

        # Write detected mentions to LAF file.
        bn = os.path.basename(ltf)
        laf = os.path.join(tagged_dir, bn.replace('.ltf.xml', tagged_ext))
        laf_doc = LAFDocument(mentions=mentions, lang=ltf_doc.lang, doc_id=doc_id)
        laf_doc.write_to_file(laf)
    except Exception:
        # Narrowed from a bare except so Ctrl-C/SystemExit still propagate.
        logger.warning('Problem with %s. Skipping.' % ltf)
    finally:
        # Clean up, even when an unexpected exception escapes the try body.
        shutil.rmtree(temp_dir)
##########################
# Ye olde' main
##########################
if __name__ == '__main__':
    # Parse command line args.
    parser = argparse.ArgumentParser(description='Perform named entity tagging.',
                                     add_help=False,
                                     usage='%(prog)s [options] model ltfs')
    parser.add_argument('model_dir', nargs='?',
                        help='Model dir')
    parser.add_argument('ltfs', nargs='*',
                        help='LTF files to be processed')
    parser.add_argument('-S', nargs='?', default=None,
                        metavar='fn', dest='scpf',
                        help='Set script file (Default: None)')
    parser.add_argument('-L', nargs='?', default='./',
                        metavar='dir', dest='tagged_dir',
                        help="Set output mentions dir (Default: current)")
    parser.add_argument('-X', nargs='?', default='.laf.xml',
                        metavar='ext', dest='ext',
                        help="Set output mentions file extension (Default: .laf.xml)")
    parser.add_argument('-j', nargs='?', default=1, type=int,
                        metavar='n', dest='n_jobs',
                        help='Set num threads to use (default: 1)')
    parser.add_argument('-t', nargs='?', default=(2**-149), type=float,
                        metavar='t', dest='threshold',
                        help='Set threshold for NE probability (default: 2**-149)')
    args = parser.parse_args()
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    # Determine ltfs to process: a script file (-S) overrides positional args.
    if args.scpf is not None:
        with open(args.scpf, 'r') as f:
            args.ltfs = [l.strip() for l in f]

    # Fail fast with a clear message instead of letting joblib choke on
    # n_jobs=0 when no input files were supplied.
    if not args.ltfs:
        logger.error('No LTF files to process. Exiting.')
        sys.exit(1)

    # Initialize chunker, aligner, and encoder.
    chunker = BILOUChunkEncoder()
    aligner = Aligner()
    encf = os.path.join(args.model_dir, 'tagger.enc')
    with open(encf, 'r') as f:
        enc = cPickle.load(f)

    # Get values of A, B, and G now to pass to each call of tag_file.
    A_vals, B_vals, G_vals = get_ABG_value_sets(args.ltfs, logger)

    # Perform tagging in parallel, dumping results to args.tagged_dir.
    n_jobs = min(len(args.ltfs), args.n_jobs)
    modelf = os.path.join(args.model_dir, 'tagger.crf')
    f = delayed(tag_file)
    Parallel(n_jobs=n_jobs, verbose=0)(f(ltf, aligner, enc, chunker,
                                         modelf,
                                         args.tagged_dir,
                                         args.ext,
                                         args.threshold,
                                         A_vals, B_vals, G_vals)
                                       for ltf in args.ltfs)