def __init__(self,
             emb_dim=768,
             num_layers=6,
             num_heads=12,
             mlp_dim=3072,
             mlp_act=activations.approximate_gelu,
             output_dropout=0.1,
             attention_dropout=0.1,
             mlp_dropout=0.1,
             norm_first=True,
             norm_input=False,
             norm_output=True,
             causal=False,
             trainable_posemb=False,
             posemb_init=initializers.HarmonicEmbeddings(
                 scale_factor=1e-4, max_freq=1.0),
             aaemb_init=tf.initializers.RandomNormal(stddev=1.0),
             kernel_init=tf.initializers.GlorotUniform(),
             aaemb_scale_factor=None,
             max_len=1024,
             **kwargs):
    """Creates the embedding, normalization, dropout and transformer layers.

    Args:
      emb_dim: Width of the amino-acid embedding table.
      num_layers: Number of `TransformerEncoderBlock` layers to create.
      num_heads: Attention heads per transformer block.
      mlp_dim: Inner (feed-forward) dimension of each transformer block.
      mlp_act: Inner activation of each transformer block.
      output_dropout: Dropout rate used both for the embedding dropout layer
        and for the output dropout of every transformer block.
      attention_dropout: Attention dropout rate of every transformer block.
      mlp_dropout: Inner dropout rate of every transformer block.
      norm_first: Forwarded to every transformer block.
      norm_input: If true, a layer norm named 'embeddings/layer_norm' is
        created; otherwise the attribute is set to None.
      norm_output: If true, a layer norm named 'output/layer_norm' is
        created; otherwise the attribute is set to None.
      causal: Stored on `self._causal`; not otherwise used here.
      trainable_posemb: `trainable` flag of the positional embedding layer.
      posemb_init: Initializer of the positional embedding layer.
      aaemb_init: Initializer of the amino-acid embedding layer.
      kernel_init: Kernel initializer of every transformer block.
      aaemb_scale_factor: `scale_factor` of the amino-acid embedding layer.
      max_len: `max_length` of the positional embedding layer.
      **kwargs: Forwarded to the parent constructor.
    """
    super().__init__(**kwargs)
    self._causal = causal

    # Embedding layers. The vocabulary size is taken from `self._vocab`,
    # which is not set in this constructor.
    self.posemb_layer = nlp_layers.PositionEmbedding(
        max_length=max_len,
        initializer=posemb_init,
        trainable=trainable_posemb,
        name='embeddings/positional')
    self.aaemb_layer = nlp_layers.OnDeviceEmbedding(
        vocab_size=len(self._vocab),
        embedding_width=emb_dim,
        initializer=aaemb_init,
        scale_factor=aaemb_scale_factor,
        name='embeddings/aminoacid')

    def make_layer_norm(name):
        # All layer norms here share the same axis and epsilon.
        return tf.keras.layers.LayerNormalization(
            axis=-1, epsilon=1e-12, name=name)

    # Optional input / output normalization: the attribute is None when the
    # corresponding flag is off.
    if norm_input:
        self._input_norm_layer = make_layer_norm('embeddings/layer_norm')
    else:
        self._input_norm_layer = None
    if norm_output:
        self._output_norm_layer = make_layer_norm('output/layer_norm')
    else:
        self._output_norm_layer = None

    self._dropout_layer = tf.keras.layers.Dropout(
        rate=output_dropout, name='embeddings/dropout')
    self._attention_mask = nlp_layers.SelfAttentionMask()

    # One independently-parameterized encoder block per layer index.
    self._transformer_layers = [
        nlp_layers.TransformerEncoderBlock(
            num_attention_heads=num_heads,
            inner_dim=mlp_dim,
            inner_activation=mlp_act,
            output_dropout=output_dropout,
            attention_dropout=attention_dropout,
            inner_dropout=mlp_dropout,
            kernel_initializer=kernel_init,
            norm_first=norm_first,
            name=f'transformer/layer_{i}')
        for i in range(num_layers)
    ]
def build(self, input_shape):
    """Creates the encoder block stack and the output normalization layer.

    Args:
      input_shape: Shape of the layer input. Its element at index 2 is passed
        to `attention_initializer` once per encoder block.
    """

    def make_block(index):
        # The attention initializer is constructed anew for every block.
        return layers.TransformerEncoderBlock(
            num_attention_heads=self.num_attention_heads,
            inner_dim=self._intermediate_size,
            inner_activation=self._activation,
            output_dropout=self._dropout_rate,
            attention_dropout=self._attention_dropout_rate,
            use_bias=self._use_bias,
            norm_first=self._norm_first,
            norm_epsilon=self._norm_epsilon,
            inner_dropout=self._intermediate_dropout,
            attention_initializer=attention_initializer(input_shape[2]),
            name=("layer_%d" % index))

    self.encoder_layers = [make_block(i) for i in range(self.num_layers)]

    # Final layer norm, created with dtype float32.
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=self._norm_epsilon, dtype="float32")
    super(TransformerEncoder, self).build(input_shape)
def __init__(self,
             vocab_size,
             embedding_width=128,
             hidden_size=768,
             num_layers=12,
             num_attention_heads=12,
             max_sequence_length=512,
             type_vocab_size=16,
             intermediate_size=3072,
             activation=activations.gelu,
             dropout_rate=0.1,
             attention_dropout_rate=0.1,
             initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
             dict_outputs=False,
             **kwargs):
    """Builds the encoder graph with the Keras Functional API.

    Word, position and type embeddings are summed, layer-normalized and
    passed through dropout, optionally projected to `hidden_size`, and then
    fed `num_layers` times through a single shared `TransformerEncoderBlock`.
    The first token of the final output is passed through a tanh Dense layer
    to form the pooled output.

    Args:
      vocab_size: Size of the word embedding vocabulary.
      embedding_width: Width of the embedding tables. If None, `hidden_size`
        is used. When it differs from `hidden_size`, an 'embedding_projection'
        layer maps the embeddings to `hidden_size`.
      hidden_size: Output size of the projection and of the pooler layer.
      num_layers: Number of times the shared transformer block is applied.
      num_attention_heads: Attention heads of the shared transformer block.
      max_sequence_length: `max_length` of the position embedding layer.
      type_vocab_size: Vocabulary size of the type embedding table.
      intermediate_size: Inner dimension of the shared transformer block.
      activation: Inner activation; resolved via `tf.keras.activations.get`.
      dropout_rate: Rate of the embedding dropout and of the block's output
        dropout.
      attention_dropout_rate: Attention dropout rate of the shared block.
      initializer: Initializer for embeddings and kernels; resolved via
        `tf.keras.initializers.get`.
      dict_outputs: If true, the model outputs a dict with keys
        'sequence_output', 'encoder_outputs' and 'pooled_output'; otherwise
        a list `[sequence_output, pooled_output]`. Note that this flag is not
        recorded in the stored config.
      **kwargs: Forwarded to the Functional model constructor.
    """
    activation = tf.keras.activations.get(activation)
    initializer = tf.keras.initializers.get(initializer)

    def int_sequence_input(name):
        # Variable-length int32 input with the given name.
        return tf.keras.layers.Input(
            shape=(None,), dtype=tf.int32, name=name)

    word_ids = int_sequence_input('input_word_ids')
    mask = int_sequence_input('input_mask')
    type_ids = int_sequence_input('input_type_ids')

    if embedding_width is None:
        embedding_width = hidden_size

    embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        name='word_embeddings')
    word_embeddings = embedding_layer(word_ids)

    # Always uses dynamic slicing for simplicity.
    position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(word_embeddings)

    type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = type_embedding_layer(type_ids)

    # Sum the three embeddings, then normalize and apply dropout.
    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    embedding_norm = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm',
        axis=-1,
        epsilon=1e-12,
        dtype=tf.float32)
    embeddings = embedding_norm(embeddings)
    embeddings = tf.keras.layers.Dropout(rate=dropout_rate)(embeddings)

    # Project the embeddings to 'hidden_size' when the widths differ.
    if embedding_width != hidden_size:
        projection = tf.keras.layers.experimental.EinsumDense(
            '...x,xy->...y',
            output_shape=hidden_size,
            bias_axes='y',
            kernel_initializer=initializer,
            name='embedding_projection')
        embeddings = projection(embeddings)

    data = embeddings
    attention_mask = layers.SelfAttentionMask()(data, mask)

    # A single block instance is reused for every layer, so all layers share
    # the same weights.
    shared_layer = layers.TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=intermediate_size,
        inner_activation=activation,
        output_dropout=dropout_rate,
        attention_dropout=attention_dropout_rate,
        kernel_initializer=initializer,
        name='transformer')
    encoder_outputs = []
    for _ in range(num_layers):
        data = shared_layer([data, attention_mask])
        encoder_outputs.append(data)

    # Slicing a Keras tensor with subscript notation creates a SliceOpLambda
    # layer, which is more portable than a Lambda layer with Python code.
    first_token_tensor = data[:, 0, :]
    pooler = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')
    cls_output = pooler(first_token_tensor)

    if dict_outputs:
        outputs = {
            'sequence_output': data,
            'encoder_outputs': encoder_outputs,
            'pooled_output': cls_output,
        }
    else:
        outputs = [data, cls_output]

    # b/164516224
    # With the network built, invoke the Functional API Model constructor so
    # this object gets all properties of a Functional model. Every assignment
    # to `self` happens after this call.
    super(AlbertEncoder, self).__init__(
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

    config_dict = {
        'vocab_size': vocab_size,
        'embedding_width': embedding_width,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'intermediate_size': intermediate_size,
        'activation': tf.keras.activations.serialize(activation),
        'dropout_rate': dropout_rate,
        'attention_dropout_rate': attention_dropout_rate,
        'initializer': tf.keras.initializers.serialize(initializer),
    }

    # The config is stored as a namedtuple rather than a dict to keep
    # checkpoint compatibility with an earlier version of this model that did
    # not track the config attribute: TF does not track immutable attributes
    # that contain no Trackables.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)
    self._embedding_layer = embedding_layer
    self._position_embedding_layer = position_embedding_layer
def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        max_sequence_length: int = 512,
        type_vocab_size: int = 16,
        inner_dim: int = 3072,
        inner_activation: Callable[..., Any] = _approx_gelu,
        output_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
            stddev=0.02),
        output_range: Optional[int] = None,
        embedding_width: Optional[int] = None,
        embedding_layer: Optional[tf.keras.layers.Layer] = None,
        norm_first: bool = False,
        **kwargs):
    """Creates all sub-layers of the encoder and records its config.

    Args:
      vocab_size: Vocabulary size of the word embedding table (used only when
        `embedding_layer` is None).
      hidden_size: Output size of the optional projection and of the pooler.
      num_layers: Number of transformer blocks to create.
      num_attention_heads: Attention heads per transformer block.
      max_sequence_length: `max_length` of the position embedding layer.
      type_vocab_size: Vocabulary size of the type embedding table.
      inner_dim: Inner dimension of each transformer block.
      inner_activation: Inner activation of each transformer block.
      output_dropout: Rate of the embedding dropout and of each block's
        output dropout.
      attention_dropout: Attention dropout rate of each block.
      initializer: Initializer for embeddings and kernels; resolved via
        `tf.keras.initializers.get`.
      output_range: Passed to the last transformer block only; all other
        blocks receive None.
      embedding_width: Width of the embedding tables; defaults to
        `hidden_size`. When it differs from `hidden_size`, a projection layer
        is created.
      embedding_layer: Optional pre-built word embedding layer to use instead
        of creating one.
      norm_first: Forwarded to each transformer block.
      **kwargs: Forwarded to the parent constructor after the V1-only keys
        are removed. The legacy keys 'intermediate_size', 'activation',
        'dropout_rate' and 'attention_dropout_rate' override `inner_dim`,
        `inner_activation`, `output_dropout` and `attention_dropout`.
    """
    # Drop kwargs that only the V1 implementation understands.
    kwargs.pop('dict_outputs', None)
    kwargs.pop('return_all_encoder_outputs', None)
    # Legacy argument names take precedence over their new equivalents.
    inner_dim = kwargs.pop('intermediate_size', inner_dim)
    inner_activation = kwargs.pop('activation', inner_activation)
    output_dropout = kwargs.pop('dropout_rate', output_dropout)
    attention_dropout = kwargs.pop('attention_dropout_rate',
                                   attention_dropout)
    super().__init__(**kwargs)

    # The resolved activation is only used for serialization in the config.
    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    if embedding_width is None:
        embedding_width = hidden_size

    if embedding_layer is not None:
        self._embedding_layer = embedding_layer
    else:
        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            name='word_embeddings')

    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')

    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm',
        axis=-1,
        epsilon=1e-12,
        dtype=tf.float32)

    self._embedding_dropout = tf.keras.layers.Dropout(
        rate=output_dropout, name='embedding_dropout')

    # Project the embeddings to 'hidden_size' only when the widths differ;
    # otherwise the attribute stays None.
    self._embedding_projection = None
    if embedding_width != hidden_size:
        self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
            '...x,xy->...y',
            output_shape=hidden_size,
            bias_axes='y',
            kernel_initializer=initializer,
            name='embedding_projection')

    self._transformer_layers = []
    self._attention_mask_layer = layers.SelfAttentionMask(
        name='self_attention_mask')
    last_index = num_layers - 1
    for i in range(num_layers):
        # Only the final block is restricted to `output_range`.
        block_output_range = output_range if i == last_index else None
        self._transformer_layers.append(
            layers.TransformerEncoderBlock(
                num_attention_heads=num_attention_heads,
                inner_dim=inner_dim,
                inner_activation=inner_activation,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                output_range=block_output_range,
                kernel_initializer=initializer,
                name='transformer/layer_%d' % i))

    self._pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')

    self._config = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
    }

    def int_sequence_input():
        # Variable-length int32 placeholder input.
        return tf.keras.Input(shape=(None,), dtype=tf.int32)

    self.inputs = {
        'input_word_ids': int_sequence_input(),
        'input_mask': int_sequence_input(),
        'input_type_ids': int_sequence_input(),
    }
def __init__(
        self,
        vocab_size,
        hidden_size=768,
        num_layers=12,
        num_attention_heads=12,
        max_sequence_length=512,
        type_vocab_size=16,
        inner_dim=3072,
        inner_activation=lambda x: tf.keras.activations.gelu(
            x, approximate=True),
        output_dropout=0.1,
        attention_dropout=0.1,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        output_range=None,
        embedding_width=None,
        embedding_layer=None,
        norm_first=False,
        dict_outputs=False,
        return_all_encoder_outputs=False,
        **kwargs):
    """Builds the encoder graph with the Keras Functional API.

    Args:
      vocab_size: Vocabulary size of the word embedding table (used only when
        `embedding_layer` is None).
      hidden_size: Output size of the optional projection and of the pooler.
      num_layers: Number of transformer blocks.
      num_attention_heads: Attention heads per transformer block.
      max_sequence_length: `max_length` of the position embedding layer.
      type_vocab_size: Vocabulary size of the type embedding table.
      inner_dim: Inner dimension of each transformer block.
      inner_activation: Inner activation of each transformer block.
      output_dropout: Rate of the embedding dropout and of each block's
        output dropout.
      attention_dropout: Attention dropout rate of each block.
      initializer: Initializer for embeddings and kernels; resolved via
        `tf.keras.initializers.get`.
      output_range: Passed to the last transformer block only.
      embedding_width: Width of the embedding tables; defaults to
        `hidden_size`.
      embedding_layer: Optional pre-built word embedding layer.
      norm_first: Forwarded to each transformer block.
      dict_outputs: If true, the model outputs a dict with keys
        'sequence_output', 'pooled_output' and 'encoder_outputs'.
      return_all_encoder_outputs: Only consulted when `dict_outputs` is
        false. If true the outputs are `[encoder_outputs, pooled_output]`,
        otherwise `[sequence_output, pooled_output]`.
      **kwargs: Forwarded to the Functional model constructor. The deprecated
        'sequence_length' key is dropped with a warning, and the legacy keys
        'intermediate_size', 'activation', 'dropout_rate' and
        'attention_dropout_rate' override their new equivalents.
    """
    if 'sequence_length' in kwargs:
        kwargs.pop('sequence_length')
        logging.warning('`sequence_length` is a deprecated argument to '
                        '`BertEncoder`, which has no effect for a while. Please '
                        'remove `sequence_length` argument.')

    # Backward-compatible argument names override the new ones.
    inner_dim = kwargs.pop('intermediate_size', inner_dim)
    inner_activation = kwargs.pop('activation', inner_activation)
    output_dropout = kwargs.pop('dropout_rate', output_dropout)
    attention_dropout = kwargs.pop('attention_dropout_rate',
                                   attention_dropout)

    # The resolved activation is only used for serialization in the config.
    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    def int_sequence_input(name):
        # Variable-length int32 input with the given name.
        return tf.keras.layers.Input(
            shape=(None,), dtype=tf.int32, name=name)

    word_ids = int_sequence_input('input_word_ids')
    mask = int_sequence_input('input_mask')
    type_ids = int_sequence_input('input_type_ids')

    if embedding_width is None:
        embedding_width = hidden_size

    if embedding_layer is not None:
        embedding_layer_inst = embedding_layer
    else:
        embedding_layer_inst = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            name='word_embeddings')
    word_embeddings = embedding_layer_inst(word_ids)

    # Always uses dynamic slicing for simplicity.
    position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(word_embeddings)

    type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = type_embedding_layer(type_ids)

    # Sum the three embeddings, then normalize and apply dropout.
    embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm',
        axis=-1,
        epsilon=1e-12,
        dtype=tf.float32)
    embeddings = embedding_norm_layer(embeddings)
    embeddings = tf.keras.layers.Dropout(rate=output_dropout)(embeddings)

    # Project the embeddings to 'hidden_size' when the widths differ.
    embedding_projection = None
    if embedding_width != hidden_size:
        embedding_projection = tf.keras.layers.experimental.EinsumDense(
            '...x,xy->...y',
            output_shape=hidden_size,
            bias_axes='y',
            kernel_initializer=initializer,
            name='embedding_projection')
        embeddings = embedding_projection(embeddings)

    transformer_layers = []
    data = embeddings
    attention_mask = layers.SelfAttentionMask()(data, mask)
    encoder_outputs = []
    last_index = num_layers - 1
    for i in range(num_layers):
        # Only the final block is restricted to `output_range`.
        layer = layers.TransformerEncoderBlock(
            num_attention_heads=num_attention_heads,
            inner_dim=inner_dim,
            inner_activation=inner_activation,
            output_dropout=output_dropout,
            attention_dropout=attention_dropout,
            norm_first=norm_first,
            output_range=output_range if i == last_index else None,
            kernel_initializer=initializer,
            name='transformer/layer_%d' % i)
        transformer_layers.append(layer)
        data = layer([data, attention_mask])
        encoder_outputs.append(data)

    last_encoder_output = encoder_outputs[-1]
    # Slicing a Keras tensor with subscript notation creates a SliceOpLambda
    # layer, which is more portable than a Lambda layer with Python code.
    first_token_tensor = last_encoder_output[:, 0, :]
    pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')
    cls_output = pooler_layer(first_token_tensor)

    # Select the output structure requested by the caller.
    if dict_outputs:
        outputs = dict(
            sequence_output=last_encoder_output,
            pooled_output=cls_output,
            encoder_outputs=encoder_outputs,
        )
    elif return_all_encoder_outputs:
        outputs = [encoder_outputs, cls_output]
    else:
        outputs = [last_encoder_output, cls_output]

    super().__init__(  # pylint: disable=bad-super-call
        inputs=[word_ids, mask, type_ids], outputs=outputs, **kwargs)

    # All `self` assignments happen after the Functional constructor ran.
    self._pooler_layer = pooler_layer
    self._transformer_layers = transformer_layers
    self._embedding_norm_layer = embedding_norm_layer
    self._embedding_layer = embedding_layer_inst
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    # The projection attribute is only set when a projection exists.
    if embedding_projection is not None:
        self._embedding_projection = embedding_projection

    config_dict = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'dict_outputs': dict_outputs,
    }
    # Attribute tracking is switched off while the config dict is stored.
    # pylint: disable=protected-access
    self._setattr_tracking = False
    self._config = config_dict
    self._setattr_tracking = True
def __init__(
        self,
        vocab_size,
        hidden_size=768,
        num_layers=12,
        num_attention_heads=12,
        max_sequence_length=512,
        type_vocab_size=16,
        inner_dim=3072,
        inner_activation=lambda x: tf.keras.activations.gelu(
            x, approximate=True),
        output_dropout=0.1,
        attention_dropout=0.1,
        pool_type=_MAX,
        pool_stride=2,
        unpool_length=0,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        output_range=None,
        embedding_width=None,
        embedding_layer=None,
        norm_first=False,
        **kwargs):
    """Creates the embedding, transformer, pooler and pooling sub-layers.

    Args:
      vocab_size: Vocabulary size of the word embedding table (used only when
        `embedding_layer` is None).
      hidden_size: Output size of the optional projection and of the pooler.
      num_layers: Number of transformer blocks to create.
      num_attention_heads: Attention heads per transformer block.
      max_sequence_length: `max_length` of the position embedding layer; also
        passed to the truncated-average transform builder.
      type_vocab_size: Vocabulary size of the type embedding table.
      inner_dim: Inner dimension of each transformer block.
      inner_activation: Inner activation of each transformer block.
      output_dropout: Rate of the embedding dropout and of each block's
        output dropout.
      attention_dropout: Attention dropout rate of each block.
      pool_type: One of the module constants `_MAX`, `_AVG` or
        `_TRUNCATED_AVG`.
      pool_stride: Either an int, which is repeated for every layer, or a
        sequence with exactly `num_layers` entries.
      unpool_length: Stored on the instance; must be 0 when `pool_type` is
        `_TRUNCATED_AVG`.
      initializer: Initializer for embeddings and kernels; resolved via
        `tf.keras.initializers.get`.
      output_range: Passed to the last transformer block only.
      embedding_width: Width of the embedding tables; defaults to
        `hidden_size`.
      embedding_layer: Optional pre-built word embedding layer.
      norm_first: Forwarded to each transformer block.
      **kwargs: Forwarded to the parent constructor.

    Raises:
      ValueError: If `pool_stride` is a sequence whose length differs from
        `num_layers`, if `pool_type` is not supported, or if `unpool_length`
        is non-zero with `_TRUNCATED_AVG` pooling.
    """
    super().__init__(**kwargs)

    # The resolved activation is only used for serialization in the config.
    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    if embedding_width is None:
        embedding_width = hidden_size

    if embedding_layer is not None:
        self._embedding_layer = embedding_layer
    else:
        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            name='word_embeddings')

    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')

    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm',
        axis=-1,
        epsilon=1e-12,
        dtype=tf.float32)

    self._embedding_dropout = tf.keras.layers.Dropout(
        rate=output_dropout, name='embedding_dropout')

    # Project the embeddings to 'hidden_size' only when the widths differ;
    # otherwise the attribute stays None.
    self._embedding_projection = None
    if embedding_width != hidden_size:
        self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
            '...x,xy->...y',
            output_shape=hidden_size,
            bias_axes='y',
            kernel_initializer=initializer,
            name='embedding_projection')

    self._transformer_layers = []
    self._attention_mask_layer = layers.SelfAttentionMask(
        name='self_attention_mask')
    last_index = num_layers - 1
    for i in range(num_layers):
        # Only the final block is restricted to `output_range`.
        block_output_range = output_range if i == last_index else None
        self._transformer_layers.append(
            layers.TransformerEncoderBlock(
                num_attention_heads=num_attention_heads,
                inner_dim=inner_dim,
                inner_activation=inner_activation,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                output_range=block_output_range,
                kernel_initializer=initializer,
                name='transformer/layer_%d' % i))

    self._pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')

    # Normalize `pool_stride` to one stride per layer.
    if isinstance(pool_stride, int):
        # TODO(b/197133196): Pooling layer can be shared.
        pool_strides = [pool_stride] * num_layers
    elif len(pool_stride) != num_layers:
        raise ValueError(
            'Lengths of pool_stride and num_layers are not equal.')
    else:
        pool_strides = pool_stride

    # TODO(crickwu): explore tf.keras.layers.serialize method.
    # `pool_cls` stays None for truncated-average pooling, which uses
    # precomputed transforms instead of Keras pooling layers.
    pool_cls = None
    if pool_type == _MAX:
        pool_cls = tf.keras.layers.MaxPooling1D
    elif pool_type == _AVG:
        pool_cls = tf.keras.layers.AveragePooling1D
    elif pool_type == _TRUNCATED_AVG:
        # TODO(b/203665205): unpool_length should be implemented.
        if unpool_length != 0:
            raise ValueError(
                'unpool_length is not supported by truncated_avg now.')
        # Compute the attention masks and pooling transforms.
        self._pooling_transforms = _create_truncated_avg_transforms(
            max_sequence_length, pool_strides)
    else:
        raise ValueError('pool_type not supported.')

    if pool_cls is not None:
        # One pooling layer per transformer layer, all created with the
        # same layer name.
        self._att_input_pool_layers = [
            pool_cls(
                pool_size=stride,
                strides=stride,
                padding='same',
                name='att_input_pool_layer')
            for stride in pool_strides
        ]

    self._pool_strides = pool_strides  # This is a list here.
    self._unpool_length = unpool_length
    self._pool_type = pool_type

    self._config = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'pool_type': pool_type,
        'pool_stride': pool_stride,
        'unpool_length': unpool_length,
    }
def __init__(self,
             network,
             bert_config,
             initializer='glorot_uniform',
             seq_length=128,
             use_pointing=True,
             is_training=True):
    """Creates Felix Tagger.

    Setting up all of the layers needed for call.

    Args:
      network: An encoder network, which should output a sequence of hidden
        states.
      bert_config: A config file which in addition to the BertConfig values
        also includes: num_classes, hidden_dropout_prob, and
        query_transformer.
      initializer: The initializer (if any) to use in the classification
        networks. Defaults to a Glorot uniform initializer. It is applied as
        the kernel initializer of the Dense layers created here (tag logits,
        edit-tagged sequence output, query and key embeddings). The default
        matches the Keras Dense default, so default-constructed models are
        initialized as before.
      seq_length: Maximum sequence length.
      use_pointing: Whether a pointing network is used.
      is_training: The model is being trained.
    """
    super(FelixTagger, self).__init__()
    self._network = network
    self._seq_length = seq_length
    self._bert_config = bert_config
    self._use_pointing = use_pointing
    self._is_training = is_training

    def fresh_initializer():
        # Give every layer its own initializer object: an Initializer
        # instance is re-created from its config so that one instance is not
        # shared between layers. String identifiers (and None) are passed
        # through and resolved by each layer independently.
        if isinstance(initializer, tf.keras.initializers.Initializer):
            return initializer.__class__.from_config(
                initializer.get_config())
        return initializer

    self._tag_logits_layer = tf.keras.layers.Dense(
        self._bert_config.num_classes,
        kernel_initializer=fresh_initializer())
    # Without a pointing network only the tag classifier is needed.
    if not self._use_pointing:
        return

    # An arbitrary heuristic (sqrt vocab size) for the tag embedding
    # dimension.
    self._tag_embedding_layer = tf.keras.layers.Embedding(
        self._bert_config.num_classes,
        int(math.ceil(math.sqrt(self._bert_config.num_classes))),
        input_length=seq_length)

    self._position_embedding_layer = layers.PositionEmbedding(
        max_length=seq_length)
    self._edit_tagged_sequence_output_layer = tf.keras.layers.Dense(
        self._bert_config.hidden_size,
        activation=activations.gelu,
        kernel_initializer=fresh_initializer())

    if self._bert_config.query_transformer:
        self._self_attention_mask_layer = layers.SelfAttentionMask()
        # NOTE(review): attention_dropout reuses hidden_dropout_prob; confirm
        # this is intended rather than a separate attention dropout setting.
        self._transformer_query_layer = layers.TransformerEncoderBlock(
            num_attention_heads=self._bert_config.num_attention_heads,
            inner_dim=self._bert_config.intermediate_size,
            inner_activation=activations.gelu,
            output_dropout=self._bert_config.hidden_dropout_prob,
            attention_dropout=self._bert_config.hidden_dropout_prob,
            output_range=seq_length,
        )

    self._query_embeddings_layer = tf.keras.layers.Dense(
        self._bert_config.query_size,
        kernel_initializer=fresh_initializer())

    self._key_embeddings_layer = tf.keras.layers.Dense(
        self._bert_config.query_size,
        kernel_initializer=fresh_initializer())
def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        num_layers: int = 12,
        num_attention_heads: int = 12,
        max_sequence_length: int = 512,
        type_vocab_size: int = 16,
        inner_dim: int = 3072,
        inner_activation: _Activation = _approx_gelu,
        output_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        token_loss_init_value: float = 10.0,
        token_loss_beta: float = 0.995,
        token_keep_k: int = 256,
        token_allow_list: Tuple[int, ...] = (100, 101, 102, 103),
        token_deny_list: Tuple[int, ...] = (0, ),
        initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
            stddev=0.02),
        output_range: Optional[int] = None,
        embedding_width: Optional[int] = None,
        embedding_layer: Optional[tf.keras.layers.Layer] = None,
        norm_first: bool = False,
        with_dense_inputs: bool = False,
        **kwargs):
    """Creates the encoder sub-layers, token-importance state and config.

    Args:
      vocab_size: Vocabulary size of the word embedding table and of the
        token-importance vector.
      hidden_size: Output size of the optional projection and of the pooler.
      num_layers: Number of transformer blocks to create.
      num_attention_heads: Attention heads per transformer block.
      max_sequence_length: `max_length` of the position embedding layer.
      type_vocab_size: Vocabulary size of the type embedding table.
      inner_dim: Inner dimension of each transformer block.
      inner_activation: Inner activation of each transformer block.
      output_dropout: Rate of the embedding dropout and of each block's
        output dropout.
      attention_dropout: Attention dropout rate of each block.
      token_loss_init_value: Initial importance assigned to every token id.
      token_loss_beta: `moving_average_beta` of the token-importance layer.
      token_keep_k: `top_k` of the `SelectTopK` layer.
      token_allow_list: Token ids whose initial importance is set to 1.0e4.
      token_deny_list: Token ids whose initial importance is set to -1.0e4;
        applied after the allow list.
      initializer: Initializer for embeddings and kernels; resolved via
        `tf.keras.initializers.get` and cloned for each layer.
      output_range: Passed to the last transformer block only.
      embedding_width: Width of the embedding tables; defaults to
        `hidden_size`.
      embedding_layer: Optional pre-built word embedding layer.
      norm_first: Forwarded to each transformer block.
      with_dense_inputs: If true, `self.inputs` additionally declares
        'dense_inputs', 'dense_mask' and 'dense_type_ids'.
      **kwargs: Forwarded to the parent constructor after the V1-only keys
        are removed. The legacy keys 'intermediate_size', 'activation',
        'dropout_rate' and 'attention_dropout_rate' override their new
        equivalents.
    """
    # Drop kwargs that only the V1 implementation understands.
    kwargs.pop('dict_outputs', None)
    kwargs.pop('return_all_encoder_outputs', None)
    # Legacy argument names take precedence over their new equivalents.
    inner_dim = kwargs.pop('intermediate_size', inner_dim)
    inner_activation = kwargs.pop('activation', inner_activation)
    output_dropout = kwargs.pop('dropout_rate', output_dropout)
    attention_dropout = kwargs.pop('attention_dropout_rate',
                                   attention_dropout)
    super().__init__(**kwargs)

    # The resolved activation is only used for serialization in the config.
    activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    if embedding_width is None:
        embedding_width = hidden_size

    if embedding_layer is not None:
        self._embedding_layer = embedding_layer
    else:
        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=tf_utils.clone_initializer(initializer),
            name='word_embeddings')

    self._position_embedding_layer = layers.PositionEmbedding(
        initializer=tf_utils.clone_initializer(initializer),
        max_length=max_sequence_length,
        name='position_embedding')

    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=tf_utils.clone_initializer(initializer),
        use_one_hot=True,
        name='type_embeddings')

    self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm',
        axis=-1,
        epsilon=1e-12,
        dtype=tf.float32)

    self._embedding_dropout = tf.keras.layers.Dropout(
        rate=output_dropout, name='embedding_dropout')

    # Project the embeddings to 'hidden_size' only when the widths differ;
    # otherwise the attribute stays None.
    self._embedding_projection = None
    if embedding_width != hidden_size:
        self._embedding_projection = tf.keras.layers.EinsumDense(
            '...x,xy->...y',
            output_shape=hidden_size,
            bias_axes='y',
            kernel_initializer=tf_utils.clone_initializer(initializer),
            name='embedding_projection')

    def pin_importance(importance, token_ids, value):
        # Overwrite the importance of each listed token id with `value`.
        return tf.tensor_scatter_nd_update(
            tensor=importance,
            indices=[[token_id] for token_id in token_ids],
            updates=[value] * len(token_ids))

    # The first 999 tokens are special tokens such as [PAD], [CLS], [SEP].
    # We want to always mask [PAD], and never mask [CLS], [SEP].
    # Note: `(vocab_size)` is the bare integer, not a one-element tuple.
    init_importance = tf.constant(token_loss_init_value, shape=(vocab_size))
    if token_allow_list:
        init_importance = pin_importance(
            init_importance, token_allow_list, 1.0e4)
    if token_deny_list:
        init_importance = pin_importance(
            init_importance, token_deny_list, -1.0e4)
    self._token_importance_embed = layers.TokenImportanceWithMovingAvg(
        vocab_size=vocab_size,
        init_importance=init_importance,
        moving_average_beta=token_loss_beta)

    self._token_separator = layers.SelectTopK(top_k=token_keep_k)
    self._transformer_layers = []
    self._num_layers = num_layers
    self._attention_mask_layer = layers.SelfAttentionMask(
        name='self_attention_mask')
    last_index = num_layers - 1
    for i in range(num_layers):
        # Only the final block is restricted to `output_range`.
        block_output_range = output_range if i == last_index else None
        self._transformer_layers.append(
            layers.TransformerEncoderBlock(
                num_attention_heads=num_attention_heads,
                inner_dim=inner_dim,
                inner_activation=inner_activation,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                output_range=block_output_range,
                kernel_initializer=tf_utils.clone_initializer(initializer),
                name='transformer/layer_%d' % i))

    self._pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='pooler_transform')

    self._config = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'token_loss_init_value': token_loss_init_value,
        'token_loss_beta': token_loss_beta,
        'token_keep_k': token_keep_k,
        'token_allow_list': token_allow_list,
        'token_deny_list': token_deny_list,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'with_dense_inputs': with_dense_inputs,
    }

    def int_sequence_input():
        # Variable-length int32 placeholder input.
        return tf.keras.Input(shape=(None, ), dtype=tf.int32)

    # The three id/mask inputs are always declared; the dense variants are
    # added only on request.
    input_specs = {
        'input_word_ids': int_sequence_input(),
        'input_mask': int_sequence_input(),
        'input_type_ids': int_sequence_input(),
    }
    if with_dense_inputs:
        input_specs['dense_inputs'] = tf.keras.Input(
            shape=(None, embedding_width), dtype=tf.float32)
        input_specs['dense_mask'] = int_sequence_input()
        input_specs['dense_type_ids'] = int_sequence_input()
    self.inputs = input_specs