def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          cache=None,
          block_size=_DEFAULT_BLOCK_SIZE,
          connectivity_seed=None):
  """Applies BigBirdBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    cache: flax autoregressive cache for fast decoding.
    block_size: Size of attention blocks.
    connectivity_seed: Optional seed for random block sparse attention.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = bigbird_attention.BigBirdSelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache,
      block_size=block_size,
      connectivity_seed=connectivity_seed)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y
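# A minimal top-level usage sketch for the block above. It assumes the
# pre-Linen Flax module API (where `Module.init` runs `apply` and returns the
# output together with the initialized parameters, and `nn.stochastic`
# supplies the RNG that dropout needs); `nn` is the same flax module already
# imported by this file. The class name BigBirdBlock comes from the docstring,
# and the hyperparameters are purely illustrative.
import jax.numpy as jnp
from jax import random

dummy_inputs = jnp.ones((2, 1024, 512))  # [batch, seq_len, features]
with nn.stochastic(random.PRNGKey(0)):  # RNG source for dropout
  out, params = BigBirdBlock.init(
      random.PRNGKey(1),
      dummy_inputs,
      qkv_dim=512,
      mlp_dim=2048,
      num_heads=8,
      block_size=64,
      deterministic=False)
# `out` has the same shape as `dummy_inputs`; the remaining blocks in this
# section are driven the same way, differing only in their attention-specific
# keyword arguments.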
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          sliding_window_size=512,
          global_mask=None,
          causal_mask=False,
          dtype=jnp.float32,
          inputs_segmentation=None,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False):
  """Applies the LongformerBlock module.

  Args:
    inputs: input data of size `[bs, seq_len, features]`.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    num_heads: number of attention heads.
    sliding_window_size: size of sliding window attention to use.
    global_mask: boolean matrix of shape `[bs, seq_len]`, where `True`
      indicates that the position is globally attended. By default, no global
      attention is used.
    causal_mask: If true, apply causal attention mask.
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    padding_mask: bool, mask padding tokens.
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: if true, disable dropout.

  Returns:
    output of shape `[bs, seq_len, features]`.
  """
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = longformer_attention.LongformerSelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      sliding_window_size=sliding_window_size,
      global_mask=global_mask,
      causal_mask=causal_mask,
      dtype=dtype,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y
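# A sketch of constructing the `global_mask` argument described above,
# assuming (hypothetically) that position 0 of each example holds a CLS token
# that should attend, and be attended to, globally; shapes are illustrative.
import jax.numpy as jnp

global_mask = jnp.zeros((2, 1024), dtype=bool)  # [bs, seq_len], no global attention yet
global_mask = global_mask.at[:, 0].set(True)    # mark the CLS position as global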
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          cache=None,
          attention_fn_cls=_DEFAULT_ATTENTION_FN_CLS,
          attention_fn_kwargs=None):
  """Applies PerformerBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    cache: flax autoregressive cache for fast decoding.
    attention_fn_cls: Attention function key or callable.
    attention_fn_kwargs: Keywords to pass to `attention_fn_cls`.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  attention_fn = _make_attention_fn(
      attention_fn_cls, attention_fn_kwargs)(
          qkv_dim // num_heads, unidirectional=causal_mask)
  x = nn.LayerNorm(inputs)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache,
      attention_fn=attention_fn)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          max_len=512,
          cache=None):
  """Applies LinformerBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    max_len: int, max sequence length.
    cache: flax autoregressive cache for fast decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = linformer_attention.LinformerSelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      max_len=max_len,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          attention_patterns,
          dtype=jnp.float32,
          inputs_segmentation=None,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          use_cls_token=False):
  """Applies the SparseTransformerBlock module.

  All Sparse Transformer attention patterns (both encoder and decoder) are
  causal. To apply the sparse attention pattern reported in the paper on the
  EnWik8 data set:

    attention_patterns = [
        sparse_attention.Fixed1Pattern(block_size=128),
        sparse_attention.Fixed2Pattern(block_size=128, c=32)
    ]

  Args:
    inputs: input data of size `[bs, seq_len, features]`.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    num_heads: number of attention heads.
    attention_patterns: list of sparse attention patterns to apply.
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    padding_mask: bool, mask padding tokens.
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: if true, disable dropout.
    use_cls_token: bool, whether a CLS token is used.

  Returns:
    output of shape `[bs, seq_len, features]`.
  """
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = sparse_attention.SparseSelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_patterns=attention_patterns,
      dtype=dtype,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      use_cls_token=use_cls_token)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y
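# The EnWik8 configuration quoted in the docstring above, expressed as a
# concrete `attention_patterns` argument. Modules in this codebase are
# typically specialized with fixed keyword arguments via `Module.partial`;
# the wrapper below is only a sketch under that assumption, and the
# hyperparameters are illustrative.
enwik8_patterns = [
    sparse_attention.Fixed1Pattern(block_size=128),
    sparse_attention.Fixed2Pattern(block_size=128, c=32),
]
sparse_block = SparseTransformerBlock.partial(
    qkv_dim=512,
    mlp_dim=2048,
    num_heads=8,
    attention_patterns=enwik8_patterns)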
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          cache=None,
          max_length=512,
          ignore_dot_product=False,
          synthesizer_mode='random'):
  """Applies SynthesizerBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    cache: flax autoregressive cache for fast decoding.
    max_length: int, the maximum supported sequence length.
    ignore_dot_product: bool, whether to ignore the dot product attention.
    synthesizer_mode: str, one of 'dense', 'random', or 'dense+random'.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = synthesizer_attention.SynthesizerSelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache,
      max_length=max_length,
      ignore_dot_product=ignore_dot_product,
      synthesizer_mode=synthesizer_mode)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y