class MyLayer(Layer):
    """Embed token ids, derive a padding mask, and run a masked LSTM over them."""

    def __init__(self, **kwargs):
        super(MyLayer, self).__init__(**kwargs)
        # NOTE(review): SimEmbedding is assigned without parentheses —
        # presumably it is already a callable layer instance (or is meant
        # to be instantiated here); confirm against its definition.
        self.embedding = SimEmbedding
        self.masking = Masking()
        self.lstm = LSTM(32, return_sequences=True)

    def call(self, inputs):
        embedded = self.embedding(inputs)
        # A mask tensor could also be prepared by hand: any boolean tensor
        # with shape (batch_size, timesteps) works.
        padding_mask = self.masking.compute_mask(embedded)
        # The LSTM skips the timesteps flagged as masked.
        return self.lstm(embedded, mask=padding_mask)
class CustomMasking(Masking):
    """Masking layer driven directly by integer token ids.

    Builds a padding mask from a ``[batch, token]`` id tensor by expanding
    it to ``output_dim`` channels and delegating to a plain ``Masking``
    layer, so the base layer's ``mask_value`` semantics are preserved.
    (Per the original note, only usable in method 3.)
    """

    def __init__(self, output_dim, **kwargs):
        super(CustomMasking, self).__init__(**kwargs)
        self.output_dim = output_dim
        self.masking = Masking()

    def compute_mask(self, inputs, mask=None):
        """Return the boolean padding mask for ``inputs``.

        :param inputs: ``[batch, token]`` tensor of token ids.
        :param mask: unused; accepted for signature compatibility with
            ``Layer.compute_mask`` (the original override omitted it, so
            framework calls passing ``mask=`` would fail).
        """
        # Tile the ids across output_dim channels so Masking's
        # reduce-over-last-axis check yields a (batch, timesteps) mask.
        # (A direct tf.cast on inputs would also work, but this keeps the
        # mask_value comparison logic of the base Masking layer.)
        expand_mask = tf.cast(
            tf.tile(tf.expand_dims(inputs, axis=-1), [1, 1, self.output_dim]),
            tf.float32)
        return self.masking.compute_mask(expand_mask)

    def get_config(self):
        config = {'output_dim': self.output_dim}
        # Bug fix: the original used super(Masking, self), which skips
        # Masking's own get_config in the MRO and drops its mask_value
        # entry, breaking round-trip serialization.
        base_config = super(CustomMasking, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))