Code Example #1
# NOTE: the imports below are assumed (the original snippet does not show them); tf.keras is assumed here.
from tensorflow.keras import backend as K
from tensorflow.keras.backend import int_shape
from tensorflow.keras.layers import Dense, Lambda, Layer, Multiply, TimeDistributed


class Attention(Layer):
    def __init__(self, units, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        print("input_shape (build): {}".format(input_shape))
        # matrix W multiplied by hidden state vector
        # self.W = TimeDistributed(Dense(self.units), input_shape=(input_shape[1], input_shape[2]))
        self.W = Dense(self.units,
                       name="W",
                       activation="tanh",
                       input_shape=(input_shape[2], ),
                       trainable=True)
        self.W.build((input_shape[0], input_shape[2]))
        self.W_td = TimeDistributed(self.W,
                                    name="W_td",
                                    input_shape=(input_shape[1],
                                                 input_shape[2]))
        print("'W' trainable weights (build): {}".format(
            self.W.trainable_weights))
        print("'W_td output shape: {}".format(
            self.W_td.compute_output_shape(input_shape)))
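        # register the inner Dense's weights with this custom layer so they are updated during training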
        self._trainable_weights.extend(self.W.trainable_weights)

        # vector V multiplied by tanh of the output of the above
        self.V = Dense(1,
                       use_bias=False,
                       activation="softmax",
                       name="V",
                       trainable=True)
        self.V.build((input_shape[0], input_shape[1], self.units))
        self.V_td = TimeDistributed(self.V,
                                    name="V_td",
                                    input_shape=(input_shape[1], self.units))
        print("'V' trainable weights (build): {}".format(
            self.V.trainable_weights))
        self._trainable_weights.extend(self.V.trainable_weights)
        print("trainable weights (build): {}".format(self._trainable_weights))

        super(Attention, self).build(input_shape)

    def call(self, hidden, **kwargs):
        batch_size, time_steps, hidden_size = int_shape(hidden)
        # hidden shape: (batch_size, time_steps, hidden_size)
        print("'hidden' shape: {}".format(int_shape(hidden)))
        print("'W' trainable weights: {}".format(self.W.trainable_weights))
        print("'W' matrix shape: {}".format(
            int_shape(self.W.trainable_weights[0])))

        # TODO: this is pretty much useless and not really self-attention => should be corrected
        scores = self.W_td(hidden)
        # score shape: (batch_size, time_steps, self.units)
        print("'scores' shape: {}".format(int_shape(scores)))

        attention_weights = self.V_td(scores)
        # attention_weights shape: (batch_size, time_steps, 1) => attention distribution over time
        # TODO: should potentially mask this (different length sequences)
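        # NOTE: because V is a 1-unit Dense with softmax activation, the softmax is taken over a
        # single value and every weight comes out as 1.0; a real attention distribution would need
        # the softmax to be taken across the time axis instead.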
        print("'attention_weights' shape: {}".format(
            int_shape(attention_weights)))

        attention_weights = K.repeat_elements(attention_weights, hidden_size,
                                              2)
        # attention_weights shape: (batch_size, time_steps, hidden_size) => repeat so that Multiply() can be used
        print("'attention_weights' shape: {}".format(
            int_shape(attention_weights)))

        context_vector = Multiply(name="context_vector_mul")(
            [attention_weights, hidden])
        # context_vector shape: (batch_size, time_steps, hidden_size) => weighted hidden states
        print("'context_vector' shape: {}".format(int_shape(context_vector)))

        context_vector = Lambda(lambda x: K.sum(x, axis=1),
                                name="context_vector_sum")(context_vector)
        # context_vector shape: (batch_size, hidden_size) => context vector which can be used for w/e
        # => basically a feature vector which can be passed through a dense classification layer
        print("'context_vector' shape: {}".format(int_shape(context_vector)))

        if kwargs.get("return_attention", None):
            return context_vector, attention_weights
        return context_vector

    def compute_output_shape(self, input_shape):
        # should be (batch_size, hidden_size)
        return input_shape[0], input_shape[2]

    def get_config(self):
        config = super(Attention, self).get_config()
        config.update({"units": self.units})
        return config
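
For reference, here is a minimal sketch of how this layer might be attached to a recurrent encoder. The LSTM, input shape, and unit counts are assumptions for illustration and are not taken from the code above.

# Usage sketch (assumed setup, not part of the original code)
from tensorflow.keras.layers import Input, LSTM, Dense
from tensorflow.keras.models import Model

inputs = Input(shape=(100, 32))                     # (time_steps, features) -- assumed values
hidden = LSTM(64, return_sequences=True)(inputs)    # (batch_size, time_steps, hidden_size)
context = Attention(units=16)(hidden)               # (batch_size, hidden_size)
outputs = Dense(1, activation="sigmoid")(context)

model = Model(inputs, outputs)
model.compile(optimizer="adam", loss="binary_crossentropy")
model.summary()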