forked from matpalm/rnn_lm
attention_rnn.py
#!/usr/bin/env python
import numpy as np
import util
import theano
import theano.tensor as T

class AttentionRnn(object):
    def __init__(self, n_in, n_embedding, n_hidden, orthogonal_init):
        # for the trivial annotation network; both _f (forward) and _b (backwards) passes
        self.Wx_a_f = util.sharedMatrix(n_in, n_embedding, 'Wx_a_f', orthogonal_init)  # embeddings for annotations
        self.Whx_f = util.sharedMatrix(n_hidden, n_embedding, 'Whx_f', orthogonal_init)
        self.Wx_a_b = util.sharedMatrix(n_in, n_embedding, 'Wx_a_b', orthogonal_init)  # embeddings for annotations
        self.Whx_b = util.sharedMatrix(n_hidden, n_embedding, 'Whx_b', orthogonal_init)
        # for the attention network
        self.Wx_g = util.sharedMatrix(n_in, n_embedding, 'Wx_g', orthogonal_init)  # embeddings for glimpses
        self.Wug = util.sharedMatrix(n_hidden, n_embedding, 'Wug', orthogonal_init)
        self.Wag = util.sharedMatrix(n_hidden, n_hidden, 'Wag', orthogonal_init)
        self.wgs = util.sharedVector(n_hidden, 'wgs')
        # final mapping to y
        self.Wy = util.sharedMatrix(n_in, n_hidden, 'Wy', orthogonal_init)

    def params(self):
        return [self.Wx_a_f, self.Whx_f,
                self.Wx_a_b, self.Whx_b,
                self.Wx_g, self.Wag, self.Wug, self.wgs,
                self.Wy]

    def _annotation_step(self,
                         x_t,          # sequence to scan
                         h_t_minus_1,  # recurrent state
                         Wx_a, Whx):   # non-sequences
        # calc new hidden state: element-wise add of the previous hidden state
        # and the recurrent weights dotted with the embedded input
        embedding = Wx_a[x_t]
        h_t = T.tanh(h_t_minus_1 + T.dot(Whx, embedding))
        # TODO annotation_t = some_f(h_t) ?
        # return next hidden state and annotation (which, for now, are the same thing)
        return [h_t, h_t]
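    # In equation form the annotation step is h_t = tanh(h_{t-1} + Whx . Wx_a[x_t]);
    # note there is no hidden-to-hidden weight matrix here, the previous state is
    # simply added in, which is what makes this annotation network "trivial".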

    def _attended_annotation(self,
                             u,             # sequence to scan
                             h_t_minus_1,   # recurrent state
                             annotations):  # non-sequences
        # first we need to mix the annotations using 'u' as the context of
        # attention. we'll be doing _all_ annotations wrt u in one hit, so we
        # need a column-broadcastable version of u
        embedding = self.Wx_g[u]
        u_col = embedding.dimshuffle(0, 'x')
        # we also want to mix in the last hidden state which, for the same
        # reason as u above, needs to be handled as a broadcasted column
        h_t_minus_1_col = h_t_minus_1.dimshuffle(0, 'x')
        # can now combine the annotations with u and the hidden state
        glimpse_vectors = T.tanh(T.dot(self.Wag, annotations.T) +
                                 T.dot(self.Wug, u_col) +
                                 h_t_minus_1_col)
        # now collapse the glimpse vectors (there's one per annotation) to scalars
        unnormalised_glimpse_scalars = T.dot(self.wgs, glimpse_vectors)
        # normalise the glimpses with a softmax (shifted by the max for numerical stability)
        exp_glimpses = T.exp(unnormalised_glimpse_scalars - T.max(unnormalised_glimpse_scalars))
        glimpses = exp_glimpses / T.sum(exp_glimpses)
        # the attended version of the annotations is the convex combination of the
        # annotations using the normalised glimpses as the combination weights
        attended_annotations = T.dot(annotations.T, glimpses)
        # return 1) attended_annotations to pass on as the next hidden state
        #        2) the (same) attended annotations & the glimpses to collect
        return [attended_annotations, attended_annotations, glimpses]
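    # Shapes, for reference (n = sequence length, h = n_hidden, e = n_embedding):
    #   annotations       (2n, h)  forward + backward annotations stacked row-wise
    #   glimpse_vectors   (h, 2n)  tanh(Wag . A^T + Wug . u + h_{t-1}), columns broadcast
    #   glimpses          (2n,)    softmax-normalised attention weights
    #   attended          (h,)     A^T . glimpses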

    def _softmax(self, annotation):
        # calc output: softmax over the output weights dotted with the attended annotation
        return T.flatten(T.nnet.softmax(T.dot(self.Wy, annotation)), 1)

    def t_y_softmax(self, x, h0):
        # first pass: build the base annotation vectors. for this simple example
        # it's just a forward and a backwards pass of a simple RNN, concatenated
        [forward_annotations, _hidden], _ = theano.scan(fn=self._annotation_step,
                                                        go_backwards=False,
                                                        sequences=[x],
                                                        non_sequences=[self.Wx_a_f, self.Whx_f],
                                                        outputs_info=[h0, None])
        [backwards_annotations, _hidden], _ = theano.scan(fn=self._annotation_step,
                                                          go_backwards=True,
                                                          sequences=[x],
                                                          non_sequences=[self.Wx_a_b, self.Whx_b],
                                                          outputs_info=[h0, None])
        backwards_annotations = backwards_annotations[::-1]  # to make indexing the same as forward_annotations
        # stack row-wise, giving 2 * seq_len annotations, each of size n_hidden
        annotations = T.concatenate([forward_annotations, backwards_annotations])
        # second pass: calculate attention over the annotations
        [_hidden, attended_annotations, glimpses], _ = theano.scan(fn=self._attended_annotation,
                                                                   sequences=[x],
                                                                   non_sequences=[annotations],
                                                                   outputs_info=[h0, None, None])
        # final pass: apply the softmax to each attended annotation
        y_softmax, _ = theano.scan(fn=self._softmax,
                                   sequences=[attended_annotations])
        return y_softmax, glimpses
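
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): one plausible way to compile
# and run the model, assuming util.sharedMatrix / util.sharedVector return
# float64 theano shared variables as the constructor implies, and that
# orthogonal_init is a flag consumed by those helpers. The dimensions and
# token ids below are made up for illustration.
if __name__ == '__main__':
    n_in, n_embedding, n_hidden = 10, 4, 8  # hypothetical vocab / embedding / hidden sizes
    rnn = AttentionRnn(n_in, n_embedding, n_hidden, orthogonal_init=True)
    x = T.ivector('x')                                  # a sequence of token ids
    h0 = theano.shared(np.zeros(n_hidden), name='h0')   # initial hidden state
    y_softmax, glimpses = rnn.t_y_softmax(x, h0)
    f = theano.function(inputs=[x], outputs=[y_softmax, glimpses])
    probs, attention = f(np.array([1, 2, 3], dtype=np.int32))
    print(probs.shape)      # (3, n_in): one distribution over the vocab per step
    print(attention.shape)  # (3, 6): per step, weights over the 2 * seq_len annotations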