Esempio n. 1
0
    def __init__(self, **kwargs):
        """Instantiate every sub-layer of a fixed-configuration encoder.

        All hyper-parameters are hard-coded: 768-wide embeddings over a
        30522-entry vocabulary, 2 token types, up to 512 positions, twelve
        Transformer layers with 12 heads and a 3072-wide intermediate layer,
        and a dropout rate of 0.1. No layer is called here; the layers are
        only created and stored on `self`.

        Args:
          **kwargs: Forwarded unchanged to the parent class constructor.
        """
        super().__init__(**kwargs)
        hidden_width = 768
        drop_prob = 0.1
        # One initializer instance is shared by every layer created below.
        weight_init = tf.keras.initializers.TruncatedNormal(stddev=0.02)

        # --- Embedding stack -------------------------------------------------
        self._embedding_layer = layers.OnDeviceEmbedding(
            name="word_embeddings",
            vocab_size=30522,
            embedding_width=hidden_width,
            initializer=weight_init,
        )

        # Always uses dynamic slicing for simplicity.
        self._position_embedding_layer = layers.PositionEmbedding(
            name="position_embedding",
            max_sequence_length=512,
            use_dynamic_slicing=True,
            initializer=weight_init,
        )
        self._type_embedding_layer = layers.OnDeviceEmbedding(
            name="type_embeddings",
            vocab_size=2,
            embedding_width=hidden_width,
            use_one_hot=True,
            initializer=weight_init,
        )
        self._add = tf.keras.layers.Add()
        self._layer_norm = tf.keras.layers.LayerNormalization(
            name="embeddings/layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32,
        )
        self._dropout = tf.keras.layers.Dropout(rate=drop_prob)

        # --- Transformer stack -----------------------------------------------
        self._attention_mask = layers.SelfAttentionMask()
        # Settings common to all twelve layers; only the name differs.
        shared_transformer_args = dict(
            num_attention_heads=12,
            intermediate_size=3072,
            intermediate_activation=activations.gelu,
            dropout_rate=drop_prob,
            attention_dropout_rate=0.1,
            output_range=None,
            kernel_initializer=weight_init,
        )
        self._transformer_layers = []
        for layer_index in range(12):
            self._transformer_layers.append(
                layers.Transformer(
                    name="transformer/layer_%d" % layer_index,
                    **shared_transformer_args))

        # --- Pooling head ----------------------------------------------------
        # Selects position 0 along axis 1 and drops that axis.
        self._lambda = tf.keras.layers.Lambda(
            lambda x: tf.squeeze(x[:, 0:1, :], axis=1))
        self._pooler_layer = tf.keras.layers.Dense(
            name="pooler_transform",
            units=hidden_width,
            activation="tanh",
            kernel_initializer=weight_init,
        )
Esempio n. 2
0
 def __init__(
         self,
         emb_dim=768,
         num_layers=6,
         num_heads=12,
         mlp_dim=3072,
         mlp_act=activations.approximate_gelu,
         output_dropout=0.1,
         attention_dropout=0.1,
         mlp_dropout=0.1,
         norm_first=True,
         norm_input=False,
         norm_output=True,
         causal=False,
         trainable_posemb=False,
         posemb_init=initializers.HarmonicEmbeddings(scale_factor=1e-4,
                                                     max_freq=1.0),
         aaemb_init=tf.initializers.RandomNormal(stddev=1.0),
         kernel_init=tf.initializers.GlorotUniform(),
         aaemb_scale_factor=None,
         max_len=1024,
         **kwargs):
     """Create the embedding, normalization and transformer sub-layers.

     Args:
       emb_dim: Width of the token embedding.
       num_layers: Number of `TransformerEncoderBlock` layers to create.
       num_heads: Attention heads per transformer block.
       mlp_dim: Inner dimension of each transformer block.
       mlp_act: Inner activation of each transformer block.
       output_dropout: Dropout rate of the embedding dropout layer and of
         each block's output.
       attention_dropout: Attention dropout rate of each block.
       mlp_dropout: Inner dropout rate of each block.
       norm_first: Forwarded to each block as `norm_first`.
       norm_input: If true, an 'embeddings/layer_norm' layer is created;
         otherwise the attribute is None.
       norm_output: If true, an 'output/layer_norm' layer is created;
         otherwise the attribute is None.
       causal: Stored as `self._causal`; not otherwise used here.
       trainable_posemb: `trainable` flag of the position embedding.
       posemb_init: Initializer of the position embedding.
       aaemb_init: Initializer of the token embedding.
       kernel_init: Kernel initializer shared by all transformer blocks.
       aaemb_scale_factor: `scale_factor` of the token embedding.
       max_len: `max_length` of the position embedding.
       **kwargs: Forwarded to the parent class constructor.
     """
     super().__init__(**kwargs)
     self._causal = causal

     # Embeddings. The vocabulary size is read from `self._vocab`, which is
     # not set in this method and must therefore already exist on the
     # instance or its class.
     self.posemb_layer = nlp_layers.PositionEmbedding(
         name='embeddings/positional',
         max_length=max_len,
         initializer=posemb_init,
         trainable=trainable_posemb)
     self.aaemb_layer = nlp_layers.OnDeviceEmbedding(
         name='embeddings/aminoacid',
         vocab_size=len(self._vocab),
         embedding_width=emb_dim,
         initializer=aaemb_init,
         scale_factor=aaemb_scale_factor)

     # Optional layer norms on the input and output side.
     make_layer_norm = functools.partial(
         tf.keras.layers.LayerNormalization, axis=-1, epsilon=1e-12)
     if norm_input:
         self._input_norm_layer = make_layer_norm(
             name='embeddings/layer_norm')
     else:
         self._input_norm_layer = None
     if norm_output:
         self._output_norm_layer = make_layer_norm(name='output/layer_norm')
     else:
         self._output_norm_layer = None

     self._dropout_layer = tf.keras.layers.Dropout(
         rate=output_dropout, name='embeddings/dropout')
     self._attention_mask = nlp_layers.SelfAttentionMask()

     # Identically configured blocks that differ only by name.
     self._transformer_layers = [
         nlp_layers.TransformerEncoderBlock(
             num_attention_heads=num_heads,
             inner_dim=mlp_dim,
             inner_activation=mlp_act,
             output_dropout=output_dropout,
             attention_dropout=attention_dropout,
             inner_dropout=mlp_dropout,
             kernel_initializer=kernel_init,
             norm_first=norm_first,
             name=f'transformer/layer_{i}')
         for i in range(num_layers)
     ]
Esempio n. 3
0
  def __init__(
      self,
      word_vocab_size,
      word_embed_size,
      type_vocab_size,
      output_embed_size,
      max_sequence_length=512,
      normalization_type='no_norm',
      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
      dropout_rate=0.1):
    """Builds the embedding sub-layers and records the size settings.

    Arguments:
      word_vocab_size: Size of the word vocabulary.
      word_embed_size: Width of the word embedding table.
      type_vocab_size: Number of distinct word types.
      output_embed_size: Width of the final embedding output; also the width
        of the type embedding and of the word-embedding projection.
      max_sequence_length: Longest input sequence supported by the position
        embedding.
      normalization_type: String naming the normalization to use; only
        'no_norm' and 'layer_norm' are supported.
      initializer: Initializer shared by the embedding tables and the
        linear projection kernel.
      dropout_rate: Rate of the embedding dropout layer.
    """
    super(MobileBertEmbedding, self).__init__()

    # Keep the scalar settings available on the instance.
    self.word_vocab_size = word_vocab_size
    self.word_embed_size = word_embed_size
    self.type_vocab_size = type_vocab_size
    self.output_embed_size = output_embed_size
    self.max_sequence_length = max_sequence_length
    self.dropout_rate = dropout_rate

    # Embedding tables: word, type and position.
    self.word_embedding = layers.OnDeviceEmbedding(
        word_vocab_size,
        word_embed_size,
        name='word_embedding',
        initializer=initializer)
    self.type_embedding = layers.OnDeviceEmbedding(
        type_vocab_size,
        output_embed_size,
        name='type_embedding',
        initializer=initializer,
        use_one_hot=True)
    self.pos_embedding = layers.PositionEmbedding(
        name='position_embedding',
        initializer=initializer,
        max_sequence_length=max_sequence_length,
        use_dynamic_slicing=True)

    # Dense projection applied on the last axis, with a bias on that axis.
    projection_shape = [None, output_embed_size]
    self.word_embedding_proj = tf.keras.layers.experimental.EinsumDense(
        'abc,cd->abd',
        name='embedding_projection',
        output_shape=projection_shape,
        bias_axes='d',
        kernel_initializer=initializer)

    self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
    self.dropout_layer = tf.keras.layers.Dropout(
        rate=dropout_rate, name='embedding_dropout')
Esempio n. 4
0
    def __init__(
            self,
            emb_dim=768,
            dropout=0.0,
            use_layer_norm=False,
            use_positional_embedding=False,
            position_embed_init=None,
            train_position_embed=True,
            aaemb_init=None,
            aaemb_scale_factor=None,
            max_len=1024,
            **kwargs):
        """Create the token embedding and its optional companion layers.

        Args:
          emb_dim: Width of the token embedding.
          dropout: Rate of the 'embeddings/dropout' layer.
          use_layer_norm: If true, an 'embeddings/layer_norm' layer is
            created; otherwise `self._layer_norm` is None.
          use_positional_embedding: If true, a position embedding layer is
            created; otherwise `self._positional_embedding` is None.
          position_embed_init: Initializer of the position embedding. None
            selects `HarmonicEmbeddings(scale_factor=1e-4, max_freq=1.0)`.
          train_position_embed: `trainable` flag of the position embedding.
          aaemb_init: Initializer of the token embedding. None selects
            `TruncatedNormal(stddev=1.0)`.
          aaemb_scale_factor: `scale_factor` of the token embedding.
          max_len: `max_length` of the position embedding.
          **kwargs: Forwarded to the parent class constructor.
        """
        super().__init__(**kwargs)

        # Resolve the None defaults into concrete initializers.
        if position_embed_init is None:
            position_embed_init = initializers.HarmonicEmbeddings(
                scale_factor=1e-4, max_freq=1.0)
        if aaemb_init is None:
            aaemb_init = tf.initializers.TruncatedNormal(stddev=1.0)

        self._use_layer_norm = use_layer_norm

        self._positional_embedding = (
            nlp_layers.PositionEmbedding(
                name='embeddings/positional',
                max_length=max_len,
                initializer=position_embed_init,
                trainable=train_position_embed)
            if use_positional_embedding else None)

        # The vocabulary size is read from `self._vocab`, which is not set
        # in this method.
        self._aa_embed = nlp_layers.OnDeviceEmbedding(
            name='embeddings/aminoacid',
            vocab_size=len(self._vocab),
            embedding_width=emb_dim,
            initializer=aaemb_init,
            scale_factor=aaemb_scale_factor)

        self._layer_norm = (
            tf.keras.layers.LayerNormalization(
                axis=-1, epsilon=1e-12, name='embeddings/layer_norm')
            if use_layer_norm else None)

        self._dropout = tf.keras.layers.Dropout(
            name='embeddings/dropout', rate=dropout)
Esempio n. 5
0
    def __init__(
            self,
            pooled_output_dim,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            embedding_cls=None,
            embedding_cfg=None,
            embedding_data=None,
            num_hidden_instances=1,
            hidden_cls=layers.Transformer,
            hidden_cfg=None,
            mask_cls=layers.SelfAttentionMask,
            mask_cfg=None,
            layer_norm_before_pooling=False,
            return_all_layer_outputs=False,
            dict_outputs=False,
            layer_idx_as_attention_seed=False,
            feed_layer_idx=False,
            recursive=False,
            **kwargs):
        """Build the encoder graph with the Functional API, then init the model.

        Args:
          pooled_output_dim: Number of units of the 'cls_transform' dense layer
            applied to the first token of the last hidden output.
          pooler_layer_initializer: Kernel initializer of that dense layer;
            resolved through `tf.keras.initializers.get`.
          embedding_cls: Class or instance of an embedding network. A class is
            instantiated with `**embedding_cfg` (or no arguments when
            `embedding_cfg` is falsy). The network must expose `.inputs` and
            return `(embeddings, attention_mask)` when called on them. When
            falsy, a default word + position + type embedding is built from
            `embedding_cfg`.
          embedding_cfg: Dict of embedding settings. The default embedding
            reads 'vocab_size', 'hidden_size', 'initializer',
            'max_seq_length', 'type_vocab_size', 'dropout_rate' and,
            optionally, 'seq_length'.
          embedding_data: Stored on the instance; not used in this method.
          num_hidden_instances: Number of hidden layers to stack.
          hidden_cls: A class, an instance, or a list of classes/instances
            (one per hidden layer). Classes are instantiated with
            `**hidden_cfg`; instances are used as given.
          hidden_cfg: Dict of keyword arguments for hidden layer classes.
          mask_cls: Class or instance of the attention mask layer. Only used
            when the default embedding is built.
          mask_cfg: Dict of keyword arguments for `mask_cls` when it is a
            class. Only used when the default embedding is built.
          layer_norm_before_pooling: If true, a 'final_layer_norm' layer is
            applied to the last hidden output before pooling.
          return_all_layer_outputs: If true (and `dict_outputs` is false), the
            first output is the list of all hidden outputs instead of only
            the last one.
          dict_outputs: If true, the outputs are a dict with keys
            'sequence_output', 'pooled_output' and 'encoder_outputs'.
          layer_idx_as_attention_seed: If true and `hidden_cfg` contains
            'attention_cfg', each hidden layer class receives a deep copy of
            `hidden_cfg` with `attention_cfg['seed']` set to its index.
          feed_layer_idx: If true, each hidden layer class additionally
            receives `layer_idx=<index>`.
          recursive: If true, each hidden layer is called with
            `[data, attention_mask, recursive_states]` and must return
            `(data, recursive_states)`.
          **kwargs: Forwarded to the parent constructor together with the
            `inputs` and `outputs` built here.

        Raises:
          RuntimeError: If `hidden_cls` is a list whose length differs from
            `num_hidden_instances`.
        """

        if embedding_cls:
            if inspect.isclass(embedding_cls):
                embedding_network = embedding_cls(
                    **embedding_cfg) if embedding_cfg else embedding_cls()
            else:
                embedding_network = embedding_cls
            inputs = embedding_network.inputs
            embeddings, attention_mask = embedding_network(inputs)
            embedding_layer = None
            position_embedding_layer = None
            type_embedding_layer = None
            embedding_norm_layer = None
        else:
            embedding_network = None
            seq_length = embedding_cfg.get('seq_length', None)
            word_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                             dtype=tf.int32,
                                             name='input_word_ids')
            mask = tf.keras.layers.Input(shape=(seq_length, ),
                                         dtype=tf.int32,
                                         name='input_mask')
            type_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                             dtype=tf.int32,
                                             name='input_type_ids')
            inputs = [word_ids, mask, type_ids]

            embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=embedding_cfg['vocab_size'],
                embedding_width=embedding_cfg['hidden_size'],
                initializer=tf_utils.clone_initializer(
                    embedding_cfg['initializer']),
                name='word_embeddings')

            word_embeddings = embedding_layer(word_ids)

            # Always uses dynamic slicing for simplicity.
            position_embedding_layer = layers.PositionEmbedding(
                initializer=tf_utils.clone_initializer(
                    embedding_cfg['initializer']),
                max_length=embedding_cfg['max_seq_length'],
                name='position_embedding')
            position_embeddings = position_embedding_layer(word_embeddings)

            type_embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=embedding_cfg['type_vocab_size'],
                embedding_width=embedding_cfg['hidden_size'],
                initializer=tf_utils.clone_initializer(
                    embedding_cfg['initializer']),
                use_one_hot=True,
                name='type_embeddings')
            type_embeddings = type_embedding_layer(type_ids)

            embeddings = tf.keras.layers.Add()(
                [word_embeddings, position_embeddings, type_embeddings])

            embedding_norm_layer = tf.keras.layers.LayerNormalization(
                name='embeddings/layer_norm',
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32)
            embeddings = embedding_norm_layer(embeddings)

            embeddings = (tf.keras.layers.Dropout(
                rate=embedding_cfg['dropout_rate'])(embeddings))

            mask_cfg = {} if mask_cfg is None else mask_cfg
            if inspect.isclass(mask_cls):
                mask_layer = mask_cls(**mask_cfg)
            else:
                mask_layer = mask_cls
            attention_mask = mask_layer(embeddings, mask)

        data = embeddings

        layer_output_data = []
        hidden_layers = []
        hidden_cfg = hidden_cfg if hidden_cfg else {}

        if isinstance(hidden_cls,
                      list) and len(hidden_cls) != num_hidden_instances:
            # `self.name` is not available before `super().__init__` has run,
            # so the model name is taken from the constructor kwargs. All
            # three values must be passed to `%` as a single tuple.
            raise RuntimeError(
                'When input hidden_cls to EncoderScaffold %s is a list, it must '
                'contain classes or instances with size specified by '
                'num_hidden_instances, got %d vs %d.' %
                (kwargs.get('name'), len(hidden_cls), num_hidden_instances))
        # Consider supporting customized init states.
        recursive_states = None
        for i in range(num_hidden_instances):
            if isinstance(hidden_cls, list):
                cur_hidden_cls = hidden_cls[i]
            else:
                cur_hidden_cls = hidden_cls
            if inspect.isclass(cur_hidden_cls):
                if hidden_cfg and 'attention_cfg' in hidden_cfg and (
                        layer_idx_as_attention_seed):
                    hidden_cfg = copy.deepcopy(hidden_cfg)
                    hidden_cfg['attention_cfg']['seed'] = i
                if feed_layer_idx:
                    # Work on a copy so the dict passed in by the caller is
                    # never modified.
                    hidden_cfg = dict(hidden_cfg)
                    hidden_cfg['layer_idx'] = i
                layer = cur_hidden_cls(**hidden_cfg)
            else:
                layer = cur_hidden_cls
            if recursive:
                data, recursive_states = layer(
                    [data, attention_mask, recursive_states])
            else:
                data = layer([data, attention_mask])
            layer_output_data.append(data)
            hidden_layers.append(layer)

        if layer_norm_before_pooling:
            # Normalize the final output.
            output_layer_norm = tf.keras.layers.LayerNormalization(
                name='final_layer_norm', axis=-1, epsilon=1e-12)
            layer_output_data[-1] = output_layer_norm(layer_output_data[-1])

        last_layer_output = layer_output_data[-1]
        # Applying a tf.slice op (through subscript notation) to a Keras tensor
        # like this will create a SliceOpLambda layer. This is better than a Lambda
        # layer with Python code, because that is fundamentally less portable.
        first_token_tensor = last_layer_output[:, 0, :]
        pooler_layer_initializer = tf.keras.initializers.get(
            pooler_layer_initializer)
        pooler_layer = tf.keras.layers.Dense(
            units=pooled_output_dim,
            activation='tanh',
            kernel_initializer=pooler_layer_initializer,
            name='cls_transform')
        cls_output = pooler_layer(first_token_tensor)

        # `dict_outputs` takes precedence over `return_all_layer_outputs`.
        if dict_outputs:
            outputs = dict(
                sequence_output=layer_output_data[-1],
                pooled_output=cls_output,
                encoder_outputs=layer_output_data,
            )
        elif return_all_layer_outputs:
            outputs = [layer_output_data, cls_output]
        else:
            outputs = [layer_output_data[-1], cls_output]

        # b/164516224
        # Once we've created the network using the Functional API, we call
        # super().__init__ as though we were invoking the Functional API Model
        # constructor, resulting in this object having all the properties of a model
        # created using the Functional API. Once super().__init__ is called, we
        # can assign attributes to `self` - note that all `self` assignments are
        # below this line.
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # Constructor arguments, kept on the instance.
        self._hidden_cls = hidden_cls
        self._hidden_cfg = hidden_cfg
        self._mask_cls = mask_cls
        self._mask_cfg = mask_cfg
        self._num_hidden_instances = num_hidden_instances
        self._pooled_output_dim = pooled_output_dim
        self._pooler_layer_initializer = pooler_layer_initializer
        self._embedding_cls = embedding_cls
        self._embedding_cfg = embedding_cfg
        self._embedding_data = embedding_data
        self._layer_norm_before_pooling = layer_norm_before_pooling
        self._return_all_layer_outputs = return_all_layer_outputs
        self._dict_outputs = dict_outputs
        self._kwargs = kwargs

        # Layers created above. The four default-embedding layers are None
        # when a custom embedding network was supplied.
        self._embedding_layer = embedding_layer
        self._embedding_network = embedding_network
        self._position_embedding_layer = position_embedding_layer
        self._type_embedding_layer = type_embedding_layer
        self._embedding_norm_layer = embedding_norm_layer
        self._hidden_layers = hidden_layers
        if self._layer_norm_before_pooling:
            self._output_layer_norm = output_layer_norm
        self._pooler_layer = pooler_layer
        self._layer_idx_as_attention_seed = layer_idx_as_attention_seed

        logging.info('EncoderScaffold configs: %s', self.get_config())
Esempio n. 6
0
    def __init__(
            self,
            vocab_size: int,
            hidden_size: int = 768,
            num_layers: int = 12,
            num_attention_heads: int = 12,
            max_sequence_length: int = 512,
            type_vocab_size: int = 16,
            inner_dim: int = 3072,
            inner_activation: _Activation = _approx_gelu,
            output_dropout: float = 0.1,
            attention_dropout: float = 0.1,
            pool_type: str = _MAX,
            pool_stride: int = 2,
            unpool_length: int = 0,
            initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            output_range: Optional[int] = None,
            embedding_width: Optional[int] = None,
            embedding_layer: Optional[tf.keras.layers.Layer] = None,
            norm_first: bool = False,
            transformer_cls: Union[
                str, tf.keras.layers.Layer] = layers.TransformerEncoderBlock,
            share_rezero: bool = True,
            **kwargs):
        """Create the embedding, transformer, pooling and pooler sub-layers.

        Args:
          vocab_size: Size of the word vocabulary.
          hidden_size: Width of the pooler layer and target width of the
            optional embedding projection.
          num_layers: Number of transformer layers to create.
          num_attention_heads: Attention heads per transformer layer.
          max_sequence_length: `max_length` of the position embedding.
          type_vocab_size: Size of the type vocabulary.
          inner_dim: Inner dimension of each transformer layer.
          inner_activation: Inner activation of each transformer layer.
          output_dropout: Dropout rate of the embedding dropout layer and of
            each transformer layer's output.
          attention_dropout: Attention dropout rate of each transformer layer.
          pool_type: One of `_MAX`, `_AVG` or `_TRUNCATED_AVG`.
          pool_stride: A single int applied to every layer, or a sequence with
            one stride per layer.
          unpool_length: Stored on the instance; must be 0 when `pool_type` is
            `_TRUNCATED_AVG`.
          initializer: Initializer that is cloned for every layer created.
          output_range: Passed as `output_range` to the last transformer layer
            only; every other layer receives None.
          embedding_width: Width of the word and type embeddings. None selects
            `hidden_size`.
          embedding_layer: Pre-built word embedding layer. None creates an
            `OnDeviceEmbedding`.
          norm_first: Forwarded to each transformer layer.
          transformer_cls: Transformer layer class, or a string key that is
            looked up in `_str2transformer_cls`.
          share_rezero: Forwarded to each transformer layer.
          **kwargs: Forwarded to the parent class constructor.

        Raises:
          ValueError: If `pool_stride` is not an int and its length differs
            from `num_layers`, if `pool_type` is not supported, or if
            `unpool_length` is non-zero with `_TRUNCATED_AVG` pooling.
        """
        super().__init__(**kwargs)
        activation = tf.keras.activations.get(inner_activation)
        initializer = tf.keras.initializers.get(initializer)

        if embedding_width is None:
            embedding_width = hidden_size

        # --- Embedding stack -------------------------------------------------
        self._embedding_layer = (
            embedding_layer if embedding_layer is not None else
            layers.OnDeviceEmbedding(
                name='word_embeddings',
                vocab_size=vocab_size,
                embedding_width=embedding_width,
                initializer=tf_utils.clone_initializer(initializer)))

        self._position_embedding_layer = layers.PositionEmbedding(
            name='position_embedding',
            max_length=max_sequence_length,
            initializer=tf_utils.clone_initializer(initializer))

        self._type_embedding_layer = layers.OnDeviceEmbedding(
            name='type_embeddings',
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            use_one_hot=True,
            initializer=tf_utils.clone_initializer(initializer))

        self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)

        self._embedding_dropout = tf.keras.layers.Dropout(
            name='embedding_dropout', rate=output_dropout)

        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        if embedding_width == hidden_size:
            self._embedding_projection = None
        else:
            self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
                '...x,xy->...y',
                name='embedding_projection',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=tf_utils.clone_initializer(initializer))

        # --- Transformer stack -----------------------------------------------
        self._transformer_layers = []
        self._attention_mask_layer = layers.SelfAttentionMask(
            name='self_attention_mask')
        # Will raise an error if the string is not supported.
        if isinstance(transformer_cls, str):
            transformer_cls = _str2transformer_cls[transformer_cls]
        final_layer_idx = num_layers - 1
        for layer_idx in range(num_layers):
            is_final_layer = layer_idx == final_layer_idx
            # Both the `intermediate_*` and the `inner_*` spellings of the
            # size and activation arguments are passed to the layer class.
            self._transformer_layers.append(
                transformer_cls(
                    num_attention_heads=num_attention_heads,
                    intermediate_size=inner_dim,
                    inner_dim=inner_dim,
                    intermediate_activation=inner_activation,
                    inner_activation=inner_activation,
                    output_dropout=output_dropout,
                    attention_dropout=attention_dropout,
                    norm_first=norm_first,
                    output_range=output_range if is_final_layer else None,
                    kernel_initializer=tf_utils.clone_initializer(initializer),
                    share_rezero=share_rezero,
                    name='transformer/layer_%d' % layer_idx))

        self._pooler_layer = tf.keras.layers.Dense(
            name='pooler_transform',
            units=hidden_size,
            activation='tanh',
            kernel_initializer=tf_utils.clone_initializer(initializer))

        # --- Pooling configuration -------------------------------------------
        if not isinstance(pool_stride, int):
            if len(pool_stride) != num_layers:
                raise ValueError(
                    'Lengths of pool_stride and num_layers are not equal.')
            pool_strides = pool_stride
        else:
            # TODO(b/197133196): Pooling layer can be shared.
            pool_strides = [pool_stride] * num_layers

        # TODO(crickwu): explore tf.keras.layers.serialize method.
        pool_cls = None
        if pool_type == _MAX:
            pool_cls = tf.keras.layers.MaxPooling1D
        elif pool_type == _AVG:
            pool_cls = tf.keras.layers.AveragePooling1D
        elif pool_type == _TRUNCATED_AVG:
            # TODO(b/203665205): unpool_length should be implemented.
            if unpool_length != 0:
                raise ValueError(
                    'unpool_length is not supported by truncated_avg now.')
        else:
            raise ValueError('pool_type not supported.')

        # Pooling layers exist only for max and average pooling.
        if pool_cls is not None:
            self._att_input_pool_layers = [
                pool_cls(pool_size=stride,
                         strides=stride,
                         padding='same',
                         name='att_input_pool_layer')
                for stride in pool_strides
            ]

        self._max_sequence_length = max_sequence_length
        self._pool_strides = pool_strides  # This is a list here.
        self._unpool_length = unpool_length
        self._pool_type = pool_type

        # `embedding_width` and `transformer_cls` are recorded after their
        # defaults / string lookups above have been resolved.
        self._config = {
            'vocab_size': vocab_size,
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'max_sequence_length': max_sequence_length,
            'type_vocab_size': type_vocab_size,
            'inner_dim': inner_dim,
            'inner_activation': tf.keras.activations.serialize(activation),
            'output_dropout': output_dropout,
            'attention_dropout': attention_dropout,
            'initializer': tf.keras.initializers.serialize(initializer),
            'output_range': output_range,
            'embedding_width': embedding_width,
            'embedding_layer': embedding_layer,
            'norm_first': norm_first,
            'pool_type': pool_type,
            'pool_stride': pool_stride,
            'unpool_length': unpool_length,
            'transformer_cls': _transformer_cls2str.get(
                transformer_cls, str(transformer_cls)),
        }

        self.inputs = dict(
            input_word_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_mask=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_type_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32))
Esempio n. 7
0
    def __init__(
            self,
            vocab_size,
            hidden_size=768,
            num_layers=12,
            num_attention_heads=12,
            sequence_length=512,
            max_sequence_length=None,
            type_vocab_size=16,
            intermediate_size=3072,
            activation=activations.gelu,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            return_all_encoder_outputs=False,
            output_range=None,
            embedding_width=None,
            **kwargs):
        """Build the encoder graph with the Functional API, then init the model.

        The model takes three int32 inputs of shape `(sequence_length,)` per
        example ('input_word_ids', 'input_mask', 'input_type_ids') and returns
        `[sequence_output, cls_output]`.

        Args:
          vocab_size: Size of the word vocabulary.
          hidden_size: Width of the pooler layer and target width of the
            optional embedding projection.
          num_layers: Number of `layers.Transformer` layers to stack.
          num_attention_heads: Attention heads per transformer layer.
          sequence_length: Length used for the shape of the three inputs.
          max_sequence_length: Maximum length of the position embedding. A
            falsy value selects `sequence_length`.
          type_vocab_size: Size of the type vocabulary.
          intermediate_size: Intermediate size of each transformer layer.
          activation: Intermediate activation; resolved through
            `tf.keras.activations.get`.
          dropout_rate: Dropout rate of the embedding dropout and of each
            transformer layer.
          attention_dropout_rate: Attention dropout rate of each transformer
            layer.
          initializer: Initializer resolved through
            `tf.keras.initializers.get`; the same instance is passed to every
            layer created here.
          return_all_encoder_outputs: If true, the first output is the list of
            all transformer layer outputs instead of only the last one.
          output_range: Passed as `output_range` to the last transformer layer
            only; every other layer receives None.
          embedding_width: Width of the word and type embeddings. None selects
            `hidden_size`.
          **kwargs: Forwarded to the parent constructor together with the
            `inputs` and `outputs` built here.
        """
        activation = tf.keras.activations.get(activation)
        initializer = tf.keras.initializers.get(initializer)

        if not max_sequence_length:
            max_sequence_length = sequence_length
        # NOTE(review): attributes are assigned on `self` before
        # `super().__init__` runs at the end of this method; this flag
        # presumably switches off Keras attribute tracking to allow that --
        # confirm against the Keras version in use.
        self._self_setattr_tracking = False
        # Constructor arguments in serialized form. 'embedding_width' is
        # recorded as passed in, i.e. before its None default is resolved.
        self._config_dict = {
            'vocab_size': vocab_size,
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'sequence_length': sequence_length,
            'max_sequence_length': max_sequence_length,
            'type_vocab_size': type_vocab_size,
            'intermediate_size': intermediate_size,
            'activation': tf.keras.activations.serialize(activation),
            'dropout_rate': dropout_rate,
            'attention_dropout_rate': attention_dropout_rate,
            'initializer': tf.keras.initializers.serialize(initializer),
            'return_all_encoder_outputs': return_all_encoder_outputs,
            'output_range': output_range,
            'embedding_width': embedding_width,
        }

        # Symbolic model inputs.
        word_ids = tf.keras.layers.Input(shape=(sequence_length, ),
                                         dtype=tf.int32,
                                         name='input_word_ids')
        mask = tf.keras.layers.Input(shape=(sequence_length, ),
                                     dtype=tf.int32,
                                     name='input_mask')
        type_ids = tf.keras.layers.Input(shape=(sequence_length, ),
                                         dtype=tf.int32,
                                         name='input_type_ids')

        if embedding_width is None:
            embedding_width = hidden_size
        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            name='word_embeddings')
        word_embeddings = self._embedding_layer(word_ids)

        # Always uses dynamic slicing for simplicity.
        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            use_dynamic_slicing=True,
            max_sequence_length=max_sequence_length,
            name='position_embedding')
        # The position embedding layer is called on the word embeddings, not
        # on the word ids.
        position_embeddings = self._position_embedding_layer(word_embeddings)
        self._type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')
        type_embeddings = self._type_embedding_layer(type_ids)

        # Sum the three embeddings, then apply layer norm and dropout.
        embeddings = tf.keras.layers.Add()(
            [word_embeddings, position_embeddings, type_embeddings])

        embeddings = (tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
        embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))

        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        # Note: `self._embedding_projection` is only assigned in this case.
        if embedding_width != hidden_size:
            self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=initializer,
                name='embedding_projection')
            embeddings = self._embedding_projection(embeddings)

        self._transformer_layers = []
        data = embeddings
        attention_mask = layers.SelfAttentionMask()([data, mask])
        encoder_outputs = []
        for i in range(num_layers):
            # Only the last layer receives `output_range`.
            if i == num_layers - 1 and output_range is not None:
                transformer_output_range = output_range
            else:
                transformer_output_range = None
            layer = layers.Transformer(
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                intermediate_activation=activation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                output_range=transformer_output_range,
                kernel_initializer=initializer,
                name='transformer/layer_%d' % i)
            self._transformer_layers.append(layer)
            data = layer([data, attention_mask])
            encoder_outputs.append(data)

        # Take position 0 along axis 1 of the last layer output and drop that
        # axis, then feed it to the tanh pooler.
        first_token_tensor = (
            tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
                encoder_outputs[-1]))
        self._pooler_layer = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')
        cls_output = self._pooler_layer(first_token_tensor)

        if return_all_encoder_outputs:
            outputs = [encoder_outputs, cls_output]
        else:
            outputs = [encoder_outputs[-1], cls_output]

        # Initialize the parent with the inputs and outputs of the graph
        # built above.
        super(TransformerEncoder,
              self).__init__(inputs=[word_ids, mask, type_ids],
                             outputs=outputs,
                             **kwargs)
    def __init__(
            self,
            vocab_size: int,
            attention_window: Union[List[int], int] = 512,
            global_attention_size: int = 0,
            pad_token_id: int = 1,
            hidden_size: int = 768,
            num_layers: int = 12,
            num_attention_heads: int = 12,
            max_sequence_length: int = 512,
            type_vocab_size: int = 16,
            inner_dim: int = 3072,
            inner_activation: Callable[..., Any] = _approx_gelu,
            output_dropout: float = 0.1,
            attention_dropout: float = 0.1,
            initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            output_range: Optional[int] = None,
            embedding_width: Optional[int] = None,
            embedding_layer: Optional[tf.keras.layers.Layer] = None,
            norm_first: bool = False,
            **kwargs):
        """Creates the sub-layers and config of a Longformer-style encoder.

        Args:
          vocab_size: Vocabulary size of the word-embedding layer that is
            created when `embedding_layer` is None.
          attention_window: Window size handed to each encoder block. Either a
            sequence with one entry per layer, or a single int that is reused
            for all `num_layers` blocks.
          global_attention_size: Stored on the instance and forwarded to every
            `LongformerEncoderBlock`.
          pad_token_id: Stored on the instance and recorded in the config; not
            otherwise used in this constructor.
          hidden_size: Number of units of the pooler layer and target width of
            the optional embedding projection.
          num_layers: Number of `LongformerEncoderBlock` layers to create.
          num_attention_heads: Forwarded to every encoder block.
          max_sequence_length: Passed as `max_length` to the position
            embedding layer.
          type_vocab_size: Vocabulary size of the type-embedding layer.
          inner_dim: Forwarded to every encoder block.
          inner_activation: Forwarded to every encoder block; its serialized
            form is recorded in the config.
          output_dropout: Dropout rate of the embedding dropout layer, also
            forwarded to every encoder block.
          attention_dropout: Forwarded to every encoder block.
          initializer: Resolved with `tf.keras.initializers.get` and used for
            the embedding, projection, encoder-block and pooler layers.
          output_range: Forwarded to the last encoder block only; every other
            block receives None.
          embedding_width: Width of the word and type embeddings. Defaults to
            `hidden_size`; when it differs, a projection layer to
            `hidden_size` is created.
          embedding_layer: Optional layer used instead of a newly created
            word-embedding layer.
          norm_first: Forwarded to every encoder block.
          **kwargs: Forwarded to the parent constructor.

        Raises:
          ValueError: If `attention_window` is a sequence with fewer than
            `num_layers` entries.
        """
        super().__init__(**kwargs)

        # The signature allows (and defaults to) a bare int, but the layer loop
        # below indexes `attention_window[i]`. Expand an int to one entry per
        # layer so the default call works instead of raising TypeError.
        if isinstance(attention_window, int):
            attention_window = [attention_window] * num_layers
        elif len(attention_window) < num_layers:
            # Fail up front with a clear message rather than with an
            # IndexError halfway through building the layers.
            raise ValueError(
                f'`attention_window` has {len(attention_window)} entries but '
                f'`num_layers` is {num_layers}; one window size per layer is '
                'required.')

        # Longformer args
        self._attention_window = attention_window
        self._global_attention_size = global_attention_size
        self._pad_token_id = pad_token_id

        # `activation` is only used to serialize the config below; the encoder
        # blocks receive `inner_activation` unchanged.
        activation = tf.keras.activations.get(inner_activation)
        initializer = tf.keras.initializers.get(initializer)

        if embedding_width is None:
            embedding_width = hidden_size

        if embedding_layer is None:
            self._embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=vocab_size,
                embedding_width=embedding_width,
                initializer=initializer,
                name='word_embeddings')
        else:
            self._embedding_layer = embedding_layer

        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            max_length=max_sequence_length,
            name='position_embedding')

        self._type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')

        self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)

        self._embedding_dropout = tf.keras.layers.Dropout(
            rate=output_dropout, name='embedding_dropout')

        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        self._embedding_projection = None
        if embedding_width != hidden_size:
            self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=initializer,
                name='embedding_projection')

        self._transformer_layers = []
        self._attention_mask_layer = layers.SelfAttentionMask(
            name='self_attention_mask')
        for i in range(num_layers):
            layer = LongformerEncoderBlock(
                global_attention_size=global_attention_size,
                num_attention_heads=num_attention_heads,
                inner_dim=inner_dim,
                inner_activation=inner_activation,
                attention_window=attention_window[i],
                layer_id=i,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                # Only the final block is restricted to `output_range`.
                output_range=output_range if i == num_layers - 1 else None,
                kernel_initializer=initializer,
                name=f'transformer/layer_{i}')
            self._transformer_layers.append(layer)

        self._pooler_layer = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')

        self._config = {
            'vocab_size': vocab_size,
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'max_sequence_length': max_sequence_length,
            'type_vocab_size': type_vocab_size,
            'inner_dim': inner_dim,
            'inner_activation': tf.keras.activations.serialize(activation),
            'output_dropout': output_dropout,
            'attention_dropout': attention_dropout,
            'initializer': tf.keras.initializers.serialize(initializer),
            'output_range': output_range,
            'embedding_width': embedding_width,
            'embedding_layer': embedding_layer,
            'norm_first': norm_first,
            'attention_window': attention_window,
            'global_attention_size': global_attention_size,
            'pad_token_id': pad_token_id,
        }
        self.inputs = dict(
            input_word_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_mask=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_type_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32))
Esempio n. 9
0
  def __init__(
      self,
      vocab_size,
      hidden_size=768,
      num_layers=12,
      num_attention_heads=12,
      max_sequence_length=512,
      type_vocab_size=16,
      inner_dim=3072,
      inner_activation=lambda x: tf.keras.activations.gelu(x, approximate=True),
      output_dropout=0.1,
      attention_dropout=0.1,
      initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
      output_range=None,
      embedding_width=None,
      embedding_layer=None,
      norm_first=False,
      dict_outputs=False,
      return_all_encoder_outputs=False,
      **kwargs):
    """Wires the encoder as a Keras functional graph and initialises the model.

    Args:
      vocab_size: Vocabulary size of the word-embedding layer that is created
        when `embedding_layer` is None.
      hidden_size: Number of units of the pooler layer and target width of the
        optional embedding projection.
      num_layers: Number of `TransformerEncoderBlock` layers to stack.
      num_attention_heads: Forwarded to every transformer block.
      max_sequence_length: Passed as `max_length` to the position embedding.
      type_vocab_size: Vocabulary size of the type-embedding layer.
      inner_dim: Forwarded to every transformer block.
      inner_activation: Forwarded to every transformer block; its serialized
        form is recorded in the config.
      output_dropout: Rate of the dropout applied to the normalised
        embeddings, also forwarded to every transformer block.
      attention_dropout: Forwarded to every transformer block.
      initializer: Resolved with `tf.keras.initializers.get` and used for the
        embedding, projection, transformer and pooler layers.
      output_range: Forwarded to the last transformer block only; all other
        blocks receive None.
      embedding_width: Width of the word and type embeddings. Defaults to
        `hidden_size`; when it differs, the embeddings are projected to
        `hidden_size`.
      embedding_layer: Optional layer used instead of a newly created
        word-embedding layer.
      norm_first: Forwarded to every transformer block.
      dict_outputs: When true, the model outputs a dict with the keys
        `sequence_output`, `pooled_output` and `encoder_outputs`.
      return_all_encoder_outputs: Only consulted when `dict_outputs` is false.
        Selects `[all block outputs, pooled output]` over
        `[last block output, pooled output]`.
      **kwargs: `sequence_length` is dropped with a warning. The legacy keys
        `intermediate_size`, `activation`, `dropout_rate` and
        `attention_dropout_rate` override `inner_dim`, `inner_activation`,
        `output_dropout` and `attention_dropout`. Everything else goes to the
        parent constructor.
    """
    if 'sequence_length' in kwargs:
      del kwargs['sequence_length']
      logging.warning('`sequence_length` is a deprecated argument to '
                      '`BertEncoder`, which has no effect for a while. Please '
                      'remove `sequence_length` argument.')

    # Legacy spellings win over the modern parameters when both are given.
    inner_dim = kwargs.pop('intermediate_size', inner_dim)
    inner_activation = kwargs.pop('activation', inner_activation)
    output_dropout = kwargs.pop('dropout_rate', output_dropout)
    attention_dropout = kwargs.pop('attention_dropout_rate', attention_dropout)

    # Only used for the serialized config; the blocks get `inner_activation`.
    resolved_activation = tf.keras.activations.get(inner_activation)
    initializer = tf.keras.initializers.get(initializer)

    word_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_word_ids')
    mask = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_mask')
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')
    model_inputs = [word_ids, mask, type_ids]

    if embedding_width is None:
      embedding_width = hidden_size

    # Reuse the caller's embedding layer when one is supplied.
    embedding_layer_inst = embedding_layer
    if embedding_layer_inst is None:
      embedding_layer_inst = layers.OnDeviceEmbedding(
          vocab_size=vocab_size,
          embedding_width=embedding_width,
          initializer=initializer,
          name='word_embeddings')
    word_embeddings = embedding_layer_inst(word_ids)

    position_embedding_layer = layers.PositionEmbedding(
        initializer=initializer,
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(word_embeddings)

    type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
        initializer=initializer,
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = type_embedding_layer(type_ids)

    # Sum the three embeddings, then normalise and apply dropout.
    summed_embeddings = tf.keras.layers.Add()(
        [word_embeddings, position_embeddings, type_embeddings])
    embedding_norm_layer = tf.keras.layers.LayerNormalization(
        name='embeddings/layer_norm', axis=-1, epsilon=1e-12, dtype=tf.float32)
    normed_embeddings = embedding_norm_layer(summed_embeddings)
    embeddings = tf.keras.layers.Dropout(rate=output_dropout)(normed_embeddings)

    # We project the 'embedding' output to 'hidden_size' if it is not already
    # 'hidden_size'.
    embedding_projection = None
    if embedding_width != hidden_size:
      embedding_projection = tf.keras.layers.experimental.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
          kernel_initializer=initializer,
          name='embedding_projection')
      embeddings = embedding_projection(embeddings)

    data = embeddings
    attention_mask = layers.SelfAttentionMask()(data, mask)
    transformer_layers = []
    encoder_outputs = []
    last_index = num_layers - 1
    for i in range(num_layers):
      block = layers.TransformerEncoderBlock(
          num_attention_heads=num_attention_heads,
          inner_dim=inner_dim,
          inner_activation=inner_activation,
          output_dropout=output_dropout,
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          # Only the final block is restricted to `output_range`.
          output_range=output_range if i == last_index else None,
          kernel_initializer=initializer,
          name='transformer/layer_%d' % i)
      transformer_layers.append(block)
      data = block([data, attention_mask])
      encoder_outputs.append(data)

    last_encoder_output = encoder_outputs[-1]
    # Applying a tf.slice op (through subscript notation) to a Keras tensor
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = last_encoder_output[:, 0, :]
    pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
        kernel_initializer=initializer,
        name='pooler_transform')
    cls_output = pooler_layer(first_token_tensor)

    # Pick the output structure, then initialise the functional model once.
    if dict_outputs:
      model_outputs = dict(
          sequence_output=last_encoder_output,
          pooled_output=cls_output,
          encoder_outputs=encoder_outputs,
      )
    elif return_all_encoder_outputs:
      model_outputs = [encoder_outputs, cls_output]
    else:
      model_outputs = [last_encoder_output, cls_output]
    super().__init__(inputs=model_inputs, outputs=model_outputs, **kwargs)

    # Layer handles can only be attached once the parent constructor has run.
    self._pooler_layer = pooler_layer
    self._transformer_layers = transformer_layers
    self._embedding_norm_layer = embedding_norm_layer
    self._embedding_layer = embedding_layer_inst
    self._position_embedding_layer = position_embedding_layer
    self._type_embedding_layer = type_embedding_layer
    if embedding_projection is not None:
      self._embedding_projection = embedding_projection

    # Assign the config with attribute tracking switched off, then restore it.
    # pylint: disable=protected-access
    self._setattr_tracking = False
    self._config = {
        'vocab_size': vocab_size,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_attention_heads': num_attention_heads,
        'max_sequence_length': max_sequence_length,
        'type_vocab_size': type_vocab_size,
        'inner_dim': inner_dim,
        'inner_activation': tf.keras.activations.serialize(resolved_activation),
        'output_dropout': output_dropout,
        'attention_dropout': attention_dropout,
        'initializer': tf.keras.initializers.serialize(initializer),
        'output_range': output_range,
        'embedding_width': embedding_width,
        'embedding_layer': embedding_layer,
        'norm_first': norm_first,
        'dict_outputs': dict_outputs,
    }
    self._setattr_tracking = True
    # pylint: enable=protected-access
Esempio n. 10
0
    def __init__(
            self,
            vocab_size,
            hidden_size=768,
            num_layers=12,
            num_attention_heads=12,
            max_sequence_length=512,
            type_vocab_size=16,
            inner_dim=3072,
            inner_activation=lambda x: tf.keras.activations.gelu(
                x, approximate=True),
            output_dropout=0.1,
            attention_dropout=0.1,
            pool_type='max',
            pool_stride=2,
            unpool_length=0,
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            output_range=None,
            embedding_width=None,
            embedding_layer=None,
            norm_first=False,
            **kwargs):
        """Creates the embedding, transformer, pooler and pooling sub-layers.

        Args:
          vocab_size: Vocabulary size of the word-embedding layer that is
            created when `embedding_layer` is None.
          hidden_size: Number of units of the pooler layer and target width of
            the optional embedding projection.
          num_layers: Number of `TransformerEncoderBlock` layers to create.
          num_attention_heads: Forwarded to every transformer block.
          max_sequence_length: Passed as `max_length` to the position
            embedding layer.
          type_vocab_size: Vocabulary size of the type-embedding layer.
          inner_dim: Forwarded to every transformer block.
          inner_activation: Forwarded to every transformer block; its
            serialized form is recorded in the config.
          output_dropout: Rate of the embedding dropout layer, also forwarded
            to every transformer block.
          attention_dropout: Forwarded to every transformer block.
          pool_type: `'max'` selects `MaxPooling1D`, `'avg'` selects
            `AveragePooling1D`.
          pool_stride: Either an int, reused for all `num_layers` pooling
            layers, or a sequence with exactly `num_layers` entries. Each
            entry is used as both `pool_size` and `strides`.
          unpool_length: Stored on the instance and recorded in the config;
            not otherwise used in this constructor.
          initializer: Resolved with `tf.keras.initializers.get` and used for
            the embedding, projection, transformer and pooler layers.
          output_range: Forwarded to the last transformer block only; every
            other block receives None.
          embedding_width: Width of the word and type embeddings. Defaults to
            `hidden_size`; when it differs, a projection layer to
            `hidden_size` is created.
          embedding_layer: Optional layer used instead of a newly created
            word-embedding layer.
          norm_first: Forwarded to every transformer block.
          **kwargs: Forwarded to the parent constructor.

        Raises:
          ValueError: If `pool_stride` is a sequence whose length differs from
            `num_layers`, or if `pool_type` is neither `'max'` nor `'avg'`.
        """
        super().__init__(**kwargs)
        # Only used for the serialized config; the blocks get `inner_activation`.
        resolved_activation = tf.keras.activations.get(inner_activation)
        initializer = tf.keras.initializers.get(initializer)

        if embedding_width is None:
            embedding_width = hidden_size

        # Reuse the caller's embedding layer when one is supplied.
        if embedding_layer is not None:
            self._embedding_layer = embedding_layer
        else:
            self._embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=vocab_size,
                embedding_width=embedding_width,
                initializer=initializer,
                name='word_embeddings')

        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            max_length=max_sequence_length,
            name='position_embedding')

        self._type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')

        self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)

        self._embedding_dropout = tf.keras.layers.Dropout(
            rate=output_dropout, name='embedding_dropout')

        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        self._embedding_projection = None
        if embedding_width != hidden_size:
            self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=initializer,
                name='embedding_projection')

        self._transformer_layers = []
        self._attention_mask_layer = layers.SelfAttentionMask(
            name='self_attention_mask')
        last_index = num_layers - 1
        for i in range(num_layers):
            block = layers.TransformerEncoderBlock(
                num_attention_heads=num_attention_heads,
                inner_dim=inner_dim,
                inner_activation=inner_activation,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                # Only the final block is restricted to `output_range`.
                output_range=output_range if i == last_index else None,
                kernel_initializer=initializer,
                name='transformer/layer_%d' % i)
            self._transformer_layers.append(block)

        self._pooler_layer = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')

        # Normalise `pool_stride` to exactly one stride per transformer layer.
        if isinstance(pool_stride, int):
            # TODO(b/197133196): Pooling layer can be shared.
            pool_strides = [pool_stride] * num_layers
        elif len(pool_stride) == num_layers:
            pool_strides = pool_stride
        else:
            raise ValueError(
                'Lengths of pool_stride and num_layers are not equal.')

        # TODO(crickwu): explore tf.keras.layers.serialize method.
        if pool_type == 'max':
            pool_cls = tf.keras.layers.MaxPooling1D
        elif pool_type == 'avg':
            pool_cls = tf.keras.layers.AveragePooling1D
        else:
            raise ValueError('pool_type not supported.')

        # One pooling layer per stride; window size and stride are the same.
        self._att_input_pool_layers = []
        for stride in pool_strides:
            self._att_input_pool_layers.append(
                pool_cls(pool_size=stride,
                         strides=stride,
                         padding='same',
                         name='att_input_pool_layer'))

        self._pool_strides = pool_strides  # This is a list here.
        self._unpool_length = unpool_length

        # The config keeps the caller's `pool_stride`, not the expanded list.
        self._config = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            max_sequence_length=max_sequence_length,
            type_vocab_size=type_vocab_size,
            inner_dim=inner_dim,
            inner_activation=tf.keras.activations.serialize(
                resolved_activation),
            output_dropout=output_dropout,
            attention_dropout=attention_dropout,
            initializer=tf.keras.initializers.serialize(initializer),
            output_range=output_range,
            embedding_width=embedding_width,
            embedding_layer=embedding_layer,
            norm_first=norm_first,
            pool_type=pool_type,
            pool_stride=pool_stride,
            unpool_length=unpool_length,
        )
Esempio n. 11
0
    def __init__(
            self,
            vocab_size: int,
            hidden_size: int = 768,
            num_layers: int = 12,
            num_attention_heads: int = 12,
            max_sequence_length: int = 512,
            type_vocab_size: int = 16,
            inner_dim: int = 3072,
            inner_activation: _Activation = _approx_gelu,
            output_dropout: float = 0.1,
            attention_dropout: float = 0.1,
            initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            output_range: Optional[int] = None,
            embedding_width: Optional[int] = None,
            embedding_layer: Optional[tf.keras.layers.Layer] = None,
            norm_first: bool = False,
            with_dense_inputs: bool = False,
            **kwargs):
        """Creates the encoder's sub-layers, config and symbolic inputs.

        Args:
          vocab_size: Vocabulary size of the word-embedding layer that is
            created when `embedding_layer` is None.
          hidden_size: Number of units of the pooler layer and target width of
            the optional embedding projection.
          num_layers: Number of `TransformerEncoderBlock` layers to create.
          num_attention_heads: Forwarded to every transformer block.
          max_sequence_length: Passed as `max_length` to the position
            embedding layer.
          type_vocab_size: Vocabulary size of the type-embedding layer.
          inner_dim: Forwarded to every transformer block.
          inner_activation: Forwarded to every transformer block; its
            serialized form is recorded in the config.
          output_dropout: Rate of the embedding dropout layer, also forwarded
            to every transformer block.
          attention_dropout: Forwarded to every transformer block.
          initializer: Resolved with `tf.keras.initializers.get` and used for
            the embedding, projection, transformer and pooler layers.
          output_range: Forwarded to the last transformer block only; every
            other block receives None.
          embedding_width: Width of the word and type embeddings. Defaults to
            `hidden_size`; when it differs, a projection layer to
            `hidden_size` is created.
          embedding_layer: Optional layer used instead of a newly created
            word-embedding layer.
          norm_first: Forwarded to every transformer block.
          with_dense_inputs: When true, `self.inputs` also contains the
            `dense_inputs`, `dense_mask` and `dense_type_ids` entries.
          **kwargs: `dict_outputs` and `return_all_encoder_outputs` are
            discarded. The legacy keys `intermediate_size`, `activation`,
            `dropout_rate` and `attention_dropout_rate` override `inner_dim`,
            `inner_activation`, `output_dropout` and `attention_dropout`.
            Everything else goes to the parent constructor.
        """
        # Pops kwargs that are used in V1 implementation.
        for v1_only_flag in ('dict_outputs', 'return_all_encoder_outputs'):
            kwargs.pop(v1_only_flag, None)
        inner_dim = kwargs.pop('intermediate_size', inner_dim)
        inner_activation = kwargs.pop('activation', inner_activation)
        output_dropout = kwargs.pop('dropout_rate', output_dropout)
        attention_dropout = kwargs.pop('attention_dropout_rate',
                                       attention_dropout)
        super().__init__(**kwargs)

        # Only used for the serialized config; the blocks get `inner_activation`.
        resolved_activation = tf.keras.activations.get(inner_activation)
        initializer = tf.keras.initializers.get(initializer)

        if embedding_width is None:
            embedding_width = hidden_size

        # Reuse the caller's embedding layer when one is supplied.
        if embedding_layer is not None:
            self._embedding_layer = embedding_layer
        else:
            self._embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=vocab_size,
                embedding_width=embedding_width,
                initializer=initializer,
                name='word_embeddings')

        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            max_length=max_sequence_length,
            name='position_embedding')

        self._type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')

        self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)

        self._embedding_dropout = tf.keras.layers.Dropout(
            rate=output_dropout, name='embedding_dropout')

        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        self._embedding_projection = None
        if embedding_width != hidden_size:
            self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=initializer,
                name='embedding_projection')

        self._transformer_layers = []
        self._attention_mask_layer = layers.SelfAttentionMask(
            name='self_attention_mask')
        last_index = num_layers - 1
        for i in range(num_layers):
            block = layers.TransformerEncoderBlock(
                num_attention_heads=num_attention_heads,
                inner_dim=inner_dim,
                inner_activation=inner_activation,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                # Only the final block is restricted to `output_range`.
                output_range=output_range if i == last_index else None,
                kernel_initializer=initializer,
                name='transformer/layer_%d' % i)
            self._transformer_layers.append(block)

        self._pooler_layer = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')

        self._config = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            max_sequence_length=max_sequence_length,
            type_vocab_size=type_vocab_size,
            inner_dim=inner_dim,
            inner_activation=tf.keras.activations.serialize(
                resolved_activation),
            output_dropout=output_dropout,
            attention_dropout=attention_dropout,
            initializer=tf.keras.initializers.serialize(initializer),
            output_range=output_range,
            embedding_width=embedding_width,
            embedding_layer=embedding_layer,
            norm_first=norm_first,
            with_dense_inputs=with_dense_inputs,
        )

        # The three id/mask inputs are always present; the dense trio is
        # appended after them only on request.
        encoder_inputs = dict(
            input_word_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_mask=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_type_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
        )
        if with_dense_inputs:
            encoder_inputs.update(
                dense_inputs=tf.keras.Input(shape=(None, embedding_width),
                                            dtype=tf.float32),
                dense_mask=tf.keras.Input(shape=(None, ), dtype=tf.int32),
                dense_type_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            )
        self.inputs = encoder_inputs
Esempio n. 12
0
    def __init__(self,
                 network,
                 bert_config,
                 initializer='glorot_uniform',
                 seq_length=128,
                 use_pointing=True,
                 is_training=True):
        """Creates the Felix tagger and the layers it owns.

        Args:
          network: An encoder network, which should output a sequence of
            hidden states. Stored unchanged.
          bert_config: Config object. `num_classes` is always read. With
            pointing enabled, `hidden_size`, `query_transformer` and
            `query_size` are read too. If `query_transformer` is truthy,
            `num_attention_heads`, `intermediate_size` and
            `hidden_dropout_prob` are read as well.
          initializer: Accepted as part of the signature; this constructor
            does not reference it.
          seq_length: Maximum sequence length. Used as `input_length` of the
            tag embedding, `max_length` of the position embedding and
            `output_range` of the optional query transformer.
          use_pointing: Whether a pointing network is used. When false, only
            the tag-logits layer is created.
          is_training: Stored on the instance; not otherwise used in this
            constructor.
        """
        super(FelixTagger, self).__init__()
        self._network = network
        self._seq_length = seq_length
        self._bert_config = bert_config
        self._use_pointing = use_pointing
        self._is_training = is_training

        num_classes = bert_config.num_classes
        self._tag_logits_layer = tf.keras.layers.Dense(num_classes)

        # Every remaining layer belongs to the pointing network.
        if not use_pointing:
            return

        # An arbitrary heuristic (sqrt vocab size) for the tag embedding dimension.
        tag_embedding_dim = int(math.ceil(math.sqrt(num_classes)))
        self._tag_embedding_layer = tf.keras.layers.Embedding(
            num_classes, tag_embedding_dim, input_length=seq_length)

        self._position_embedding_layer = layers.PositionEmbedding(
            max_length=seq_length)
        self._edit_tagged_sequence_output_layer = tf.keras.layers.Dense(
            bert_config.hidden_size, activation=activations.gelu)

        if bert_config.query_transformer:
            self._self_attention_mask_layer = layers.SelfAttentionMask()
            # `hidden_dropout_prob` is used for both dropout rates.
            dropout_prob = bert_config.hidden_dropout_prob
            self._transformer_query_layer = layers.TransformerEncoderBlock(
                num_attention_heads=bert_config.num_attention_heads,
                inner_dim=bert_config.intermediate_size,
                inner_activation=activations.gelu,
                output_dropout=dropout_prob,
                attention_dropout=dropout_prob,
                output_range=seq_length,
            )

        query_size = bert_config.query_size
        self._query_embeddings_layer = tf.keras.layers.Dense(query_size)
        self._key_embeddings_layer = tf.keras.layers.Dense(query_size)
Esempio n. 13
0
    def __init__(
            self,
            vocab_size: int,
            hidden_size: int = 768,
            num_layers: int = 12,
            num_attention_heads: int = 12,
            max_sequence_length: int = 512,
            type_vocab_size: int = 16,
            inner_dim: int = 3072,
            inner_activation: _Activation = _approx_gelu,
            output_dropout: float = 0.1,
            attention_dropout: float = 0.1,
            token_loss_init_value: float = 10.0,
            token_loss_beta: float = 0.995,
            token_keep_k: int = 256,
            token_allow_list: Tuple[int, ...] = (100, 101, 102, 103),
            token_deny_list: Tuple[int, ...] = (0, ),
            initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            output_range: Optional[int] = None,
            embedding_width: Optional[int] = None,
            embedding_layer: Optional[tf.keras.layers.Layer] = None,
            norm_first: bool = False,
            with_dense_inputs: bool = False,
            **kwargs):
        """Creates the sub-layers of a BERT encoder with token-importance state.

        Args:
          vocab_size: Vocabulary size of the word-embedding layer (created
            when `embedding_layer` is None) and length of the initial
            token-importance vector.
          hidden_size: Number of units of the pooler layer and target width of
            the optional embedding projection.
          num_layers: Number of `TransformerEncoderBlock` layers to create.
          num_attention_heads: Forwarded to every transformer block.
          max_sequence_length: Passed as `max_length` to the position
            embedding layer.
          type_vocab_size: Vocabulary size of the type-embedding layer.
          inner_dim: Forwarded to every transformer block.
          inner_activation: Forwarded to every transformer block; its
            serialized form is recorded in the config.
          output_dropout: Rate of the embedding dropout layer, also forwarded
            to every transformer block.
          attention_dropout: Forwarded to every transformer block.
          token_loss_init_value: Fill value of the initial importance vector.
          token_loss_beta: Passed as `moving_average_beta` to
            `TokenImportanceWithMovingAvg`.
          token_keep_k: Passed as `top_k` to `SelectTopK`.
          token_allow_list: Token ids whose initial importance is set to
            `1.0e4`.
          token_deny_list: Token ids whose initial importance is set to
            `-1.0e4`. Applied after the allow list, so an id present in both
            ends up at `-1.0e4`.
          initializer: Resolved with `tf.keras.initializers.get`. A separate
            clone is given to each layer that takes an initializer.
          output_range: Forwarded to the last transformer block only; every
            other block receives None.
          embedding_width: Width of the word and type embeddings. Defaults to
            `hidden_size`; when it differs, a projection layer to
            `hidden_size` is created.
          embedding_layer: Optional layer used instead of a newly created
            word-embedding layer.
          norm_first: Forwarded to every transformer block.
          with_dense_inputs: When true, `self.inputs` also contains the
            `dense_inputs`, `dense_mask` and `dense_type_ids` entries.
          **kwargs: `dict_outputs` and `return_all_encoder_outputs` are
            discarded. The legacy keys `intermediate_size`, `activation`,
            `dropout_rate` and `attention_dropout_rate` override `inner_dim`,
            `inner_activation`, `output_dropout` and `attention_dropout`.
            Everything else goes to the parent constructor.
        """
        # Pops kwargs that are used in V1 implementation.
        for v1_only_flag in ('dict_outputs', 'return_all_encoder_outputs'):
            kwargs.pop(v1_only_flag, None)
        inner_dim = kwargs.pop('intermediate_size', inner_dim)
        inner_activation = kwargs.pop('activation', inner_activation)
        output_dropout = kwargs.pop('dropout_rate', output_dropout)
        attention_dropout = kwargs.pop('attention_dropout_rate',
                                       attention_dropout)
        super().__init__(**kwargs)

        # Only used for the serialized config; the blocks get `inner_activation`.
        resolved_activation = tf.keras.activations.get(inner_activation)
        initializer = tf.keras.initializers.get(initializer)

        if embedding_width is None:
            embedding_width = hidden_size

        # Reuse the caller's embedding layer when one is supplied.
        if embedding_layer is not None:
            self._embedding_layer = embedding_layer
        else:
            self._embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=vocab_size,
                embedding_width=embedding_width,
                initializer=tf_utils.clone_initializer(initializer),
                name='word_embeddings')

        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=tf_utils.clone_initializer(initializer),
            max_length=max_sequence_length,
            name='position_embedding')

        self._type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=tf_utils.clone_initializer(initializer),
            use_one_hot=True,
            name='type_embeddings')

        self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)

        self._embedding_dropout = tf.keras.layers.Dropout(
            rate=output_dropout, name='embedding_dropout')

        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        self._embedding_projection = None
        if embedding_width != hidden_size:
            self._embedding_projection = tf.keras.layers.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=tf_utils.clone_initializer(initializer),
                name='embedding_projection')

        # The first 999 tokens are special tokens such as [PAD], [CLS], [SEP].
        # We want to always mask [PAD], and never mask [CLS], [SEP].
        init_importance = tf.constant(token_loss_init_value, shape=vocab_size)
        # Allow-listed ids get a large positive score, deny-listed ids a large
        # negative one. The deny list is applied last.
        if token_allow_list:
            init_importance = tf.tensor_scatter_nd_update(
                tensor=init_importance,
                indices=[[token_id] for token_id in token_allow_list],
                updates=[1.0e4 for _ in token_allow_list])
        if token_deny_list:
            init_importance = tf.tensor_scatter_nd_update(
                tensor=init_importance,
                indices=[[token_id] for token_id in token_deny_list],
                updates=[-1.0e4 for _ in token_deny_list])
        self._token_importance_embed = layers.TokenImportanceWithMovingAvg(
            vocab_size=vocab_size,
            init_importance=init_importance,
            moving_average_beta=token_loss_beta)

        self._token_separator = layers.SelectTopK(top_k=token_keep_k)
        self._transformer_layers = []
        self._num_layers = num_layers
        self._attention_mask_layer = layers.SelfAttentionMask(
            name='self_attention_mask')
        last_index = num_layers - 1
        for i in range(num_layers):
            block = layers.TransformerEncoderBlock(
                num_attention_heads=num_attention_heads,
                inner_dim=inner_dim,
                inner_activation=inner_activation,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                # Only the final block is restricted to `output_range`.
                output_range=output_range if i == last_index else None,
                kernel_initializer=tf_utils.clone_initializer(initializer),
                name='transformer/layer_%d' % i)
            self._transformer_layers.append(block)

        self._pooler_layer = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=tf_utils.clone_initializer(initializer),
            name='pooler_transform')

        self._config = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            max_sequence_length=max_sequence_length,
            type_vocab_size=type_vocab_size,
            inner_dim=inner_dim,
            inner_activation=tf.keras.activations.serialize(
                resolved_activation),
            output_dropout=output_dropout,
            attention_dropout=attention_dropout,
            token_loss_init_value=token_loss_init_value,
            token_loss_beta=token_loss_beta,
            token_keep_k=token_keep_k,
            token_allow_list=token_allow_list,
            token_deny_list=token_deny_list,
            initializer=tf.keras.initializers.serialize(initializer),
            output_range=output_range,
            embedding_width=embedding_width,
            embedding_layer=embedding_layer,
            norm_first=norm_first,
            with_dense_inputs=with_dense_inputs,
        )

        # The three id/mask inputs are always present; the dense trio is
        # appended after them only on request.
        encoder_inputs = dict(
            input_word_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_mask=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            input_type_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
        )
        if with_dense_inputs:
            encoder_inputs.update(
                dense_inputs=tf.keras.Input(shape=(None, embedding_width),
                                            dtype=tf.float32),
                dense_mask=tf.keras.Input(shape=(None, ), dtype=tf.int32),
                dense_type_ids=tf.keras.Input(shape=(None, ), dtype=tf.int32),
            )
        self.inputs = encoder_inputs
Esempio n. 14
0
    def __init__(
            self,
            vocab_size,
            embedding_width=128,
            hidden_size=768,
            num_layers=12,
            num_attention_heads=12,
            max_sequence_length=512,
            type_vocab_size=16,
            intermediate_size=3072,
            activation=activations.gelu,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            **kwargs):
        """Builds the encoder graph and passes it to the parent constructor.

        Word, position and type embeddings are added together, layer-normalized
        and passed through dropout. When `embedding_width` differs from
        `hidden_size` the result is projected to `hidden_size`. One transformer
        block instance is then applied `num_layers` times, and the first token
        of the final sequence feeds a tanh dense pooler.

        The model inputs are `[input_word_ids, input_mask, input_type_ids]` and
        its outputs are `[sequence_output, cls_output]`.

        Args:
          vocab_size: Vocabulary size of the word embedding layer.
          embedding_width: Width of the word and type embeddings. `None` means
            "use `hidden_size`".
          hidden_size: Output size of the embedding projection and number of
            units in the pooler.
          num_layers: How many times the shared transformer block is applied.
          num_attention_heads: Attention head count of the transformer block.
          max_sequence_length: Forwarded to the position embedding layer.
          type_vocab_size: Vocabulary size of the type embedding layer.
          intermediate_size: Inner dimension of the transformer block.
          activation: Inner activation of the transformer block; resolved with
            `tf.keras.activations.get`.
          dropout_rate: Dropout rate for the embeddings and for the transformer
            block output.
          attention_dropout_rate: Attention dropout rate of the transformer
            block.
          initializer: Initializer used for the embedding, projection,
            transformer and pooler kernels; resolved with
            `tf.keras.initializers.get`.
          **kwargs: Forwarded unchanged to the parent constructor.
        """
        activation = tf.keras.activations.get(activation)
        initializer = tf.keras.initializers.get(initializer)

        # Attribute tracking is switched off before anything is stored on
        # `self`; the parent constructor is only invoked at the very end.
        self._self_setattr_tracking = False
        # The config records `embedding_width` as it was passed in, i.e. before
        # the `None` fallback further down is applied.
        self._config_dict = dict(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            max_sequence_length=max_sequence_length,
            type_vocab_size=type_vocab_size,
            intermediate_size=intermediate_size,
            activation=tf.keras.activations.serialize(activation),
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            initializer=tf.keras.initializers.serialize(initializer),
        )

        # Three int32 inputs with an unspecified sequence dimension.
        input_names = ('input_word_ids', 'input_mask', 'input_type_ids')
        word_ids, mask, type_ids = [
            tf.keras.layers.Input(shape=(None, ),
                                  dtype=tf.int32,
                                  name=input_name)
            for input_name in input_names
        ]

        embedding_width = (hidden_size
                           if embedding_width is None else embedding_width)

        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            name='word_embeddings')
        word_embeddings = self._embedding_layer(word_ids)

        # Always uses dynamic slicing for simplicity.
        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            use_dynamic_slicing=True,
            max_sequence_length=max_sequence_length,
            name='position_embedding')
        position_embeddings = self._position_embedding_layer(word_embeddings)

        type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')
        type_embeddings = type_embedding_layer(type_ids)

        summed_embeddings = tf.keras.layers.Add()(
            [word_embeddings, position_embeddings, type_embeddings])
        embedding_norm = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)
        normalized_embeddings = embedding_norm(summed_embeddings)
        embedding_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
        embeddings = embedding_dropout(normalized_embeddings)

        # Bring the embeddings to `hidden_size` only when the widths differ.
        if embedding_width != hidden_size:
            projection = tf.keras.layers.experimental.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=initializer,
                name='embedding_projection')
            embeddings = projection(embeddings)

        attention_mask = layers.SelfAttentionMask()([embeddings, mask])
        # A single block instance is reused for every pass of the loop below.
        shared_layer = keras_nlp.TransformerEncoderBlock(
            num_attention_heads=num_attention_heads,
            inner_dim=intermediate_size,
            inner_activation=activation,
            output_dropout=dropout_rate,
            attention_dropout=attention_dropout_rate,
            kernel_initializer=initializer,
            name='transformer')
        sequence_output = embeddings
        for _ in range(num_layers):
            sequence_output = shared_layer([sequence_output, attention_mask])

        # Take the first position along axis 1 and drop that axis.
        first_token_layer = tf.keras.layers.Lambda(
            lambda x: tf.squeeze(x[:, 0:1, :], axis=1))
        first_token_tensor = first_token_layer(sequence_output)
        pooler = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')
        cls_output = pooler(first_token_tensor)

        super(AlbertTransformerEncoder,
              self).__init__(inputs=[word_ids, mask, type_ids],
                             outputs=[sequence_output, cls_output],
                             **kwargs)
Esempio n. 15
0
    def __init__(
            self,
            vocab_size,
            embedding_width=128,
            hidden_size=768,
            num_layers=12,
            num_attention_heads=12,
            max_sequence_length=512,
            type_vocab_size=16,
            intermediate_size=3072,
            activation=activations.gelu,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            dict_outputs=False,
            **kwargs):
        """Builds the encoder graph with the Keras Functional API.

        Word, position and type embeddings are added together, layer-normalized
        and passed through dropout, then projected to `hidden_size` when
        `embedding_width` differs from it. One transformer block instance is
        applied `num_layers` times; the first token of the final sequence
        feeds a tanh dense pooler.

        Args:
          vocab_size: Vocabulary size of the word embedding layer.
          embedding_width: Width of the word and type embeddings. `None` means
            "use `hidden_size`".
          hidden_size: Output size of the embedding projection and number of
            units in the pooler.
          num_layers: How many times the shared transformer block is applied.
          num_attention_heads: Attention head count of the transformer block.
          max_sequence_length: Passed to the position embedding layer as
            `max_length`.
          type_vocab_size: Vocabulary size of the type embedding layer.
          intermediate_size: Inner dimension of the transformer block.
          activation: Inner activation of the transformer block; resolved with
            `tf.keras.activations.get`.
          dropout_rate: Dropout rate for the embeddings and for the transformer
            block output.
          attention_dropout_rate: Attention dropout rate of the transformer
            block.
          initializer: Initializer used for the embedding, projection,
            transformer and pooler kernels; resolved with
            `tf.keras.initializers.get`.
          dict_outputs: When true the model outputs a dict with keys
            `sequence_output`, `encoder_outputs` and `pooled_output`;
            otherwise it outputs the list `[sequence_output, pooled_output]`.
          **kwargs: Forwarded unchanged to the parent constructor.
        """
        activation = tf.keras.activations.get(activation)
        initializer = tf.keras.initializers.get(initializer)

        word_ids = tf.keras.layers.Input(shape=(None, ),
                                         dtype=tf.int32,
                                         name='input_word_ids')
        mask = tf.keras.layers.Input(shape=(None, ),
                                     dtype=tf.int32,
                                     name='input_mask')
        type_ids = tf.keras.layers.Input(shape=(None, ),
                                         dtype=tf.int32,
                                         name='input_type_ids')

        if embedding_width is None:
            embedding_width = hidden_size
        embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            name='word_embeddings')
        word_embeddings = embedding_layer(word_ids)

        # The position embedding layer is configured with
        # `max_length=max_sequence_length` and applied to the word embeddings.
        position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            max_length=max_sequence_length,
            name='position_embedding')
        position_embeddings = position_embedding_layer(word_embeddings)

        type_embeddings = (layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')(type_ids))

        embeddings = tf.keras.layers.Add()(
            [word_embeddings, position_embeddings, type_embeddings])
        embeddings = (tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)(embeddings))
        embeddings = (tf.keras.layers.Dropout(rate=dropout_rate)(embeddings))
        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        if embedding_width != hidden_size:
            embeddings = tf.keras.layers.experimental.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=initializer,
                name='embedding_projection')(embeddings)

        data = embeddings
        attention_mask = layers.SelfAttentionMask()(data, mask)
        # A single block instance is reused for every pass of the loop below.
        shared_layer = layers.TransformerEncoderBlock(
            num_attention_heads=num_attention_heads,
            inner_dim=intermediate_size,
            inner_activation=activation,
            output_dropout=dropout_rate,
            attention_dropout=attention_dropout_rate,
            kernel_initializer=initializer,
            name='transformer')
        # Output of every pass, in order; only exposed when `dict_outputs`.
        encoder_outputs = []
        for _ in range(num_layers):
            data = shared_layer([data, attention_mask])
            encoder_outputs.append(data)

        # Applying a tf.slice op (through subscript notation) to a Keras tensor
        # like this will create a SliceOpLambda layer. This is better than a Lambda
        # layer with Python code, because that is fundamentally less portable.
        first_token_tensor = data[:, 0, :]
        cls_output = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')(first_token_tensor)
        if dict_outputs:
            outputs = dict(
                sequence_output=data,
                encoder_outputs=encoder_outputs,
                pooled_output=cls_output,
            )
        else:
            outputs = [data, cls_output]

        # b/164516224
        # Once we've created the network using the Functional API, we call
        # super().__init__ as though we were invoking the Functional API Model
        # constructor, resulting in this object having all the properties of a model
        # created using the Functional API. Once super().__init__ is called, we
        # can assign attributes to `self` - note that all `self` assignments are
        # below this line.
        super(AlbertEncoder, self).__init__(inputs=[word_ids, mask, type_ids],
                                            outputs=outputs,
                                            **kwargs)
        # Note: `embedding_width` is recorded after the `None` fallback above,
        # so the stored value is always concrete.
        config_dict = {
            'vocab_size': vocab_size,
            'embedding_width': embedding_width,
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'max_sequence_length': max_sequence_length,
            'type_vocab_size': type_vocab_size,
            'intermediate_size': intermediate_size,
            'activation': tf.keras.activations.serialize(activation),
            'dropout_rate': dropout_rate,
            'attention_dropout_rate': attention_dropout_rate,
            'initializer': tf.keras.initializers.serialize(initializer),
            # `dict_outputs` changes the output structure of the model, so it
            # is recorded too. Without it, code that re-creates the encoder
            # from this config (presumably get_config/from_config, which are
            # not part of this method) would silently fall back to list
            # outputs.
            'dict_outputs': dict_outputs,
        }

        # We are storing the config dict as a namedtuple here to ensure checkpoint
        # compatibility with an earlier version of this model which did not track
        # the config dict attribute. TF does not track immutable attrs which
        # do not contain Trackables, so by creating a config namedtuple instead of
        # a dict we avoid tracking it.
        config_cls = collections.namedtuple('Config', config_dict.keys())
        self._config = config_cls(**config_dict)
        self._embedding_layer = embedding_layer
        self._position_embedding_layer = position_embedding_layer
Esempio n. 16
0
  def __init__(
      self,
      pooled_output_dim,
      pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=0.02),
      embedding_cls=None,
      embedding_cfg=None,
      embedding_data=None,
      num_hidden_instances=1,
      hidden_cls=layers.Transformer,
      hidden_cfg=None,
      return_all_layer_outputs=False,
      **kwargs):
    """Assembles an encoder from pluggable embedding and hidden components.

    The embeddings come either from a caller-supplied embedding network or
    from a default stack (word + position + type embeddings, layer norm,
    dropout) configured through `embedding_cfg`. They are then passed through
    `num_hidden_instances` hidden layers, and the first token of the last
    layer's output feeds a tanh dense pooler.

    Args:
      pooled_output_dim: Number of units of the pooler dense layer.
      pooler_layer_initializer: Kernel initializer of the pooler dense layer.
      embedding_cls: Either a class, which is instantiated with
        `**embedding_cfg` (or with no arguments when `embedding_cfg` is
        falsy), or an object that is used as the embedding network directly.
        The network's `inputs` become the model inputs and calling it on them
        must yield `(embeddings, mask)`. When falsy, the default embedding
        stack is built instead.
      embedding_cfg: Keyword arguments for `embedding_cls` when it is a class.
        For the default stack it is indexed with the keys `seq_length`,
        `vocab_size`, `hidden_size`, `initializer`, `max_seq_length`,
        `type_vocab_size` and `dropout_rate`.
      embedding_data: Stored on the instance; not otherwise used in this
        constructor.
      num_hidden_instances: Number of hidden layers applied in sequence.
      hidden_cls: Either a class, instantiated once per hidden instance with
        `**hidden_cfg` (or with no arguments when `hidden_cfg` is falsy), or
        an object that is then reused for every hidden instance.
      hidden_cfg: Keyword arguments for `hidden_cls` when it is a class.
      return_all_layer_outputs: When true, the first model output is the list
        of all hidden layer outputs; otherwise only the last one.
      **kwargs: Stored on the instance and forwarded to the parent
        constructor.
    """
    # Attribute tracking is switched off before anything is stored on `self`;
    # the parent constructor is only invoked at the very end.
    self._self_setattr_tracking = False
    self._hidden_cls = hidden_cls
    self._hidden_cfg = hidden_cfg
    self._num_hidden_instances = num_hidden_instances
    self._pooled_output_dim = pooled_output_dim
    self._pooler_layer_initializer = pooler_layer_initializer
    self._embedding_cls = embedding_cls
    self._embedding_cfg = embedding_cfg
    self._embedding_data = embedding_data
    self._return_all_layer_outputs = return_all_layer_outputs
    self._kwargs = kwargs

    if not embedding_cls:
      # No custom embedding network: build the default embedding stack.
      self._embedding_network = None
      sequence_shape = (embedding_cfg['seq_length'],)
      word_ids, mask, type_ids = [
          tf.keras.layers.Input(
              shape=sequence_shape, dtype=tf.int32, name=input_name)
          for input_name in ('input_word_ids', 'input_mask', 'input_type_ids')
      ]
      inputs = [word_ids, mask, type_ids]

      cfg_vocab_size = embedding_cfg['vocab_size']
      cfg_hidden_size = embedding_cfg['hidden_size']
      cfg_initializer = embedding_cfg['initializer']

      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=cfg_vocab_size,
          embedding_width=cfg_hidden_size,
          initializer=cfg_initializer,
          name='word_embeddings')
      word_embeddings = self._embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      self._position_embedding_layer = layers.PositionEmbedding(
          initializer=cfg_initializer,
          use_dynamic_slicing=True,
          max_sequence_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = self._position_embedding_layer(word_embeddings)

      type_embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['type_vocab_size'],
          embedding_width=cfg_hidden_size,
          initializer=cfg_initializer,
          use_one_hot=True,
          name='type_embeddings')
      type_embeddings = type_embedding_layer(type_ids)

      summed_embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
      embedding_norm = tf.keras.layers.LayerNormalization(
          name='embeddings/layer_norm',
          axis=-1,
          epsilon=1e-12,
          dtype=tf.float32)
      normalized_embeddings = embedding_norm(summed_embeddings)
      embedding_dropout = tf.keras.layers.Dropout(
          rate=embedding_cfg['dropout_rate'])
      embeddings = embedding_dropout(normalized_embeddings)
    else:
      # Custom embedding network, given either as a class or as an object.
      if not inspect.isclass(embedding_cls):
        self._embedding_network = embedding_cls
      elif embedding_cfg:
        self._embedding_network = embedding_cls(**embedding_cfg)
      else:
        self._embedding_network = embedding_cls()
      inputs = self._embedding_network.inputs
      embeddings, mask = self._embedding_network(inputs)

    attention_mask = layers.SelfAttentionMask()([embeddings, mask])
    data = embeddings

    layer_output_data = []
    self._hidden_layers = []
    for _ in range(num_hidden_instances):
      # A non-class `hidden_cls` is reused as-is on every iteration.
      if not inspect.isclass(hidden_cls):
        layer = hidden_cls
      elif hidden_cfg:
        layer = hidden_cls(**hidden_cfg)
      else:
        layer = hidden_cls()
      data = layer([data, attention_mask])
      layer_output_data.append(data)
      self._hidden_layers.append(layer)

    # Take the first position along axis 1 of the last layer's output and
    # drop that axis.
    first_token_layer = tf.keras.layers.Lambda(
        lambda x: tf.squeeze(x[:, 0:1, :], axis=1))
    first_token_tensor = first_token_layer(layer_output_data[-1])
    self._pooler_layer = tf.keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
        kernel_initializer=pooler_layer_initializer,
        name='cls_transform')
    cls_output = self._pooler_layer(first_token_tensor)

    sequence_outputs = (
        layer_output_data
        if return_all_layer_outputs else layer_output_data[-1])
    outputs = [sequence_outputs, cls_output]

    super(EncoderScaffold, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)
Esempio n. 17
0
    def __init__(
            self,
            vocab_size,
            hidden_size=768,
            num_layers=12,
            num_attention_heads=12,
            sequence_length=512,
            max_sequence_length=None,
            type_vocab_size=16,
            intermediate_size=3072,
            activation=activations.gelu,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            float_dtype='float32',
            **kwargs):
        """Builds the encoder graph and passes it to the parent constructor.

        Word, position and type embeddings are added together, layer-normalized
        and passed through dropout. The result runs through `num_layers`
        separately created transformer layers, and the first token of the
        final sequence feeds a tanh dense pooler.

        The model inputs are `[input_word_ids, input_mask, input_type_ids]` and
        its outputs are `[sequence_output, cls_output]`.

        Args:
          vocab_size: Vocabulary size of the word embedding layer.
          hidden_size: Width of the word and type embeddings and number of
            units in the pooler.
          num_layers: Number of transformer layers to create.
          num_attention_heads: Attention head count of each transformer layer.
          sequence_length: Length used in the shape of the three inputs.
          max_sequence_length: Forwarded to the position embedding layer. A
            falsy value means "use `sequence_length`".
          type_vocab_size: Vocabulary size of the type embedding layer.
          intermediate_size: Intermediate size of each transformer layer.
          activation: Intermediate activation of each transformer layer;
            resolved with `tf.keras.activations.get`.
          dropout_rate: Dropout rate for the embeddings and the transformer
            layers.
          attention_dropout_rate: Attention dropout rate of the transformer
            layers.
          initializer: Initializer used for the embedding, transformer and
            pooler kernels; resolved with `tf.keras.initializers.get`.
          float_dtype: Passed as `dtype` to every transformer layer. When it
            equals 'float16' the embeddings are cast to `tf.float16` before
            the transformer layers.
          **kwargs: Forwarded unchanged to the parent constructor.
        """
        activation = tf.keras.activations.get(activation)
        initializer = tf.keras.initializers.get(initializer)

        # Fall back to the input length when no maximum was provided.
        max_sequence_length = max_sequence_length or sequence_length

        # Attribute tracking is switched off before anything is stored on
        # `self`; the parent constructor is only invoked at the very end.
        self._self_setattr_tracking = False
        self._config_dict = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            num_attention_heads=num_attention_heads,
            sequence_length=sequence_length,
            max_sequence_length=max_sequence_length,
            type_vocab_size=type_vocab_size,
            intermediate_size=intermediate_size,
            activation=tf.keras.activations.serialize(activation),
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            initializer=tf.keras.initializers.serialize(initializer),
            float_dtype=float_dtype,
        )

        # Three int32 inputs that share the same sequence length.
        input_names = ('input_word_ids', 'input_mask', 'input_type_ids')
        word_ids, mask, type_ids = [
            tf.keras.layers.Input(shape=(sequence_length, ),
                                  dtype=tf.int32,
                                  name=input_name)
            for input_name in input_names
        ]

        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=vocab_size,
            embedding_width=hidden_size,
            initializer=initializer,
            name='word_embeddings')
        word_embeddings = self._embedding_layer(word_ids)

        # Always uses dynamic slicing for simplicity.
        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=initializer,
            use_dynamic_slicing=True,
            max_sequence_length=max_sequence_length)
        position_embeddings = self._position_embedding_layer(word_embeddings)

        type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=hidden_size,
            initializer=initializer,
            use_one_hot=True,
            name='type_embeddings')
        type_embeddings = type_embedding_layer(type_ids)

        summed_embeddings = tf.keras.layers.Add()(
            [word_embeddings, position_embeddings, type_embeddings])
        embedding_norm = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)
        normalized_embeddings = embedding_norm(summed_embeddings)
        embedding_dropout = tf.keras.layers.Dropout(rate=dropout_rate,
                                                    dtype=tf.float32)
        embeddings = embedding_dropout(normalized_embeddings)

        # The embedding stack above is float32; cast only for float16 runs.
        if float_dtype == 'float16':
            embeddings = tf.cast(embeddings, tf.float16)

        attention_mask = MakeAttentionMaskLayer()([embeddings, mask])
        sequence_output = embeddings
        for layer_index in range(num_layers):
            transformer_layer = layers.Transformer(
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                intermediate_activation=activation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                kernel_initializer=initializer,
                dtype=float_dtype,
                name='transformer/layer_%d' % layer_index)
            sequence_output = transformer_layer(
                [sequence_output, attention_mask])

        # Take the first position along axis 1 and drop that axis.
        first_token_layer = tf.keras.layers.Lambda(
            lambda x: tf.squeeze(x[:, 0:1, :], axis=1))
        first_token_tensor = first_token_layer(sequence_output)
        pooler = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=initializer,
            name='pooler_transform')
        cls_output = pooler(first_token_tensor)

        super(TransformerEncoder,
              self).__init__(inputs=[word_ids, mask, type_ids],
                             outputs=[sequence_output, cls_output],
                             **kwargs)
Esempio n. 18
0
          dtype=tf.int32,
          name='input_type_ids')
      inputs = [word_ids, mask, type_ids]

      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
          initializer=embedding_cfg['initializer'],
          name='word_embeddings')

      word_embeddings = self._embedding_layer(word_ids)

      # Always uses dynamic slicing for simplicity.
      self._position_embedding_layer = layers.PositionEmbedding(
          initializer=embedding_cfg['initializer'],
          use_dynamic_slicing=True,
          max_sequence_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = self._position_embedding_layer(word_embeddings)

      type_embeddings = (
          layers.OnDeviceEmbedding(
              vocab_size=embedding_cfg['type_vocab_size'],
              embedding_width=embedding_cfg['hidden_size'],
              initializer=embedding_cfg['initializer'],
              use_one_hot=True,
              name='type_embeddings')(type_ids))

      embeddings = tf.keras.layers.Add()(
          [word_embeddings, position_embeddings, type_embeddings])
      embeddings = (