from keras import activations
from keras import backend as K
from keras.layers import LSTMCell


class AttentionLSTMCell(LSTMCell):
    """LSTM cell that gates its hidden state with an attention signal
    computed from a constant (per-sample) context vector."""

    def __init__(self, output_dim, attn_activation='tanh',
                 single_attention_param=False, **kwargs):
        self.attn_activation = activations.get(attn_activation)
        self.single_attention_param = single_attention_param
        # Inner LSTM cell that performs the actual recurrent step.
        self.cell = LSTMCell(output_dim, **kwargs)
        super(AttentionLSTMCell, self).__init__(output_dim, **kwargs)

    def build(self, input_shape):
        # When the wrapping RNN layer is called with `constants=...`,
        # `input_shape` is a list: [step_input_shape, constants_shape].
        constants_shape = input_shape[-1]
        self.cell.build(input_shape[0])

        attention_dim = constants_shape[-1]
        output_dim = self.units

        # Projects the hidden state into the attention space.
        self.U_a = self.add_weight(shape=(output_dim, output_dim),
                                   name='U_a',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_a = self.add_weight(shape=(output_dim,),
                                   name='b_a',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # Projects the constant context vector into the attention space.
        self.U_m = self.add_weight(shape=(attention_dim, output_dim),
                                   name='U_m',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_m = self.add_weight(shape=(output_dim,),
                                   name='b_m',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # Attention gate: either a single scalar per step that is broadcast
        # over all units, or one gate value per unit.
        if self.single_attention_param:
            self.U_s = self.add_weight(shape=(output_dim, 1),
                                       name='U_s',
                                       initializer=self.recurrent_initializer,
                                       regularizer=self.recurrent_regularizer,
                                       constraint=self.recurrent_constraint)
            # Scalar bias; a (output_dim, 1) shape would not broadcast against
            # the (batch, 1) result of K.dot(m, self.U_s).
            self.b_s = self.add_weight(shape=(1,),
                                       name='b_s',
                                       initializer=self.bias_initializer,
                                       regularizer=self.bias_regularizer,
                                       constraint=self.bias_constraint)
        else:
            self.U_s = self.add_weight(shape=(output_dim, output_dim),
                                       name='U_s',
                                       initializer=self.recurrent_initializer,
                                       regularizer=self.recurrent_regularizer,
                                       constraint=self.recurrent_constraint)
            self.b_s = self.add_weight(shape=(output_dim,),
                                       name='b_s',
                                       initializer=self.bias_initializer,
                                       regularizer=self.bias_regularizer,
                                       constraint=self.bias_constraint)

        if self._initial_weights is not None:
            self.set_weights(self._initial_weights)
            del self._initial_weights
        self.built = True

    def call(self, x, states, training=None, constants=None):
        # Run the plain LSTM step first.
        h, [h, c] = self.cell.call(x, states, training=training)

        # The RNN layer passes constants as a list; we use a single context vector.
        constants = constants[0]
        attention = K.dot(constants, self.U_m) + self.b_m
        m = self.attn_activation(K.dot(h, self.U_a) * attention + self.b_a)

        # Intuitively it makes more sense to use a sigmoid (was getting some NaN
        # problems which I think might have been caused by the exponential
        # function -> gradients blow up).
        s = K.sigmoid(K.dot(m, self.U_s) + self.b_s)

        if self.single_attention_param:
            # s has shape (batch, 1); repeat it across all units before gating.
            h = h * K.repeat_elements(s, self.units, axis=1)
        else:
            h = h * s

        return h, [h, c]
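
# --- Usage sketch (illustrative, not part of the original code) -------------
# A minimal example, assuming standalone Keras >= 2.1 (swap `keras` for
# `tensorflow.keras` as needed): the cell is wrapped in keras.layers.RNN, whose
# call accepts a `constants` argument that is forwarded to
# AttentionLSTMCell.call(). The shapes and names below (timesteps, features,
# attn_dim, units) are placeholder assumptions.
if __name__ == '__main__':
    from keras.layers import Input, RNN
    from keras.models import Model

    timesteps, features, attn_dim, units = 10, 8, 16, 32

    x = Input(shape=(timesteps, features))  # input sequence
    attended = Input(shape=(attn_dim,))     # per-sample context vector seen at every step
    y = RNN(AttentionLSTMCell(units))(x, constants=attended)

    model = Model([x, attended], y)
    model.compile(optimizer='adam', loss='mse')
    model.summary()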