# Assumes module-level imports elsewhere in the file: functools, tensorflow as
# tf, the NLP modeling layers as `nlp_layers`, and the local `activations` and
# `initializers` modules.
def __init__(self,
             emb_dim=768,
             num_layers=6,
             num_heads=12,
             mlp_dim=3072,
             mlp_act=activations.approximate_gelu,
             output_dropout=0.1,
             attention_dropout=0.1,
             mlp_dropout=0.1,
             norm_first=True,
             norm_input=False,
             norm_output=True,
             causal=False,
             trainable_posemb=False,
             posemb_init=initializers.HarmonicEmbeddings(
                 scale_factor=1e-4, max_freq=1.0),
             aaemb_init=tf.initializers.RandomNormal(stddev=1.0),
             kernel_init=tf.initializers.GlorotUniform(),
             aaemb_scale_factor=None,
             max_len=1024,
             **kwargs):
  super().__init__(**kwargs)
  self._causal = causal
  # Positional embeddings, harmonic (sinusoidal) by default and optionally
  # trainable, covering up to `max_len` positions.
  self.posemb_layer = nlp_layers.PositionEmbedding(
      max_length=max_len,
      initializer=posemb_init,
      trainable=trainable_posemb,
      name='embeddings/positional')
  # Amino-acid (token) embedding lookup table of width `emb_dim`.
  self.aaemb_layer = nlp_layers.OnDeviceEmbedding(
      vocab_size=len(self._vocab),
      embedding_width=emb_dim,
      initializer=aaemb_init,
      scale_factor=aaemb_scale_factor,
      name='embeddings/aminoacid')
  layer_norm_cls = functools.partial(
      tf.keras.layers.LayerNormalization, axis=-1, epsilon=1e-12)
  self._input_norm_layer = (
      layer_norm_cls(name='embeddings/layer_norm') if norm_input else None)
  self._output_norm_layer = (
      layer_norm_cls(name='output/layer_norm') if norm_output else None)
  self._dropout_layer = tf.keras.layers.Dropout(
      rate=output_dropout, name='embeddings/dropout')
  self._attention_mask = nlp_layers.SelfAttentionMask()
  # Stack of `num_layers` Transformer encoder blocks (pre-LayerNorm when
  # `norm_first` is True).
  self._transformer_layers = []
  for i in range(num_layers):
    self._transformer_layers.append(
        nlp_layers.TransformerEncoderBlock(
            num_attention_heads=num_heads,
            inner_dim=mlp_dim,
            inner_activation=mlp_act,
            output_dropout=output_dropout,
            attention_dropout=attention_dropout,
            inner_dropout=mlp_dropout,
            kernel_initializer=kernel_init,
            norm_first=norm_first,
            name=f'transformer/layer_{i}'))
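# Illustrative forward-pass sketch (not part of the original file): it shows
# how the sublayers constructed above are typically composed. The method name
# `call`, its `inputs`/`mask` arguments, and the default all-ones padding mask
# are assumptions for illustration, not the repository's actual implementation.
def call(self, inputs, mask=None, training=None):
  # Embed the integer-encoded sequence and add positional embeddings.
  embeddings = self.aaemb_layer(inputs)
  embeddings += self.posemb_layer(embeddings)
  if self._input_norm_layer is not None:
    embeddings = self._input_norm_layer(embeddings)
  embeddings = self._dropout_layer(embeddings, training=training)
  # Build the (batch, seq_len, seq_len) self-attention mask from the padding
  # mask; a causal (look-ahead) mask would additionally be applied here when
  # `self._causal` is True.
  if mask is None:
    mask = tf.ones_like(inputs, dtype=embeddings.dtype)
  attention_mask = self._attention_mask(embeddings, mask)
  # Run the Transformer stack and optionally normalize the output.
  outputs = embeddings
  for layer in self._transformer_layers:
    outputs = layer([outputs, attention_mask], training=training)
  if self._output_norm_layer is not None:
    outputs = self._output_norm_layer(outputs)
  return outputs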
def __init__(self,
             emb_dim=768,
             dropout=0.0,
             use_layer_norm=False,
             use_positional_embedding=False,
             position_embed_init=None,
             train_position_embed=True,
             aaemb_init=None,
             aaemb_scale_factor=None,
             max_len=1024,
             **kwargs):
  super().__init__(**kwargs)
  # Default initializers: harmonic (sinusoidal) positional embeddings and a
  # truncated-normal amino-acid embedding table.
  if position_embed_init is None:
    position_embed_init = initializers.HarmonicEmbeddings(
        scale_factor=1e-4, max_freq=1.0)
  if aaemb_init is None:
    aaemb_init = tf.initializers.TruncatedNormal(stddev=1.0)
  self._use_layer_norm = use_layer_norm
  # Optional positional embeddings, added to the token embeddings.
  if use_positional_embedding:
    self._positional_embedding = nlp_layers.PositionEmbedding(
        max_length=max_len,
        initializer=position_embed_init,
        trainable=train_position_embed,
        name='embeddings/positional')
  else:
    self._positional_embedding = None
  # Amino-acid (token) embedding lookup table of width `emb_dim`.
  self._aa_embed = nlp_layers.OnDeviceEmbedding(
      vocab_size=len(self._vocab),
      embedding_width=emb_dim,
      initializer=aaemb_init,
      scale_factor=aaemb_scale_factor,
      name='embeddings/aminoacid')
  # Optional layer normalization over the embedding dimension.
  if use_layer_norm:
    self._layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='embeddings/layer_norm')
  else:
    self._layer_norm = None
  self._dropout = tf.keras.layers.Dropout(
      rate=dropout, name='embeddings/dropout')
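# Illustrative forward-pass sketch (not part of the original file): it shows
# how the embedding sublayers constructed above would typically be composed.
# The method name `call` and its arguments are assumptions for illustration,
# not the repository's actual implementation.
def call(self, inputs, training=None):
  # Look up amino-acid embeddings for the integer-encoded sequence.
  embeddings = self._aa_embed(inputs)
  # Optionally add positional embeddings.
  if self._positional_embedding is not None:
    embeddings += self._positional_embedding(embeddings)
  # Optionally normalize over the embedding dimension.
  if self._layer_norm is not None:
    embeddings = self._layer_norm(embeddings)
  # Dropout is a no-op at the default rate of 0.0.
  return self._dropout(embeddings, training=training)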