  def __call__(self,
               inputs,
               encoder_mask=None,
               encoder_relative_position=None):
    """Applies Transformer block.

    Args:
      inputs: input data `[batch_size, ..., length, dim]`
      encoder_mask: encoder self-attention mask
      encoder_relative_position: encoder relative positions tensor
          `[batch_sizes..., length, length]`

    Returns:
      Encoded input data `[batch_size, ..., length, mlp_dim]`
    """
    cfg = self.config

    # Attention block.
    x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
    if cfg.use_relative_attention:
      x = relative_attention.RelativeSelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic,
          bidirectional=self.bidirectional_attention,
          num_relative_position_buckets=self.num_relative_position_buckets,
          max_distance=self.max_distance)(
              x, encoder_mask, encoder_relative_position)
    else:
      x = nn.SelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic)(x, encoder_mask)

    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    y = MLPBlock(config=cfg)(y)

    return x + y
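
For context, the block above leaves it to the caller to supply `encoder_mask` and `encoder_relative_position`. Below is a minimal sketch of how such inputs are commonly built with standard Flax/JAX utilities; the token ids, the use of 0 as the padding id, and the `SomeEncoderBlock`/`cfg`/`params` names in the trailing comment are assumptions made only for this illustration, not part of the example above.

import jax.numpy as jnp
from flax import linen as nn

# Toy inputs: already-embedded tokens plus their raw ids (0 = padding here).
batch_size, length, emb_dim = 2, 6, 16
token_ids = jnp.array([[5, 3, 8, 1, 0, 0],
                       [7, 2, 4, 0, 0, 0]])
inputs = jnp.ones((batch_size, length, emb_dim))  # stand-in embeddings

# Padding-aware self-attention mask, shape [batch, 1, length, length],
# built with the standard flax.linen helper.
encoder_mask = nn.make_attention_mask(token_ids > 0, token_ids > 0)

# Signed key-minus-query offsets broadcast over the batch,
# shape [batch, length, length], matching the docstring above.
positions = jnp.arange(length)
encoder_relative_position = jnp.broadcast_to(
    positions[None, :] - positions[:, None], (batch_size, length, length))

# The block itself would then be applied through its parent module, e.g.
# something along the lines of
#   encoded = SomeEncoderBlock(config=cfg, ...).apply(
#       params, inputs, encoder_mask, encoder_relative_position)
# where SomeEncoderBlock, cfg and params come from the surrounding project.
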
  def do_flat_encoded_self_attention(self, flat_encoded, mod_position):
    """Does self-attention for the flat encoding."""
    cfg = self.config.base_config
    x = nn.LayerNorm(dtype=cfg.dtype)(flat_encoded)
    x = relative_attention.RelativeSelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=True,
        num_relative_position_buckets=(
            cfg.num_flat_encoding_relative_position_buckets),
        max_distance=cfg.max_flat_encoding_distance,
        mod_position=mod_position)(x, None, None)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    x = x + flat_encoded
    return x
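
The `num_relative_position_buckets` and `max_distance` arguments passed to `relative_attention.RelativeSelfAttention` in these snippets typically control T5-style bucketing of relative positions. The project's `relative_attention` module is not shown here, so the function below is only a generic re-implementation of that bucketing scheme, included to illustrate what the two parameters do; it is not the module's actual code.

import jax.numpy as jnp

def t5_relative_position_bucket(relative_position,
                                bidirectional=True,
                                num_buckets=32,
                                max_distance=128):
  """Maps signed relative offsets to bucket ids, T5-style (illustrative)."""
  ret = 0
  n = -relative_position
  if bidirectional:
    # Half of the buckets are reserved for offsets in the other direction.
    num_buckets //= 2
    ret += (n < 0).astype(jnp.int32) * num_buckets
    n = jnp.abs(n)
  else:
    n = jnp.maximum(n, 0)
  # Small offsets each get their own bucket ...
  max_exact = num_buckets // 2
  is_small = n < max_exact
  # ... larger offsets share logarithmically spaced buckets up to max_distance.
  val_if_large = max_exact + (
      jnp.log(n.astype(jnp.float32) / max_exact + 1e-6) /
      jnp.log(max_distance / max_exact) *
      (num_buckets - max_exact)).astype(jnp.int32)
  val_if_large = jnp.minimum(val_if_large, num_buckets - 1)
  ret += jnp.where(is_small, n, val_if_large)
  return ret

# Example: bucket ids for every query/key pair in a length-6 sequence.
pos = jnp.arange(6)
buckets = t5_relative_position_bucket(pos[None, :] - pos[:, None])
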
Example #3
  def __call__(self,
               targets,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None,
               decoder_relative_position=None,
               encoder_decoder_relative_position=None):
    """Applies Transformer block.

    Args:
      targets: input data for decoder `[batch_size, ..., length, dim]`
      encoded: input data from encoder `[batch_size, ..., length2, dim2]`
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask
      decoder_relative_position: decoder relative positions tensor
          `[batch_sizes..., length2, length2]'
      encoder_decoder_relative_position: encoder-decoder relative tensor
          `[batch_sizes..., length2, length]'

    Returns:
      Decoded data `[batch_size, ..., length2, mlp_dim]`
    """
    cfg = self.config

    # Decoder block.
    x = nn.LayerNorm(dtype=cfg.dtype)(targets)
    if cfg.use_relative_attention:
      x = relative_attention.RelativeSelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic,
          bidirectional=self.bidirectional_attention,
          num_relative_position_buckets=self.num_relative_position_buckets,
          max_distance=self.max_distance)(
              x, decoder_mask, decoder_relative_position)
    else:
      x = nn.SelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic)(x, decoder_mask)

    x = nn.Dropout(rate=cfg.dropout_rate)(
        x, deterministic=cfg.deterministic)
    x = x + targets

    # Encoder-Decoder block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    if self.relative_cross_attention:
      y = relative_attention.RelativeMultiHeadDotProductAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic,
          bidirectional=self.bidirectional_cross_attention,
          num_relative_position_buckets=(
              self.num_relative_position_buckets_cross_attention),
          max_distance=self.max_distance_cross_attention)(
              y, encoded, encoder_decoder_mask,
              encoder_decoder_relative_position)
    else:
      y = nn.MultiHeadDotProductAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)

    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)
    y = y + x

    # MLP block.
    z = nn.LayerNorm(dtype=cfg.dtype)(y)
    z = MLPBlock(config=cfg)(z)

    return y + z
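
As with the encoder block, the decoder block expects its masks and relative-position tensors from the caller. A minimal sketch of how they are commonly assembled with standard flax.linen helpers follows; the token ids and the choice of 0 as the padding id are assumptions made only for this illustration.

import jax.numpy as jnp
from flax import linen as nn

# Toy token ids; 0 is treated as padding (an assumption for this sketch).
target_ids = jnp.array([[4, 9, 2, 0, 0]])     # decoder side, length  = 5
source_ids = jnp.array([[7, 3, 5, 8, 1, 0]])  # encoder side, length2 = 6

# Decoder self-attention mask: padding mask combined with a causal mask so
# each position attends only to itself and earlier non-padding positions.
decoder_mask = nn.combine_masks(
    nn.make_attention_mask(target_ids > 0, target_ids > 0),
    nn.make_causal_mask(target_ids))

# Cross-attention mask: decoder queries against encoder keys,
# shape [batch, 1, length, length2].
encoder_decoder_mask = nn.make_attention_mask(target_ids > 0, source_ids > 0)

# Matching relative-position tensors (key index minus query index):
tgt_pos = jnp.arange(target_ids.shape[1])
src_pos = jnp.arange(source_ids.shape[1])
decoder_relative_position = (tgt_pos[None, :] - tgt_pos[:, None])[None]          # [1, length, length]
encoder_decoder_relative_position = (src_pos[None, :] - tgt_pos[:, None])[None]  # [1, length, length2]
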