forked from matpalm/snli_nn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_rnn.py
62 lines (50 loc) · 2.02 KB
/
simple_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
import theano
import theano.tensor as T
from updates import vanilla, rmsprop
import util
class SimpleRnn(object):
    """Vanilla recurrent layer: h_t = tanh(Uh.h_{t-1} + Wh.x_t [+ Whc.context] + bh).

    `inputs` is a symbolic sequence of (already embedded) input vectors and `h0`
    is the symbolic initial hidden state. If `context` is given it is a fixed
    vector mixed into the hidden state at every timestep (requires `context_dim`).
    `opts` is accepted for interface compatibility but unused here.
    """

    def __init__(self, name, input_dim, hidden_dim, opts, update_fn, h0, inputs,
                 context=None, context_dim=None):
        self.name_ = name
        self.update_fn = update_fn
        self.h0 = h0
        self.inputs = inputs    # symbolic input sequence
        self.context = context  # additional context to add at each timestep of input
        # hidden -> hidden
        self.Uh = util.sharedMatrix(hidden_dim, hidden_dim, 'Uh', orthogonal_init=True)
        # embedded input -> hidden
        self.Wh = util.sharedMatrix(hidden_dim, input_dim, 'Wh', orthogonal_init=True)
        # context -> hidden (if applicable)
        # NOTE: `context` is typically a symbolic theano variable whose truth
        # value cannot be evaluated, so compare against None explicitly
        # (bare `if self.context:` raises for a TensorVariable).
        if self.context is not None:
            # shared-var name fixed to match the attribute (was 'Wch')
            self.Whc = util.sharedMatrix(hidden_dim, context_dim, 'Whc',
                                         orthogonal_init=True)
        # bias
        self.bh = util.shared(util.zeros((hidden_dim,)), 'bh')

    def name(self):
        return self.name_

    def dense_params(self):
        """Return the trainable (dense) parameters of this layer."""
        params = [self.Uh, self.Wh, self.bh]
        if self.context is not None:
            params.append(self.Whc)
        return params

    def params_for_l2_penalty(self):
        # all dense params are subject to the l2 penalty
        return self.dense_params()

    def updates_wrt_cost(self, cost, learning_opts):
        """Return gradient-clipped update pairs for all dense params wrt `cost`."""
        gradients = util.clipped(T.grad(cost=cost, wrt=self.dense_params()))
        return self.update_fn(self.dense_params(), gradients, learning_opts)

    def recurrent_step(self, inp, h_t_minus_1):
        """One scan step; returns [h_t, h_t]: the recurrent state and the emitted state."""
        h_t = (T.dot(self.Uh, h_t_minus_1) +
               T.dot(self.Wh, inp) +
               self.bh)
        if self.context is not None:
            h_t += T.dot(self.Whc, self.context)
        h_t = T.tanh(h_t)
        return [h_t, h_t]

    def all_states(self):
        """Return the symbolic sequence of hidden states, one per timestep of `inputs`."""
        [_h_t, h_t], _ = theano.scan(fn=self.recurrent_step,
                                     sequences=[self.inputs],
                                     outputs_info=[self.h0, None])
        return h_t

    def final_state(self):
        # hidden state after the last timestep
        return self.all_states()[-1]