# model.py (forked from juliuskunze/thalnet)
from typing import Optional, Callable
import functools

import numpy as np
import tensorflow as tf

from util import define_scope, unzip, single, lazy_property
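
# NOTE: This module targets the TensorFlow 1.x graph API; it relies on
# tf.contrib.layers and tf.nn.rnn_cell, both of which were removed in TensorFlow 2.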

'''
Classifier base class
'''


class Classifier:
    def __init__(self, data, target, dropout):
        self.data = data
        self.target = target
        self.dropout = dropout
        # Touching each lazy/scoped property here builds the corresponding
        # part of the graph once, as a side effect of construction.
        self.prediction
        self.cross_entropy
        self.accuracy
        self.optimize
        self.train_summary
        self.test_summary
        self.weights_summary

    @lazy_property
    def run(self):
        raise NotImplementedError("Please Implement this method")

    @lazy_property
    def logits(self):
        return tf.contrib.layers.fully_connected(self.run, num_outputs=int(self.target.shape[1]),
                                                 activation_fn=None)

    @define_scope
    def prediction(self):
        return tf.nn.softmax(self.logits)

    @define_scope
    def cross_entropy(self):
        return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.target))

    @define_scope
    def accuracy(self):
        correct = tf.equal(tf.argmax(self.logits, axis=1), tf.argmax(self.target, axis=1))
        return tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

    @define_scope
    def optimize(self):
        return tf.train.AdamOptimizer(learning_rate=1e-3).minimize(self.cross_entropy)

    @define_scope('weights')
    def weights_summary(self):
        variables_except_from_optimizer = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='^(?!.*optimize).*$')
        return tf.summary.merge([tf.summary.histogram(v.name, v) for v in variables_except_from_optimizer])

    def summary(self):
        cross_entropy_summary = tf.summary.scalar('cross_entropy', self.cross_entropy)
        accuracy_summary = tf.summary.scalar('accuracy', self.accuracy)
        return tf.summary.merge([cross_entropy_summary, accuracy_summary])

    @define_scope('train')
    def train_summary(self):
        return self.summary()

    @define_scope('test')
    def test_summary(self):
        return self.summary()

    @lazy_property
    def num_parameters(self):
        return np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
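
# The classifiers below override only `run`; the base class's `logits` then
# projects its [batch, features] output to the number of target classes.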


'''
Classifier extensions
'''


class MLPClassifier(Classifier):
    def __init__(self, data, target, dropout, num_hidden: int, num_layers: int):
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        super().__init__(data, target, dropout)

    # override run()
    @lazy_property
    def run(self):
        last_output = self.data
        for _ in range(self.num_layers):
            last_output = tf.contrib.layers.fully_connected(last_output, num_outputs=self.num_hidden,
                                                            activation_fn=tf.nn.relu)
        return last_output


class SequenceClassifier(Classifier):
    def __init__(self, data, target, dropout,
                 get_rnn_cell: Callable[[], tf.nn.rnn_cell.RNNCell], num_rows, row_size):
        self.get_rnn_cell = get_rnn_cell
        # Reshape flat inputs into a sequence of num_rows steps with row_size features each.
        data = tf.reshape(data, shape=[-1, num_rows, row_size])
        super().__init__(data, target, dropout)

    @lazy_property
    def run(self):
        rnn_cell_with_dropout = tf.nn.rnn_cell.DropoutWrapper(self.get_rnn_cell(), output_keep_prob=1 - self.dropout)
        output, last_state = tf.nn.dynamic_rnn(rnn_cell_with_dropout, self.data, dtype=tf.float32)
        # Classify from the RNN output at the last time step.
        output = tf.transpose(output, [1, 0, 2])  # -> [time, batch, features]
        last_output = tf.nn.embedding_lookup(output, int(output.shape[0]) - 1)
        return last_output


'''
RNN extensions
'''


def GRUCell(num_hidden: int, num_layers=4):
    return tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.GRUCell(num_hidden) for _ in range(num_layers)])
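
# A minimal usage sketch (not part of the original file): the factory above can
# serve as the `get_rnn_cell` argument of SequenceClassifier, e.g. for row-wise
# sequential classification. The placeholder shapes (784 inputs, 10 classes,
# 28x28 rows) and cell sizes are illustrative assumptions, not requirements.
#
#   data = tf.placeholder(tf.float32, [None, 784])
#   target = tf.placeholder(tf.float32, [None, 10])
#   dropout = tf.placeholder_with_default(0.0, shape=[])
#   classifier = SequenceClassifier(data, target, dropout,
#                                   get_rnn_cell=lambda: GRUCell(num_hidden=64, num_layers=2),
#                                   num_rows=28, row_size=28)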


class FfGruModule:
    def __init__(self,
                 center_size: int,
                 context_input_size: int,
                 center_output_size: int,
                 input_size: Optional[int] = None,
                 output_size: Optional[int] = None,
                 name: str = ''):
        self.name = name
        self.center_size = center_size
        self.context_input_size = context_input_size
        self.center_output_size = center_output_size
        self.input_size = input_size
        self.output_size = output_size
        # input_size / output_size of 0 or None mean the module has no direct task input / output.
        self.num_gru_units = (self.output_size or 0) + self.center_output_size

    def __call__(self, inputs, center_state, module_state):
        """
        :return: output, new_center_features, new_module_state
        """
        with tf.variable_scope(self.name):
            # Read from the shared center through a learned, norm-clipped projection.
            reading_weights = tf.get_variable('reading_weights', shape=[self.center_size, self.context_input_size],
                                              initializer=tf.truncated_normal_initializer(stddev=0.1))
            context_input = tf.matmul(center_state, tf.clip_by_norm(reading_weights, 1.0))

            inputs = tf.concat([inputs, context_input], axis=1) if self.input_size else context_input
            inputs = tf.contrib.layers.fully_connected(inputs, num_outputs=self.center_output_size)

            gru = tf.nn.rnn_cell.GRUCell(self.num_gru_units)
            gru_output, new_module_state = gru(inputs=inputs, state=module_state)

            # Split the GRU output into the module's task output (if any) and the
            # features it writes back to the center.
            output, center_feature_output = tf.split(gru_output,
                                                     [self.output_size, self.center_output_size],
                                                     axis=1) if self.output_size else (None, gru_output)

            return output, center_feature_output, new_module_state
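
# Shape walk-through with illustrative numbers (center_size=128, context_input_size=32,
# center_output_size=32, output_size=10): reading_weights is [128, 32], so context_input
# is [batch, 32]; the GRU has 10 + 32 = 42 units, and its output is split into a
# [batch, 10] task output and a [batch, 32] write-back to the center.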


class ThalNetCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self,
                 input_size: int,
                 output_size: int,
                 context_input_size: int,
                 center_size_per_module: int,
                 num_modules: int = 4):
        self._context_input_size = context_input_size
        self._input_size = input_size
        self._output_size = output_size
        self._center_size = num_modules * center_size_per_module
        self.center_size_per_module = center_size_per_module
        self._num_modules = num_modules
        super().__init__(_reuse=None)

    @lazy_property
    def state_size(self):
        return [module.center_output_size for module in self.modules] + \
               [module.num_gru_units for module in self.modules]

    @lazy_property
    def output_size(self):
        return self._output_size

    @lazy_property
    def modules(self):
        return [FfGruModule(center_size=self._center_size,
                            context_input_size=self._context_input_size,
                            center_output_size=self.center_size_per_module,
                            input_size=self._input_size if i == 0 else 0,
                            output_size=self.output_size if i == self._num_modules - 1 else 0,
                            name=f'module{i}') for i in range(self._num_modules)]

    def __call__(self, inputs, state, scope=None):
        # The state is the per-module center features followed by the per-module GRU states.
        center_state_per_module = state[:self._num_modules]
        module_states = state[self._num_modules:]

        center_state = tf.concat(center_state_per_module, axis=1)

        # Only the first module sees the external input; every module reads the shared center.
        outputs, new_center_features, new_module_states = unzip(
            [module(inputs if module.input_size else None, center_state=center_state, module_state=module_state)
             for module, module_state in zip(self.modules, module_states)])

        # Only the last module produces a task output.
        output = single([o for o in outputs if o is not None])

        return output, list(new_center_features + new_module_states)
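

# A minimal end-to-end sketch (not part of the original file). It assumes
# row-wise sequential, MNIST-style inputs; the sizes below (28x28 inputs, 10
# classes, context/center sizes of 32) are illustrative assumptions, as is the
# feed_dict-based training step. Requires TensorFlow 1.x and the local `util` module.
if __name__ == '__main__':
    num_rows, row_size, num_classes = 28, 28, 10

    data = tf.placeholder(tf.float32, [None, num_rows * row_size])
    target = tf.placeholder(tf.float32, [None, num_classes])
    dropout = tf.placeholder_with_default(0.0, shape=[])

    classifier = SequenceClassifier(
        data, target, dropout,
        get_rnn_cell=lambda: ThalNetCell(input_size=row_size,
                                         output_size=num_classes,
                                         context_input_size=32,
                                         center_size_per_module=32,
                                         num_modules=4),
        num_rows=num_rows, row_size=row_size)

    print(f'{classifier.num_parameters} trainable parameters')

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # batch_x, batch_y would come from a data loader (hypothetical here):
        # session.run(classifier.optimize, feed_dict={data: batch_x, target: batch_y, dropout: 0.5})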