from keras import backend as K
from keras.layers import Layer, Bidirectional, LSTM, Dense, TimeDistributed, Softmax


class SpanEnd(Layer):
    """BiDAF output layer: computes the answer-span end-probability
    distribution over the passage, conditioned on the begin distribution."""

    def __init__(self, **kwargs):
        super(SpanEnd, self).__init__(**kwargs)

    def build(self, input_shape):
        # Recover the embedding dimension from the encoded passage (2 * emdim).
        emdim = input_shape[0][-1] // 2

        # BiLSTM over the 14 * emdim span-end representation built in call().
        input_shape_bilstm_1 = input_shape[0][:-1] + (emdim * 14,)
        self.bilstm_1 = Bidirectional(LSTM(emdim, return_sequences=True))
        self.bilstm_1.build(input_shape_bilstm_1)

        # Per-timestep scoring of the 10 * emdim span-end input.
        input_shape_dense_1 = input_shape[0][:-1] + (emdim * 10,)
        self.dense_1 = Dense(units=1)
        self.dense_1.build(input_shape_dense_1)

        # Register the sublayers' weights with this layer. Note: assigning
        # trainable_weights directly works in Keras 2.x; in later versions
        # it is a read-only property.
        self.trainable_weights = self.bilstm_1.trainable_weights + self.dense_1.trainable_weights
        super(SpanEnd, self).build(input_shape)

    def call(self, inputs):
        encoded_passage, merged_context, modeled_passage, span_begin_probabilities = inputs

        # Expected value of the modeled passage under the begin distribution.
        weighted_sum = K.sum(
            K.expand_dims(span_begin_probabilities, axis=-1) * modeled_passage, -2)

        # Broadcast the attended vector across every passage position.
        passage_weighted_by_predicted_span = K.expand_dims(weighted_sum, axis=1)
        tile_shape = K.concatenate([[1], [K.shape(encoded_passage)[1]], [1]], axis=0)
        passage_weighted_by_predicted_span = K.tile(
            passage_weighted_by_predicted_span, tile_shape)
        multiply1 = modeled_passage * passage_weighted_by_predicted_span

        # [G; M; weighted; M * weighted] -> BiLSTM -> concat with G -> score.
        span_end_representation = K.concatenate([
            merged_context, modeled_passage, passage_weighted_by_predicted_span, multiply1
        ])
        span_end_representation = self.bilstm_1(span_end_representation)
        span_end_input = K.concatenate([merged_context, span_end_representation])
        span_end_weights = TimeDistributed(self.dense_1)(span_end_input)

        span_end_probabilities = Softmax()(K.squeeze(span_end_weights, axis=-1))
        return span_end_probabilities

    def compute_output_shape(self, input_shape):
        _, merged_context_shape, _, _ = input_shape
        return merged_context_shape[:-1]  # (batch_size, passage_length)

    def get_config(self):
        config = super().get_config()
        return config
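
# A minimal usage sketch for SpanEnd (not from the original source). The
# shapes follow the BiDAF convention implied by build(): the merged context
# G carries 8 * emdim features and the modeled passage M carries 2 * emdim.
# The names `emdim`, `passage_len`, and `demo` are hypothetical placeholders.
from keras.layers import Input
from keras.models import Model

emdim, passage_len = 100, 120  # assumed dimensions
encoded_passage = Input(shape=(passage_len, 2 * emdim))
merged_context = Input(shape=(passage_len, 8 * emdim))
modeled_passage = Input(shape=(passage_len, 2 * emdim))
span_begin_probabilities = Input(shape=(passage_len,))

span_end_probabilities = SpanEnd()(
    [encoded_passage, merged_context, modeled_passage, span_begin_probabilities])
demo = Model(
    inputs=[encoded_passage, merged_context, modeled_passage, span_begin_probabilities],
    outputs=span_end_probabilities)  # output shape: (batch, passage_len)
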
from keras.layers import Bidirectional, LSTM, RNN
from keras.utils import CustomObjectScope


def lrpify_model(model):
    '''
    Takes a user-defined Keras Model and replaces every LSTM / bidirectional
    LSTM layer with an equivalent RNN layer whose core cell is LRP_LSTMCell.
    (LRP_LSTMCell is assumed to be defined elsewhere in this codebase.)
    '''
    cell_config_keys = ['units', 'activation', 'recurrent_activation', 'use_bias',
                        'unit_forget_bias', 'kernel_constraint',
                        'recurrent_constraint', 'bias_constraint']
    rnn_config_keys = ['return_sequences', 'return_state', 'go_backwards',
                       'stateful', 'unroll']
    bidirect_config_keys = ['merge_mode']

    for i, layer in enumerate(model.layers):
        if isinstance(layer, Bidirectional):
            weights = layer.get_weights()
            inp_shape = layer.input_shape
            config = layer.get_config()
            # The wrapped LSTM's config lives under config['layer']['config'].
            cell_config = {key: config['layer']['config'][key] for key in cell_config_keys}
            rnn_config = {key: config['layer']['config'][key] for key in rnn_config_keys}
            bidirect_config = {key: config[key] for key in bidirect_config_keys}
            with CustomObjectScope({'LRP_LSTMCell': LRP_LSTMCell}):
                cell = LRP_LSTMCell(**cell_config, implementation=1)
                bi_lstm = Bidirectional(RNN(cell, **rnn_config), **bidirect_config)
                bi_lstm.build(inp_shape)
                bi_lstm.call(layer.input)
                # Graft the new layer into the graph by reusing the old
                # layer's (private) node bookkeeping.
                bi_lstm._inbound_nodes = layer._inbound_nodes
                bi_lstm._outbound_nodes = layer._outbound_nodes
                bi_lstm.set_weights(weights)
                model.layers[i] = bi_lstm
        elif isinstance(layer, LSTM):
            weights = layer.get_weights()
            inp_shape = layer.input_shape
            config = layer.get_config()
            cell_config = {key: config[key] for key in cell_config_keys}
            rnn_config = {key: config[key] for key in rnn_config_keys}
            with CustomObjectScope({'LRP_LSTMCell': LRP_LSTMCell}):
                cell = LRP_LSTMCell(**cell_config, implementation=1)
                lstm = RNN(cell, **rnn_config)
                lstm.build(inp_shape)
                lstm.call(layer.input)
                lstm.set_weights(weights)
                lstm._inbound_nodes = layer._inbound_nodes
                lstm._outbound_nodes = layer._outbound_nodes
                model.layers[i] = lstm
    return model
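
# A minimal sketch of how lrpify_model might be applied (assumptions: the
# toy model below is hypothetical, and LRP_LSTMCell is drop-in weight-
# compatible with the standard LSTM cell, as the weight transfer above
# requires).
from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Bidirectional(LSTM(32, return_sequences=True), input_shape=(20, 8)))
model.add(LSTM(16))
model.add(Dense(1, activation='sigmoid'))

model = lrpify_model(model)  # both recurrent layers now wrap LRP_LSTMCell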