Example #1
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              attention_fn=nn.dot_product_attention,
              cache=None):
        """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      attention_fn: dot product function to use inside attention.
      cache: Cache for decoding.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=causal_mask,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic,
                             attention_fn=attention_fn,
                             cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
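
The MlpBlock used above is not defined on this page. As a rough sketch (an assumption about its shape, not the project's exact definition), the feed-forward block in these pre-Linen Flax codebases is typically dense → GELU → dropout → dense → dropout:

class MlpBlock(nn.Module):
    """Sketch of the transformer feed-forward block (assumed definition)."""

    def apply(self,
              inputs,
              mlp_dim,
              dtype=jnp.float32,
              out_dim=None,
              dropout_rate=0.1,
              deterministic=False,
              kernel_init=nn.initializers.xavier_uniform(),
              bias_init=nn.initializers.normal(stddev=1e-6)):
        # Project up to mlp_dim, apply GELU, then project back down to the
        # input feature dimension (or out_dim if given).
        actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
        x = nn.Dense(inputs, mlp_dim, dtype=dtype,
                     kernel_init=kernel_init, bias_init=bias_init)
        x = nn.gelu(x)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        output = nn.Dense(x, actual_out_dim, dtype=dtype,
                          kernel_init=kernel_init, bias_init=bias_init)
        output = nn.dropout(output, rate=dropout_rate,
                            deterministic=deterministic)
        return output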
Example #2
File: models.py Project: zhang-yd15/flax
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: number of heads.
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: bool, mask padding tokens.
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate for attention weights.
      deterministic: bool, deterministic or not (to apply dropout).

    Returns:
      output after transformer encoder block.
    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs, dtype=dtype)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=x,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=False,
                             segmentation=inputs_segmentation,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x, dtype=dtype)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
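
For context, a rough top-level usage sketch under the pre-Linen flax.nn API that these examples appear to use; the exact incantation is an assumption (the class name Encoder1DBlock is taken from the docstring, and nn.stochastic / nn.Model come from the deprecated flax.nn module):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
dummy = jnp.ones((2, 16, 512), jnp.float32)  # (batch, length, qkv_dim)

# nn.stochastic provides the RNG stream that nn.dropout draws from.
with nn.stochastic(rng):
    _, params = Encoder1DBlock.init(
        rng, dummy, qkv_dim=512, mlp_dim=2048, num_heads=8)
    model = nn.Model(Encoder1DBlock, params)
    out = model(dummy, qkv_dim=512, mlp_dim=2048, num_heads=8)
assert out.shape == dummy.shape  # residual block preserves the input shape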
Example #3
  def apply(self,
            inputs,
            mlp_dim,
            inputs_masks=None,
            dtype=jnp.float32,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=True,
            layer_drop_p=None,
            **attention_kwargs):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      mlp_dim: dimension of the mlp on top of attention block.
      inputs_masks: bool, input mask.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
      deterministic: bool, deterministic or not (to apply dropout).
      layer_drop_p: probability of dropping a layer.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output after transformer encoder block.
    """

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs, dtype=dtype)
    x = nn.SelfAttention(
        x,
        dtype=dtype,
        inputs_kv=x,
        attention_axis=(1,),
        causal_mask=False,
        padding_mask=inputs_masks,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=attention_dropout_rate,
        **attention_kwargs)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)

    drop_pattern = self.get_drop_pattern(x, layer_drop_p)
    x = x * (1.0 - drop_pattern) + inputs

    # MLP block.
    y = nn.LayerNorm(x, dtype=dtype)
    y = MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dtype=dtype,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    drop_pattern = self.get_drop_pattern(x, layer_drop_p)
    return y * (1.0 - drop_pattern) + x
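
get_drop_pattern is not shown on this page; it implements stochastic depth ("layer drop"). A minimal sketch of what such a helper usually does, assuming per-example Bernoulli gating (the real helper may differ):

  def get_drop_pattern(self, x, layer_drop_p):
    # Stochastic depth: with probability layer_drop_p, zero the whole
    # residual branch for an example so the block acts as the identity
    # for that example. Hypothetical sketch only.
    if not layer_drop_p:
      return 0.0
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one flag per example
    return jax.random.bernoulli(
        nn.make_rng(), layer_drop_p, shape).astype(x.dtype)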
Example #4
  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            train=True,
            normalizer='layer_norm',
            attention_fn=None,
            cache=None):
    """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      train: bool: if model is training.
      normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
        'pre_layer_norm', 'none'
      attention_fn: Attention function to use. If None, defaults to
        nn.dot_product_attention.
      cache: flax autoregressive cache for fast decoding.

    Returns:
      output after transformer block.

    """

    # Attention block.
    assert inputs.ndim == 3
    if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
      maybe_pre_normalize = model_utils.get_normalizer(normalizer, train)
      maybe_post_normalize = model_utils.get_normalizer('none', train)
    elif normalizer == 'post_layer_norm':
      maybe_pre_normalize = model_utils.get_normalizer('none', train)
      maybe_post_normalize = model_utils.get_normalizer(normalizer, train)
    else:
      raise ValueError('Unsupported normalizer: {}'.format(normalizer))

    x = maybe_pre_normalize(inputs)

    if attention_fn is None:
      attention_fn = nn.dot_product_attention
    x = nn.SelfAttention(
        x,
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        attention_fn=attention_fn,
        dropout_rate=attention_dropout_rate,
        deterministic=not train,
        cache=cache)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    x = x + inputs
    x = maybe_post_normalize(x)

    # MLP block.
    y = maybe_pre_normalize(x)
    y = MlpBlock(
        y, mlp_dim=mlp_dim, dropout_rate=dropout_rate, deterministic=not train)
    res = x + y

    return maybe_post_normalize(res)
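
model_utils.get_normalizer is external to this page. As an assumed sketch of its contract (not the actual implementation): it maps a normalizer name to a callable that either normalizes its input or is the identity, which is what lets the block above switch between pre-LN, post-LN, batch norm, and no normalization:

import functools

def get_normalizer(normalizer, train):
  # Assumed contract: return a callable normalization layer or the identity.
  if normalizer == 'none':
    return lambda x: x
  if normalizer in ('layer_norm', 'pre_layer_norm', 'post_layer_norm'):
    return nn.LayerNorm
  if normalizer == 'batch_norm':
    # BatchNorm should use running statistics when not training.
    return functools.partial(nn.BatchNorm, use_running_average=not train)
  raise ValueError('Unsupported normalizer: {}'.format(normalizer))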
Example #5
File: models.py Project: zhang-yd15/flax
    def apply(self,
              targets,
              encoded,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              targets_segmentation=None,
              padding_mask=None,
              key_padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              cache=None):
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      padding_mask: bool, mask padding tokens
      key_padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      cache: flax attention cache for fast decoding.

    Returns:
      output after transformer encoder-decoder block.
    """

        # Decoder block.
        assert targets.ndim == 3
        x = nn.LayerNorm(targets, dtype=dtype)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=x,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=True,
                             padding_mask=padding_mask,
                             segmentation=targets_segmentation,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic,
                             cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + targets

        # Encoder-Decoder block.
        y = nn.LayerNorm(x, dtype=dtype)
        y = nn.SelfAttention(y,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=encoded,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=False,
                             padding_mask=padding_mask,
                             key_padding_mask=key_padding_mask,
                             segmentation=targets_segmentation,
                             key_segmentation=inputs_segmentation,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic)
        y = nn.dropout(y, rate=dropout_rate, deterministic=deterministic)
        y = y + x

        # MLP block.
        z = nn.LayerNorm(y, dtype=dtype)
        z = MlpBlock(z,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return y + z
Example #6
    def apply(self,
              targets,
              encoded,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              targets_segmentation=None,
              padding_mask=None,
              key_padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              normalizer='layer_norm',
              cache=None):
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: <float>[batch_size, target_sequence_length, qkv_dim]
      encoded: <float>[batch_size, input_sequence_length, qkv_dim]
      qkv_dim: Dimension of the query/key/value.
      mlp_dim: Dimension of the mlp on top of attention block.
      num_heads: Number of heads.
      dtype: Dtype of the computation (default: float32).
      inputs_segmentation: Input segmentation info for packed examples.
      targets_segmentation: Target segmentation info for packed examples.
      padding_mask: <bool> Mask padding tokens.
      key_padding_mask: <bool> Mask padding tokens.
      dropout_rate: <float> Dropout rate.
      attention_dropout_rate: <float> Dropout rate for attention weights
      deterministic: <bool> Deterministic or not (to apply dropout)
      normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
        'pre_layer_norm', 'none'
      cache: Flax attention cache for fast decoding.

    Returns:
      output: <float>[batch_size, target_sequence_length, qkv_dim]
    """

        # Decoder block.
        assert targets.ndim == 3
        if normalizer in [
                'batch_norm', 'layer_norm', 'pre_layer_norm', 'none'
        ]:
            maybe_pre_normalize = model_utils.get_normalizer(
                normalizer, not deterministic)
            maybe_post_normalize = model_utils.get_normalizer(
                'none', not deterministic)
        elif normalizer == 'post_layer_norm':
            maybe_pre_normalize = model_utils.get_normalizer(
                'none', not deterministic)
            maybe_post_normalize = model_utils.get_normalizer(
                normalizer, not deterministic)
        else:
            raise ValueError('Unsupported normalizer: {}'.format(normalizer))

        x = maybe_pre_normalize(targets)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=x,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=True,
                             padding_mask=padding_mask,
                             segmentation=targets_segmentation,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic,
                             cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + targets
        x = maybe_post_normalize(x)

        # Encoder-Decoder block.
        # TODO(ankugarg): Support for configurable pre vs post layernorm.
        y = maybe_pre_normalize(x)
        y = nn.SelfAttention(y,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=encoded,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=False,
                             padding_mask=padding_mask,
                             key_padding_mask=key_padding_mask,
                             segmentation=targets_segmentation,
                             key_segmentation=inputs_segmentation,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic)
        y = nn.dropout(y, rate=dropout_rate, deterministic=deterministic)
        y = y + x
        y = maybe_post_normalize(y)

        # MLP block.
        z = maybe_pre_normalize(y)
        z = MlpBlock(z,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)
        res = y + z

        return maybe_post_normalize(res)
Example #7
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              normalizer='layer_norm',
              deterministic=False):
        """Applies Encoder1DBlock module.

    Args:
      inputs: <float>[batch_size, input_sequence_length, qkv_dim]
      qkv_dim: <int> Dimension of the query/key/value.
      mlp_dim: <int> Dimension of the mlp on top of attention block.
      num_heads: <int> Number of heads.
      dtype: Dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: <bool> Mask padding tokens.
      dropout_rate: <float> Dropout rate.
      attention_dropout_rate: <float> Dropout rate for attention weights.
      normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
        'pre_layer_norm', 'none'
      deterministic: <bool> Deterministic or not (to apply dropout).

    Returns:
      Output: <float>[batch_size, input_sequence_length, qkv_dim]
    """

        # Attention block.
        assert inputs.ndim == 3
        if normalizer in [
                'batch_norm', 'layer_norm', 'pre_layer_norm', 'none'
        ]:
            maybe_pre_normalize = model_utils.get_normalizer(
                normalizer, not deterministic)
            maybe_post_normalize = model_utils.get_normalizer(
                'none', not deterministic)
        elif normalizer == 'post_layer_norm':
            maybe_pre_normalize = model_utils.get_normalizer(
                'none', not deterministic)
            maybe_post_normalize = model_utils.get_normalizer(
                normalizer, not deterministic)
        else:
            raise ValueError('Unsupported normalizer: {}'.format(normalizer))

        x = maybe_pre_normalize(inputs)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=x,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=False,
                             segmentation=inputs_segmentation,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs
        x = maybe_post_normalize(x)

        # MLP block.
        y = maybe_pre_normalize(x)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)
        res = x + y

        return maybe_post_normalize(res)
Example #8
  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            dtype=jnp.float32,
            inputs_segmentation=None,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=False,
            cache=None,
            attention_fn=None):
    """Applies TransformerBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      cache: flax autoregressive cache for fast decoding.
      attention_fn: attention function to use; if None, defaults to
        nn.attention.dot_product_attention.

    Returns:
      output after transformer block.

    """

    if attention_fn is None:
      attention_fn = nn.attention.dot_product_attention

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs)
    x = nn.SelfAttention(
        x,
        num_heads=num_heads,
        dtype=dtype,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        segmentation=inputs_segmentation,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        dropout_rate=attention_dropout_rate,
        deterministic=deterministic,
        cache=cache,
        attention_fn=attention_fn)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(x)
    y = common_layers.MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dtype=dtype,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    return x + y
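
Because attention_fn defaults to nn.attention.dot_product_attention, a drop-in replacement only needs the same call signature. A minimal sketch of one way to supply it (the partial-application pattern and the name high_precision_attention are illustrative assumptions, not something this project necessarily does):

import functools
import jax

# Same dot-product attention, but forcing full-precision matmuls.
# Passed to the block via attention_fn=high_precision_attention.
high_precision_attention = functools.partial(
    nn.attention.dot_product_attention,
    precision=jax.lax.Precision.HIGHEST)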