forked from jiyfeng/RSTParser
/
model.py
114 lines (95 loc) · 3.32 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
## model.py
## Author: Yangfeng Ji
## Date: 09-09-2014
## Time-stamp: <yangfeng 11/05/2014 20:44:25>
""" As a parsing model, it includes the following functions
1, Mini-batch training on the data generated by the Data class
2, Shift-Reduce RST parsing for a given text sequence
3, Save/load parsing model
"""
from sklearn.svm import LinearSVC
from cPickle import load, dump
from parser import SRParser
from feature import FeatureGenerator
from tree import RSTTree
from util import *
from datastructure import ActionError
import gzip, sys
class ParsingModel(object):
    """ Shift-reduce RST parsing model.

    Wraps a multiclass LinearSVC classifier to
    1, batch-train on data generated by the Data class,
    2, predict shift-reduce parsing actions for a list of EDUs,
    3, save/load the classifier together with its vocab and label map.
    """
    def __init__(self, vocab=None, idxlabelmap=None, clf=None):
        """ Initialization

        :type vocab: dict
        :param vocab: mapping from feature templates to feature indices

        :type idxlabelmap: dict
        :param idxlabelmap: mapping from parsing action indices to
                            parsing actions

        :type clf: LinearSVC
        :param clf: a multiclass classifier from sklearn; a fresh
                    LinearSVC is created when None is given
        """
        self.vocab = vocab
        self.labelmap = idxlabelmap
        # Bug fix: the original only assigned self.clf when clf was None,
        # silently discarding a caller-supplied classifier and leaving
        # the attribute unset.
        if clf is None:
            self.clf = LinearSVC()
        else:
            self.clf = clf

    def train(self, trnM, trnL):
        """ Perform batch learning on the parsing model.

        :param trnM: training feature matrix
        :param trnL: training label vector
        """
        self.clf.fit(trnM, trnL)

    def predict(self, features):
        """ Predict a parsing action for a given set of features.

        :type features: list
        :param features: feature list generated by FeatureGenerator
        :return: the parsing action mapped from the classifier's
                 predicted label index
        """
        vec = vectorize(features, self.vocab)
        label = self.clf.predict(vec)
        return self.labelmap[label[0]]

    def savemodel(self, fname):
        """ Save classifier, vocab and label map as a gzipped pickle.

        :type fname: string
        :param fname: output file name; '.gz' is appended if missing
        """
        if not fname.endswith('.gz'):
            fname += '.gz'
        D = {'clf': self.clf, 'vocab': self.vocab,
             'idxlabelmap': self.labelmap}
        with gzip.open(fname, 'w') as fout:
            dump(D, fout)
        print('Save model into file: {}'.format(fname))

    def loadmodel(self, fname):
        """ Load a model previously written by savemodel.

        :type fname: string
        :param fname: gzipped pickle file name
        """
        with gzip.open(fname, 'r') as fin:
            D = load(fin)
        self.clf = D['clf']
        self.vocab = D['vocab']
        self.labelmap = D['idxlabelmap']
        print('Load model from file: {}'.format(fname))

    def sr_parse(self, texts):
        """ Shift-reduce RST parsing based on model prediction.

        :type texts: list of string
        :param texts: list of EDUs for parsing
        :return: an RSTTree built from the final parser state
        """
        # Initialize parser
        srparser = SRParser([], [])
        srparser.init(texts)
        # Predict and apply actions until the parser reaches a
        # terminal state.
        while not srparser.endparsing():
            # Make sure to call the generator with the same arguments
            # as in the data generation part, so train/test feature
            # spaces match.
            stack, queue = srparser.getstatus()
            fg = FeatureGenerator(stack, queue)
            features = fg.features()
            label = self.predict(features)
            action = label2action(label)
            try:
                srparser.operate(action)
            except ActionError:
                # NOTE(review): no fallback to the next-best legal
                # action yet; an illegal prediction aborts the run.
                print('Parsing action error with {}'.format(action))
                sys.exit()
        tree = srparser.getparsetree()
        rst = RSTTree(tree=tree)
        return rst