def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          attention_fn=nn.dot_product_attention,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    attention_fn: dot product function to use inside attention.
    cache: Cache for decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      attention_fn=attention_fn,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y

def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False):
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    num_heads: number of heads.
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    padding_mask: bool, mask padding tokens.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate for attention weights.
    deterministic: bool, deterministic or not (to apply dropout).

  Returns:
    output after transformer encoder block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs, dtype=dtype)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=x,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x, dtype=dtype)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y

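# A minimal, illustrative sketch of how an Encoder1DBlock like the one above
# is typically stacked inside a parent encoder under the pre-Linen flax.nn
# API. The `Encoder` wrapper, `num_layers`, and the assumption that `inputs`
# is already embedded are hypothetical and not part of the original source.
class Encoder(nn.Module):

  def apply(self,
            inputs,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            num_heads=8,
            padding_mask=None,
            train=True):
    # Assumed already embedded: <float>[batch_size, length, qkv_dim].
    x = inputs
    for lyr in range(num_layers):
      # Each call instantiates a separately parameterized block.
      x = Encoder1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          padding_mask=padding_mask,
          deterministic=not train,
          name=f'encoderblock_{lyr}')
    return nn.LayerNorm(x, name='encoder_norm')
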
def apply(self,
          inputs,
          mlp_dim,
          inputs_masks=None,
          dtype=jnp.float32,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=True,
          layer_drop_p=None,
          **attention_kwargs):
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    mlp_dim: dimension of the mlp on top of attention block.
    inputs_masks: bool, input mask.
    dtype: the dtype of the computation (default: float32).
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate for attention weights.
    deterministic: bool, deterministic or not (to apply dropout).
    layer_drop_p: probability of dropping a layer.
    **attention_kwargs: kwargs passed to nn.SelfAttention.

  Returns:
    output after transformer encoder block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs, dtype=dtype)
  x = nn.SelfAttention(
      x,
      dtype=dtype,
      inputs_kv=x,
      attention_axis=(1,),
      causal_mask=False,
      padding_mask=inputs_masks,
      kernel_init=nn.initializers.xavier_uniform(),
      broadcast_dropout=False,
      deterministic=deterministic,
      dropout_rate=attention_dropout_rate,
      **attention_kwargs)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)

  # Stochastic depth: with probability layer_drop_p, drop the whole
  # attention residual branch for a given example.
  drop_pattern = self.get_drop_pattern(x, layer_drop_p)
  x = x * (1.0 - drop_pattern) + inputs

  # MLP block.
  y = nn.LayerNorm(x, dtype=dtype)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  drop_pattern = self.get_drop_pattern(x, layer_drop_p)
  return y * (1.0 - drop_pattern) + x

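# A minimal, illustrative sketch of the linearly scaled stochastic-depth
# schedule that a `layer_drop_p` argument like the one above is usually
# paired with: the drop probability grows from 0 for the first block to a
# maximum value for the last one. The helper name, `num_layers`, and
# `stochastic_depth` are hypothetical and not part of the original source.
def layer_drop_schedule(num_layers, stochastic_depth=0.1):
  """Returns the per-layer drop probability for each encoder block."""
  if num_layers <= 1:
    return [0.0] * num_layers
  return [
      stochastic_depth * lyr / (num_layers - 1) for lyr in range(num_layers)
  ]

# Illustrative usage inside a hypothetical encoder stack:
#   for lyr, p in enumerate(layer_drop_schedule(num_layers, 0.1)):
#     x = Encoder1DBlock(x, mlp_dim=mlp_dim, layer_drop_p=p,
#                        deterministic=not train, name=f'encoderblock_{lyr}')
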
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          train=True,
          normalizer='layer_norm',
          attention_fn=None,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    train: bool, whether the model is training.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'
    attention_fn: Attention function to use. If None, defaults to
      nn.dot_product_attention.
    cache: flax autoregressive cache for fast decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(normalizer, train)
    maybe_post_normalize = model_utils.get_normalizer('none', train)
  elif normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer('none', train)
    maybe_post_normalize = model_utils.get_normalizer(normalizer, train)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(normalizer))

  x = maybe_pre_normalize(inputs)

  if attention_fn is None:
    attention_fn = nn.dot_product_attention
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      attention_fn=attention_fn,
      dropout_rate=attention_dropout_rate,
      deterministic=not train,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  x = x + inputs
  x = maybe_post_normalize(x)

  # MLP block.
  y = maybe_pre_normalize(x)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dropout_rate=dropout_rate,
      deterministic=not train)
  res = x + y

  return maybe_post_normalize(res)

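# A minimal, illustrative sketch of the contract this block assumes for
# `model_utils.get_normalizer`: it returns a callable that applies the
# requested normalization, or the identity for 'none', so that pre- and
# post-normalization can be toggled without changing the residual wiring.
# This is an assumption for illustration only; the real helper in
# model_utils may differ in details.
def get_normalizer(normalizer, train):
  """Returns a callable applying the requested normalization (sketch)."""
  if normalizer in ('layer_norm', 'pre_layer_norm', 'post_layer_norm'):
    return nn.LayerNorm
  elif normalizer == 'batch_norm':
    return nn.BatchNorm.partial(use_running_average=not train)
  elif normalizer == 'none':
    return lambda x: x  # Identity: no normalization applied.
  raise ValueError('Unsupported normalizer: {}'.format(normalizer))
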
def apply(self,
          targets,
          encoded,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          targets_segmentation=None,
          padding_mask=None,
          key_padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          cache=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: input data for decoder
    encoded: input data from encoder
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32)
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.
    padding_mask: bool, mask padding tokens
    key_padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    cache: flax attention cache for fast decoding.

  Returns:
    output after transformer encoder-decoder block.
  """
  # Decoder block.
  assert targets.ndim == 3
  x = nn.LayerNorm(targets, dtype=dtype)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=x,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=True,
      padding_mask=padding_mask,
      segmentation=targets_segmentation,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + targets

  # Encoder-Decoder block.
  y = nn.LayerNorm(x, dtype=dtype)
  y = nn.SelfAttention(
      y,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=encoded,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      padding_mask=padding_mask,
      key_padding_mask=key_padding_mask,
      segmentation=targets_segmentation,
      key_segmentation=inputs_segmentation,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  y = nn.dropout(y, rate=dropout_rate, deterministic=deterministic)
  y = y + x

  # MLP block.
  z = nn.LayerNorm(y, dtype=dtype)
  z = MlpBlock(
      z,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return y + z

def apply(self,
          targets,
          encoded,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          targets_segmentation=None,
          padding_mask=None,
          key_padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          normalizer='layer_norm',
          cache=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: <float>[batch_size, target_sequence_length, qkv_dim]
    encoded: <float>[batch_size, input_sequence_length, qkv_dim]
    qkv_dim: Dimension of the query/key/value.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: Number of heads.
    dtype: Dtype of the computation (default: float32).
    inputs_segmentation: Input segmentation info for packed examples.
    targets_segmentation: Target segmentation info for packed examples.
    padding_mask: <bool> Mask padding tokens.
    key_padding_mask: <bool> Mask padding tokens.
    dropout_rate: <float> Dropout rate.
    attention_dropout_rate: <float> Dropout rate for attention weights.
    deterministic: <bool> Deterministic or not (to apply dropout).
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'.
    cache: Flax attention cache for fast decoding.

  Returns:
    output: <float>[batch_size, target_sequence_length, qkv_dim]
  """
  # Decoder block.
  assert targets.ndim == 3
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(
        normalizer, not deterministic)
    maybe_post_normalize = model_utils.get_normalizer(
        'none', not deterministic)
  elif normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer(
        'none', not deterministic)
    maybe_post_normalize = model_utils.get_normalizer(
        normalizer, not deterministic)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(normalizer))

  x = maybe_pre_normalize(targets)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=x,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=True,
      padding_mask=padding_mask,
      segmentation=targets_segmentation,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + targets
  x = maybe_post_normalize(x)

  # Encoder-Decoder block.
  # TODO(ankugarg): Support for configurable pre vs post layernorm.
  y = maybe_pre_normalize(x)
  y = nn.SelfAttention(
      y,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=encoded,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      padding_mask=padding_mask,
      key_padding_mask=key_padding_mask,
      segmentation=targets_segmentation,
      key_segmentation=inputs_segmentation,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  y = nn.dropout(y, rate=dropout_rate, deterministic=deterministic)
  y = y + x
  y = maybe_post_normalize(y)

  # MLP block.
  z = maybe_pre_normalize(y)
  z = MlpBlock(
      z,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  res = y + z

  return maybe_post_normalize(res)

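# A minimal, illustrative sketch of how an EncoderDecoder1DBlock like the one
# above is typically chained inside a parent decoder under the pre-Linen
# flax.nn API, including reuse of a single cache object for fast
# autoregressive decoding. The `Decoder` wrapper, `num_layers`, and the
# assumption that `targets` is already embedded are hypothetical.
class Decoder(nn.Module):

  def apply(self,
            targets,
            encoded,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            num_heads=8,
            padding_mask=None,
            key_padding_mask=None,
            train=True,
            cache=None):
    # Assumed already embedded: <float>[batch_size, target_length, qkv_dim].
    y = targets
    for lyr in range(num_layers):
      y = EncoderDecoder1DBlock(
          y,
          encoded,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          padding_mask=padding_mask,
          key_padding_mask=key_padding_mask,
          deterministic=not train,
          cache=cache,  # Shared cache object; each layer stores its own slot.
          name=f'encoderdecoderblock_{lyr}')
    return nn.LayerNorm(y, name='decoder_norm')
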
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          normalizer='layer_norm',
          deterministic=False):
  """Applies Encoder1DBlock module.

  Args:
    inputs: <float>[batch_size, input_sequence_length, qkv_dim]
    qkv_dim: <int> Dimension of the query/key/value.
    mlp_dim: <int> Dimension of the mlp on top of attention block.
    num_heads: <int> Number of heads.
    dtype: Dtype of the computation (default: float32).
    inputs_segmentation: Input segmentation info for packed examples.
    padding_mask: <bool> Mask padding tokens.
    dropout_rate: <float> Dropout rate.
    attention_dropout_rate: <float> Dropout rate for attention weights.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'.
    deterministic: <bool> Deterministic or not (to apply dropout).

  Returns:
    output: <float>[batch_size, input_sequence_length, qkv_dim]
  """
  # Attention block.
  assert inputs.ndim == 3
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(
        normalizer, not deterministic)
    maybe_post_normalize = model_utils.get_normalizer(
        'none', not deterministic)
  elif normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer(
        'none', not deterministic)
    maybe_post_normalize = model_utils.get_normalizer(
        normalizer, not deterministic)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(normalizer))

  x = maybe_pre_normalize(inputs)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=x,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs
  x = maybe_post_normalize(x)

  # MLP block.
  y = maybe_pre_normalize(x)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  res = x + y

  return maybe_post_normalize(res)

def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          cache=None,
          attention_fn=None):
  """Applies TransformerBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    cache: flax autoregressive cache for fast decoding.
    attention_fn: attention function to use. If None, defaults to
      nn.attention.dot_product_attention.

  Returns:
    output after transformer block.
  """
  if attention_fn is None:
    attention_fn = nn.attention.dot_product_attention

  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache,
      attention_fn=attention_fn)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)

  return x + y

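# A minimal, illustrative sketch of how a custom attention function can be
# injected through the `attention_fn` hook above. The wrapper simply
# delegates to nn.attention.dot_product_attention; a real replacement (for
# example, a linear-time approximation) would keep the same signature. The
# wrapper name and the usage snippet are hypothetical.
def wrapped_attention_fn(*args, **kwargs):
  # Hook point: inspect or transform queries/keys/values here before
  # delegating to the standard dot-product attention.
  return nn.attention.dot_product_attention(*args, **kwargs)

# Illustrative usage inside a parent module's apply:
#   x = TransformerBlock(x, qkv_dim=512, mlp_dim=2048, num_heads=8,
#                        attention_fn=wrapped_attention_fn)
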