Example no. 1
  def _init_mixing_sublayer(self, layer, model_arch,
                            mixing_key):
    """Initializes config-dependent mixing sublayer."""
    if model_arch == ModelArchitecture.BERT:
      mixing_sublayer = nn.SelfAttention(
          num_heads=self.config.num_heads,
          qkv_features=self.config.d_model,
          broadcast_dropout=False,
          kernel_init=default_kernel_init,
          bias_init=default_bias_init,
          dropout_rate=self.config.mixing_dropout_rate,
          use_bias=True,
          name=f"self_attention_{layer}")
    elif model_arch == ModelArchitecture.F_NET:
      mixing_sublayer = layers.FourierTransform(
          fourier_transform=self.fourier_transform,
          name=f"fourier_transform_{layer}")
    elif model_arch == ModelArchitecture.FF_ONLY:
      mixing_sublayer = layers.IdentityTransform(
          name=f"identity_transform_{layer}")
    elif model_arch == ModelArchitecture.LINEAR:
      mixing_sublayer = layers.LinearTransform(
          precision=lax.Precision.DEFAULT, name=f"linear_transform_{layer}")
    elif model_arch == ModelArchitecture.RANDOM:
      mixing_sublayer = layers.RandomTransform(
          max_seq_length=self.config.max_seq_length,
          d_model=self.config.d_model,
          key=mixing_key,
          precision=lax.Precision.DEFAULT,
          name=f"random_transform_{layer}")
    else:
      raise ValueError(f"Unexpected model architecture: {model_arch.name}")

    return mixing_sublayer
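
The factory above dispatches on a ModelArchitecture enum to pick the token-mixing sublayer; the `layers.*` transforms it returns are repo-specific. As a self-contained illustration of the same dispatch pattern, here is a minimal sketch (the `arch`, `num_heads` and `d_model` fields are hypothetical) that falls back to a parameter-free FFT mixing in the spirit of F-Net:

import flax.linen as nn
import jax.numpy as jnp


class FourierMix(nn.Module):
  """Parameter-free mixing: real part of a 2D FFT over (seq, hidden)."""

  @nn.compact
  def __call__(self, x):
    return jnp.fft.fft2(x).real


class MixingLayer(nn.Module):
  arch: str = "f_net"   # illustrative config field
  num_heads: int = 8
  d_model: int = 512

  @nn.compact
  def __call__(self, x, mask=None):
    if self.arch == "bert":
      return nn.SelfAttention(num_heads=self.num_heads,
                              qkv_features=self.d_model)(x, mask)
    elif self.arch == "f_net":
      return FourierMix()(x)
    elif self.arch == "ff_only":
      return x  # identity mixing
    raise ValueError(f"Unexpected architecture: {self.arch}")
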
Example no. 2
  def test_decoding(self, spatial_shape, attn_dims):
    bs = 2
    num_heads = 3
    num_features = 4
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    inputs = random.normal(
        key1, (bs,) + spatial_shape + (num_heads * num_features,))
    module = nn.SelfAttention(
        num_heads=num_heads,
        qkv_features=num_heads * num_features,
        precision=lax.Precision.HIGHEST,
        decode=False)
    decode_module = module.clone(decode=True)

    initial_vars = decode_module.init(key2, inputs)
    causal_mask = nn.attention.make_causal_mask(jnp.ones((bs,) + spatial_shape))
    y_ref = jax.jit(lambda x, y: module.apply(initial_vars, x, y))(
        inputs, causal_mask)
    # Feed the inputs sequentially to simulate decoding.
    def body_fn(vars_in, x):
      y, cache_out = decode_module.apply(vars_in, x,
                                         mutable=['cache'])
      # `mutable=['cache']` returns only the updated cache collection, so
      # merge it back with the params before the next step.
      return {**vars_in, **cache_out}, y
    # scan_in_dim supports scanning multiple dims
    _, y = jax_utils.scan_in_dim(body_fn, initial_vars, inputs,
                                 axis=attn_dims, keepdims=True)

    np.testing.assert_allclose(y_ref, y, atol=1e-5)
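
The test above relies on Flax's cache-based autoregressive decoding: initializing a module built with `decode=True` allocates `cached_key`/`cached_value` buffers sized from the init input, and each subsequent single-step `apply` with `mutable=['cache']` fills the next slot. A minimal sketch of that loop (shapes and variable names are illustrative):

import jax.numpy as jnp
from jax import random
import flax.linen as nn

B, L, F = 2, 8, 12
attn = nn.SelfAttention(num_heads=3, qkv_features=F, decode=True)

# Initializing on a full-length input sizes the cache to L positions.
variables = attn.init(random.PRNGKey(0), jnp.zeros((B, L, F)))
params, cache = variables['params'], variables['cache']

tokens = random.normal(random.PRNGKey(1), (B, L, F))
outputs = []
for t in range(L):
  # One position at a time; the updated cache is threaded to the next step.
  y, mutated = attn.apply({'params': params, 'cache': cache},
                          tokens[:, t:t + 1, :], mutable=['cache'])
  cache = mutated['cache']
  outputs.append(y)
decoded = jnp.concatenate(outputs, axis=1)  # [B, L, F]
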
Example no. 3
    def __call__(self, inputs, encoder_mask=None):
        """Applies Encoder1DBlock module.
    Args:
      inputs: input data.
      encoder_mask: encoder self-attention mask.
    Returns:
      output after transformer encoder block.
    """
        cfg = self.config

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=cfg.deterministic)(x, encoder_mask)

        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = MlpBlock(config=cfg)(y)

        return x + y
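
The `encoder_mask` consumed by blocks like the one above is normally derived from the padding pattern of the token ids before the block is called; a minimal sketch using Flax's mask helpers, assuming id 0 marks padding:

import jax.numpy as jnp
import flax.linen as nn

# inputs: integer token ids [batch, length]; 0 is assumed to be padding.
inputs = jnp.array([[5, 7, 2, 0, 0],
                    [3, 1, 0, 0, 0]])
encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0)
# encoder_mask: [batch, 1, length, length], 1.0 where both the query and the
# key position are real tokens.
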
Example no. 4
    def __call__(self, inputs, encoder_mask=None):
        """Applies Transformer block.

    Args:
      inputs: input data `[batch_size, ..., length, dim]`
      encoder_mask: encoder self-attention mask

    Returns:
      Encoded input data `[batch_size, ..., length, mlp_dim]`
    """
        cfg = self.config

        # Attention block.
        x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=cfg.deterministic)(x, encoder_mask)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = MLPBlock(config=cfg)(y)

        return x + y
Example no. 5
  def __call__(self,
               targets,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None):
    """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output after transformer encoder-decoder block.
    """
    cfg = self.config

    # Decoder block.
    assert targets.ndim == 3
    x = nn.LayerNorm(dtype=cfg.dtype)(targets)
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic,
        decode=cfg.decode)(x, decoder_mask)
    x = nn.Dropout(rate=cfg.dropout_rate)(
        x, deterministic=cfg.deterministic)
    x = x + targets

    # Encoder-Decoder block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    y = nn.MultiHeadDotProductAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(
            y, encoded, encoder_decoder_mask)

    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)
    y = y + x

    # MLP block.
    z = nn.LayerNorm(dtype=cfg.dtype)(y)
    z = MlpBlock(config=cfg)(z)

    return y + z
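
The `decoder_mask` and `encoder_decoder_mask` arguments used above are typically built outside the block by combining padding and causal masks; a minimal sketch, again assuming id 0 marks padding:

import jax.numpy as jnp
import flax.linen as nn

inputs = jnp.ones((2, 7), dtype=jnp.int32)   # encoder token ids
targets = jnp.ones((2, 5), dtype=jnp.int32)  # decoder token ids
decoder_mask = nn.combine_masks(
    nn.make_attention_mask(targets > 0, targets > 0),
    nn.make_causal_mask(targets))
encoder_decoder_mask = nn.make_attention_mask(targets > 0, inputs > 0)
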
Example no. 6
    def __call__(self,
                 inputs,
                 encoder_mask=None,
                 encoder_relative_position=None):
        """Applies Transformer block.

    Args:
      inputs: input data `[batch_size, ..., length, dim]`
      encoder_mask: encoder self-attention mask
      encoder_relative_position: encoder relative positions tensor
          `[batch_sizes..., length, length]`

    Returns:
      Encoded input data `[batch_size, ..., length, mlp_dim]`
    """
        cfg = self.config

        # Attention block.
        x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        if cfg.use_relative_attention:
            x = relative_attention.RelativeSelfAttention(
                num_heads=cfg.num_heads,
                dtype=cfg.dtype,
                qkv_features=cfg.qkv_dim,
                kernel_init=cfg.kernel_init,
                bias_init=cfg.bias_init,
                use_bias=False,
                broadcast_dropout=False,
                dropout_rate=cfg.attention_dropout_rate,
                deterministic=cfg.deterministic,
                bidirectional=self.bidirectional_attention,
                num_relative_position_buckets=(
                    self.num_relative_position_buckets),
                max_distance=self.max_distance)(x, encoder_mask,
                                                encoder_relative_position)
        else:
            x = nn.SelfAttention(num_heads=cfg.num_heads,
                                 dtype=cfg.dtype,
                                 qkv_features=cfg.qkv_dim,
                                 kernel_init=cfg.kernel_init,
                                 bias_init=cfg.bias_init,
                                 use_bias=False,
                                 broadcast_dropout=False,
                                 dropout_rate=cfg.attention_dropout_rate,
                                 deterministic=cfg.deterministic)(x,
                                                                  encoder_mask)

        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = MLPBlock(config=cfg)(y)

        return x + y
Example no. 7
    def __call__(self,
                 targets,
                 encoded,
                 decoder_mask=None,
                 encoder_decoder_mask=None):
        """Applies Transformer block.

    Args:
      targets: input data for decoder `[batch_size, ..., length, dim]`
      encoded: input data from encoder `[batch_size, ..., length2, dim2]`
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask

    Returns:
      Decoded data `[batch_size, ..., length, mlp_dim]`
    """
        cfg = self.config

        # Decoder block.
        x = nn.LayerNorm(dtype=cfg.dtype)(targets)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=cfg.deterministic,
                             decode=cfg.decode)(x, decoder_mask)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + targets

        # Encoder-Decoder block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=cfg.num_heads,
            dtype=cfg.dtype,
            qkv_features=cfg.qkv_dim,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=cfg.attention_dropout_rate,
            deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)
        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)
        y = y + x

        # MLP block.
        z = nn.LayerNorm(dtype=cfg.dtype)(y)
        z = MLPBlock(config=cfg)(z)

        return y + z
Example no. 8
    def __call__(self, pixel_embeddings, patch_embeddings):

        cfg = self.config

        v = cfg.image_size // cfg.patch_size

        #Inner T-Block
        x = nn.LayerNorm(dtype=cfg.dtype)(pixel_embeddings)
        x = nn.SelfAttention(num_heads=cfg.inner_heads,
                             qkv_features=cfg.inner_heads * cfg.inner_dim_head,
                             out_features=cfg.inner_dim,
                             use_bias=False,
                             kernel_init=cfg.kernel_init,
                             deterministic=True)(x)
        x = x + pixel_embeddings
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = MlpBlock(config=cfg, inner=True)(y)
        inner_output = x + y

        x = rearrange(pixel_embeddings, '... n d -> ... (n d)')
        x = nn.Dense(cfg.outer_dim,
                     dtype=cfg.dtype,
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init)(x)
        x = rearrange(x, '(b h w) d -> b (h w) d', h=v, w=v)
        x = jnp.pad(x, ((0, 0), (0, 1), (0, 0)))
        x = x + patch_embeddings

        #Outer T-Block
        x = nn.LayerNorm(dtype=cfg.dtype)(x)
        x = nn.SelfAttention(num_heads=cfg.outer_heads,
                             qkv_features=cfg.outer_heads * cfg.outer_dim_head,
                             out_features=cfg.outer_dim,
                             use_bias=False,
                             kernel_init=cfg.kernel_init,
                             deterministic=True)(x)
        x = x + patch_embeddings
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = MlpBlock(config=cfg, inner=False)(y)
        outer_output = x + y

        return inner_output, outer_output
Example no. 9
 def setup(self):
     self.attention_layer = nn.SelfAttention(
         num_heads=self.num_heads,
         dtype=self.dtype,
         qkv_features=self.model_dim,
         dropout_rate=self.dropout_rate,
         kernel_init=self.kernel_init,
         bias_init=self.bias_init,
     )
     self.dropout = nn.Dropout(self.dropout_rate)
     self.layer_norm = nn.LayerNorm(epsilon=self.layer_norm_epsilon)
Example no. 10
 def test_multihead_self_attention(self):
   rng = random.PRNGKey(0)
   x = jnp.ones((4, 6, 5))
   sa_module = nn.SelfAttention(
       num_heads=8,
       qkv_features=16,
       kernel_init=initializers.ones,
       bias_init=initializers.zeros,
   )
   y, _ = sa_module.init_with_output(rng, x)
   self.assertEqual(y.shape, x.shape)
Example no. 11
    def __call__(self, inputs, encoder_mask=None, train=True):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      encoder_mask: encoder self-attention mask.
      train: if it is training.

    Returns:
      output after transformer encoder block.
    """
        # Attention block.
        assert inputs.ndim == 3
        if self.normalizer in [
                'batch_norm', 'layer_norm', 'pre_layer_norm', 'none'
        ]:
            maybe_pre_normalize = model_utils.get_normalizer(
                self.normalizer, train)
            maybe_post_normalize = model_utils.get_normalizer('none', train)
        elif self.normalizer == 'post_layer_norm':
            maybe_pre_normalize = model_utils.get_normalizer('none', train)
            maybe_post_normalize = model_utils.get_normalizer(
                self.normalizer, train)
        else:
            raise ValueError('Unsupported normalizer: {}'.format(
                self.normalizer))

        x = maybe_pre_normalize()(inputs)
        x = nn.SelfAttention(num_heads=self.num_heads,
                             dtype=self.dtype,
                             qkv_features=self.qkv_dim,
                             kernel_init=self.enc_self_attn_kernel_init_fn,
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=self.attention_dropout_rate,
                             name='EncoderSelfAttention')(
                                 x, mask=encoder_mask, deterministic=not train)

        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        x = x + inputs

        x = maybe_post_normalize()(x)
        # MLP block.
        y = maybe_pre_normalize()(x)
        y = MlpBlock(mlp_dim=self.mlp_dim,
                     dtype=self.dtype,
                     dropout_rate=self.dropout_rate,
                     name='MLPBlock')(y, train=train)

        res = x + y
        return maybe_post_normalize()(res)
Example no. 12
 def __call__(self, x, training: bool = True):
     x = x.astype(jnp.int32)
     x = nn.Embed(
         num_embeddings=self.num_embeddings,
         features=self.embedding_dim,
         name="embed",
     )(x)
     x = jnp.reshape(x, (x.shape[0], -1))
     x = nn.SelfAttention(
         num_heads=self.num_heads,
         qkv_features=self.qkv_features,
         out_features=self.out_features,
         use_bias=False,
         deterministic=not training,
     )(x)
     return x
Example no. 13
  def __call__(self, query, deterministic):
    out_params = query.shape[-1]

    # Self-attention over the query sequence.
    attention_output = nn.SelfAttention(
        num_heads=self.attention_heads,
        qkv_features=self.qkv_params,
        out_features=out_params,
        dropout_rate=self.dropout_rate)(
            query, deterministic=deterministic)
    normalized_attention_output = nn.LayerNorm()(query + attention_output)

    mlp_output = Mlp(
        hidden_params=self.mlp_params,
        out_params=out_params,
        dropout_rate=self.dropout_rate)(
            normalized_attention_output, deterministic=deterministic)
    return nn.LayerNorm()(normalized_attention_output + mlp_output)
Example no. 14
    def __call__(
            self,
            inputs,
            inputs_segmentation=None,  # REFACTOR
            padding_mask=None):  # REFACTOR
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: bool, mask padding tokens.

    Returns:
      output after transformer encoder block.
    """
        cfg = self.config

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=False,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=cfg.deterministic)(
                                 x,
                                 segmentation=inputs_segmentation,  # REFACTOR
                                 padding_mask=padding_mask)  # REFACTOR

        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = MlpBlock(config=cfg)(y)

        return x + y
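
The `padding_mask` and `segmentation` arguments above belong to the pre-Linen attention API (hence the REFACTOR notes); in current Flax the same information is folded into a single mask before the attention call. A sketch of that refactor, assuming 0-padded token ids and an integer segment-id array for packed examples:

import jax.numpy as jnp
import flax.linen as nn

def make_encoder_mask(inputs, inputs_segmentation=None):
  # Padding mask: attend only between real (non-zero) tokens.
  mask = nn.make_attention_mask(inputs > 0, inputs > 0)
  if inputs_segmentation is not None:
    # Packed examples: additionally restrict attention to the same segment.
    mask = nn.combine_masks(
        mask,
        nn.make_attention_mask(inputs_segmentation, inputs_segmentation,
                               pairwise_fn=jnp.equal))
  return mask
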
Example no. 15
    def __call__(self,
                 inputs,
                 temb,
                 deterministic,
                 decoder_mask=None,
                 encoder_decoder_mask=None):
        """Applies EncoderDecoder1DBlock module.

    Args:
      inputs: Input data for decoder.
      temb: Time embedding representation.
      deterministic: Whether to run dropout deterministically (i.e. disabled).
      decoder_mask: Decoder self-attention mask.
      encoder_decoder_mask: Encoder-decoder attention mask.

    Returns:
      output after transformer encoder-decoder block.
    """
        cfg = self.config

        # Decoder block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=deterministic,
                             decode=False)(x, decoder_mask)
        x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        z = nn.LayerNorm(dtype=cfg.dtype)(x)
        z = MlpBlock(config=cfg)(z, temb, deterministic)

        return x + z
Example no. 16
    def __call__(self, x, deterministic=True):
        out = {}
        y = nn.LayerNorm(name='LayerNorm_0')(x)
        y = out['sa'] = nn.SelfAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            deterministic=deterministic,
            name='MultiHeadDotProductAttention_1',
        )(y)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out['+sa'] = x + y

        y = nn.LayerNorm(name='LayerNorm_2')(x)
        y = out['mlp'] = MlpBlock(
            mlp_dim=self.mlp_dim,
            dropout=self.dropout,
            name='MlpBlock_3',
        )(y, deterministic)
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out['+mlp'] = x + y
        return x, out
Example no. 17
  def __call__(self,
               targets,
               encoded,
               decoder_mask=None,
               encoder_decoder_mask=None,
               decoder_relative_position=None,
               encoder_decoder_relative_position=None):
    """Applies Transformer block.

    Args:
      targets: input data for decoder `[batch_size, ..., length, dim]`
      encoded: input data from encoder `[batch_size, ..., length2, dim2]`
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask
      decoder_relative_position: decoder relative positions tensor
          `[batch_sizes..., length, length]`
      encoder_decoder_relative_position: encoder-decoder relative tensor
          `[batch_sizes..., length, length2]`

    Returns:
      Decoded data `[batch_size, ..., length, mlp_dim]`
    """
    cfg = self.config

    # Decoder block.
    x = nn.LayerNorm(dtype=cfg.dtype)(targets)
    if cfg.use_relative_attention:
      x = relative_attention.RelativeSelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic,
          bidirectional=self.bidirectional_attention,
          num_relative_position_buckets=self.num_relative_position_buckets,
          max_distance=self.max_distance)(
              x, decoder_mask, decoder_relative_position)
    else:
      x = nn.SelfAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic)(x, decoder_mask)

    x = nn.Dropout(rate=cfg.dropout_rate)(
        x, deterministic=cfg.deterministic)
    x = x + targets

    # Encoder-Decoder block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    if self.relative_cross_attention:
      y = relative_attention.RelativeMultiHeadDotProductAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic,
          bidirectional=self.bidirectional_cross_attention,
          num_relative_position_buckets=(
              self.num_relative_position_buckets_cross_attention),
          max_distance=self.max_distance_cross_attention)(
              y, encoded, encoder_decoder_mask,
              encoder_decoder_relative_position)
    else:
      y = nn.MultiHeadDotProductAttention(
          num_heads=cfg.num_heads,
          dtype=cfg.dtype,
          qkv_features=cfg.qkv_dim,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          use_bias=False,
          broadcast_dropout=False,
          dropout_rate=cfg.attention_dropout_rate,
          deterministic=cfg.deterministic)(y, encoded, encoder_decoder_mask)

    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)
    y = y + x

    # MLP block.
    z = nn.LayerNorm(dtype=cfg.dtype)(y)
    z = MLPBlock(config=cfg)(z)

    return y + z
Example no. 18
    def __call__(
            self,
            targets,
            encoded,
            inputs_segmentation=None,  # REFACTOR
            targets_segmentation=None,  # REFACTOR
            padding_mask=None,  # REFACTOR
            key_padding_mask=None):  # REFACTOR
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      padding_mask: bool, mask padding tokens
      key_padding_mask: bool, mask padding tokens

    Returns:
      output after transformer encoder-decoder block.
    """
        cfg = self.config

        # Decoder block.
        assert targets.ndim == 3
        x = nn.LayerNorm(dtype=cfg.dtype)(targets)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=True,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=cfg.deterministic,
                             decode=cfg.decode)(
                                 x,
                                 padding_mask=padding_mask,
                                 segmentation=targets_segmentation)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + targets

        # Encoder-Decoder block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = nn.MultiHeadDotProductAttention(
            num_heads=cfg.num_heads,
            dtype=cfg.dtype,
            qkv_features=cfg.qkv_dim,
            attention_axis=(1, ),
            causal_mask=False,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init,
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=cfg.attention_dropout_rate,
            deterministic=cfg.deterministic)(
                y,
                encoded,
                padding_mask=padding_mask,
                key_padding_mask=key_padding_mask,
                segmentation=targets_segmentation,
                key_segmentation=inputs_segmentation)

        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)
        y = y + x

        # MLP block.
        z = nn.LayerNorm(dtype=cfg.dtype)(y)
        z = MlpBlock(config=cfg)(z)

        return y + z
Example no. 19
    def __call__(self,
                 inputs,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 inputs_kv=None):
        """Applies EncoderDecoder1DBlock module.

    Args:
      inputs: input data for decoder
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      inputs_kv: key/value inputs for the cross-attention branch (used when
        `is_self_att` is False).

    Returns:
      output after transformer encoder-decoder block.
    """
        cfg = self.config

        # Decoder block.
        assert inputs.ndim == 3
        # assert decoder_mask.ndim == 4
        if cfg.use_layernorm:
            x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        else:
            x = inputs
        if self.is_self_att:
            x = nn.SelfAttention(num_heads=cfg.num_heads,
                                 out_features=self.out_features,
                                 dtype=cfg.dtype,
                                 qkv_features=cfg.qkv_dim,
                                 kernel_init=cfg.kernel_init,
                                 bias_init=cfg.bias_init,
                                 use_bias=False,
                                 broadcast_dropout=False,
                                 dropout_rate=cfg.attention_dropout_rate,
                                 deterministic=cfg.deterministic,
                                 attention_fn=self.attention_fn,
                                 decode=cfg.decode)(x, decoder_mask)
        else:
            if cfg.use_layernorm:
                x_kv = nn.LayerNorm(dtype=cfg.dtype)(inputs_kv)
            else:
                x_kv = inputs_kv
            x = nn.MultiHeadDotProductAttention(
                num_heads=cfg.num_heads,
                out_features=self.out_features,
                dtype=cfg.dtype,
                qkv_features=cfg.qkv_dim,
                kernel_init=cfg.kernel_init,
                bias_init=cfg.bias_init,
                use_bias=False,
                broadcast_dropout=False,
                dropout_rate=cfg.attention_dropout_rate,
                deterministic=cfg.deterministic,
                attention_fn=self.attention_fn,
                decode=cfg.decode)(x, x_kv, mask=decoder_mask)

        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        x = x + inputs

        # MLP block.
        if cfg.use_layernorm:
            z = nn.LayerNorm(dtype=cfg.dtype)(x)
        else:
            z = x
        z = MlpBlock(config=cfg)(z)

        return x + z
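
Example no. 19 threads a configurable `attention_fn` through both attention modules. Any replacement has to accept the keyword arguments Flax forwards to `nn.dot_product_attention`; a minimal pass-through sketch (a real use would substitute e.g. a linear-attention kernel):

import flax.linen as nn

def wrapped_attention_fn(query, key, value, **kwargs):
  # Same signature contract as nn.dot_product_attention: query/key/value are
  # [batch..., length, num_heads, head_dim]; kwargs carry mask, dropout and
  # precision settings forwarded by the attention module.
  return nn.dot_product_attention(query, key, value, **kwargs)

attn = nn.SelfAttention(num_heads=4, qkv_features=64,
                        attention_fn=wrapped_attention_fn)
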
Example no. 20
    def __call__(self,
                 inputs,
                 train,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      train: bool: if model is training.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        if self.normalizer in [
                'batch_norm', 'layer_norm', 'pre_layer_norm', 'none'
        ]:
            maybe_pre_normalize = model_utils.get_normalizer(self.normalizer,
                                                             train,
                                                             dtype=self.dtype)
            maybe_post_normalize = model_utils.get_normalizer('none',
                                                              train,
                                                              dtype=self.dtype)
        elif self.normalizer == 'post_layer_norm':
            maybe_pre_normalize = model_utils.get_normalizer('none',
                                                             train,
                                                             dtype=self.dtype)
            maybe_post_normalize = model_utils.get_normalizer(self.normalizer,
                                                              train,
                                                              dtype=self.dtype)
        else:
            raise ValueError('Unsupported normalizer: {}'.format(
                self.normalizer))

        x = maybe_pre_normalize()(inputs)

        if self.attention_fn is None:
            attention_fn = nn.dot_product_attention
        else:
            attention_fn = self.attention_fn
        x = nn.SelfAttention(num_heads=self.num_heads,
                             qkv_features=self.qkv_dim,
                             decode=self.decode,
                             dtype=self.dtype,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             use_bias=False,
                             broadcast_dropout=False,
                             attention_fn=attention_fn,
                             dropout_rate=self.attention_dropout_rate,
                             deterministic=not train)(x, decoder_mask)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        x = x + inputs
        x = maybe_post_normalize()(x)

        # MLP block.
        y = maybe_pre_normalize()(x)
        y = MlpBlock(mlp_dim=self.mlp_dim,
                     dropout_rate=self.dropout_rate)(y, train=train)
        res = x + y

        return maybe_post_normalize()(res)
Example no. 21
    def exec_op(self, op, input_values, deterministic, training, **_):
        """Executes an op according to the normal concrete semantics."""
        input_kwargs: Dict[str, Any] = op.input_kwargs
        op_kwargs: Dict[str, Any] = op.op_kwargs
        op_type = op.type
        if "name" not in op_kwargs:
            raise ValueError("Op kwargs must contain a name.")
        op_name = op_kwargs["name"]

        if op_type == OpType.NONE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [lax.stop_gradient(input_value)]

        elif op_type == OpType.IDENTITY:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [input_value]

        # nn.linear

        elif op_type == OpType.DENSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.Dense(**op_kwargs)(input_value)]

        elif op_type == OpType.DENSE_GENERAL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert 2 <= len(op_kwargs) <= 7
            output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]

        elif op_type == OpType.CONV:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs

            ks = op_kwargs["kernel_size"]
            if isinstance(ks, int):
                op_kwargs["kernel_size"] = (ks, ) * (input_value.ndim - 2)

            output_values = [nn.Conv(**op_kwargs)(input_value)]

        # others

        elif op_type == OpType.MUL:
            assert len(input_values) == 2
            assert not input_kwargs
            assert len(op_kwargs) == 1  # name
            output_values = [input_values[0] * input_values[1]]

        elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
            assert len(op_kwargs) == 1  # name

            input_value = input_values[0]
            if "layer_drop_rate" in input_kwargs:
                assert len(input_kwargs) == 1
                survival_rate = 1 - input_kwargs["layer_drop_rate"]
                if survival_rate == 1.0 or deterministic:
                    pass
                else:
                    # Reuse dropout's rng stream.
                    rng = self.make_rng("dropout")
                    mask_shape = [input_value.shape[0]
                                  ] + [1] * (input_value.ndim - 1)
                    mask = random.bernoulli(rng,
                                            p=survival_rate,
                                            shape=mask_shape)
                    mask = jnp.tile(mask, [1] + list(input_value.shape[1:]))
                    input_value = lax.select(mask, input_value / survival_rate,
                                             jnp.zeros_like(input_value))
            else:
                assert not input_kwargs
                assert op_type == OpType.ADD

            if op_type == OpType.ADD:
                assert len(input_values) == 2
                output_values = [input_value + input_values[1]]
            else:
                assert len(input_values) == 1
                output_values = [input_value]

        elif op_type == OpType.SCALAR_MUL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            if "const" in input_kwargs:
                c = input_kwargs["const"]
            else:
                c = 1 / jnp.sqrt(input_values[0].shape[-1])
            output_values = [input_values[0] * c]

        elif op_type == OpType.SCALAR_ADD:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            assert "const" in input_kwargs
            c = input_kwargs["const"]
            output_values = [input_values[0] + c]

        elif op_type == OpType.DOT_GENERAL:
            assert len(input_values) == 2
            assert 0 < len(input_kwargs) <= 3
            assert len(op_kwargs) == 1  # name
            output_values = [
                lax.dot_general(input_values[0], input_values[1],
                                **input_kwargs)
            ]

        elif op_type == OpType.EINSUM:
            assert len(input_values) == 2
            assert len(input_kwargs) == 1
            assert "sum" in input_kwargs
            output_values = [
                jnp.einsum(input_kwargs["sum"], input_values[0],
                           input_values[1])
            ]

        # nn.attention

        elif op_type == OpType.SELF_ATTENTION:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [
                nn.SelfAttention(**op_kwargs,
                                 deterministic=deterministic)(input_value)
            ]

        # nn.activation

        elif op_type in [
                OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID
        ]:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            fn = {
                OpType.RELU: nn.relu,
                OpType.GELU: nn.gelu,
                OpType.SWISH: nn.swish,
                OpType.SIGMOID: nn.sigmoid
            }[op_type]
            output_values = [fn(input_value)]

        elif op_type == OpType.SOFTMAX:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [nn.softmax(input_value, **input_kwargs)]

        # nn.normalization

        elif op_type == OpType.BATCH_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            if "use_running_average" not in input_kwargs:
                add_kwargs = {"use_running_average": not training}
            else:
                add_kwargs = {}
            output_values = [
                nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs,
                                          **add_kwargs)
            ]

        elif op_type == OpType.LAYER_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.LayerNorm(**op_kwargs)(input_value)]

        elif op_type == OpType.GROUP_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.GroupNorm(**op_kwargs)(input_value)]

        # reshape operators

        elif op_type == OpType.RESHAPE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert 0 < len(input_kwargs) < 3
            new_shape = input_kwargs.pop("new_shape")
            if new_shape[0] == "B":
                new_shape = (input_value.shape[0], ) + new_shape[1:]
            output_values = [
                jnp.reshape(input_value, new_shape, **input_kwargs)
            ]

        elif op_type == OpType.FLATTEN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            new_shape = (input_value.shape[0], -1)
            output_values = [jnp.reshape(input_value, new_shape)]

        elif op_type == OpType.TRANSPOSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) == 1
            assert len(op_kwargs) == 1  # name
            output_values = [jnp.transpose(input_value, **input_kwargs)]

        # nn.stochastic

        elif op_type == OpType.DROPOUT:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [
                nn.Dropout(**op_kwargs)(input_value,
                                        deterministic=deterministic,
                                        **input_kwargs)
            ]

        # nn.pooling

        elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
            op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs

            ws = input_kwargs["window_shape"]
            if isinstance(ws, int):
                ws = [ws] * (input_value.ndim - 2)
            new_ws = []
            for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]):
                if window_dim_shape == 0:
                    new_ws.append(dim_shape)
                else:
                    new_ws.append(window_dim_shape)
            input_kwargs["window_shape"] = tuple(new_ws)

            if "strides" in input_kwargs:
                s = input_kwargs["strides"]
                if isinstance(s, int):
                    input_kwargs["strides"] = (s, ) * (input_value.ndim - 2)

            output_values = [op_fn(input_value, **input_kwargs)]

        elif op_type == OpType.MEAN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs
            output_values = [jnp.mean(input_value, **input_kwargs)]

        # new param

        elif op_type == OpType.PARAM:
            assert not input_values
            assert 0 < len(input_kwargs) <= 2
            init_fn = input_kwargs.pop("init_fn")

            init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
            output_values = [self.param(op_name, init_fn_with_kwargs)]

        else:
            raise ValueError(f"op_type {op_type} not supported...")

        return output_values
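
The ADD/STOCH_DEPTH branch is the only op above with non-trivial randomness; restated as a standalone helper for clarity (a sketch mirroring that branch, not part of the original module):

import jax.numpy as jnp
from jax import lax, random

def stochastic_depth(rng, x, layer_drop_rate, deterministic):
  """Drops the whole residual branch per example and rescales survivors."""
  survival_rate = 1.0 - layer_drop_rate
  if survival_rate == 1.0 or deterministic:
    return x
  # One Bernoulli draw per example, broadcast over the remaining dims.
  mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  keep = random.bernoulli(rng, p=survival_rate, shape=mask_shape)
  keep = jnp.broadcast_to(keep, x.shape)
  return lax.select(keep, x / survival_rate, jnp.zeros_like(x))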