Example #1
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              cache=None,
              block_size=_DEFAULT_BLOCK_SIZE,
              connectivity_seed=None):
        """Applies BigBirdBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, whether to mask out future positions (causal attention).
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, if true, dropout is not applied.
      cache: flax autoregressive cache for fast decoding.
      block_size: Size of attention blocks.
      connectivity_seed: Optional seed for random block sparse attention.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = bigbird_attention.BigBirdSelfAttention(
            x,
            num_heads=num_heads,
            dtype=dtype,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=causal_mask,
            segmentation=inputs_segmentation,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            cache=cache,
            block_size=block_size,
            connectivity_seed=connectivity_seed)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
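
The sketch below is a rough illustration of the connectivity that a BigBird-style block-sparse mask combines, not the internals of bigbird_attention.BigBirdSelfAttention: a sliding window of neighbouring blocks, a few global blocks, and a handful of random blocks per row (seeded the way connectivity_seed is above). All names and default values in the sketch are assumptions made for illustration.

import numpy as np

def bigbird_block_mask(num_blocks, num_rand_blocks=2, num_global_blocks=1, seed=0):
    """Returns a [num_blocks, num_blocks] boolean block-connectivity mask."""
    rng = np.random.default_rng(seed)
    mask = np.zeros((num_blocks, num_blocks), dtype=bool)
    # Sliding window: each block attends to itself and its immediate neighbours.
    for i in range(num_blocks):
        for j in (i - 1, i, i + 1):
            if 0 <= j < num_blocks:
                mask[i, j] = True
    # Global blocks attend to everything and are attended to by everything.
    mask[:num_global_blocks, :] = True
    mask[:, :num_global_blocks] = True
    # Random blocks: each row additionally attends to a few random blocks.
    for i in range(num_blocks):
        mask[i, rng.choice(num_blocks, size=num_rand_blocks, replace=False)] = True
    return mask

# Expanding to token resolution (block_size tokens per block), e.g.:
# token_mask = np.kron(bigbird_block_mask(8), np.ones((block_size, block_size), bool))
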
Example #2
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              sliding_window_size=512,
              global_mask=None,
              causal_mask=False,
              dtype=jnp.float32,
              inputs_segmentation=None,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False):
        """Applies the LongformerBlock module.

    Args:
      inputs: input data of size `[bs, seq_len, features]`.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: number of attention heads.
      sliding_window_size: size of sliding window attention to use.
      global_mask: boolean matrix of shape `[bs, seq_len]`, where `True`
        indicates that the position is globally attended. By default, no global
        attention is used.
      causal_mask: If true, apply causal attention mask.
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: bool, mask padding tokens.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: if true, dropout is not applied.

    Returns:
      output of shape `[bs, seq_len, features]` (same shape as `inputs`).
    """

        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = longformer_attention.LongformerSelfAttention(
            x,
            num_heads=num_heads,
            qkv_features=qkv_dim,
            sliding_window_size=sliding_window_size,
            global_mask=global_mask,
            causal_mask=causal_mask,
            dtype=dtype,
            segmentation=inputs_segmentation,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
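
As a rough, hedged sketch of the mask implied by the arguments above (not the actual logic inside longformer_attention.LongformerSelfAttention): every position attends to a symmetric sliding window, positions flagged as global attend to and are attended by all positions, and an optional causal restriction removes attention to the future. The batch dimension of global_mask is dropped here, and the window convention (half the window on each side) is an assumption.

import numpy as np

def longformer_mask(seq_len, sliding_window_size, global_mask=None, causal=False):
    """Returns a [seq_len, seq_len] boolean mask; True means "may attend"."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    half = sliding_window_size // 2
    mask = np.abs(i - j) <= half                 # symmetric sliding window
    if global_mask is not None:
        g = np.asarray(global_mask, dtype=bool)  # [seq_len]; batch dim omitted here
        mask |= g[None, :]                       # everyone attends to global tokens
        mask |= g[:, None]                       # global tokens attend to everyone
    if causal:
        mask &= j <= i                           # no attention to future positions
    return mask

# Example: 16 tokens, window of 4, token 0 global.
# m = longformer_mask(16, 4, global_mask=np.arange(16) == 0)
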
Example #3
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              cache=None,
              attention_fn_cls=_DEFAULT_ATTENTION_FN_CLS,
              attention_fn_kwargs=None):
        """Applies PerformerBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, whether to mask out future positions (causal attention).
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, if true, dropout is not applied.
      cache: flax autoregressive cache for fast decoding.
      attention_fn_cls: Attention function key or callable.
      attention_fn_kwargs: Keywords to pass to `attention_fn_cls`.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        attention_fn = _make_attention_fn(
            attention_fn_cls, attention_fn_kwargs)(qkv_dim // num_heads,
                                                   unidirectional=causal_mask)
        x = nn.LayerNorm(inputs)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=causal_mask,
                             segmentation=inputs_segmentation,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic,
                             cache=cache,
                             attention_fn=attention_fn)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
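
The Performer block above swaps the attention kernel in via attention_fn. What follows is a minimal sketch of the kernelised ("linear") attention idea behind it, assuming a simple elu+1 feature map rather than the FAVOR+ random-feature map the real attention_fn_cls provides: with a non-negative feature map phi, softmax(Q K^T) V is replaced by phi(Q) (phi(K)^T V), so the cost grows linearly with sequence length.

import jax
import jax.numpy as jnp

def linear_attention(q, k, v, eps=1e-6):
    """q, k, v: [seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]."""
    phi = lambda x: jax.nn.elu(x) + 1.0            # non-negative feature map (stand-in)
    q, k = phi(q), phi(k)
    kv = jnp.einsum('shd,she->hde', k, v)          # sum over positions of phi(k_s) v_s^T
    normalizer = jnp.einsum('shd,hd->sh', q, k.sum(axis=0)) + eps
    return jnp.einsum('shd,hde,sh->she', q, kv, 1.0 / normalizer)

# The causal case (unidirectional=True above) replaces the sums over keys/values
# with prefix sums along the sequence axis.
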
Example #4
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              max_len=512,
              cache=None):
        """Applies LinformerBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, whether to mask out future positions (causal attention).
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, if true, dropout is not applied.
      max_len: int, max sequence length.
      cache: flax autoregressive cache for fast decoding.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = linformer_attention.LinformerSelfAttention(
            x,
            num_heads=num_heads,
            dtype=dtype,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=causal_mask,
            segmentation=inputs_segmentation,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            max_len=max_len,
            cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
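
Behind linformer_attention.LinformerSelfAttention (and the reason max_len appears above) is a low-rank projection of keys and values along the sequence axis. The single-head sketch below is illustrative only; E and F stand in for the learned projection matrices.

import jax
import jax.numpy as jnp

def linformer_attention(q, k, v, E, F):
    """q, k, v: [seq_len, head_dim]; E, F: [proj_len, seq_len] with proj_len << seq_len."""
    k_proj = E @ k                                   # [proj_len, head_dim]
    v_proj = F @ v                                   # [proj_len, head_dim]
    scores = q @ k_proj.T / jnp.sqrt(q.shape[-1])    # [seq_len, proj_len], not seq_len^2
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v_proj                          # [seq_len, head_dim]
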
Example #5
  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            attention_patterns,
            dtype=jnp.float32,
            inputs_segmentation=None,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=False,
            use_cls_token=False):
    """Applies the SparseTransformerBlock module.

    All Sparse Transformer attention patterns (both encoder and decoder) are
    causal. To apply the sparse attention pattern reported in the paper for the
    enwik8 dataset, use:
    attention_patterns = [
        sparse_attention.Fixed1Pattern(block_size=128),
        sparse_attention.Fixed2Pattern(block_size=128, c=32)
    ].

    Args:
      inputs: input data of size `[bs, seq_len, features]`.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: number of attention heads.
      attention_patterns: list of sparse attention patterns to apply.
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: bool, mask padding tokens.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: if true, dropout is not applied.
      use_cls_token: bool, whether a CLS token is used.

    Returns:
      output of shape `[bs, seq_len, features]` (same shape as `inputs`).
    """

    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs)
    x = sparse_attention.SparseSelfAttention(
        x,
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_patterns=attention_patterns,
        dtype=dtype,
        segmentation=inputs_segmentation,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        dropout_rate=attention_dropout_rate,
        deterministic=deterministic,
        use_cls_token=use_cls_token)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = x + inputs

    y = nn.LayerNorm(x)
    y = common_layers.MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dtype=dtype,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    return x + y
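
The Fixed1Pattern/Fixed2Pattern objects named in the docstring correspond to the "fixed" factorisation of the Sparse Transformer paper. The sketch below only illustrates the two token-level masks that factorisation implies, under the causal convention noted above; it is not the sparse_attention implementation.

import numpy as np

def fixed_sparse_masks(seq_len, block_size=128, c=32):
    """Returns the two boolean masks of the 'fixed' factorisation."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    causal = j <= i
    # Pattern 1: attend within the current block of `block_size` positions.
    fixed1 = (i // block_size == j // block_size) & causal
    # Pattern 2: attend to the last `c` ("summary") positions of every block.
    fixed2 = (j % block_size >= block_size - c) & causal
    return fixed1, fixed2
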
Example #6
  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            dtype=jnp.float32,
            inputs_segmentation=None,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=False,
            cache=None,
            max_length=512,
            ignore_dot_product=False,
            synthesizer_mode='random'):
    """Applies SynthesizerBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, whether to mask out future positions (causal attention).
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, if true, dropout is not applied.
      cache: flax autoregressive cache for fast decoding.
      max_length: int, the maximum supported sequence length.
      ignore_dot_product: bool, whether to ignore the dot-product attention term.
      synthesizer_mode: str, one of 'dense', 'random', or 'dense+random'.

    Returns:
      output after transformer block.

    """

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs)
    x = synthesizer_attention.SynthesizerSelfAttention(
        x,
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        dropout_rate=attention_dropout_rate,
        deterministic=deterministic,
        cache=cache,
        max_length=max_length,
        ignore_dot_product=ignore_dot_product,
        synthesizer_mode=synthesizer_mode)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(x)
    y = common_layers.MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    return x + y
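
For intuition, the single-head sketch below shows the two synthesizer_mode variants named in the docstring, with ignore_dot_product in mind: 'dense' synthesises the [seq_len, seq_len] attention logits from the inputs with a small two-layer network, while 'random' uses a trainable logit matrix that ignores the inputs entirely. All parameter names and shapes are illustrative, not the SynthesizerSelfAttention internals.

import jax
import jax.numpy as jnp

def dense_synthesizer(x, w1, b1, w2, b2, v_proj):
    """x: [seq_len, d]; w1: [d, hidden]; w2: [hidden, seq_len]; v_proj: [d, d]."""
    logits = jax.nn.relu(x @ w1 + b1) @ w2 + b2   # [seq_len, seq_len], input-dependent
    weights = jax.nn.softmax(logits, axis=-1)
    return weights @ (x @ v_proj)

def random_synthesizer(x, R, v_proj):
    """R: trainable [seq_len, seq_len] logits, independent of the inputs."""
    weights = jax.nn.softmax(R, axis=-1)
    return weights @ (x @ v_proj)
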