# Imports assumed by the layers below (standalone Keras 2.x with the
# TensorFlow backend, where build() may assign to `trainable_weights`).
import numpy as np
import tensorflow as tf

from keras import backend as K
from keras.activations import relu, sigmoid, softmax
from keras.layers import (Activation, Bidirectional, Dense, Layer, LSTM,
                          Softmax, TimeDistributed)


class SpanEnd(Layer):
    """Computes the answer-span end distribution from the span-begin
    distribution and the passage representations."""

    def __init__(self, **kwargs):
        super(SpanEnd, self).__init__(**kwargs)

    def build(self, input_shape):
        # Inputs: [encoded_passage, merged_context, modeled_passage,
        #          span_begin_probabilities]; emdim is half the encoded-passage width.
        emdim = input_shape[0][-1] // 2
        # bilstm_1 consumes the concatenation built in call(), whose feature
        # size works out to 14 * emdim.
        input_shape_bilstm_1 = input_shape[0][:-1] + (emdim * 14,)
        self.bilstm_1 = Bidirectional(LSTM(emdim, return_sequences=True))
        self.bilstm_1.build(input_shape_bilstm_1)
        # dense_1 scores the concatenation of the merged context and the
        # BiLSTM output (10 * emdim features) with a single unit per position.
        input_shape_dense_1 = input_shape[0][:-1] + (emdim * 10,)
        self.dense_1 = Dense(units=1)
        self.dense_1.build(input_shape_dense_1)
        self.trainable_weights = self.bilstm_1.trainable_weights + self.dense_1.trainable_weights
        super(SpanEnd, self).build(input_shape)

    def call(self, inputs):
        encoded_passage, merged_context, modeled_passage, span_begin_probabilities = inputs
        # Summarise the modeled passage, weighted by the span-begin distribution.
        weighted_sum = K.sum(
            K.expand_dims(span_begin_probabilities, axis=-1) * modeled_passage, -2)
        # Broadcast that summary vector across every passage position.
        passage_weighted_by_predicted_span = K.expand_dims(weighted_sum, axis=1)
        tile_shape = K.concatenate([[1], [K.shape(encoded_passage)[1]], [1]], axis=0)
        passage_weighted_by_predicted_span = K.tile(
            passage_weighted_by_predicted_span, tile_shape)
        multiply1 = modeled_passage * passage_weighted_by_predicted_span
        span_end_representation = K.concatenate([
            merged_context, modeled_passage, passage_weighted_by_predicted_span,
            multiply1
        ])
        span_end_representation = self.bilstm_1(span_end_representation)
        span_end_input = K.concatenate([merged_context, span_end_representation])
        # Score every position and normalise into a probability distribution.
        span_end_weights = TimeDistributed(self.dense_1)(span_end_input)
        span_end_probabilities = Softmax()(K.squeeze(span_end_weights, axis=-1))
        return span_end_probabilities

    def compute_output_shape(self, input_shape):
        _, merged_context_shape, _, _ = input_shape
        return merged_context_shape[:-1]

    def get_config(self):
        config = super().get_config()
        return config
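
# Usage sketch (illustrative, not part of the original layer code): wires
# SpanEnd into a functional model, assuming BiDAF-style input shapes where the
# merged context is 8*d wide and the encoded/modeled passages are 2*d wide.
# `passage_len` and `d` are hypothetical sizes chosen only for the example.
def _span_end_usage_sketch():
    from keras.layers import Input
    from keras.models import Model

    passage_len, d = 200, 100
    encoded_passage = Input(shape=(passage_len, 2 * d))
    merged_context = Input(shape=(passage_len, 8 * d))
    modeled_passage = Input(shape=(passage_len, 2 * d))
    span_begin_probabilities = Input(shape=(passage_len,))

    span_end_probabilities = SpanEnd()([
        encoded_passage, merged_context, modeled_passage,
        span_begin_probabilities
    ])
    return Model(inputs=[encoded_passage, merged_context, modeled_passage,
                         span_begin_probabilities],
                 outputs=span_end_probabilities)
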
class Gated_attention_with_self(Layer):
    """Gated query-passage attention followed by gated self-attention over
    the query-aware passage."""

    def __init__(self, passage_len=200, num_units=300, emb_dim=600, **kwargs):
        self.num_units = num_units
        self.emb_dim = emb_dim
        self.passage_len = passage_len
        super(Gated_attention_with_self, self).__init__(**kwargs)

    def build(self, input_shape):
        # The input stacks passage and query along the time axis: the first
        # `passage_len` steps are the passage, the remaining steps the query.
        passage_shape = (input_shape[0], self.passage_len, input_shape[-1])
        query_shape = (input_shape[0], input_shape[1] - self.passage_len, input_shape[-1])
        concat_shape = (input_shape[0], self.passage_len, input_shape[-1] * 2)
        self.dense_1 = Dense(self.num_units, activation=relu)
        self.dense_1.build(passage_shape)
        self.dense_2 = Dense(self.num_units, activation=relu)
        self.dense_2.build(query_shape)
        self.dense_3 = Dense(input_shape[-1], activation=sigmoid)
        self.dense_3.build(concat_shape)
        self.dense_4 = Dense(input_shape[-1], activation=sigmoid)
        self.dense_4.build((input_shape[0], self.passage_len,
                            input_shape[-1] + (self.emb_dim * 2)))
        self.bilstm_1 = Bidirectional(LSTM(self.emb_dim, return_sequences=True))
        self.bilstm_1.build(passage_shape)
        self.bilstm_2 = Bidirectional(LSTM(self.emb_dim, return_sequences=True))
        self.bilstm_2.build(passage_shape)
        # Register the weights of every sub-layer so all of them are trained.
        self.trainable_weights = (self.dense_1.trainable_weights +
                                  self.dense_2.trainable_weights +
                                  self.dense_3.trainable_weights +
                                  self.dense_4.trainable_weights +
                                  self.bilstm_1.trainable_weights +
                                  self.bilstm_2.trainable_weights)
        super(Gated_attention_with_self, self).build(input_shape)

    def call(self, stacked_input):
        # unstacked_input = tf.unstack(stacked_input)
        input_1 = stacked_input[:, :self.passage_len, :]  # Passage input
        input_2 = stacked_input[:, self.passage_len:, :]  # Query input (P' in case of self attention)
        dense_1_op = self.dense_1(input_1)
        dense_2_op = self.dense_2(input_2)
        # Scaled dot-product similarity between passage and query positions.
        co_mat = tf.matmul(dense_1_op, tf.transpose(dense_2_op, perm=[0, 2, 1]))
        co_mat = 1 / np.sqrt(self.emb_dim) * co_mat  # TODO: check if emb_dim or input_passage_dim
        passage_bar = Activation(softmax)(co_mat)
        passage_bar = tf.matmul(passage_bar, input_2)
        # Gate the passage with the attended query representation.
        passage_concat = tf.concat([input_1, passage_bar], 2)
        dense_3_op = self.dense_3(passage_concat)
        passage_mul = tf.multiply(dense_3_op, input_1)
        query_depend_passage = self.bilstm_1(passage_mul)

        # Self-attention part: gate the passage with its query-aware encoding.
        p_dash_concat = tf.concat([input_1, query_depend_passage], 2)
        dense_4_op = self.dense_4(p_dash_concat)
        p_dash_mul = tf.multiply(dense_4_op, input_1)
        query_depend_p_dash = self.bilstm_2(p_dash_mul)
        # Return a list so Keras treats this as a two-output layer.
        return [query_depend_passage, query_depend_p_dash]

    def compute_output_shape(self, input_shape):
        # Both outputs are bidirectional LSTM encodings over the passage steps.
        output_shape = (input_shape[0], self.passage_len, 2 * self.emb_dim)
        return [output_shape, output_shape]
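
# Usage sketch (illustrative, not part of the original layer code): the layer
# expects the passage and query encodings stacked along the time axis, passage
# first, exactly as call() slices them apart. `passage_len`, `query_len` and
# `enc_dim` are hypothetical sizes chosen only for the example.
def _gated_attention_usage_sketch():
    from keras.layers import Concatenate, Input
    from keras.models import Model

    passage_len, query_len, enc_dim = 200, 30, 600
    passage = Input(shape=(passage_len, enc_dim))
    query = Input(shape=(query_len, enc_dim))

    # Stack along axis 1 so the layer can split passage and query apart again.
    stacked = Concatenate(axis=1)([passage, query])
    query_aware_passage, self_aware_passage = Gated_attention_with_self(
        passage_len=passage_len, num_units=300, emb_dim=enc_dim)(stacked)
    return Model(inputs=[passage, query],
                 outputs=[query_aware_passage, self_aware_passage])
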