Example #1
 def __init__(self,
              emb_dim=768,
              num_layers=6,
              num_heads=12,
              mlp_dim=3072,
              mlp_act=activations.approximate_gelu,
              output_dropout=0.1,
              attention_dropout=0.1,
              mlp_dropout=0.1,
              norm_first=True,
              norm_input=False,
              norm_output=True,
              causal=False,
              trainable_posemb=False,
              posemb_init=initializers.HarmonicEmbeddings(scale_factor=1e-4,
                                                          max_freq=1.0),
              aaemb_init=tf.initializers.RandomNormal(stddev=1.0),
              kernel_init=tf.initializers.GlorotUniform(),
              aaemb_scale_factor=None,
              max_len=1024,
              **kwargs):
     super().__init__(**kwargs)
     self._causal = causal
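     # Positional embeddings for positions up to max_len (harmonic init by default, frozen unless trainable_posemb=True).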
     self.posemb_layer = nlp_layers.PositionEmbedding(
         max_length=max_len,
         initializer=posemb_init,
         trainable=trainable_posemb,
         name='embeddings/positional')
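     # Amino-acid token embedding table; self._vocab is assumed to be set by the (unshown) base class.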
     self.aaemb_layer = nlp_layers.OnDeviceEmbedding(
         vocab_size=len(self._vocab),
         embedding_width=emb_dim,
         initializer=aaemb_init,
         scale_factor=aaemb_scale_factor,
         name='embeddings/aminoacid')
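     # Shared LayerNormalization settings (via functools.partial) reused by the optional input/output norms below.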
     layer_norm_cls = functools.partial(tf.keras.layers.LayerNormalization,
                                        axis=-1,
                                        epsilon=1e-12)
     self._input_norm_layer = (layer_norm_cls(
         name='embeddings/layer_norm') if norm_input else None)
     self._output_norm_layer = (layer_norm_cls(
         name='output/layer_norm') if norm_output else None)
     self._dropout_layer = tf.keras.layers.Dropout(
         rate=output_dropout, name='embeddings/dropout')
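     # Expands a (batch, seq) padding mask into the (batch, seq, seq) self-attention mask.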
     self._attention_mask = nlp_layers.SelfAttentionMask()
     self._transformer_layers = []
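     # Stack of num_layers identical Transformer encoder blocks (pre-norm when norm_first=True).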
     for i in range(num_layers):
         self._transformer_layers.append(
             nlp_layers.TransformerEncoderBlock(
                 num_attention_heads=num_heads,
                 inner_dim=mlp_dim,
                 inner_activation=mlp_act,
                 output_dropout=output_dropout,
                 attention_dropout=attention_dropout,
                 inner_dropout=mlp_dropout,
                 kernel_initializer=kernel_init,
                 norm_first=norm_first,
                 name=f'transformer/layer_{i}'))
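This snippet relies on module-level imports that are not shown: `nlp_layers` is assumed to be `official.nlp.modeling.layers` from tf-models-official, while `activations` and `initializers` are assumed to be project-local modules providing `approximate_gelu` and `HarmonicEmbeddings`. Below is a minimal standalone sketch (not the class above) of how the same tf-models-official layers fit together on dummy data, assuming a recent tf-models-official release; shapes and hyperparameters are illustrative only.

import tensorflow as tf
from official.nlp.modeling import layers as nlp_layers  # assumed import path

emb_dim, num_heads, mlp_dim, max_len, vocab_size = 64, 4, 128, 32, 25
ids = tf.random.uniform((2, max_len), maxval=vocab_size, dtype=tf.int32)  # dummy token ids
mask = tf.ones((2, max_len), dtype=tf.int32)                              # 1 = real token, 0 = padding

aa_embed = nlp_layers.OnDeviceEmbedding(vocab_size=vocab_size, embedding_width=emb_dim)
pos_embed = nlp_layers.PositionEmbedding(max_length=max_len)

x = aa_embed(ids)                                         # (2, max_len, emb_dim)
x = x + pos_embed(x)                                      # add positional embeddings
attention_mask = nlp_layers.SelfAttentionMask()(x, mask)  # (2, max_len, max_len)
block = nlp_layers.TransformerEncoderBlock(
    num_attention_heads=num_heads, inner_dim=mlp_dim, inner_activation=tf.nn.gelu)
out = block([x, attention_mask])                          # (2, max_len, emb_dim)
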
Example #2
    def __init__(self,
                 emb_dim=768,
                 dropout=0.0,
                 use_layer_norm=False,
                 use_positional_embedding=False,
                 position_embed_init=None,
                 train_position_embed=True,
                 aaemb_init=None,
                 aaemb_scale_factor=None,
                 max_len=1024,
                 **kwargs):
        super().__init__(**kwargs)
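        # Resolve None defaults here so each instance gets its own initializer object
        # (object-valued defaults in the signature would be created once and shared).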
        if position_embed_init is None:
            position_embed_init = initializers.HarmonicEmbeddings(
                scale_factor=1e-4, max_freq=1.0)
        if aaemb_init is None:
            aaemb_init = tf.initializers.TruncatedNormal(stddev=1.0)

        self._use_layer_norm = use_layer_norm

        if use_positional_embedding:
            self._positional_embedding = nlp_layers.PositionEmbedding(
                max_length=max_len,
                initializer=position_embed_init,
                trainable=train_position_embed,
                name='embeddings/positional')
        else:
            self._positional_embedding = None

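        # Amino-acid token embedding table; self._vocab is assumed to be set by the (unshown) base class.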
        self._aa_embed = nlp_layers.OnDeviceEmbedding(
            vocab_size=len(self._vocab),
            embedding_width=emb_dim,
            initializer=aaemb_init,
            scale_factor=aaemb_scale_factor,
            name='embeddings/aminoacid')

        if use_layer_norm:
            self._layer_norm = tf.keras.layers.LayerNormalization(
                axis=-1, epsilon=1e-12, name='embeddings/layer_norm')
        else:
            self._layer_norm = None

        self._dropout = tf.keras.layers.Dropout(rate=dropout,
                                                name='embeddings/dropout')
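Unlike Example #1, which constructs its default initializers directly in the signature (so every call that uses the defaults shares the objects created once when the def is evaluated), this constructor uses None sentinels and resolves them in the body. A minimal sketch of that pattern, with a hypothetical class name for illustration:

import tensorflow as tf

class EmbedderDemo:  # hypothetical name, for illustration only
    def __init__(self, aaemb_init=None):
        # Resolve the default inside __init__ so each instance owns a fresh initializer.
        if aaemb_init is None:
            aaemb_init = tf.initializers.TruncatedNormal(stddev=1.0)
        self.aaemb_init = aaemb_init

a, b = EmbedderDemo(), EmbedderDemo()
assert a.aaemb_init is not b.aaemb_init  # two distinct initializer objects

For stateless Keras initializers the sharing in Example #1 is usually harmless, but the sentinel pattern avoids it entirely.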