Example #1
    def build(self, input_shape):
        self._vocab_size, embedding_width = self.embedding_table.shape
        hidden_size = input_shape[-1]
        self.dense = tf.keras.layers.Dense(
            hidden_size,
            activation=self.activation,
            kernel_initializer=tf_utils.clone_initializer(self.initializer),
            name='transform/dense')

        if hidden_size > embedding_width:
            self.extra_output_weights = self.add_weight(
                'extra_output_weights',
                shape=(self._vocab_size, hidden_size - embedding_width),
                initializer=tf_utils.clone_initializer(self.initializer),
                trainable=True)
        elif hidden_size == embedding_width:
            self.extra_output_weights = None
        else:
            raise ValueError(
                'hidden size %d cannot be smaller than embedding width %d.' %
                (hidden_size, embedding_width))

        self.layer_norm = tf.keras.layers.LayerNormalization(
            axis=-1, epsilon=1e-12, name='transform/LayerNorm')
        self.bias = self.add_weight('output_bias/bias',
                                    shape=(self._vocab_size, ),
                                    initializer='zeros',
                                    trainable=True)

        super(MobileBertMaskedLM, self).build(input_shape)
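All of the examples in this section route initializers through `tf_utils.clone_initializer` before handing them to a sublayer or weight. A minimal sketch of what such a helper is assumed to do (re-create a Keras initializer from its config so layers do not share one seeded/stateful instance, while passing strings and other configs through unchanged); this is a sketch, not necessarily the exact library implementation:

import tensorflow as tf

def clone_initializer(initializer):
    # Rebuild stateful initializer objects so each layer gets its own copy.
    if isinstance(initializer, tf.keras.initializers.Initializer):
        return initializer.__class__.from_config(initializer.get_config())
    # Strings (e.g. 'glorot_uniform') and dict configs are resolved to fresh
    # instances by Keras anyway, so they can be returned as-is.
    return initializer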
Example #2
  def _make_output_dense(self, free_dims, common_kwargs, name=None,
                         use_bias=True):
    """Builds the output projection matrix.

    Args:
      free_dims: Number of free dimensions for einsum equation building.
      common_kwargs: Common keyword arguments for einsum layer.
      name: Name for the projection layer.
      use_bias: Whether to use a bias; a bias is only added if
        `self._use_bias` is also True.

    Returns:
      Projection layer.
    """
    if self._output_shape:
      if not isinstance(self._output_shape, collections.abc.Sized):
        output_shape = [self._output_shape]
      else:
        output_shape = self._output_shape
    else:
      output_shape = [self._query_shape[-1]]
    einsum_equation, bias_axes, output_rank = _build_proj_equation(
        free_dims, bound_dims=2, output_dims=len(output_shape))
    return tf.keras.layers.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1, output_shape),
        bias_axes=bias_axes if (use_bias and self._use_bias) else None,
        name=name,
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)
Example #3
 def build(self, unused_input_shapes):
     common_kwargs = dict(kernel_regularizer=self._kernel_regularizer,
                          bias_regularizer=self._bias_regularizer,
                          activity_regularizer=self._activity_regularizer,
                          kernel_constraint=self._kernel_constraint,
                          bias_constraint=self._bias_constraint)
     self._query_dense = tf.keras.layers.EinsumDense(
         "BAE,ENH->BANH",
         output_shape=(None, self._num_heads, self._head_size),
         bias_axes="NH",
         name="query",
         kernel_initializer=tf_utils.clone_initializer(
             self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(
             self._bias_initializer),
         **common_kwargs)
     self._key_dense = tf.keras.layers.EinsumDense(
         "BAE,ENH->BANH",
         output_shape=(None, self._num_heads, self._head_size),
         bias_axes="NH",
         name="key",
         kernel_initializer=tf_utils.clone_initializer(
             self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(
             self._bias_initializer),
         **common_kwargs)
     super().build(unused_input_shapes)
Example #4
    def __init__(
            self,
            word_vocab_size,
            word_embed_size,
            type_vocab_size,
            output_embed_size,
            max_sequence_length=512,
            normalization_type='no_norm',
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
            dropout_rate=0.1,
            **kwargs):
        """Class initialization.

    Args:
      word_vocab_size: Number of words in the vocabulary.
      word_embed_size: Word embedding size.
      type_vocab_size: Number of word types.
      output_embed_size: Embedding size for the final embedding output.
      max_sequence_length: Maximum length of input sequence.
      normalization_type: String. The type of normalization layer; only
        `no_norm` and `layer_norm` are supported.
      initializer: The initializer to use for the embedding weights and
        linear projection weights.
      dropout_rate: Dropout rate.
      **kwargs: keyword arguments.
    """
        super(MobileBertEmbedding, self).__init__(**kwargs)
        self.word_vocab_size = word_vocab_size
        self.word_embed_size = word_embed_size
        self.type_vocab_size = type_vocab_size
        self.output_embed_size = output_embed_size
        self.max_sequence_length = max_sequence_length
        self.normalization_type = normalization_type
        self.initializer = tf.keras.initializers.get(initializer)
        self.dropout_rate = dropout_rate

        self.word_embedding = on_device_embedding.OnDeviceEmbedding(
            self.word_vocab_size,
            self.word_embed_size,
            initializer=tf_utils.clone_initializer(self.initializer),
            name='word_embedding')
        self.type_embedding = on_device_embedding.OnDeviceEmbedding(
            self.type_vocab_size,
            self.output_embed_size,
            initializer=tf_utils.clone_initializer(self.initializer),
            name='type_embedding')
        self.pos_embedding = position_embedding.PositionEmbedding(
            max_length=max_sequence_length,
            initializer=tf_utils.clone_initializer(self.initializer),
            name='position_embedding')
        self.word_embedding_proj = tf.keras.layers.EinsumDense(
            'abc,cd->abd',
            output_shape=[None, self.output_embed_size],
            kernel_initializer=tf_utils.clone_initializer(self.initializer),
            bias_axes='d',
            name='embedding_projection')
        self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
        self.dropout_layer = tf.keras.layers.Dropout(self.dropout_rate,
                                                     name='embedding_dropout')
Example #5
    def __init__(self,
                 generator_network,
                 discriminator_network,
                 vocab_size,
                 num_classes,
                 num_token_predictions,
                 mlm_activation=None,
                 mlm_initializer='glorot_uniform',
                 output_type='logits',
                 disallow_correct=False,
                 **kwargs):
        super(ElectraPretrainer, self).__init__()
        self._config = {
            'generator_network': generator_network,
            'discriminator_network': discriminator_network,
            'vocab_size': vocab_size,
            'num_classes': num_classes,
            'num_token_predictions': num_token_predictions,
            'mlm_activation': mlm_activation,
            'mlm_initializer': mlm_initializer,
            'output_type': output_type,
            'disallow_correct': disallow_correct,
        }
        for k, v in kwargs.items():
            self._config[k] = v

        self.generator_network = generator_network
        self.discriminator_network = discriminator_network
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.num_token_predictions = num_token_predictions
        self.mlm_activation = mlm_activation
        self.mlm_initializer = mlm_initializer
        self.output_type = output_type
        self.disallow_correct = disallow_correct
        self.masked_lm = layers.MaskedLM(
            embedding_table=generator_network.get_embedding_table(),
            activation=mlm_activation,
            initializer=tf_utils.clone_initializer(mlm_initializer),
            output=output_type,
            name='generator_masked_lm')
        self.classification = layers.ClassificationHead(
            inner_dim=generator_network.get_config()['hidden_size'],
            num_classes=num_classes,
            initializer=tf_utils.clone_initializer(mlm_initializer),
            name='generator_classification_head')
        self.discriminator_projection = tf.keras.layers.Dense(
            units=discriminator_network.get_config()['hidden_size'],
            activation=mlm_activation,
            kernel_initializer=tf_utils.clone_initializer(mlm_initializer),
            name='discriminator_projection_head')
        self.discriminator_head = tf.keras.layers.Dense(
            units=1,
            kernel_initializer=tf_utils.clone_initializer(mlm_initializer))
Example #6
    def build(self, input_shape: Tuple[int, ...]) -> None:
        """Builds GroupConv2D layer as a collection of smaller Conv2D layers."""
        input_shape = tf.TensorShape(input_shape)
        input_channel = self._get_input_channel(input_shape)
        if input_channel % self._groups != 0:
            raise ValueError(
                f'Number of input channels: {input_channel} are not divisible '
                f'by number of groups: {self._groups}.')

        self.group_input_channel = int(input_channel / self._groups)
        self.group_output_channel = int(self.filters / self._groups)
        self.group_kernel_shape = self.kernel_size + (
            self.group_input_channel, self.group_output_channel)

        self.kernel = []
        self.bias = []
        for g in range(self._groups):
            self.kernel.append(
                self.add_weight(name='kernel_{}'.format(g),
                                shape=self.group_kernel_shape,
                                initializer=tf_utils.clone_initializer(
                                    self.kernel_initializer),
                                regularizer=self.kernel_regularizer,
                                constraint=self.kernel_constraint,
                                trainable=True,
                                dtype=self.dtype))
            if self.use_bias:
                self.bias.append(
                    self.add_weight(name='bias_{}'.format(g),
                                    shape=(self.group_output_channel, ),
                                    initializer=tf_utils.clone_initializer(
                                        self.bias_initializer),
                                    regularizer=self.bias_regularizer,
                                    constraint=self.bias_constraint,
                                    trainable=True,
                                    dtype=self.dtype))
        channel_axis = self._get_channel_axis()
        self.input_spec = tf.keras.layers.InputSpec(
            ndim=self.rank + 2, axes={channel_axis: input_channel})

        self._build_conv_op_data_shape = input_shape[-(self.rank + 1):]
        self._build_input_channel = input_channel
        self._padding_op = self._get_padding_op()
        # channels_last corresponds to 'NHWC' data format.
        self._conv_op_data_format = 'NHWC'

        self.bn_layers = []
        if self.use_batch_norm:
            for group_index in range(self._groups):
                self.bn_layers.append(self.batch_norm_layer[group_index])

        self.built = True
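To make the grouping arithmetic in the build above concrete (hypothetical sizes, not taken from the source):

# GroupConv2D with filters=64, kernel_size=(3, 3), groups=4 and a 32-channel input:
#   group_input_channel  = 32 // 4 = 8
#   group_output_channel = 64 // 4 = 16
#   group_kernel_shape   = (3, 3) + (8, 16) = (3, 3, 8, 16)
# so the layer creates 4 kernels of shape (3, 3, 8, 16) and, when use_bias=True,
# 4 bias vectors of shape (16,), one pair per group.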
Example #7
    def __init__(self,
                 input_width,
                 start_n_top=5,
                 end_n_top=5,
                 activation='tanh',
                 dropout_rate=0.,
                 initializer='glorot_uniform',
                 **kwargs):
        super().__init__(**kwargs)
        self._config = {
            'input_width': input_width,
            'activation': activation,
            'initializer': initializer,
            'start_n_top': start_n_top,
            'end_n_top': end_n_top,
            'dropout_rate': dropout_rate,
        }
        if start_n_top <= 1:
            raise ValueError('`start_n_top` must be greater than 1.')
        self._start_n_top = start_n_top
        self._end_n_top = end_n_top
        self.start_logits_dense = tf.keras.layers.Dense(
            units=1,
            kernel_initializer=tf_utils.clone_initializer(initializer),
            name='predictions/transform/start_logits')

        self.end_logits_inner_dense = tf.keras.layers.Dense(
            units=input_width,
            kernel_initializer=tf_utils.clone_initializer(initializer),
            activation=activation,
            name='predictions/transform/end_logits/inner')
        self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
            axis=-1,
            epsilon=1e-12,
            name='predictions/transform/end_logits/layernorm')
        self.end_logits_output_dense = tf.keras.layers.Dense(
            units=1,
            kernel_initializer=tf_utils.clone_initializer(initializer),
            name='predictions/transform/end_logits/output')

        self.answer_logits_inner = tf.keras.layers.Dense(
            units=input_width,
            kernel_initializer=tf_utils.clone_initializer(initializer),
            activation=activation,
            name='predictions/transform/answer_logits/inner')
        self.answer_logits_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
        self.answer_logits_output = tf.keras.layers.Dense(
            units=1,
            kernel_initializer=tf_utils.clone_initializer(initializer),
            use_bias=False,
            name='predictions/transform/answer_logits/output')
Example #8
    def build(self, input_shape: List[int]) -> None:
        # Disable the attribute-defined-outside-init violations in this function
        # pylint: disable=attribute-defined-outside-init
        if input_shape[-1] is None:
            raise ValueError(
                'The last dimension of the inputs to `TNExpandCondense` '
                'should be defined. Found `None`.')

        super(TNExpandCondense, self).build(input_shape)

        self.proj_size = self.proj_multiplier * input_shape[-1]

        assert (self.proj_size //
                input_shape[-1]) * input_shape[-1] == self.proj_size, (
                    f'{self.proj_size} / {input_shape[-1]} must be '
                    f'round')
        assert (input_shape[-1] // 128) * 128 == input_shape[
            -1], f'{input_shape[-1]} / 128 must be round'

        self.w1 = self.add_weight(name='w1',
                                  shape=(input_shape[-1], input_shape[-1]),
                                  trainable=True,
                                  initializer=tf_utils.clone_initializer(
                                      self.kernel_initializer))

        self.w2 = self.add_weight(
            name='w2',
            shape=(128, (128 * (self.proj_size // input_shape[-1]))),
            trainable=True,
            initializer=tf_utils.clone_initializer(self.kernel_initializer))

        self.w3 = self.add_weight(
            name='w3',
            shape=(128 * (self.proj_size // input_shape[-1]), 128),
            trainable=True,
            initializer=tf_utils.clone_initializer(self.kernel_initializer))
        self.w4 = self.add_weight(
            name='w4',
            shape=(input_shape[-1] // 128, 128, input_shape[-1]),
            trainable=True,
            initializer=tf_utils.clone_initializer(self.kernel_initializer))

        if self.use_bias:
            self.bias = self.add_weight(
                name='b',
                shape=(input_shape[-1] // 128, 1,
                       128 * (self.proj_size // input_shape[-1])),
                trainable=True,
                initializer=self.bias_initializer)
        else:
            self.bias = None
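To trace the weight shapes above with concrete (hypothetical) numbers: with input_shape[-1] = 512 and proj_multiplier = 4, proj_size = 2048 and proj_size // input_shape[-1] = 4, so:

# w1:   (512, 512)
# w2:   (128, 128 * 4)           = (128, 512)
# w3:   (128 * 4, 128)           = (512, 128)
# w4:   (512 // 128, 128, 512)   = (4, 128, 512)
# bias: (512 // 128, 1, 128 * 4) = (4, 1, 512)   # only when use_bias=True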
Example #9
    def build(self, input_shape):
        num_reduced_filters = nn_layers.make_divisible(
            max(1, int(self._in_filters * self._se_ratio)),
            divisor=self._divisible_by,
            round_down_protect=self._round_down_protect)

        self._se_reduce = helper.Conv2DQuantized(
            filters=num_reduced_filters,
            kernel_size=1,
            strides=1,
            padding='same',
            use_bias=True,
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activation=helper.NoOpActivation())

        self._se_expand = helper.Conv2DOutputQuantized(
            filters=self._out_filters,
            kernel_size=1,
            strides=1,
            padding='same',
            use_bias=True,
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activation=helper.NoOpActivation())

        self._multiply = tfmot.quantization.keras.QuantizeWrapperV2(
            tf.keras.layers.Multiply(),
            configs.Default8BitQuantizeConfig([], [], True))
        self._reduce_mean_quantizer = (
            tfmot.quantization.keras.quantizers.MovingAverageQuantizer(
                num_bits=8,
                per_axis=False,
                symmetric=False,
                narrow_range=False))
        self._reduce_mean_quantizer_vars = self._reduce_mean_quantizer.build(
            None, 'reduce_mean_quantizer_vars', self)

        self._activation_layer = tfmot.quantization.keras.QuantizeWrapperV2(
            tf_utils.get_activation(self._activation, use_keras_layer=True),
            configs.Default8BitActivationQuantizeConfig())
        self._create_gating_activation_layer()

        self._build_quantizer_vars()
        super().build(input_shape)
Example #10
  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
    """Creates the variables of the mask scoring head."""
    conv_op = tf.keras.layers.Conv2D
    conv_kwargs = {
        'filters': self._config_dict['num_filters'],
        'kernel_size': 3,
        'padding': 'same',
    }
    conv_kwargs.update({
        'kernel_initializer': tf.keras.initializers.VarianceScaling(
            scale=2, mode='fan_out', distribution='untruncated_normal'),
        'bias_initializer': tf.zeros_initializer(),
        'kernel_regularizer': self._config_dict['kernel_regularizer'],
        'bias_regularizer': self._config_dict['bias_regularizer'],
    })
    bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
             if self._config_dict['use_sync_bn']
             else tf.keras.layers.BatchNormalization)
    bn_kwargs = {
        'axis': self._bn_axis,
        'momentum': self._config_dict['norm_momentum'],
        'epsilon': self._config_dict['norm_epsilon'],
    }

    self._convs = []
    self._conv_norms = []
    for i in range(self._config_dict['num_convs']):
      conv_name = 'mask-scoring_{}'.format(i)
      if 'kernel_initializer' in conv_kwargs:
        conv_kwargs['kernel_initializer'] = tf_utils.clone_initializer(
            conv_kwargs['kernel_initializer'])
      self._convs.append(conv_op(name=conv_name, **conv_kwargs))
      bn_name = 'mask-scoring-bn_{}'.format(i)
      self._conv_norms.append(bn_op(name=bn_name, **bn_kwargs))

    self._fcs = []
    self._fc_norms = []
    for i in range(self._config_dict['num_fcs']):
      fc_name = 'mask-scoring-fc_{}'.format(i)
      self._fcs.append(
          tf.keras.layers.Dense(
              units=self._config_dict['fc_dims'],
              kernel_initializer=tf.keras.initializers.VarianceScaling(
                  scale=1 / 3.0, mode='fan_out', distribution='uniform'),
              kernel_regularizer=self._config_dict['kernel_regularizer'],
              bias_regularizer=self._config_dict['bias_regularizer'],
              name=fc_name))
      bn_name = 'mask-scoring-fc-bn_{}'.format(i)
      self._fc_norms.append(bn_op(name=bn_name, **bn_kwargs))

    self._classifier = tf.keras.layers.Dense(
        units=self._config_dict['num_classes'],
        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        bias_regularizer=self._config_dict['bias_regularizer'],
        name='iou-scores')

    super(MaskScoring, self).build(input_shape)
Example #11
 def build(self, input_shape):
   self._embedding_projection = tf.keras.layers.EinsumDense(
       '...x,xy->...y',
       output_shape=self._output_dim,
       bias_axes=None,
       kernel_initializer=tf_utils.clone_initializer(self._initializer),
       name='embedding_projection')
   super().build(input_shape)
Example #12
    def __init__(self,
                 inner_dim,
                 cls_list,
                 cls_token_idx=0,
                 activation="tanh",
                 dropout_rate=0.0,
                 initializer="glorot_uniform",
                 **kwargs):
        """Initializes the `MultiClsHeads`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      cls_list: A list of (classification problem name, number of classes)
        pairs.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
        super().__init__(**kwargs)
        self.dropout_rate = dropout_rate
        self.inner_dim = inner_dim
        self.cls_list = cls_list
        self.activation = tf_utils.get_activation(activation)
        self.initializer = tf.keras.initializers.get(initializer)
        self.cls_token_idx = cls_token_idx

        if self.inner_dim:
            self.dense = tf.keras.layers.Dense(
                units=inner_dim,
                activation=self.activation,
                kernel_initializer=tf_utils.clone_initializer(
                    self.initializer),
                name="pooler_dense")
        self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
        self.out_projs = []
        for name, num_classes in cls_list:
            self.out_projs.append(
                tf.keras.layers.Dense(
                    units=num_classes,
                    kernel_initializer=tf_utils.clone_initializer(
                        self.initializer),
                    name=name))
Example #13
    def _build_attention(self, qkv_rank):
        """Builds multi-head dot-product attention computations.

    This function overrides base class to create additional linear projection
    that will be applied on attention scores before and after softmax.

    Args:
      qkv_rank: The rank of query, key, value tensors after projection.
    """
        super(TalkingHeadsAttention, self)._build_attention(qkv_rank)

        # Build an equation:
        # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
        # (<batch_dims>, num_heads_b, ...)
        # qkv_ranks has `batch_dims`, `attention_dims`, `num_heads` and `channels`.
        num_batch_dims = qkv_rank - len(self._attention_axes) - 2

        # The shape of attn_scores is:
        # (<batch_dims>, num_heads, <query_attn_dims>, <key_attn_dims>)
        attn_scores_rank = num_batch_dims + 1 + len(self._attention_axes) * 2
        scores_notation = _CHR_IDX[:attn_scores_rank]
        projection_notation = scores_notation[num_batch_dims] + (
            _CHR_IDX[attn_scores_rank])
        projected_scores_notation = scores_notation[:num_batch_dims] + (
            _CHR_IDX[attn_scores_rank] + scores_notation[num_batch_dims + 1:])
        self._talking_heads_equation = "%s,%s->%s" % (
            scores_notation, projection_notation, projected_scores_notation)

        self._pre_softmax_weight = self.add_weight(
            "pre_softmax_weight",
            shape=(self._num_heads, self._num_heads),
            initializer=tf_utils.clone_initializer(self._kernel_initializer),
            regularizer=self._kernel_regularizer,
            constraint=self._kernel_constraint,
            dtype=self.dtype,
            trainable=True)
        self._post_softmax_weight = self.add_weight(
            "post_softmax_weight",
            shape=(self._num_heads, self._num_heads),
            initializer=tf_utils.clone_initializer(self._kernel_initializer),
            regularizer=self._kernel_regularizer,
            constraint=self._kernel_constraint,
            dtype=self.dtype,
            trainable=True)
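A worked instance of the equation building above, assuming `_CHR_IDX` is the lowercase alphabet and there is a single attention axis (so the projected query/key/value tensors have rank 4):

# num_batch_dims   = 4 - 1 - 2 = 1
# attn_scores_rank = 1 + 1 + 2 * 1 = 4
# scores_notation           = "abcd"   # (batch, num_heads_a, query_len, key_len)
# projection_notation       = "be"     # (num_heads_a, num_heads_b)
# projected_scores_notation = "aecd"   # (batch, num_heads_b, query_len, key_len)
# _talking_heads_equation   = "abcd,be->aecd"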
Example #14
    def __init__(self,
                 inner_dim,
                 num_classes,
                 cls_token_idx=0,
                 activation="tanh",
                 dropout_rate=0.0,
                 initializer="glorot_uniform",
                 use_spec_norm=True,
                 use_gp_layer=True,
                 temperature=None,
                 **kwargs):
        """Initializes the `GaussianProcessClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      use_spec_norm: Whether to apply spectral normalization to pooler layer.
      use_gp_layer: Whether to use Gaussian process as the output layer.
      temperature: The temperature parameter to be used for mean-field
        approximation during inference. If None then no mean-field adjustment is
        applied.
      **kwargs: Additional keyword arguments.
    """
        # Collects spectral normalization and Gaussian process args from kwargs.
        self.use_spec_norm = use_spec_norm
        self.use_gp_layer = use_gp_layer
        self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs)
        self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs)
        self.temperature = temperature

        super().__init__(inner_dim=inner_dim,
                         num_classes=num_classes,
                         cls_token_idx=cls_token_idx,
                         activation=activation,
                         dropout_rate=dropout_rate,
                         initializer=initializer,
                         **kwargs)

        # Applies spectral normalization to the dense pooler layer.
        if self.use_spec_norm and hasattr(self, "dense"):
            self.dense = spectral_normalization.SpectralNormalization(
                self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)

        # Replace Dense output layer with the Gaussian process layer.
        if use_gp_layer:
            self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
                self.num_classes,
                kernel_initializer=tf_utils.clone_initializer(
                    self.initializer),
                name="logits",
                **self.gp_layer_kwargs)
Example #15
    def build(self, input_shape):
        """Builds the layer."""
        # Layers for linearly projecting the queries, keys, and values.
        size_per_head = self.hidden_size // self.num_heads

        def _glorot_initializer(fan_in, fan_out):
            limit = math.sqrt(6.0 / (fan_in + fan_out))
            return tf.keras.initializers.RandomUniform(minval=-limit,
                                                       maxval=limit)

        attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
                                                    self.hidden_size)
        self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
            "BTE,ENH->BTNH",
            output_shape=(None, self.num_heads, size_per_head),
            kernel_initializer=tf_utils.clone_initializer(
                attention_initializer),
            bias_axes=None,
            name="query")
        self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
            "BTE,ENH->BTNH",
            output_shape=(None, self.num_heads, size_per_head),
            kernel_initializer=tf_utils.clone_initializer(
                attention_initializer),
            bias_axes=None,
            name="key")
        self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
            "BTE,ENH->BTNH",
            output_shape=(None, self.num_heads, size_per_head),
            kernel_initializer=tf_utils.clone_initializer(
                attention_initializer),
            bias_axes=None,
            name="value")

        output_initializer = _glorot_initializer(self.hidden_size,
                                                 self.hidden_size)
        self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
            "BTNH,NHE->BTE",
            output_shape=(None, self.hidden_size),
            kernel_initializer=output_initializer,
            bias_axes=None,
            name="output_transform")
        super(Attention, self).build(input_shape)
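For reference, `_glorot_initializer` above reproduces the standard Glorot (Xavier) uniform rule, limit = sqrt(6 / (fan_in + fan_out)). Keras ships the same rule as `tf.keras.initializers.GlorotUniform`, which infers the fans from the weight shape rather than taking them explicitly, so an (assumed) equivalent would be:

# Roughly equivalent built-in; fans are inferred from the variable shape at build time.
attention_initializer = tf.keras.initializers.GlorotUniform()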
Example #16
    def __init__(self,
                 inner_dim,
                 num_classes,
                 cls_token_idx=0,
                 activation="tanh",
                 dropout_rate=0.0,
                 initializer="glorot_uniform",
                 **kwargs):
        """Initializes the `ClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
        super().__init__(**kwargs)
        self.dropout_rate = dropout_rate
        self.inner_dim = inner_dim
        self.num_classes = num_classes
        self.activation = tf_utils.get_activation(activation)
        self.initializer = tf.keras.initializers.get(initializer)
        self.cls_token_idx = cls_token_idx

        if self.inner_dim:
            self.dense = tf.keras.layers.Dense(
                units=self.inner_dim,
                activation=self.activation,
                kernel_initializer=tf_utils.clone_initializer(
                    self.initializer),
                name="pooler_dense")
        self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

        self.out_proj = tf.keras.layers.Dense(
            units=num_classes,
            kernel_initializer=tf_utils.clone_initializer(self.initializer),
            name="logits")
Example #17
 def __init__(self,
              num_attention_heads,
              intermediate_size,
              intermediate_activation,
              dropout_rate=0.0,
              attention_dropout_rate=0.0,
              multi_channel_cross_attention=False,
              kernel_initializer="glorot_uniform",
              bias_initializer="zeros",
              kernel_regularizer=None,
              bias_regularizer=None,
              activity_regularizer=None,
              kernel_constraint=None,
              bias_constraint=None,
              use_bias=True,
              norm_first=False,
              norm_epsilon=1e-12,
              intermediate_dropout=0.0,
              attention_initializer=None,
              **kwargs):
     super().__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
     self.intermediate_size = intermediate_size
     self.intermediate_activation = tf.keras.activations.get(
         intermediate_activation)
     self.dropout_rate = dropout_rate
     self.attention_dropout_rate = attention_dropout_rate
     self.multi_channel_cross_attention = multi_channel_cross_attention
     self._kernel_initializer = tf.keras.initializers.get(
         kernel_initializer)
     self._bias_initializer = tf.keras.initializers.get(bias_initializer)
     self._kernel_regularizer = tf.keras.regularizers.get(
         kernel_regularizer)
     self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
     self._activity_regularizer = tf.keras.regularizers.get(
         activity_regularizer)
     self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)
     self._use_bias = use_bias
     self._norm_first = norm_first
     self._norm_epsilon = norm_epsilon
     self._intermediate_dropout = intermediate_dropout
     if attention_initializer:
         self._attention_initializer = tf.keras.initializers.get(
             attention_initializer)
     else:
         self._attention_initializer = tf_utils.clone_initializer(
             self._kernel_initializer)
     if self.multi_channel_cross_attention:
         self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
         self._cross_attention_cls = attention.MultiHeadAttention
Example #18
    def __init__(self,
                 num_attention_heads,
                 intermediate_size,
                 intermediate_activation,
                 dropout_rate=0.0,
                 attention_dropout_rate=0.0,
                 output_range=None,
                 kernel_initializer="glorot_uniform",
                 bias_initializer="zeros",
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 use_bias=True,
                 norm_first=False,
                 norm_epsilon=1e-12,
                 intermediate_dropout=0.0,
                 attention_initializer=None,
                 **kwargs):
        super(TNTransformerExpandCondense, self).__init__(**kwargs)

        self._num_heads = num_attention_heads
        self._intermediate_size = intermediate_size
        self._intermediate_activation = intermediate_activation
        self._attention_dropout_rate = attention_dropout_rate
        self._dropout_rate = dropout_rate
        self._output_range = output_range
        self._kernel_initializer = tf.keras.initializers.get(
            kernel_initializer)
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        self._kernel_regularizer = tf.keras.regularizers.get(
            kernel_regularizer)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._activity_regularizer = tf.keras.regularizers.get(
            activity_regularizer)
        self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self._bias_constraint = tf.keras.constraints.get(bias_constraint)
        self._use_bias = use_bias
        self._norm_first = norm_first
        self._norm_epsilon = norm_epsilon
        self._intermediate_dropout = intermediate_dropout
        if attention_initializer:
            self._attention_initializer = tf.keras.initializers.get(
                attention_initializer)
        else:
            self._attention_initializer = tf_utils.clone_initializer(
                self._kernel_initializer)
Example #19
    def build(self, input_shape):
        hidden_size = input_shape.as_list()[-1]

        common_kwargs = dict(kernel_regularizer=self._kernel_regularizer,
                             bias_regularizer=self._bias_regularizer,
                             activity_regularizer=self._activity_regularizer,
                             kernel_constraint=self._kernel_constraint,
                             bias_constraint=self._bias_constraint)

        self._intermediate_dense = tf.keras.layers.EinsumDense(
            "abc,cde->abde",
            output_shape=(None, self._num_blocks,
                          self._intermediate_size // self._num_blocks),
            bias_axes="de",
            name="intermediate",
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            bias_initializer=tf_utils.clone_initializer(
                self._bias_initializer),
            **common_kwargs)

        policy = tf.keras.mixed_precision.global_policy()
        if policy.name == "mixed_bfloat16":
            # bfloat16 causes BERT with the LAMB optimizer to not converge
            # as well, so we use float32.
            policy = tf.float32
        self._intermediate_activation_layer = tf.keras.layers.Activation(
            self._intermediate_activation, dtype=policy)

        self._output_dense = tf.keras.layers.EinsumDense(
            "abde,deo->abdo",
            output_shape=(None, self._num_blocks,
                          hidden_size // self._num_blocks),
            bias_axes="do",
            name="output",
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            bias_initializer=tf_utils.clone_initializer(
                self._bias_initializer),
            **common_kwargs)

        if self._apply_mixing:
            self._output_mixing = tf.keras.layers.EinsumDense(
                "abdo,de->abeo",
                output_shape=(None, self._num_blocks,
                              hidden_size // self._num_blocks),
                name="output_mixing",
                kernel_initializer=tf_utils.clone_initializer(
                    self._kernel_initializer),
                bias_initializer=tf_utils.clone_initializer(
                    self._bias_initializer),
                **common_kwargs)
        self._output_reshape = tf.keras.layers.Reshape((-1, hidden_size))

        self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout)
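To trace the block-diagonal einsum shapes above with hypothetical sizes (hidden_size = 768, _num_blocks = 4, _intermediate_size = 3072):

# intermediate  "abc,cde->abde":  [batch, seq, 768] x [768, 4, 768]    -> [batch, seq, 4, 768]
# output        "abde,deo->abdo": [batch, seq, 4, 768] x [4, 768, 192] -> [batch, seq, 4, 192]
# output_mixing "abdo,de->abeo":  [batch, seq, 4, 192] x [4, 4]        -> [batch, seq, 4, 192]
# reshape back to [batch, seq, 4 * 192] = [batch, seq, 768]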
Example #20
 def build(self, input_shape):
     """Implements build() for the layer."""
     self.encoder_layers = []
     for i in range(self.num_layers):
         self.encoder_layers.append(
             TransformerEncoderBlock(
                 num_attention_heads=self.num_attention_heads,
                 inner_dim=self._intermediate_size,
                 inner_activation=self._activation,
                 output_dropout=self._dropout_rate,
                 attention_dropout=self._attention_dropout_rate,
                 use_bias=self._use_bias,
                 norm_first=self._norm_first,
                 norm_epsilon=self._norm_epsilon,
                 inner_dropout=self._intermediate_dropout,
                 attention_initializer=tf_utils.clone_initializer(
                     models.seq2seq_transformer.attention_initializer(
                         input_shape[2])),
                 name=("layer_%d" % i)))
     self.output_normalization = tf.keras.layers.LayerNormalization(
         epsilon=self._norm_epsilon, dtype="float32")
     super(TransformerEncoder, self).build(input_shape)
Example #21
  def build(self, input_shape):
    hidden_size = input_shape.as_list()[-1]

    common_kwargs = dict(
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._intermediate_dense = []
    self._intermediate_activation_layers = []
    self._gate_dense = []
    self._output_dense = []
    self._output_dropout = []
    self._output_layer_norm = []
    activation_policy = tf.keras.mixed_precision.global_policy()
    if activation_policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      activation_policy = tf.float32
    for i in range(self._num_blocks):
      self._intermediate_dense.append(
          tf.keras.layers.experimental.EinsumDense(
              "abc,cd->abd",
              output_shape=(None, self._intermediate_size),
              bias_axes="d",
              name="intermediate_%d" % i,
              kernel_initializer=tf_utils.clone_initializer(
                  self._kernel_initializer),
              bias_initializer=tf_utils.clone_initializer(
                  self._bias_initializer),
              **common_kwargs))
      self._intermediate_activation_layers.append(
          tf.keras.layers.Activation(
              self._intermediate_activation, dtype=activation_policy))
      if self._use_gate:
        self._gate_dense.append(
            tf.keras.layers.experimental.EinsumDense(
                "abc,cd->abd",
                output_shape=(None, self._intermediate_size),
                bias_axes="d",
                name="gate_%d" % i,
                kernel_initializer=tf_utils.clone_initializer(
                    self._kernel_initializer),
                bias_initializer=tf_utils.clone_initializer(
                    self._bias_initializer),
                **common_kwargs))
      self._output_dense.append(
          tf.keras.layers.experimental.EinsumDense(
              "abc,cd->abd",
              output_shape=(None, hidden_size),
              bias_axes="d",
              name="output_%d" % i,
              kernel_initializer=tf_utils.clone_initializer(
                  self._kernel_initializer),
              bias_initializer=tf_utils.clone_initializer(
                  self._bias_initializer),
              **common_kwargs))
      self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
      # Use float32 in layernorm for numeric stability.
      if self._apply_output_layer_norm:
        self._output_layer_norm.append(
            tf.keras.layers.LayerNormalization(
                name="output_layer_norm_%d" % i,
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32))
Example #22
    def __init__(self,
                 num_attention_heads,
                 inner_dim,
                 inner_activation,
                 output_range=None,
                 kernel_initializer="glorot_uniform",
                 bias_initializer="zeros",
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 use_bias=True,
                 norm_first=False,
                 norm_epsilon=1e-12,
                 output_dropout=0.0,
                 attention_dropout=0.0,
                 inner_dropout=0.0,
                 attention_initializer=None,
                 attention_axes=None,
                 use_query_residual=True,
                 key_dim=None,
                 value_dim=None,
                 output_last_dim=None,
                 diff_q_kv_att_layer_norm=False,
                 **kwargs):
        """Initializes `TransformerEncoderBlock`.

    Note: If `output_last_dim` is used and `use_query_residual` is `True`, the
    `output_last_dim`'s value must equal the first input's last dimension for
    the query residual connection to work. This is because the residual
    connection after the multi-head-attention requires their dimensions to
     match. If `use_query_residual` is `False`, the `output_last_dim` dictates
    the last dimension of the output of this module and the
    multi-head-attention.

    E.g. let's say input dims are `[batch_size, seq_dim, input_last_dim]`.
    Scenario 1: If `output_last_dim` is not `None`, then the output dims of this
    module would be `[batch_size, seq_dim, output_last_dim]`. Note `key_dim` is
     overridden by `output_last_dim`.
    Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`, then
    the output dims of this module would be `[batch_size, seq_dim, key_dim]`.
    Scenario 3: If the `output_last_dim` and `key_dim` are both `None`, the
    output dims would be `[batch_size, seq_dim, input_last_dim]`.

    Args:
      num_attention_heads: Number of attention heads.
      inner_dim: The output dimension of the first Dense layer in a two-layer
        feedforward network.
      inner_activation: The activation for the first Dense layer in a two-layer
        feedforward network.
      output_range: the sequence output range, [0, output_range) for slicing the
        target sequence. `None` means the target sequence is not sliced.
      kernel_initializer: Initializer for dense layer kernels.
      bias_initializer: Initializer for dense layer biases.
      kernel_regularizer: Regularizer for dense layer kernels.
      bias_regularizer: Regularizer for dense layer biases.
      activity_regularizer: Regularizer for dense layer activity.
      kernel_constraint: Constraint for dense layer kernels.
      bias_constraint: Constraint for dense layer biases.
      use_bias: Whether to enable use_bias in attention layer. If set False,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set False, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      output_dropout: Dropout probability for the post-attention and output
        dropout.
      attention_dropout: Dropout probability for within the attention layer.
      inner_dropout: Dropout probability for the first Dense layer in a
        two-layer feedforward network.
      attention_initializer: Initializer for kernels of attention layers. If set
        `None`, attention layers use kernel_initializer as initializer for
        kernel.
      attention_axes: axes over which the attention is applied. `None` means
        attention over all axes except batch, heads, and features.
      use_query_residual: Toggle to execute residual connection after attention.
      key_dim: `key_dim` for the `tf.keras.layers.MultiHeadAttention`. If
        `None`, we use the first `input_shape`'s last dim.
      value_dim: `value_dim` for the `tf.keras.layers.MultiHeadAttention`.
      output_last_dim: Final dimension of the output of this module. This also
        dictates the value for the final dimension of the
        multi-head-attention. When it's `None`, we use, in order of decreasing
        precedence, `key_dim` * `num_heads` or the first `input_shape`'s last
        dim as the output's last dim.
      diff_q_kv_att_layer_norm: If `True`, create a separate attention layer
        norm layer for query and key-value if `norm_first` is `True`. Invalid
        to set to `True` if `norm_first` is `False`.
      **kwargs: keyword arguments.
    """
        util.filter_kwargs(kwargs)
        super().__init__(**kwargs)

        # Deprecation warning.
        if output_range is not None:
            logging.warning(
                "`output_range` is avaliable as an argument for `call()`."
                "The `output_range` as __init__ argument is deprecated.")

        self._num_heads = num_attention_heads
        self._inner_dim = inner_dim
        self._inner_activation = inner_activation
        self._attention_dropout_rate = attention_dropout
        self._output_dropout_rate = output_dropout
        self._output_range = output_range
        self._kernel_initializer = tf.keras.initializers.get(
            kernel_initializer)
        self._bias_initializer = tf.keras.initializers.get(bias_initializer)
        self._kernel_regularizer = tf.keras.regularizers.get(
            kernel_regularizer)
        self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
        self._activity_regularizer = tf.keras.regularizers.get(
            activity_regularizer)
        self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
        self._bias_constraint = tf.keras.constraints.get(bias_constraint)
        self._use_bias = use_bias
        self._norm_first = norm_first
        self._norm_epsilon = norm_epsilon
        self._inner_dropout = inner_dropout
        self._use_query_residual = use_query_residual
        self._key_dim = key_dim
        self._value_dim = value_dim
        self._output_last_dim = output_last_dim
        self._diff_q_kv_att_layer_norm = diff_q_kv_att_layer_norm
        if attention_initializer:
            self._attention_initializer = tf.keras.initializers.get(
                attention_initializer)
        else:
            self._attention_initializer = tf_utils.clone_initializer(
                self._kernel_initializer)
        self._attention_axes = attention_axes

        if self._diff_q_kv_att_layer_norm and not self._norm_first:
            raise ValueError(
                "Setting `diff_q_and_kv_attention_layer_norm` to True"
                "when `norm_first` is False is invalid.")
Example #23
    def build(self, input_shape):
        if isinstance(input_shape, tf.TensorShape):
            input_tensor_shape = input_shape
        elif isinstance(input_shape, (list, tuple)):
            input_tensor_shape = tf.TensorShape(input_shape[0])
        else:
            raise ValueError(
                "The type of input shape argument is not supported, got: %s" %
                type(input_shape))
        einsum_equation = "abc,cd->abd"
        if len(input_tensor_shape.as_list()) > 3:
            einsum_equation = "...bc,cd->...bd"
        hidden_size = input_tensor_shape[-1]
        if hidden_size % self._num_heads != 0:
            logging.warning(
                "The input size (%d) is not a multiple of the number of attention "
                "heads (%d)", hidden_size, self._num_heads)
        if self._key_dim is None:
            self._key_dim = int(hidden_size // self._num_heads)
        if self._output_last_dim is None:
            last_output_shape = hidden_size
        else:
            last_output_shape = self._output_last_dim

        common_kwargs = dict(bias_regularizer=self._bias_regularizer,
                             activity_regularizer=self._activity_regularizer,
                             kernel_constraint=self._kernel_constraint,
                             bias_constraint=self._bias_constraint)
        self._attention_layer = tf.keras.layers.MultiHeadAttention(
            num_heads=self._num_heads,
            key_dim=self._key_dim,
            value_dim=self._value_dim,
            dropout=self._attention_dropout_rate,
            use_bias=self._use_bias,
            kernel_initializer=self._attention_initializer,
            bias_initializer=tf_utils.clone_initializer(
                self._bias_initializer),
            attention_axes=self._attention_axes,
            output_shape=self._output_last_dim,
            name="self_attention",
            **common_kwargs)
        self._attention_dropout = tf.keras.layers.Dropout(
            rate=self._attention_dropout_rate)
        # Use float32 in layernorm for numeric stability.
        # It is probably safe in mixed_float16, but we haven't validated this yet.
        self._attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
        self._attention_layer_norm_kv = self._attention_layer_norm
        if self._diff_q_kv_att_layer_norm:
            self._attention_layer_norm_kv = (
                tf.keras.layers.LayerNormalization(
                    name="self_attention_layer_norm_kv",
                    axis=-1,
                    epsilon=self._norm_epsilon,
                    dtype=tf.float32))

        self._intermediate_dense = tf.keras.layers.EinsumDense(
            einsum_equation,
            output_shape=(None, self._inner_dim),
            bias_axes="d",
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            bias_initializer=tf_utils.clone_initializer(
                self._bias_initializer),
            name="intermediate",
            **common_kwargs)
        policy = tf.keras.mixed_precision.global_policy()
        if policy.name == "mixed_bfloat16":
            # bfloat16 causes BERT with the LAMB optimizer to not converge
            # as well, so we use float32.
            # TODO(b/154538392): Investigate this.
            policy = tf.float32
        self._intermediate_activation_layer = tf.keras.layers.Activation(
            self._inner_activation, dtype=policy)
        self._inner_dropout_layer = tf.keras.layers.Dropout(
            rate=self._inner_dropout)
        self._output_dense = tf.keras.layers.EinsumDense(
            einsum_equation,
            output_shape=(None, last_output_shape),
            bias_axes="d",
            name="output",
            kernel_initializer=tf_utils.clone_initializer(
                self._kernel_initializer),
            bias_initializer=tf_utils.clone_initializer(
                self._bias_initializer),
            **common_kwargs)
        self._output_dropout = tf.keras.layers.Dropout(
            rate=self._output_dropout_rate)
        # Use float32 in layernorm for numeric stability.
        self._output_layer_norm = tf.keras.layers.LayerNormalization(
            name="output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32)

        super().build(input_shape)
Example #24
    def __init__(
            self,
            pooled_output_dim,
            pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            embedding_cls=None,
            embedding_cfg=None,
            embedding_data=None,
            num_hidden_instances=1,
            hidden_cls=layers.Transformer,
            hidden_cfg=None,
            mask_cls=layers.SelfAttentionMask,
            mask_cfg=None,
            layer_norm_before_pooling=False,
            return_all_layer_outputs=False,
            dict_outputs=False,
            layer_idx_as_attention_seed=False,
            feed_layer_idx=False,
            recursive=False,
            **kwargs):

        if embedding_cls:
            if inspect.isclass(embedding_cls):
                embedding_network = embedding_cls(
                    **embedding_cfg) if embedding_cfg else embedding_cls()
            else:
                embedding_network = embedding_cls
            inputs = embedding_network.inputs
            embeddings, attention_mask = embedding_network(inputs)
            embedding_layer = None
            position_embedding_layer = None
            type_embedding_layer = None
            embedding_norm_layer = None
        else:
            embedding_network = None
            seq_length = embedding_cfg.get('seq_length', None)
            word_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                             dtype=tf.int32,
                                             name='input_word_ids')
            mask = tf.keras.layers.Input(shape=(seq_length, ),
                                         dtype=tf.int32,
                                         name='input_mask')
            type_ids = tf.keras.layers.Input(shape=(seq_length, ),
                                             dtype=tf.int32,
                                             name='input_type_ids')
            inputs = [word_ids, mask, type_ids]

            embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=embedding_cfg['vocab_size'],
                embedding_width=embedding_cfg['hidden_size'],
                initializer=tf_utils.clone_initializer(
                    embedding_cfg['initializer']),
                name='word_embeddings')

            word_embeddings = embedding_layer(word_ids)

            # Always uses dynamic slicing for simplicity.
            position_embedding_layer = layers.PositionEmbedding(
                initializer=tf_utils.clone_initializer(
                    embedding_cfg['initializer']),
                max_length=embedding_cfg['max_seq_length'],
                name='position_embedding')
            position_embeddings = position_embedding_layer(word_embeddings)

            type_embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=embedding_cfg['type_vocab_size'],
                embedding_width=embedding_cfg['hidden_size'],
                initializer=tf_utils.clone_initializer(
                    embedding_cfg['initializer']),
                use_one_hot=True,
                name='type_embeddings')
            type_embeddings = type_embedding_layer(type_ids)

            embeddings = tf.keras.layers.Add()(
                [word_embeddings, position_embeddings, type_embeddings])

            embedding_norm_layer = tf.keras.layers.LayerNormalization(
                name='embeddings/layer_norm',
                axis=-1,
                epsilon=1e-12,
                dtype=tf.float32)
            embeddings = embedding_norm_layer(embeddings)

            embeddings = (tf.keras.layers.Dropout(
                rate=embedding_cfg['dropout_rate'])(embeddings))

            mask_cfg = {} if mask_cfg is None else mask_cfg
            if inspect.isclass(mask_cls):
                mask_layer = mask_cls(**mask_cfg)
            else:
                mask_layer = mask_cls
            attention_mask = mask_layer(embeddings, mask)

        data = embeddings

        layer_output_data = []
        hidden_layers = []
        hidden_cfg = hidden_cfg if hidden_cfg else {}

        if isinstance(hidden_cls,
                      list) and len(hidden_cls) != num_hidden_instances:
            raise RuntimeError(
                ('When the hidden_cls passed to EncoderScaffold %s is a list, '
                 'it must contain classes or instances whose count matches '
                 'num_hidden_instances, got %d vs %d.') %
                (self.name, len(hidden_cls), num_hidden_instances))
        # Consider supporting customized init states.
        recursive_states = None
        for i in range(num_hidden_instances):
            if isinstance(hidden_cls, list):
                cur_hidden_cls = hidden_cls[i]
            else:
                cur_hidden_cls = hidden_cls
            if inspect.isclass(cur_hidden_cls):
                if hidden_cfg and 'attention_cfg' in hidden_cfg and (
                        layer_idx_as_attention_seed):
                    hidden_cfg = copy.deepcopy(hidden_cfg)
                    hidden_cfg['attention_cfg']['seed'] = i
                if feed_layer_idx:
                    hidden_cfg['layer_idx'] = i
                layer = cur_hidden_cls(**hidden_cfg)
            else:
                layer = cur_hidden_cls
            if recursive:
                data, recursive_states = layer(
                    [data, attention_mask, recursive_states])
            else:
                data = layer([data, attention_mask])
            layer_output_data.append(data)
            hidden_layers.append(layer)

        if layer_norm_before_pooling:
            # Normalize the final output.
            output_layer_norm = tf.keras.layers.LayerNormalization(
                name='final_layer_norm', axis=-1, epsilon=1e-12)
            layer_output_data[-1] = output_layer_norm(layer_output_data[-1])

        last_layer_output = layer_output_data[-1]
        # Applying a tf.slice op (through subscript notation) to a Keras tensor
        # like this will create a SliceOpLambda layer. This is better than a Lambda
        # layer with Python code, because that is fundamentally less portable.
        first_token_tensor = last_layer_output[:, 0, :]
        pooler_layer_initializer = tf.keras.initializers.get(
            pooler_layer_initializer)
        pooler_layer = tf.keras.layers.Dense(
            units=pooled_output_dim,
            activation='tanh',
            kernel_initializer=pooler_layer_initializer,
            name='cls_transform')
        cls_output = pooler_layer(first_token_tensor)

        if dict_outputs:
            outputs = dict(
                sequence_output=layer_output_data[-1],
                pooled_output=cls_output,
                encoder_outputs=layer_output_data,
            )
        elif return_all_layer_outputs:
            outputs = [layer_output_data, cls_output]
        else:
            outputs = [layer_output_data[-1], cls_output]

        # b/164516224
        # Once we've created the network using the Functional API, we call
        # super().__init__ as though we were invoking the Functional API Model
        # constructor, resulting in this object having all the properties of a model
        # created using the Functional API. Once super().__init__ is called, we
        # can assign attributes to `self` - note that all `self` assignments are
        # below this line.
        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        self._hidden_cls = hidden_cls
        self._hidden_cfg = hidden_cfg
        self._mask_cls = mask_cls
        self._mask_cfg = mask_cfg
        self._num_hidden_instances = num_hidden_instances
        self._pooled_output_dim = pooled_output_dim
        self._pooler_layer_initializer = pooler_layer_initializer
        self._embedding_cls = embedding_cls
        self._embedding_cfg = embedding_cfg
        self._embedding_data = embedding_data
        self._layer_norm_before_pooling = layer_norm_before_pooling
        self._return_all_layer_outputs = return_all_layer_outputs
        self._dict_outputs = dict_outputs
        self._kwargs = kwargs

        self._embedding_layer = embedding_layer
        self._embedding_network = embedding_network
        self._position_embedding_layer = position_embedding_layer
        self._type_embedding_layer = type_embedding_layer
        self._embedding_norm_layer = embedding_norm_layer
        self._hidden_layers = hidden_layers
        if self._layer_norm_before_pooling:
            self._output_layer_norm = output_layer_norm
        self._pooler_layer = pooler_layer
        self._layer_idx_as_attention_seed = layer_idx_as_attention_seed

        logging.info('EncoderScaffold configs: %s', self.get_config())
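
A minimal usage sketch for the constructor above (assuming it belongs to EncoderScaffold from the TF Model Garden's official.nlp.modeling.networks, and that the default hidden_cls=layers.Transformer accepts the hidden_cfg keys below; all hyperparameter values are illustrative):

import numpy as np
import tensorflow as tf

init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
embedding_cfg = dict(
    vocab_size=30522, type_vocab_size=2, hidden_size=256,
    seq_length=128, max_seq_length=128, initializer=init, dropout_rate=0.1)
hidden_cfg = dict(
    num_attention_heads=4, intermediate_size=1024,
    intermediate_activation='gelu')

encoder = EncoderScaffold(
    pooled_output_dim=256,
    num_hidden_instances=3,
    embedding_cfg=embedding_cfg,
    hidden_cfg=hidden_cfg,
    dict_outputs=True)

# Dummy int32 inputs matching [input_word_ids, input_mask, input_type_ids].
word_ids = np.random.randint(0, 30522, size=(2, 128)).astype(np.int32)
mask = np.ones((2, 128), dtype=np.int32)
type_ids = np.zeros((2, 128), dtype=np.int32)
outputs = encoder([word_ids, mask, type_ids])
print(outputs['sequence_output'].shape)  # (2, 128, 256)
print(outputs['pooled_output'].shape)    # (2, 256)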
Example #25
    def __init__(
            self,
            vocab_size: int,
            hidden_size: int = 768,
            num_layers: int = 12,
            num_attention_heads: int = 12,
            max_sequence_length: int = 512,
            type_vocab_size: int = 16,
            inner_dim: int = 3072,
            inner_activation: _Activation = _approx_gelu,
            output_dropout: float = 0.1,
            attention_dropout: float = 0.1,
            pool_type: str = _MAX,
            pool_stride: int = 2,
            unpool_length: int = 0,
            initializer: _Initializer = tf.keras.initializers.TruncatedNormal(
                stddev=0.02),
            output_range: Optional[int] = None,
            embedding_width: Optional[int] = None,
            embedding_layer: Optional[tf.keras.layers.Layer] = None,
            norm_first: bool = False,
            transformer_cls: Union[
                str, tf.keras.layers.Layer] = layers.TransformerEncoderBlock,
            share_rezero: bool = True,
            **kwargs):
        super().__init__(**kwargs)
        activation = tf.keras.activations.get(inner_activation)
        initializer = tf.keras.initializers.get(initializer)

        if embedding_width is None:
            embedding_width = hidden_size

        if embedding_layer is None:
            self._embedding_layer = layers.OnDeviceEmbedding(
                vocab_size=vocab_size,
                embedding_width=embedding_width,
                initializer=tf_utils.clone_initializer(initializer),
                name='word_embeddings')
        else:
            self._embedding_layer = embedding_layer

        self._position_embedding_layer = layers.PositionEmbedding(
            initializer=tf_utils.clone_initializer(initializer),
            max_length=max_sequence_length,
            name='position_embedding')

        self._type_embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
            initializer=tf_utils.clone_initializer(initializer),
            use_one_hot=True,
            name='type_embeddings')

        self._embedding_norm_layer = tf.keras.layers.LayerNormalization(
            name='embeddings/layer_norm',
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32)

        self._embedding_dropout = tf.keras.layers.Dropout(
            rate=output_dropout, name='embedding_dropout')

        # We project the 'embedding' output to 'hidden_size' if it is not already
        # 'hidden_size'.
        self._embedding_projection = None
        if embedding_width != hidden_size:
            self._embedding_projection = tf.keras.layers.EinsumDense(
                '...x,xy->...y',
                output_shape=hidden_size,
                bias_axes='y',
                kernel_initializer=tf_utils.clone_initializer(initializer),
                name='embedding_projection')

        self._transformer_layers = []
        self._attention_mask_layer = layers.SelfAttentionMask(
            name='self_attention_mask')
        # Will raise an error if the string is not supported.
        if isinstance(transformer_cls, str):
            transformer_cls = _str2transformer_cls[transformer_cls]
        for i in range(num_layers):
            layer = transformer_cls(
                num_attention_heads=num_attention_heads,
                intermediate_size=inner_dim,
                inner_dim=inner_dim,
                intermediate_activation=inner_activation,
                inner_activation=inner_activation,
                output_dropout=output_dropout,
                attention_dropout=attention_dropout,
                norm_first=norm_first,
                output_range=output_range if i == num_layers - 1 else None,
                kernel_initializer=tf_utils.clone_initializer(initializer),
                share_rezero=share_rezero,
                name='transformer/layer_%d' % i)
            self._transformer_layers.append(layer)

        self._pooler_layer = tf.keras.layers.Dense(
            units=hidden_size,
            activation='tanh',
            kernel_initializer=tf_utils.clone_initializer(initializer),
            name='pooler_transform')
        if isinstance(pool_stride, int):
            # TODO(b/197133196): Pooling layer can be shared.
            pool_strides = [pool_stride] * num_layers
        else:
            if len(pool_stride) != num_layers:
                raise ValueError(
                    'The length of pool_stride must equal num_layers.')
            pool_strides = pool_stride
        # TODO(crickwu): explore tf.keras.layers.serialize method.
        if pool_type == _MAX:
            pool_cls = tf.keras.layers.MaxPooling1D
        elif pool_type == _AVG:
            pool_cls = tf.keras.layers.AveragePooling1D
        elif pool_type == _TRUNCATED_AVG:
            # TODO(b/203665205): unpool_length should be implemented.
            if unpool_length != 0:
                raise ValueError(
                    'unpool_length is not supported by truncated_avg now.')
        else:
            raise ValueError('pool_type not supported.')

        if pool_type in (_MAX, _AVG):
            self._att_input_pool_layers = []
            for layer_pool_stride in pool_strides:
                att_input_pool_layer = pool_cls(pool_size=layer_pool_stride,
                                                strides=layer_pool_stride,
                                                padding='same',
                                                name='att_input_pool_layer')
                self._att_input_pool_layers.append(att_input_pool_layer)

        self._max_sequence_length = max_sequence_length
        self._pool_strides = pool_strides  # This is a list here.
        self._unpool_length = unpool_length
        self._pool_type = pool_type

        self._config = {
            'vocab_size': vocab_size,
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'max_sequence_length': max_sequence_length,
            'type_vocab_size': type_vocab_size,
            'inner_dim': inner_dim,
            'inner_activation': tf.keras.activations.serialize(activation),
            'output_dropout': output_dropout,
            'attention_dropout': attention_dropout,
            'initializer': tf.keras.initializers.serialize(initializer),
            'output_range': output_range,
            'embedding_width': embedding_width,
            'embedding_layer': embedding_layer,
            'norm_first': norm_first,
            'pool_type': pool_type,
            'pool_stride': pool_stride,
            'unpool_length': unpool_length,
            'transformer_cls': _transformer_cls2str.get(transformer_cls,
                                                        str(transformer_cls)),
        }

        self.inputs = dict(input_word_ids=tf.keras.Input(shape=(None, ),
                                                         dtype=tf.int32),
                           input_mask=tf.keras.Input(shape=(None, ),
                                                     dtype=tf.int32),
                           input_type_ids=tf.keras.Input(shape=(None, ),
                                                         dtype=tf.int32))
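
A small arithmetic sketch of how the per-layer pool_stride schedule above shrinks the attention input length: with 'same' padding and strides equal to the pool size, each layer rounds the length up after division (the stride values here are illustrative):

import math

seq_length = 128
for stride in [2, 2, 1, 1]:  # one pool_stride per layer
    seq_length = math.ceil(seq_length / stride)
    print(seq_length)  # 64, 32, 32, 32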
Example #26
    def build(self, input_shape):
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]

        self.aspp_layers = []

        if self.use_sync_bn:
            bn_op = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            bn_op = tf.keras.layers.BatchNormalization

        if tf.keras.backend.image_data_format() == 'channels_last':
            bn_axis = -1
        else:
            bn_axis = 1

        conv_sequential = tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                filters=self.output_channels,
                kernel_size=(1, 1),
                kernel_initializer=tf_utils.clone_initializer(
                    self.kernel_initializer),
                kernel_regularizer=self.kernel_regularizer,
                use_bias=False),
            bn_op(axis=bn_axis,
                  momentum=self.batchnorm_momentum,
                  epsilon=self.batchnorm_epsilon),
            tf.keras.layers.Activation(self.activation)
        ])
        self.aspp_layers.append(conv_sequential)

        for dilation_rate in self.dilation_rates:
            leading_layers = []
            kernel_size = (3, 3)
            if self.use_depthwise_convolution:
                leading_layers += [
                    tf.keras.layers.DepthwiseConv2D(
                        depth_multiplier=1,
                        kernel_size=kernel_size,
                        padding='same',
                        depthwise_regularizer=self.kernel_regularizer,
                        depthwise_initializer=tf_utils.clone_initializer(
                            self.kernel_initializer),
                        dilation_rate=dilation_rate,
                        use_bias=False)
                ]
                kernel_size = (1, 1)
            conv_sequential = tf.keras.Sequential(leading_layers + [
                tf.keras.layers.Conv2D(
                    filters=self.output_channels,
                    kernel_size=kernel_size,
                    padding='same',
                    kernel_regularizer=self.kernel_regularizer,
                    kernel_initializer=tf_utils.clone_initializer(
                        self.kernel_initializer),
                    dilation_rate=dilation_rate,
                    use_bias=False),
                bn_op(axis=bn_axis,
                      momentum=self.batchnorm_momentum,
                      epsilon=self.batchnorm_epsilon),
                tf.keras.layers.Activation(self.activation)
            ])
            self.aspp_layers.append(conv_sequential)

        if self.pool_kernel_size is None:
            pool_sequential = tf.keras.Sequential([
                tf.keras.layers.GlobalAveragePooling2D(),
                tf.keras.layers.Reshape((1, 1, channels))
            ])
        else:
            pool_sequential = tf.keras.Sequential(
                [tf.keras.layers.AveragePooling2D(self.pool_kernel_size)])

        pool_sequential.add(
            tf.keras.Sequential([
                tf.keras.layers.Conv2D(
                    filters=self.output_channels,
                    kernel_size=(1, 1),
                    kernel_initializer=tf_utils.clone_initializer(
                        self.kernel_initializer),
                    kernel_regularizer=self.kernel_regularizer,
                    use_bias=False),
                bn_op(axis=bn_axis,
                      momentum=self.batchnorm_momentum,
                      epsilon=self.batchnorm_epsilon),
                tf.keras.layers.Activation(self.activation),
                tf.keras.layers.experimental.preprocessing.Resizing(
                    height,
                    width,
                    interpolation=self.interpolation,
                    dtype=tf.float32)
            ]))

        self.aspp_layers.append(pool_sequential)

        self.projection = tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                filters=self.output_channels,
                kernel_size=(1, 1),
                kernel_initializer=tf_utils.clone_initializer(
                    self.kernel_initializer),
                kernel_regularizer=self.kernel_regularizer,
                use_bias=False),
            bn_op(axis=bn_axis,
                  momentum=self.batchnorm_momentum,
                  epsilon=self.batchnorm_epsilon),
            tf.keras.layers.Activation(self.activation),
            tf.keras.layers.Dropout(rate=self.dropout)
        ])
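
A quick sketch of why the dilated 3x3 branches above see different amounts of context: a k x k kernel with dilation d has an effective extent of k + (k - 1) * (d - 1). The dilation rates used here are common ASPP choices, not values taken from this config:

for d in (6, 12, 18):  # illustrative dilation_rates
    k_eff = 3 + (3 - 1) * (d - 1)
    print(d, k_eff)  # 6 -> 13, 12 -> 25, 18 -> 37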
Example #27
  def build(self, input_shape):
    input_tensor_shape = input_shape[0] if (
        len(input_shape) == 2) else input_shape
    input_tensor_shape = tf.TensorShape(input_tensor_shape)
    if len(input_tensor_shape.as_list()) != 3:
      raise ValueError(
          "TransformerScaffold expects a three-dimensional input of "
          "shape [batch, sequence, width].")
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)

    common_kwargs = dict(
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)

    def get_layer_instance(instance_or_cls, config, default_config):
      if isinstance(instance_or_cls, tf.keras.layers.Layer):
        return instance_or_cls
      else:
        if config is None:
          return instance_or_cls(**default_config)
        else:
          return instance_or_cls(**config)

    default_attention_cfg = {
        "kernel_initializer": tf_utils.clone_initializer(
            self._kernel_initializer),
        "bias_initializer": tf_utils.clone_initializer(self._bias_initializer),
        "num_heads": self._num_heads,
        "key_dim": self._attention_head_size,
        "dropout": self._attention_dropout_rate,
        "name": "self_attention"
    }
    default_attention_cfg.update(common_kwargs)
    self._attention_layer = get_layer_instance(
        self._attention_cls,
        config=self._attention_cfg,
        default_config=default_attention_cfg)

    if self._feedforward_cls is not None:
      default_feedforward_cfg = {
          "kernel_initializer": tf_utils.clone_initializer(
              self._kernel_initializer),
          "bias_initializer": tf_utils.clone_initializer(
              self._bias_initializer),
          "inner_dim": self._inner_dim,
          "inner_activation": self._inner_activation,
          # TODO(hongkuny): try to update all ffn block args.
          "intermediate_size": self._inner_dim,
          "intermediate_activation": self._inner_activation,
          "dropout": self._dropout_rate,
          "name": "feedforward",
      }
      default_feedforward_cfg.update(common_kwargs)
      self._feedforward_block = get_layer_instance(
          self._feedforward_cls,
          config=self._feedforward_cfg,
          default_config=default_feedforward_cfg)
    else:
      self._feedforward_block = None

    # self._dropout_rate controls dropout rates at two places:
    # after attention, and after FFN.
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=1e-12,
            dtype=tf.float32))

    if self._feedforward_block is None:
      self._intermediate_dense = tf.keras.layers.EinsumDense(
          "abc,cd->abd",
          output_shape=(None, self._inner_dim),
          bias_axes="d",
          name="intermediate",
          kernel_initializer=tf_utils.clone_initializer(
              self._kernel_initializer),
          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
          **common_kwargs)
      policy = tf.keras.mixed_precision.global_policy()
      if policy.name == "mixed_bfloat16":
        # bfloat16 causes BERT with the LAMB optimizer to not converge
        # as well, so we use float32.
        # TODO(b/154538392): Investigate this.
        policy = tf.float32
      self._intermediate_activation_layer = tf.keras.layers.Activation(
          self._inner_activation, dtype=policy)
      self._output_dense = tf.keras.layers.EinsumDense(
          "abc,cd->abd",
          output_shape=(None, hidden_size),
          bias_axes="d",
          name="output",
          kernel_initializer=tf_utils.clone_initializer(
              self._kernel_initializer),
          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
          **common_kwargs)

    self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
    # Use float32 in layernorm for numeric stability.
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)

    super(TransformerScaffold, self).build(input_shape)
    logging.info("%s configs: %s", self.__class__.__name__, self.get_config())
Example #28
    def build(self, input_shape):
        input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
        input_tensor_shape = tf.TensorShape(input_tensor)
        if len(input_tensor_shape.as_list()) != 3:
            raise ValueError(
                "TNTransformerExpandCondense expects a three-dimensional input of "
                "shape [batch, sequence, width].")
        batch_size, sequence_length, hidden_size = input_tensor_shape

        if len(input_shape) == 2:
            mask_tensor_shape = tf.TensorShape(input_shape[1])
            expected_mask_tensor_shape = tf.TensorShape(
                [batch_size, sequence_length, sequence_length])
            if not expected_mask_tensor_shape.is_compatible_with(
                    mask_tensor_shape):
                raise ValueError(
                    "When passing a mask tensor to TNTransformerExpandCondense, the "
                    "mask tensor must be of shape [batch, "
                    "sequence_length, sequence_length] (here %s). Got a "
                    "mask tensor of shape %s." %
                    (expected_mask_tensor_shape, mask_tensor_shape))
        if hidden_size % self._num_heads != 0:
            raise ValueError(
                "The input size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, self._num_heads))
        self._attention_head_size = int(hidden_size // self._num_heads)
        common_kwargs = dict(kernel_regularizer=self._kernel_regularizer,
                             bias_regularizer=self._bias_regularizer,
                             activity_regularizer=self._activity_regularizer,
                             kernel_constraint=self._kernel_constraint,
                             bias_constraint=self._bias_constraint)
        self._attention_layer = tf.keras.layers.MultiHeadAttention(
            num_heads=self._num_heads,
            key_dim=self._attention_head_size,
            dropout=self._attention_dropout_rate,
            use_bias=self._use_bias,
            kernel_initializer=self._attention_initializer,
            bias_initializer=tf_utils.clone_initializer(
                self._bias_initializer),
            name="self_attention",
            **common_kwargs)
        self._attention_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate)
        # Use float32 in layernorm for numeric stability.
        # It is probably safe in mixed_float16, but we haven't validated this yet.
        self._attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))

        # Substitute Dense layers with a single Expand-Condense layer.
        self._output_dense = TNExpandCondense(
            4,
            use_bias=True,
            activation=self._intermediate_activation,
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer)

        self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
        # Use float32 in layernorm for numeric stability.
        self._output_layer_norm = tf.keras.layers.LayerNormalization(
            name="output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32)

        super(TNTransformerExpandCondense, self).build(input_shape)
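
A tiny sketch of the mask-shape check above: tf.TensorShape.is_compatible_with treats None dimensions as wildcards, so an unknown batch size still passes while a wrong sequence length does not:

import tensorflow as tf

expected = tf.TensorShape([None, 128, 128])
print(expected.is_compatible_with(tf.TensorShape([8, 128, 128])))  # True
print(expected.is_compatible_with(tf.TensorShape([8, 128, 64])))   # False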
Example #29
    def call(self, inputs):
        """Implements call() for the layer."""
        input_ids = inputs["input_ids"]
        segment_ids = inputs["segment_ids"]
        input_mask = inputs["input_mask"]
        state = inputs["state"]
        permutation_mask = inputs["permutation_mask"]
        target_mapping = inputs["target_mapping"]
        masked_tokens = inputs["masked_tokens"]

        batch_size = tf.shape(input_ids)[0]
        seq_length = tf.shape(input_ids)[1]
        if state is not None:
            memory_length = tf.shape(state[0])[1]
        else:
            memory_length = 0
        total_length = memory_length + seq_length

        if self._two_stream and masked_tokens is None:
            raise ValueError("`masked_tokens` must be provided in order to "
                             "initialize the query stream in "
                             "`TwoStreamRelativeAttention`.")
        if masked_tokens is not None and not self._two_stream:
            logging.warning(
                "`masked_tokens` was provided but `two_stream` is disabled, "
                "so no query stream will be created; enable `two_stream` to "
                "use two-stream attention.")

        if input_mask is not None:
            dtype = input_mask.dtype
        elif permutation_mask is not None:
            dtype = permutation_mask.dtype
        else:
            dtype = tf.int32
        query_attention_mask, content_attention_mask = _compute_attention_mask(
            input_mask=input_mask,
            permutation_mask=permutation_mask,
            attention_type=self._attention_type,
            seq_length=seq_length,
            memory_length=memory_length,
            batch_size=batch_size,
            dtype=dtype)
        relative_position_encoding = _compute_positional_encoding(
            attention_type=self._attention_type,
            position_encoding_layer=self.position_encoding,
            hidden_size=self._hidden_size,
            batch_size=batch_size,
            total_length=total_length,
            seq_length=seq_length,
            clamp_length=self._clamp_length,
            bi_data=self._bi_data,
            dtype=tf.float32)
        relative_position_encoding = self.embedding_dropout(
            relative_position_encoding)

        if segment_ids is None:
            segment_embedding = None
            segment_matrix = None
        else:
            if self._segment_embedding is None:
                self._segment_embedding = self.add_weight(
                    "seg_embed",
                    shape=[
                        self._num_layers, 2, self._num_attention_heads,
                        self._head_size
                    ],
                    dtype=tf.float32,
                    initializer=tf_utils.clone_initializer(self._initializer))

            segment_embedding = self._segment_embedding
            segment_matrix = _compute_segment_matrix(
                segment_ids=segment_ids,
                memory_length=memory_length,
                batch_size=batch_size,
                use_cls_mask=self._use_cls_mask)

        word_embeddings = self._embedding_layer(input_ids)
        content_stream = self._dropout(word_embeddings)

        if self._two_stream:
            if self._mask_embedding is None:
                self._mask_embedding = self.add_weight(
                    "mask_emb/mask_emb",
                    shape=[1, 1, self._hidden_size],
                    dtype=tf.float32)
            if target_mapping is None:
                masked_tokens = masked_tokens[:, :, None]
                masked_token_embedding = (
                    masked_tokens * self._mask_embedding +
                    (1 - masked_tokens) * word_embeddings)
            else:
                masked_token_embedding = tf.tile(
                    self._mask_embedding,
                    [batch_size, tf.shape(target_mapping)[1], 1])
            query_stream = self._dropout(masked_token_embedding)
        else:
            query_stream = None

        return self._transformer_xl(
            content_stream=content_stream,
            query_stream=query_stream,
            target_mapping=target_mapping,
            state=state,
            relative_position_encoding=relative_position_encoding,
            segment_matrix=segment_matrix,
            segment_embedding=segment_embedding,
            content_attention_mask=content_attention_mask,
            query_attention_mask=query_attention_mask)
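
A small numeric sketch of the query-stream blend above for the case where target_mapping is None: positions flagged in masked_tokens take the shared mask embedding, while the rest keep their word embeddings (toy shapes, not real model dimensions):

import numpy as np

word_embeddings = np.array([[[1., 1.], [2., 2.], [3., 3.]]])  # [batch, seq, hidden]
mask_embedding = np.array([[[9., 9.]]])                       # [1, 1, hidden]
masked_tokens = np.array([[0., 1., 0.]])[:, :, None]          # [batch, seq, 1]
blended = masked_tokens * mask_embedding + (1 - masked_tokens) * word_embeddings
print(blended)  # [[[1. 1.] [9. 9.] [3. 3.]]]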
Example #30
    def __init__(self,
                 vocab_size,
                 num_layers,
                 hidden_size,
                 num_attention_heads,
                 head_size,
                 inner_size,
                 dropout_rate,
                 attention_dropout_rate,
                 attention_type,
                 bi_data,
                 initializer,
                 two_stream=False,
                 tie_attention_biases=True,
                 memory_length=None,
                 clamp_length=-1,
                 reuse_length=None,
                 inner_activation="relu",
                 use_cls_mask=False,
                 embedding_width=None,
                 **kwargs):
        super(XLNetBase, self).__init__(**kwargs)

        self._vocab_size = vocab_size
        self._initializer = initializer
        self._attention_type = attention_type
        self._num_layers = num_layers
        self._hidden_size = hidden_size
        self._num_attention_heads = num_attention_heads
        self._head_size = head_size
        self._inner_size = inner_size
        self._inner_activation = inner_activation
        self._dropout_rate = dropout_rate
        self._attention_dropout_rate = attention_dropout_rate
        self._tie_attention_biases = tie_attention_biases
        self._two_stream = two_stream

        self._memory_length = memory_length
        self._reuse_length = reuse_length
        self._bi_data = bi_data
        self._clamp_length = clamp_length
        self._use_cls_mask = use_cls_mask

        self._segment_embedding = None
        self._mask_embedding = None
        self._embedding_width = embedding_width

        if embedding_width is None:
            embedding_width = hidden_size

        self._embedding_layer = layers.OnDeviceEmbedding(
            vocab_size=self._vocab_size,
            embedding_width=embedding_width,
            initializer=tf_utils.clone_initializer(self._initializer),
            dtype=tf.float32,
            name="word_embedding")
        self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)

        self.embedding_dropout = tf.keras.layers.Dropout(
            rate=self._dropout_rate)
        self.position_encoding = RelativePositionEncoding(self._hidden_size)

        self._transformer_xl = transformer_xl.TransformerXL(
            vocab_size=vocab_size,
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            head_size=head_size,
            inner_size=inner_size,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            initializer=initializer,
            two_stream=two_stream,
            tie_attention_biases=tie_attention_biases,
            memory_length=memory_length,
            reuse_length=reuse_length,
            inner_activation=inner_activation,
            name="transformer_xl")