def __call__(self, inputs, encoder_mask=None, encoder_relative_position=None):
  """Applies Transformer encoder block.

  Args:
    inputs: input data `[batch_size, ..., length, dim]`
    encoder_mask: encoder self-attention mask
    encoder_relative_position: encoder relative positions tensor
      `[batch_sizes..., length, length]`

  Returns:
    Encoded input data `[batch_size, ..., length, dim]`
  """
  cfg = self.config

  # Attention block: pre-LayerNorm, then (relative) self-attention.
  x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
  if cfg.use_relative_attention:
    x = relative_attention.RelativeSelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_attention,
        num_relative_position_buckets=self.num_relative_position_buckets,
        max_distance=self.max_distance)(
            x, encoder_mask, encoder_relative_position)
  else:
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, encoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  # Residual connection around the attention block.
  x = x + inputs

  # MLP block with its own residual connection.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  y = MLPBlock(config=cfg)(y)
  return x + y
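# Usage sketch (not part of this module): how a block like the one above is
# typically stacked inside a Flax encoder. `EncoderBlock` is a hypothetical
# name for the module owning the `__call__` above; `nn.make_attention_mask`
# is a standard flax.linen helper, and 0 is assumed to be the padding id.
#
#   encoder_mask = nn.make_attention_mask(tokens > 0, tokens > 0,
#                                         dtype=cfg.dtype)
#   x = embedded_tokens
#   for _ in range(cfg.num_layers):
#     x = EncoderBlock(config=cfg)(x, encoder_mask,
#                                  encoder_relative_position=None)
#   encoded = nn.LayerNorm(dtype=cfg.dtype)(x)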
def do_flat_encoded_self_attention(self, flat_encoded, mod_position):
  """Does self-attention for the flat encoding."""
  cfg = self.config.base_config
  x = nn.LayerNorm(dtype=cfg.dtype)(flat_encoded)
  x = relative_attention.RelativeSelfAttention(
      num_heads=cfg.num_heads,
      dtype=cfg.dtype,
      qkv_features=cfg.qkv_dim,
      kernel_init=cfg.kernel_init,
      bias_init=cfg.bias_init,
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=cfg.attention_dropout_rate,
      deterministic=cfg.deterministic,
      bidirectional=True,
      num_relative_position_buckets=(
          cfg.num_flat_encoding_relative_position_buckets),
      max_distance=cfg.max_flat_encoding_distance,
      mod_position=mod_position)(x, None, None)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  x = x + flat_encoded
  return x
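# Usage sketch (hypothetical, for illustration only): the flat self-attention
# above runs bidirectionally over an already-computed flat encoding with no
# attention mask; `mod_position` lets the relative-attention layer account for
# module boundaries when bucketing relative distances.
#
#   flat_encoded = self.do_flat_encoded_self_attention(flat_encoded,
#                                                      mod_position)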
def __call__(self,
             targets,
             encoded,
             decoder_mask=None,
             encoder_decoder_mask=None,
             decoder_relative_position=None,
             encoder_decoder_relative_position=None):
  """Applies Transformer decoder block.

  Args:
    targets: input data for decoder `[batch_size, ..., length2, dim]`
    encoded: input data from encoder `[batch_size, ..., length, dim2]`
    decoder_mask: decoder self-attention mask
    encoder_decoder_mask: encoder-decoder attention mask
    decoder_relative_position: decoder relative positions tensor
      `[batch_sizes..., length2, length2]`
    encoder_decoder_relative_position: encoder-decoder relative positions
      tensor `[batch_sizes..., length2, length]`

  Returns:
    Decoded data `[batch_size, ..., length2, dim]`
  """
  cfg = self.config

  # Decoder self-attention block.
  x = nn.LayerNorm(dtype=cfg.dtype)(targets)
  if cfg.use_relative_attention:
    x = relative_attention.RelativeSelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_attention,
        num_relative_position_buckets=self.num_relative_position_buckets,
        max_distance=self.max_distance)(
            x, decoder_mask, decoder_relative_position)
  else:
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, decoder_mask)
  x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
  # Residual connection around the decoder self-attention block.
  x = x + targets

  # Encoder-decoder cross-attention block.
  y = nn.LayerNorm(dtype=cfg.dtype)(x)
  if self.relative_cross_attention:
    y = relative_attention.RelativeMultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        bidirectional=self.bidirectional_cross_attention,
        num_relative_position_buckets=(
            self.num_relative_position_buckets_cross_attention),
        max_distance=self.max_distance_cross_attention)(
            y, encoded, encoder_decoder_mask,
            encoder_decoder_relative_position)
  else:
    y = nn.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)
  # Residual connection around the cross-attention block.
  y = y + x

  # MLP block with its own residual connection.
  z = nn.LayerNorm(dtype=cfg.dtype)(y)
  z = MLPBlock(config=cfg)(z)
  return y + z
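# Usage sketch (not part of this module): constructing the masks consumed by
# the decoder block above with standard flax.linen helpers. `DecoderBlock` is
# a hypothetical name for the module owning this `__call__`; `targets` and
# `inputs` are integer token arrays and 0 is assumed to be the padding id.
#
#   decoder_mask = nn.combine_masks(
#       nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
#       nn.make_causal_mask(targets, dtype=cfg.dtype))
#   encoder_decoder_mask = nn.make_attention_mask(
#       targets > 0, inputs > 0, dtype=cfg.dtype)
#   y = DecoderBlock(config=cfg)(embedded_targets, encoded,
#                                decoder_mask, encoder_decoder_mask)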