Exemplo n.º 1
0
 def build(self, input_shape):
     """Build the LSTM weights from the first entry of ``input_shape``.

     Only ``input_shape[0]`` is handed to ``LSTMCell.build``; the remaining
     entries of ``input_shape`` are not used by this method.
     """
     first_shape = input_shape[0]
     LSTMCell.build(self, first_shape)
class AttentionLSTMCell(LSTMCell):
    """LSTM cell whose hidden output is gated by an attention score.

    Each step runs a wrapped ``LSTMCell`` (``self.cell``) and then scales the
    resulting hidden state by a sigmoid gate computed from that hidden state
    and from the first tensor passed in ``constants``.

    Arguments:
        output_dim: number of units, forwarded to ``LSTMCell``.
        attn_activation: activation (anything accepted by
            ``activations.get``) applied to the combined attention term.
        single_attention_param: if true, one scalar gate per sample is
            computed and repeated across all units; otherwise one gate per
            unit is computed.
        **kwargs: forwarded to ``LSTMCell`` for both this cell and the
            wrapped one.

    NOTE(review): this class both subclasses ``LSTMCell`` and wraps a second
    ``LSTMCell``. The recurrence in ``call`` uses the wrapped cell only; the
    inherited configuration supplies ``units`` and the initializers,
    regularizers and constraints of the attention weights. Whether the
    wrapped cell's weights are tracked as part of this layer depends on the
    Keras version in use -- confirm before relying on it.
    """

    def __init__(self, output_dim, attn_activation='tanh', single_attention_param=False, **kwargs):
        # Initialise the base layer first so the attributes below (one of
        # which is itself a layer) are assigned on a fully set-up object.
        super(AttentionLSTMCell, self).__init__(output_dim, **kwargs)

        self.attn_activation = activations.get(attn_activation)
        self.single_attention_param = single_attention_param
        self.cell = LSTMCell(output_dim, **kwargs)

    def _add_recurrent_weight(self, name, shape):
        """Add a weight configured with the recurrent initializer/regularizer/constraint."""
        return self.add_weight(shape=shape,
                               name=name,
                               initializer=self.recurrent_initializer,
                               regularizer=self.recurrent_regularizer,
                               constraint=self.recurrent_constraint)

    def _add_bias_weight(self, name, shape):
        """Add a weight configured with the bias initializer/regularizer/constraint."""
        return self.add_weight(shape=shape,
                               name=name,
                               initializer=self.bias_initializer,
                               regularizer=self.bias_regularizer,
                               constraint=self.bias_constraint)

    def build(self, input_shape):
        """Create the wrapped cell's weights and the attention weights.

        Arguments:
            input_shape: sequence of shapes. The first entry is passed to the
                wrapped cell's ``build``; the last entry is the shape of the
                constant tensor, whose final dimension sizes ``U_m``.
        """
        constants_shape = input_shape[-1]
        self.cell.build(input_shape[0])
        attention_dim = constants_shape[-1]
        units = self.units

        # Creation order is kept stable (U_a, b_a, U_m, b_m, U_s, b_s) because
        # ``set_weights`` relies on it.
        self.U_a = self._add_recurrent_weight('U_a', (units, units))
        self.b_a = self._add_bias_weight('b_a', (units,))

        # Previously registered as 'U_a', colliding with the weight above.
        self.U_m = self._add_recurrent_weight('U_m', (attention_dim, units))
        self.b_m = self._add_bias_weight('b_m', (units,))

        # The gate is either a single scalar per sample or one value per unit.
        # The bias must match the last axis of ``K.dot(m, U_s)`` so it
        # broadcasts over the batch; the former ``(units, 1)`` bias for the
        # scalar case could not be added to a ``(batch, 1)`` tensor.
        # NOTE(review): weights saved with single_attention_param=True under
        # the old ``(units, 1)`` bias shape will not load -- confirm none exist.
        gate_dim = 1 if self.single_attention_param else units
        self.U_s = self._add_recurrent_weight('U_s', (units, gate_dim))
        self.b_s = self._add_bias_weight('b_s', (gate_dim,))

        initial_weights = getattr(self, '_initial_weights', None)
        if initial_weights is not None:
            self.set_weights(initial_weights)
            # Reset rather than delete so the attribute stays readable.
            self._initial_weights = None

        self.built = True

    def call(self, x, states, training=None, constants=None):
        """Run one LSTM step and gate the hidden state with attention.

        Arguments:
            x: step input, passed unchanged to the wrapped cell.
            states: previous states, passed unchanged to the wrapped cell.
            training: training flag, passed unchanged to the wrapped cell.
            constants: sequence whose first element is the tensor attended
                over; its last dimension must match the one seen in ``build``.

        Returns:
            ``(h, [h, c])`` where ``h`` is the gated hidden state and ``c`` is
            the cell state produced by the wrapped cell.

        Raises:
            ValueError: if ``constants`` is not provided.
        """
        if constants is None:
            raise ValueError('AttentionLSTMCell requires `constants` holding '
                             'the tensor to attend over.')

        # Take the hidden and cell state from the wrapped cell's new states.
        _, (h, c) = self.cell.call(x, states, training)
        attended = constants[0]
        attention = K.dot(attended, self.U_m) + self.b_m

        m = self.attn_activation(K.dot(h, self.U_a) * attention + self.b_a)
        # A sigmoid is used for the gate: an exponential was suspected of
        # causing NaNs through exploding gradients.
        s = K.sigmoid(K.dot(m, self.U_s) + self.b_s)

        if self.single_attention_param:
            # Expand the per-sample scalar gate to one value per unit.
            h = h * K.repeat_elements(s, self.units, axis=1)
        else:
            h = h * s

        return h, [h, c]