forked from rlsummerscales/acres
-
Notifications
You must be signed in to change notification settings - Fork 0
/
finder.py
executable file
·135 lines (113 loc) · 5.16 KB
/
finder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/python
# base class for an object that identifies mentions and quantities
# author: Rodney Summerscales
import sys
import os.path
import nltk
from nltk.corpus import stopwords
from irstats import IRstats
######################################################################
# class for recording recall/precision/F-score (RPF) statistics for an entity finder
######################################################################
class EntityStats:
    """Recall/precision/F-score (RPF) bookkeeping for an entity finder.

    Keeps one ``IRstats`` object per entity type, stored in ``self.irstats``
    and keyed by the entity type name.
    """

    # Class-level fallbacks only: __init__ rebinds both names on every
    # instance, so these shared objects are not mutated through an instance
    # that was constructed normally.
    irstats = {}
    entityTypes = []  # list of entity types found by an entity finder

    def __init__(self, entityTypes):
        """Start computing RPF statistics for a new set of abstracts.

        entityTypes = list of entity type names; a fresh ``IRstats`` object
        is created for each one.
        """
        self.irstats = {}
        self.entityTypes = entityTypes
        for mType in self.entityTypes:
            self.irstats[mType] = IRstats()

    def add(self, ms):
        """Add the tp, fp, fn counts from another stats object to this one.

        ms = object whose ``irstats`` mapping has an entry (with ``tp``,
        ``fp`` and ``fn`` attributes) for each of this object's entity types.
        A missing entry raises KeyError.
        """
        for mType in self.entityTypes:
            mine = self.irstats[mType]
            theirs = ms.irstats[mType]
            mine.tp += theirs.tp
            mine.fp += theirs.fp
            mine.fn += theirs.fn

    def printStats(self):
        """Output the RPF statistics to the screen (sys.stdout)."""
        # Single-argument print(...) produces the same text under Python 2's
        # print statement and Python 3's print function.
        print('TP FP FN R P F')
        for mType in self.entityTypes:
            # Row label followed by three spaces and no newline, so that
            # displayrpf() continues on the same line. This is what the
            # former Python 2 statement `print mType, ' ',` produced
            # (including its soft-space) when displayrpf() itself printed.
            # NOTE(review): confirm against IRstats.displayrpf().
            sys.stdout.write('%s   ' % (mType,))
            self.irstats[mType].displayrpf()

    def saveStats(self, statList, keyPrefix=''):
        """Add these stats to a given collection of stats.

        statList = object providing ``addIRstats(key, irstats)``.
        keyPrefix = prefix string attached to the beginning of the entity
        type; the combined string is used as the key into the given hash
        of stats.
        """
        for mType in self.entityTypes:
            statList.addIRstats(keyPrefix + mType, self.irstats[mType])

    def writeStats(self, out):
        """Write the RPF stats to the given output stream.

        out = writable text stream. A header line is written first, then one
        line per entity type: the type name, a tab, and whatever that type's
        ``IRstats.writerpf(out)`` writes.
        """
        out.write('\tTP FP FN R P F\n')
        for mType in self.entityTypes:
            out.write(mType + '\t')
            self.irstats[mType].writerpf(out)
######################################################################
# Mention finder base class
######################################################################
class Finder:
    """Base class for training/testing a classifier that finds mentions
    in a list of abstracts.

    NOTE: This is the base class. The actual mention finders should be
    derived from this class and implement computeFeatures(), train(),
    test() and computeStats(), which all raise NotImplementedError here.
    """

    finderType = 'basefinder'
    entityTypes = []  # list of types mention finder will look for
    entityTypesString = ''

    def __init__(self, entityTypes):
        """Create a new mention finder to find a given list of mention types.

        entityTypes = list of mention types (e.g. group, outcome) to find
        """
        self.finderType = 'basefinder'
        self.entityTypes = entityTypes
        self.entityTypesString = self.entityTypesToString(self.entityTypes)

    def getDefaultModelFilename(self):
        """Return the default filename used for creating a model file
        during training (the '-'-joined entity type string).
        """
        return self.entityTypesString

    def entityTypesToString(self, entityTypes):
        """Return a string containing all of the entity types in a given
        list, joined with '-'.
        """
        return '-'.join(entityTypes)

    def getFoldString(self, foldIndex):
        """Return the fold index formatted for filenames.

        Returns '.<foldIndex>' for an integer fold index, or '' when
        foldIndex is None.
        """
        # Identity test: fold 0 is a valid index and must not be treated
        # like "no fold".
        if foldIndex is not None:
            return '.%d' % foldIndex
        return ''

    def computeFeatures(self, absList, mode=''):
        """Compute classifier features for each token in each abstract in a
        given list of abstracts.

        mode = 'test', 'train', or 'crossval'
        Must be implemented by subclasses; raises NotImplementedError here.
        """
        raise NotImplementedError("Need to implement computeFeatures()")

    def train(self, absList, modelfilename):
        """Train a mention finder model given a list of abstracts.

        Must be implemented by subclasses; raises NotImplementedError here.
        """
        raise NotImplementedError("Need to implement train()")

    def test(self, absList, modelfilename, fold=None):
        """Apply the mention finder to a given list of abstracts
        using the given model file.

        Must be implemented by subclasses; raises NotImplementedError here.
        """
        raise NotImplementedError("Need to implement test()")

    def crossvalidate(self, absList, modelPath):
        """Apply the mention finder to a list of abstracts using k-fold
        crossvalidation.

        absList = object whose ``cvSets`` attribute is the sequence of
        crossvalidation sets; each set provides ``train`` and ``test``.
        modelPath = directory in which the per-fold model files are created.
        A trailing '/' is added when missing; an empty string means the
        current directory.

        For fold k (counted from 0) the model file is
        '<modelPath><entityTypesString>.<k>.model'. The model is trained on
        the fold's training set and then applied to its test set.
        """
        # Guard against '' first: indexing modelPath[-1] on an empty string
        # raised IndexError, and appending '/' to it would point at the
        # filesystem root.
        if modelPath and not modelPath.endswith('/'):
            modelPath = modelPath + '/'

        for k, dataSet in enumerate(absList.cvSets):
            # Progress line: 1-based fold number, training size, test size.
            # Single-argument print(...) emits the same text on Python 2
            # and Python 3.
            print('%s %s %s' % (k + 1, len(dataSet.train), len(dataSet.test)))
            modelFilename = (modelPath + self.entityTypesString
                             + self.getFoldString(k) + '.model')
            # train model
            self.train(dataSet.train, modelFilename)
            # apply to test set
            self.test(dataSet.test, modelFilename)
            print('-----------------')

    def computeStats(self, absList, out=None, errorOut=None):
        """Compute RPF stats for detected mentions in a list of abstracts
        and write the results to an output stream.

        Must be implemented by subclasses; raises NotImplementedError here.
        """
        raise NotImplementedError("Need to implement computeStats()")