Example #1
0
class SpanEnd(Layer):
    """Layer producing a probability distribution for the span end.

    ``call`` expects four inputs, in order: the encoded passage, the merged
    context, the modeled passage and the span-begin probabilities. It
    returns a softmax over the per-position scores emitted by an internal
    bidirectional LSTM followed by a single-unit dense layer.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create and build the inner bidirectional LSTM and dense scorer.

        The shapes of both sub-layers are derived from the first entry of
        ``input_shape`` only: its last dimension is halved to get ``emdim``,
        and the feature widths are fixed multiples of that value.
        """
        # Half of the first input's last dimension; used as the LSTM unit
        # count and as the base for the sub-layer feature widths below.
        emdim = input_shape[0][-1] // 2
        leading_dims = input_shape[0][:-1]

        self.bilstm_1 = Bidirectional(LSTM(emdim, return_sequences=True))
        self.bilstm_1.build(leading_dims + (emdim * 14, ))

        self.dense_1 = Dense(units=1)
        self.dense_1.build(leading_dims + (emdim * 10, ))

        # Expose the weights of both sub-layers as this layer's own.
        self.trainable_weights = (self.bilstm_1.trainable_weights
                                  + self.dense_1.trainable_weights)
        super().build(input_shape)

    def call(self, inputs):
        """Compute span-end probabilities from the four input tensors."""
        encoded, merged, modeled, begin_probs = inputs

        # Weight the modeled passage by the span-begin probabilities and
        # reduce over axis -2.
        begin_weights = K.expand_dims(begin_probs, axis=-1)
        begin_summary = K.sum(begin_weights * modeled, -2)

        # Re-insert axis 1 and repeat the summary along it as many times as
        # the encoded passage's axis-1 length.
        begin_summary = K.expand_dims(begin_summary, axis=1)
        repeats = K.concatenate([[1], [K.shape(encoded)[1]], [1]], axis=0)
        begin_summary = K.tile(begin_summary, repeats)

        interaction = modeled * begin_summary
        features = K.concatenate(
            [merged, modeled, begin_summary, interaction])

        recurrent_out = self.bilstm_1(features)
        scorer_input = K.concatenate([merged, recurrent_out])

        # One score per position; drop the trailing unit axis before softmax.
        scores = TimeDistributed(self.dense_1)(scorer_input)
        return Softmax()(K.squeeze(scores, axis=-1))

    def compute_output_shape(self, input_shape):
        """Return the merged-context shape with its last dimension removed."""
        _encoded, merged_shape, _modeled, _begin = input_shape
        return merged_shape[:-1]

    def get_config(self):
        """Return the base layer configuration unchanged."""
        return super().get_config()
def lrpify_model(model):
    """Replace LSTM / Bidirectional layers with LRP_LSTMCell-based twins.

    Walks ``model.layers`` and, for every ``Bidirectional`` or ``LSTM``
    layer, builds an equivalent ``RNN`` layer whose core cell is an
    ``LRP_LSTMCell`` (wrapped in ``Bidirectional`` again where applicable).
    The new layer is configured from the original layer's config, receives
    the original weights and inbound/outbound nodes, and is stored back at
    the same index of ``model.layers``.

    Returns the same ``model`` object that was passed in.
    """
    cell_keys = ('units', 'activation', 'recurrent_activation', 'use_bias',
                 'unit_forget_bias', 'kernel_constraint',
                 'recurrent_constraint', 'bias_constraint')
    rnn_keys = ('return_sequences', 'return_state', 'go_backwards',
                'stateful', 'unroll')
    wrapper_keys = ('merge_mode',)

    def pick(config, keys):
        # Sub-dictionary of `config` restricted to `keys`, in `keys` order.
        return {key: config[key] for key in keys}

    for index, layer in enumerate(model.layers):
        if isinstance(layer, Bidirectional):
            saved_weights = layer.get_weights()
            source_shape = layer.input_shape
            wrapper_cfg = layer.get_config()
            # Settings of the wrapped recurrent layer live under 'layer'.
            wrapped_cfg = wrapper_cfg['layer']['config']
            cell_kwargs = pick(wrapped_cfg, cell_keys)
            rnn_kwargs = pick(wrapped_cfg, rnn_keys)
            wrapper_kwargs = pick(wrapper_cfg, wrapper_keys)

            with CustomObjectScope({'LRP_LSTMCell': LRP_LSTMCell}):
                lrp_cell = LRP_LSTMCell(**cell_kwargs, implementation=1)
                replacement = Bidirectional(RNN(lrp_cell, **rnn_kwargs),
                                            **wrapper_kwargs)
                replacement.build(source_shape)
                # NOTE(review): `call` is invoked directly instead of calling
                # the layer itself; presumably so no new graph nodes are
                # registered before the original ones are attached - confirm.
                replacement.call(layer.input)
                replacement._inbound_nodes = layer._inbound_nodes
                replacement._outbound_nodes = layer._outbound_nodes
                replacement.set_weights(saved_weights)

            model.layers[index] = replacement

        if isinstance(layer, LSTM):
            saved_weights = layer.get_weights()
            source_shape = layer.input_shape
            layer_cfg = layer.get_config()
            cell_kwargs = pick(layer_cfg, cell_keys)
            rnn_kwargs = pick(layer_cfg, rnn_keys)

            with CustomObjectScope({'LRP_LSTMCell': LRP_LSTMCell}):
                lrp_cell = LRP_LSTMCell(**cell_kwargs, implementation=1)
                replacement = RNN(lrp_cell, **rnn_kwargs)
                replacement.build(source_shape)
                # NOTE(review): same direct `call` as in the Bidirectional
                # branch above; intent is not visible from here.
                replacement.call(layer.input)
                replacement.set_weights(saved_weights)
                replacement._inbound_nodes = layer._inbound_nodes
                replacement._outbound_nodes = layer._outbound_nodes

            model.layers[index] = replacement

    return model