import tensorflow as tf

# NOTE: MultiHeadAttention is a project-local layer (its constructor takes
# head_size and num_heads, and its call() takes a [query, key, value] list).
# The import path below is an assumption; adjust it to wherever the layer
# lives in this repository.
from .multihead_attention import MultiHeadAttention


class MHSAModule(tf.keras.layers.Layer):
    """Multi-head self-attention block: pre-layer-norm, attention, dropout,
    and a residual connection around the whole block."""

    def __init__(self, head_size, num_heads, dropout=0.0, name="mhsa_module", **kwargs):
        super(MHSAModule, self).__init__(name=name, **kwargs)
        # self.pc = PositionalEncoding()
        self.ln = tf.keras.layers.LayerNormalization()
        self.mha = MultiHeadAttention(head_size=head_size, num_heads=num_heads)
        self.do = tf.keras.layers.Dropout(dropout)
        self.res_add = tf.keras.layers.Add()

    # @tf.function(experimental_relax_shapes=True)
    def call(self, inputs, training=False, **kwargs):
        # outputs = self.pc(inputs)
        outputs = self.ln(inputs, training=training)
        # Self-attention: query, key, and value are all the normalized inputs.
        outputs = self.mha([outputs, outputs, outputs], training=training)
        outputs = self.do(outputs, training=training)
        # Residual connection from the un-normalized inputs.
        outputs = self.res_add([inputs, outputs])
        return outputs

    def get_config(self):
        conf = super(MHSAModule, self).get_config()
        # conf.update(self.pc.get_config())
        conf.update(self.ln.get_config())
        conf.update(self.mha.get_config())
        conf.update(self.do.get_config())
        conf.update(self.res_add.get_config())
        return conf
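
# --- Usage sketch (not in the original source) ---
# A minimal smoke test, assuming inputs are shaped [batch, time, dmodel] and
# that MultiHeadAttention projects its output back to dmodel so the residual
# add lines up. The head_size/num_heads/shape values here are illustrative.
if __name__ == "__main__":
    mhsa = MHSAModule(head_size=36, num_heads=4, dropout=0.1)
    dummy = tf.random.normal([2, 50, 144])  # 144 = 4 heads * head_size 36
    out = mhsa(dummy, training=False)
    print(out.shape)  # residual add preserves the input shape: (2, 50, 144)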