"""
Example creating an LSTM for character-level language models.
"""
from __future__ import print_function, division
# normal imports
import numpy
import theano.tensor as T
# cloudpickle is used instead of cPickle so the module-level pre/post-processing functions
# defined below can be pickled for use in production
import cloudpickle as pickle
import string
# opendeep imports
from opendeep.data import TextDataset
from opendeep.models import LSTM
from opendeep.monitor import Monitor
from opendeep.optimization import RMSProp
from opendeep.utils.misc import numpy_one_hot
#################
# Training data #
#################
def get_dataset(path_to_data='data/tokenized/'):
    # our input data is .txt files in a folder, formatted as follows:
    # each line is a single token (word) separated from its class label by a tab character.
    # preprocessing converts to lowercase, splits each token into characters, and repeats the
    # label for each character. Because punctuation counts as a token, we apply special rules
    # for adding spaces around punctuation tokens to build a more accurate language model.
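    # as a rough illustration (this token is made up, not taken from the dataset), a first
    # line like "Obama\tPERSON" would come back from process_line() below as
    #   (" obama", ['O', 'PERSON', 'PERSON', 'PERSON', 'PERSON', 'PERSON'])
    # i.e. a leading space is prepended to the characters and a matching 'O' label is prepended
    # so that characters and labels stay the same length.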
class StringProcessor:
"""
This is a helper class (normally we would just do functions for preprocessing) to preprocess
our text files (line by line) into the appropriate input and target data. The class is used
because we needed to keep track of state when processing line by line.
"""
def __init__(self):
self.previous_label = ''
self.space_before_punct = ['(', '``', '[', '{', '$', '#', '&']
self.space_after_punct = ['&']
self.previous_char = ''
def process_line(self, line):
chars, label = line.split('\t', 1)
chars = chars.lower()
label = label.rstrip()
labels = [label] * len(chars)
if (not chars[0] in string.punctuation or chars[0] in self.space_before_punct) and \
(not self.previous_char in self.space_before_punct or self.previous_char in self.space_after_punct):
chars = ' ' + chars
if label == self.previous_label:
labels = [label] + labels
else:
labels = ['O'] + labels
self.previous_label = label
self.previous_char = chars[-1]
return chars, labels
def get_inputs(self, line):
return self.process_line(line)[0]
def get_labels(self, line):
return self.process_line(line)[1]
    # now that we have defined our preprocessor, create a new TextDataset (works over files in a folder).
    # TextDataset is an OpenDeep class that creates one-hot encodings of inputs and outputs automatically
    # and keeps the encoding dictionaries in its vocab and label_vocab attributes.
processor = StringProcessor()
dataset = TextDataset(path=path_to_data,
inputs_preprocess=lambda line: processor.get_inputs(line),
targets_preprocess=lambda line: processor.get_labels(line),
level="char", sequence_length=120)
# save the computed dictionaries to use for converting inputs and outputs from running the model.
with open('vocab.pkl', 'wb') as f:
pickle.dump(dataset.vocab, f, protocol=pickle.HIGHEST_PROTOCOL)
with open('entity_vocab.pkl', 'wb') as f:
pickle.dump(dataset.label_vocab, f, protocol=pickle.HIGHEST_PROTOCOL)
return dataset
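# the rest of this example uses these dictionaries as {character: integer index} (vocab) and
# {entity label: integer index} (label_vocab); for instance, run_model() and postprocess() below
# invert them to map model predictions back to characters and entity names.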
##############################
# Step 1 and 2: create model #
##############################
def create_model(init_config_file=None, vocab={}, label_vocab={}):
# load from a configuration file, or define the model configuration
if init_config_file is not None:
with open(init_config_file, 'rb') as f:
init_config = pickle.load(f)
else:
init_config = {
'input_size': len(vocab),
'hidden_size': 128,
'output_size': len(label_vocab),
'hidden_activation': 'tanh',
'inner_hidden_activation': 'sigmoid',
'activation': 'softmax',
'weights_init': 'uniform',
'weights_interval': 'montreal',
'r_weights_init': 'orthogonal',
'clip_recurrent_grads': 5.,
'noise': 'dropout',
'noise_level': 0.5,
'direction': 'bidirectional',
'cost_function': 'nll',
'cost_args': {'one_hot': True}
}
# instantiate the model!
lstm = LSTM(**init_config)
return lstm
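# a minimal sketch of persisting the configuration dictionary so that
# create_model(init_config_file=...) can rebuild the same architecture later;
# the helper name and filename here are arbitrary, not part of the original steps.
def save_model_config(init_config, config_file='lstm_config.pkl'):
    with open(config_file, 'wb') as f:
        pickle.dump(init_config, f, protocol=pickle.HIGHEST_PROTOCOL)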
############################
# Step 2a: train the model #
############################
def train_model(model, dataset):
# train the lstm on our dataset!
# let's monitor the error %
    # model output is in shape (n_timesteps, n_sequences, data_dim), while targets come in as
    # (n_sequences, n_timesteps, data_dim), so dimshuffle the targets to match before comparing.
    # calculate the mean prediction error over timesteps and batches
predictions = T.argmax(model.get_outputs(), axis=2)
actual = T.argmax(model.get_targets()[0].dimshuffle(1, 0, 2), axis=2)
char_error = T.mean(T.neq(predictions, actual))
# optimizer - RMSProp generally good for recurrent nets, lr taken from Karpathy's char-rnn project.
# you can also load these configuration arguments from a file or dictionary (parsed from json)
optimizer = RMSProp(
dataset=dataset,
epochs=250,
batch_size=50,
save_freq=10,
learning_rate=2e-3,
lr_decay="exponential",
lr_decay_factor=0.97,
decay=0.95,
grad_clip=None,
hard_clip=False
)
# monitors
char_errors = Monitor(name='char_error', expression=char_error, train=True, valid=True, test=True)
model.train(optimizer=optimizer, monitor_channels=[char_errors])
#################################
# Step 3: load model parameters #
#################################
def load_model_params(model, param_file='outputs/lstm/trained_epoch_100.pkl'):
# load params
model.load_params(param_file)
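    # the default param_file above is an assumption about where the optimizer writes its checkpoints
    # (given save_freq=10 in train_model); point it at whichever saved epoch you want to restore.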
########################################
# Step 4: compile model's run function #
########################################
def compile_model_run_fn(model):
success, path = model.save_run('lstm_run_fn.pkl')
return path
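# the pickled 'lstm_run_fn.pkl' produced above is assumed to hold the compiled Theano run function;
# run_model() below shows the matching load path so new inputs can be scored without rebuilding the model.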
############################
# Step 5: run on real data #
############################
# parse a string into some input data
def string_to_data(query, vocab=None):
    # load the saved character vocabulary if one wasn't passed in
    if vocab is None:
        with open('vocab.pkl', 'rb') as f:
            vocab = pickle.load(f)
    # process the raw input data string
    data = []
    # get the integer encodings
    for data_char in query:
        data.append(vocab.get(data_char, 0))
    # convert the integers to one-hot arrays
    data = numpy_one_hot(numpy.asarray(data), n_classes=numpy.amax(list(vocab.values())) + 1)
    # make 3D for model input
    seq, dim = data.shape
    data = numpy.reshape(data, (1, seq, dim))
    return data
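# for example (assuming 'vocab.pkl' has already been written by get_dataset above):
#   data = string_to_data("barack obama")  # made-up 12-character query
#   # data.shape == (1, 12, numpy.amax(list(vocab.values())) + 1), ready to feed to the model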
def data_to_str(data, predictions):
    # placeholder left unimplemented in this example; postprocess() below handles converting
    # one-hot input data and model predictions back into readable entity strings.
    pass
def run_model(model, data, vocab, label_vocab, compiled_fn_path=None):
    # here, data is given by the user as a raw string; we run it through the encoding vocabulary on
    # the way in and the inverse (decoding) vocabularies on the way out to get meaningful entities back.
def _get_entities(data, predictions, vocab_inv, entity_vocab):
# find contiguous entity characters across timesteps
non_entity_label = entity_vocab.get('O')
entities = []
for i, query in enumerate(predictions):
previous_label = non_entity_label
entity_string = ""
used_indices = set()
for j, label in enumerate(query):
# find entity start point (expand to space character) and extract the continuous entity
if label != non_entity_label and label != previous_label and j not in used_indices:
entity_start = j
while vocab_inv.get(
numpy.argmax(data[i, entity_start])) not in string.whitespace and entity_start >= 0:
entity_start -= 1
# move start point forward one to get out of whitespace or back to 0 index
entity_start += 1
# now from the start point, extract continuous until whitespace or punctuation
entity_idx = entity_start
while entity_idx < len(query) and \
(
query[entity_idx] == label or
entity_idx == entity_start or
(
entity_idx > entity_start and
vocab_inv.get(numpy.argmax(data[i, entity_idx])) not in string.whitespace + string.punctuation and
vocab_inv.get(numpy.argmax(data[i, entity_idx - 1])) not in string.whitespace + string.punctuation
)
):
entity_string += vocab_inv.get(numpy.argmax(data[i, entity_idx]))
used_indices.add(entity_idx)
entity_idx += 1
# get rid of trailing matched punctuation
if entity_string[-1] in string.punctuation:
entity_string = entity_string[:-1]
# add the entity stripped of whitespace in beginning and end, and reset the string
entities.append(entity_string.strip())
entity_string = ""
previous_label = label
return entities
data = string_to_data(data, vocab)
######
# running the model
######
# use the model's run function
if compiled_fn_path is None:
character_probs = model.run(data)
    # or alternatively, load the pickled run function from disk (saved by compile_model_run_fn above)
else:
with open(compiled_fn_path, 'rb') as f:
run_fn = pickle.load(f)
character_probs = run_fn(data)
# this has the shape (timesteps, batches, data), so swap axes to (batches, timesteps, data)
character_probs = numpy.swapaxes(character_probs, 0, 1)
# now extract the guessed entities
predictions = numpy.argmax(character_probs, axis=2)
entities = _get_entities(data, predictions, {v:k for k,v in vocab.items()}, label_vocab)
return entities
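# a rough usage sketch, assuming the vocabularies pickled by get_dataset() are on disk and a trained
# model has been loaded (the query string is a made-up placeholder):
#   with open('vocab.pkl', 'rb') as f:
#       vocab = pickle.load(f)
#   with open('entity_vocab.pkl', 'rb') as f:
#       label_vocab = pickle.load(f)
#   print(run_model(lstm, "barack obama went to paris", vocab, label_vocab))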
def postprocess(input, output):
    with open('vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)
    with open('entity_vocab.pkl', 'rb') as f:
        entity_vocab = pickle.load(f)
    entity_inv = {v: k for k, v in entity_vocab.items()}
    vocab_inv = {v: k for k, v in vocab.items()}
    # output has the shape (timesteps, batches, data), so swap axes to (batches, timesteps, data)
    character_probs = numpy.swapaxes(output, 0, 1)
    # now extract the guessed entities from the axis-swapped probabilities
    predictions = numpy.argmax(character_probs, axis=2)
# find contiguous entity characters across timesteps
non_entity_label = entity_vocab.get('O')
entities = []
for i, query in enumerate(predictions):
previous_label = non_entity_label
entity_string = ""
used_indices = set()
for j, label in enumerate(query):
# find entity start point (expand to space character) and extract the continuous entity
if label != non_entity_label and label != previous_label and j not in used_indices:
entity_start = j
while vocab_inv.get(
numpy.argmax(input[i, entity_start])) not in string.whitespace and entity_start >= 0:
entity_start -= 1
# move start point forward one to get out of whitespace or back to 0 index
entity_start += 1
# now from the start point, extract continuous until whitespace or punctuation
entity_idx = entity_start
while entity_idx < len(query) and \
(
query[entity_idx] == label or
entity_idx == entity_start or
(
entity_idx > entity_start and
vocab_inv.get(numpy.argmax(input[i, entity_idx])) not in string.whitespace + string.punctuation and
vocab_inv.get(numpy.argmax(input[i, entity_idx - 1])) not in string.whitespace + string.punctuation
)
):
entity_string += vocab_inv.get(numpy.argmax(input[i, entity_idx]))
used_indices.add(entity_idx)
entity_idx += 1
# get rid of trailing matched punctuation
if entity_string[-1] in string.punctuation:
entity_string = entity_string[:-1]
# add the entity stripped of whitespace in beginning and end, and reset the string
entities.append((entity_string.strip(), entity_inv.get(label)))
entity_string = ""
previous_label = label
return entities
if __name__ == '__main__':
with open('input_process.pkl', 'wb') as f:
pickle.dump(string_to_data, f)
with open('output_process.pkl', 'wb') as f:
pickle.dump(postprocess, f)
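    # a hedged sketch of how the steps above fit together end to end; uncomment to run the full
    # pipeline (the query string below is a made-up placeholder, not part of the original example).
    # dataset = get_dataset()
    # lstm = create_model(vocab=dataset.vocab, label_vocab=dataset.label_vocab)
    # train_model(lstm, dataset)
    # load_model_params(lstm)
    # compile_model_run_fn(lstm)
    # print(run_model(lstm, "barack obama went to paris", dataset.vocab, dataset.label_vocab))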