from keras import backend as K
from keras.backend import int_shape
from keras.layers import Dense, Lambda, Layer, Multiply, Softmax, TimeDistributed


class Attention(Layer):
    """Additive (Bahdanau-style) attention pooling over the time axis of a 3D tensor."""

    def __init__(self, units, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # input_shape: (batch_size, time_steps, hidden_size)
        print("input_shape (build): {}".format(input_shape))

        # matrix W multiplied by each hidden state vector
        self.W = Dense(self.units, name="W", activation="tanh",
                       input_shape=(input_shape[2],), trainable=True)
        self.W.build((input_shape[0], input_shape[2]))
        self.W_td = TimeDistributed(self.W, name="W_td",
                                    input_shape=(input_shape[1], input_shape[2]))
        print("'W' trainable weights (build): {}".format(self.W.trainable_weights))
        print("'W_td' output shape: {}".format(self.W_td.compute_output_shape(input_shape)))
        self._trainable_weights.extend(self.W.trainable_weights)

        # vector V multiplied by the tanh output above; yields one unnormalised
        # score per timestep (the softmax over time is applied in call())
        self.V = Dense(1, use_bias=False, name="V", trainable=True)
        self.V.build((input_shape[0], input_shape[1], self.units))
        self.V_td = TimeDistributed(self.V, name="V_td",
                                    input_shape=(input_shape[1], self.units))
        print("'V' trainable weights (build): {}".format(self.V.trainable_weights))
        self._trainable_weights.extend(self.V.trainable_weights)

        print("trainable weights (build): {}".format(self._trainable_weights))
        super(Attention, self).build(input_shape)

    def call(self, hidden, **kwargs):
        # hidden shape: (batch_size, time_steps, hidden_size)
        batch_size, time_steps, hidden_size = int_shape(hidden)
        print("'hidden' shape: {}".format(int_shape(hidden)))
        print("'W' trainable weights: {}".format(self.W.trainable_weights))
        print("'W' matrix shape: {}".format(int_shape(self.W.trainable_weights[0])))

        # TODO: this is pretty much useless and not really self-attention => should be corrected
        scores = self.W_td(hidden)
        # scores shape: (batch_size, time_steps, self.units)
        print("'scores' shape: {}".format(int_shape(scores)))

        attention_weights = self.V_td(scores)
        # attention_weights shape: (batch_size, time_steps, 1) => one raw score per timestep
        # softmax over the time axis (axis=1) so the weights form a distribution over timesteps;
        # a softmax over the last axis of size 1 would always return 1.0
        attention_weights = Softmax(axis=1, name="attention_softmax")(attention_weights)
        # TODO: should potentially mask this (different length sequences)
        print("'attention_weights' shape: {}".format(int_shape(attention_weights)))

        attention_weights = K.repeat_elements(attention_weights, hidden_size, 2)
        # attention_weights shape: (batch_size, time_steps, hidden_size) => repeated so that Multiply() can be used
        print("'attention_weights' shape: {}".format(int_shape(attention_weights)))

        context_vector = Multiply(name="context_vector_mul")([attention_weights, hidden])
        # context_vector shape: (batch_size, time_steps, hidden_size) => weighted hidden states
        print("'context_vector' shape: {}".format(int_shape(context_vector)))

        context_vector = Lambda(lambda x: K.sum(x, axis=1),
                                name="context_vector_sum")(context_vector)
        # context_vector shape: (batch_size, hidden_size) => weighted sum of hidden states,
        # basically a feature vector which can be passed to a dense classification layer
        print("'context_vector' shape: {}".format(int_shape(context_vector)))

        if kwargs.get("return_attention", None):
            return context_vector, attention_weights
        return context_vector

    def compute_output_shape(self, input_shape):
        # output shape: (batch_size, hidden_size)
        return input_shape[0], input_shape[2]

    def get_config(self):
        config = super(Attention, self).get_config()
        config.update({"units": self.units})
        return config
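
# Minimal usage sketch of the layer above, assuming a typical Keras 2.x text
# classification setup. The vocabulary size, sequence length, layer sizes and
# variable names below are illustrative assumptions, not part of the original code.
if __name__ == "__main__":
    from keras.layers import Bidirectional, Embedding, Input, LSTM
    from keras.models import Model

    inputs = Input(shape=(100,), name="token_ids")              # (batch, 100)
    x = Embedding(input_dim=10000, output_dim=128)(inputs)      # (batch, 100, 128)
    x = Bidirectional(LSTM(64, return_sequences=True))(x)       # (batch, 100, 128)
    context = Attention(units=32, name="attention")(x)          # (batch, 128)
    outputs = Dense(1, activation="sigmoid")(context)           # binary classifier head

    model = Model(inputs, outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    model.summary()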